├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── README.md ├── SegTracker.py ├── aot ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── __init__.py ├── configs │ ├── default.py │ ├── models │ │ ├── aotb.py │ │ ├── aotl.py │ │ ├── aots.py │ │ ├── aott.py │ │ ├── deaotb.py │ │ ├── deaotl.py │ │ ├── deaots.py │ │ ├── deaott.py │ │ ├── default.py │ │ ├── default_deaot.py │ │ ├── r101_aotl.py │ │ ├── r50_aotl.py │ │ ├── r50_deaotl.py │ │ ├── rs101_aotl.py │ │ ├── swinb_aotl.py │ │ └── swinb_deaotl.py │ ├── pre.py │ ├── pre_dav.py │ ├── pre_ytb.py │ ├── pre_ytb_dav.py │ └── ytb.py ├── dataloaders │ ├── __init__.py │ ├── eval_datasets.py │ ├── image_transforms.py │ ├── train_datasets.py │ └── video_transforms.py ├── datasets │ ├── .DS_Store │ ├── DAVIS │ │ └── README.md │ ├── Static │ │ └── README.md │ └── YTB │ │ ├── 2018 │ │ ├── train │ │ │ └── README.md │ │ ├── valid │ │ │ └── README.md │ │ └── valid_all_frames │ │ │ └── README.md │ │ └── 2019 │ │ ├── train │ │ └── README.md │ │ ├── valid │ │ └── README.md │ │ └── valid_all_frames │ │ └── README.md ├── networks │ ├── .DS_Store │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ └── fpn.py │ ├── encoders │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── resnest │ │ │ ├── __init__.py │ │ │ ├── resnest.py │ │ │ ├── resnet.py │ │ │ └── splat.py │ │ ├── resnet.py │ │ └── swin │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── swin_transformer.py │ ├── engines │ │ ├── __init__.py │ │ ├── aot_engine.py │ │ └── deaot_engine.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── basic.py │ │ ├── loss.py │ │ ├── normalization.py │ │ ├── position.py │ │ └── transformer.py │ ├── managers │ │ ├── evaluator.py │ │ └── trainer.py │ └── models │ │ ├── __init__.py │ │ ├── aot.py │ │ └── deaot.py ├── pretrain_models │ └── README.md ├── source │ ├── .DS_Store │ ├── overview.png │ └── overview_deaot.png ├── tools │ ├── demo.py │ ├── eval.py │ └── train.py ├── train_eval.sh └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── cp_ckpt.py │ ├── ema.py │ ├── eval.py │ ├── image.py │ ├── learning.py │ ├── math.py │ ├── meters.py │ └── metric.py ├── aot_tracker.py ├── app.py ├── assets ├── 840_iSXIa0hE8Ek.zip ├── blackswan.mp4 ├── cars.mp4 ├── cell.mp4 ├── demo_3x2.gif ├── gradio.jpg ├── interactive_webui.jpg └── top.gif ├── demo.ipynb ├── demo_instseg.ipynb ├── img2vid.py ├── licenses.md ├── model_args.py ├── prepare.py ├── sam ├── .flake8 ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── assets │ ├── masks1.png │ ├── masks2.jpg │ ├── model_diagram.png │ ├── notebook1.png │ └── notebook2.png ├── linter.sh ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── images │ │ ├── dog.jpg │ │ ├── groceries.jpg │ │ └── truck.jpg │ ├── onnx_model_example.ipynb │ └── predictor_example.ipynb ├── scripts │ ├── amg.py │ └── export_onnx_model.py ├── segment_anything │ ├── .DS_Store │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── setup.cfg └── setup.py ├── script ├── download_ckpt.sh └── install.sh ├── seg_track_anything.py ├── tool ├── detector.py ├── segmentor.py └── transfer_tools.py └── tutorial ├── img ├── Drawing_board.jpg ├── add_positive_base_on_everything.jpg 
├── add_positive_base_on_everything_cxk.jpg ├── add_positive_points.jpg ├── add_positive_points_2.jpg ├── audio_tab.jpg ├── click_input_video.jpg ├── click_segment.jpg ├── click_segment_everything.jpg ├── detect_result.jpg ├── enter_text.jpg ├── grounding-tab.jpg ├── input_video.jpg ├── new_object.jpg ├── second_object.jpg ├── segment_everything_blackswan.jpg ├── select_fps.jpg ├── start_tracking.jpg ├── switch2ImgSeq.jpg ├── switch2textT.jpg ├── upload_Image_seq.jpg └── use_exa4ImgSeq.jpg ├── tutorial for Image-Sequence input.md ├── tutorial for WebUI-1.0-Version.md ├── tutorial for WebUI-1.5-Version.md └── tutorial for WebUI-1.6-Version.md /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | ckpt/* 3 | assets/*masks 4 | assets/*mp4 5 | # assets/*zip 6 | assets/*gif 7 | *.pyc 8 | debug 9 | cym_utils 10 | /src 11 | /tracking_results 12 | /aot/results 13 | /aot/pretrain_models 14 | /aot/datasets 15 | /ast_master 16 | -------------------------------------------------------------------------------- /aot/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, z-x-yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /aot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/__init__.py -------------------------------------------------------------------------------- /aot/configs/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | class DefaultEngineConfig(): 6 | def __init__(self, exp_name='default', model='aott'): 7 | model_cfg = importlib.import_module('configs.models.' + 8 | model).ModelConfig() 9 | self.__dict__.update(model_cfg.__dict__) # add model config 10 | 11 | self.EXP_NAME = exp_name + '_' + self.MODEL_NAME 12 | 13 | self.STAGE_NAME = 'YTB' 14 | 15 | self.DATASETS = ['youtubevos'] 16 | self.DATA_WORKERS = 8 17 | self.DATA_RANDOMCROP = (465, 18 | 465) if self.MODEL_ALIGN_CORNERS else (464, 19 | 464) 20 | self.DATA_RANDOMFLIP = 0.5 21 | self.DATA_MAX_CROP_STEPS = 10 22 | self.DATA_SHORT_EDGE_LEN = 480 23 | self.DATA_MIN_SCALE_FACTOR = 0.7 24 | self.DATA_MAX_SCALE_FACTOR = 1.3 25 | self.DATA_RANDOM_REVERSE_SEQ = True 26 | self.DATA_SEQ_LEN = 5 27 | self.DATA_DAVIS_REPEAT = 5 28 | self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps) 29 | self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps) 30 | self.DATA_DYNAMIC_MERGE_PROB = 0.3 31 | 32 | self.PRETRAIN = True 33 | self.PRETRAIN_FULL = False # if False, load encoder only 34 | self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth' 35 | # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth' 36 | 37 | self.TRAIN_TOTAL_STEPS = 100000 38 | self.TRAIN_START_STEP = 0 39 | self.TRAIN_WEIGHT_DECAY = 0.07 40 | self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = { 41 | # 'encoder.': 0.01 42 | } 43 | self.TRAIN_WEIGHT_DECAY_EXEMPTION = [ 44 | 'absolute_pos_embed', 'relative_position_bias_table', 45 | 'relative_emb_v', 'conv_out' 46 | ] 47 | self.TRAIN_LR = 2e-4 48 | self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5 49 | self.TRAIN_LR_POWER = 0.9 50 | self.TRAIN_LR_ENCODER_RATIO = 0.1 51 | self.TRAIN_LR_WARM_UP_RATIO = 0.05 52 | self.TRAIN_LR_COSINE_DECAY = False 53 | self.TRAIN_LR_RESTART = 1 54 | self.TRAIN_LR_UPDATE_STEP = 1 55 | self.TRAIN_AUX_LOSS_WEIGHT = 1.0 56 | self.TRAIN_AUX_LOSS_RATIO = 1.0 57 | self.TRAIN_OPT = 'adamw' 58 | self.TRAIN_SGD_MOMENTUM = 0.9 59 | self.TRAIN_GPUS = 4 60 | self.TRAIN_BATCH_SIZE = 16 61 | self.TRAIN_TBLOG = False 62 | self.TRAIN_TBLOG_STEP = 50 63 | self.TRAIN_LOG_STEP = 20 64 | self.TRAIN_IMG_LOG = True 65 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 66 | self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank'] 67 | self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5 68 | self.TRAIN_HARD_MINING_RATIO = 0.5 69 | self.TRAIN_EMA_RATIO = 0.1 70 | self.TRAIN_CLIP_GRAD_NORM = 5. 71 | self.TRAIN_SAVE_STEP = 5000 72 | self.TRAIN_MAX_KEEP_CKPT = 8 73 | self.TRAIN_RESUME = False 74 | self.TRAIN_RESUME_CKPT = None 75 | self.TRAIN_RESUME_STEP = 0 76 | self.TRAIN_AUTO_RESUME = True 77 | self.TRAIN_DATASET_FULL_RESOLUTION = False 78 | self.TRAIN_ENABLE_PREV_FRAME = False 79 | self.TRAIN_ENCODER_FREEZE_AT = 2 80 | self.TRAIN_LSTT_EMB_DROPOUT = 0. 81 | self.TRAIN_LSTT_ID_DROPOUT = 0. 
82 | self.TRAIN_LSTT_DROPPATH = 0.1 83 | self.TRAIN_LSTT_DROPPATH_SCALING = False 84 | self.TRAIN_LSTT_DROPPATH_LST = False 85 | self.TRAIN_LSTT_LT_DROPOUT = 0. 86 | self.TRAIN_LSTT_ST_DROPOUT = 0. 87 | 88 | self.TEST_GPU_ID = 0 89 | self.TEST_GPU_NUM = 1 90 | self.TEST_FRAME_LOG = False 91 | self.TEST_DATASET = 'youtubevos' 92 | self.TEST_DATASET_FULL_RESOLUTION = False 93 | self.TEST_DATASET_SPLIT = 'val' 94 | self.TEST_CKPT_PATH = None 95 | # if "None", evaluate the latest checkpoint. 96 | self.TEST_CKPT_STEP = None 97 | self.TEST_FLIP = False 98 | self.TEST_MULTISCALE = [1] 99 | self.TEST_MAX_SHORT_EDGE = None 100 | self.TEST_MAX_LONG_EDGE = 800 * 1.3 101 | self.TEST_WORKERS = 4 102 | 103 | # GPU distribution 104 | self.DIST_ENABLE = True 105 | self.DIST_BACKEND = "nccl" # "gloo" 106 | self.DIST_URL = "tcp://127.0.0.1:13241" 107 | self.DIST_START_GPU = 0 108 | 109 | def init_dir(self): 110 | self.DIR_DATA = '../VOS02/datasets'#'./datasets' 111 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 112 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB') 113 | self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static') 114 | 115 | self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs' 116 | 117 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME, 118 | self.STAGE_NAME) 119 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 120 | self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt') 121 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 122 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 123 | # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 124 | # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 125 | self.DIR_IMG_LOG = './img_logs' 126 | self.DIR_EVALUATION = './results' 127 | 128 | for path in [ 129 | self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT, 130 | self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, 131 | self.DIR_TB_LOG 132 | ]: 133 | if not os.path.isdir(path): 134 | try: 135 | os.makedirs(path) 136 | except Exception as inst: 137 | print(inst) 138 | print('Failed to make dir: {}.'.format(path)) 139 | -------------------------------------------------------------------------------- /aot/configs/models/aotb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot/configs/models/aotl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/aots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot/configs/models/aott.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTT' 8 | -------------------------------------------------------------------------------- /aot/configs/models/deaotb.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot/configs/models/deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 14 | -------------------------------------------------------------------------------- /aot/configs/models/deaots.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot/configs/models/deaott.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTT' 8 | -------------------------------------------------------------------------------- /aot/configs/models/default.py: -------------------------------------------------------------------------------- 1 | class DefaultModelConfig(): 2 | def __init__(self): 3 | self.MODEL_NAME = 'AOTDefault' 4 | 5 | self.MODEL_VOS = 'aot' 6 | self.MODEL_ENGINE = 'aotengine' 7 | self.MODEL_ALIGN_CORNERS = True 8 | self.MODEL_ENCODER = 'mobilenetv2' 9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' 10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x 11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True 13 | self.MODEL_FREEZE_BN = True 14 | self.MODEL_FREEZE_BACKBONE = False 15 | self.MODEL_MAX_OBJ_NUM = 10 16 | self.MODEL_SELF_HEADS = 8 17 | self.MODEL_ATT_HEADS = 8 18 | self.MODEL_LSTT_NUM = 1 19 | self.MODEL_EPSILON = 1e-5 20 | self.MODEL_USE_PREV_PROB = False 21 | 22 | self.TRAIN_LONG_TERM_MEM_GAP = 9999 23 | self.TRAIN_AUG_TYPE = 'v1' 24 | 25 | self.TEST_LONG_TERM_MEM_GAP = 9999 26 | 27 | self.TEST_SHORT_TERM_MEM_SKIP = 1 28 | -------------------------------------------------------------------------------- /aot/configs/models/default_deaot.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig as BaseConfig 2 | 3 | 4 | class DefaultModelConfig(BaseConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTDefault' 8 | 9 | self.MODEL_VOS = 'deaot' 10 | self.MODEL_ENGINE = 'deaotengine' 11 | 12 | 
self.MODEL_DECODER_INTERMEDIATE_LSTT = False 13 | 14 | self.MODEL_SELF_HEADS = 1 15 | self.MODEL_ATT_HEADS = 1 16 | 17 | self.TRAIN_AUG_TYPE = 'v2' 18 | -------------------------------------------------------------------------------- /aot/configs/models/r101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/r50_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/r50_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 11 | 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 17 | -------------------------------------------------------------------------------- /aot/configs/models/rs101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnest101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/swinb_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_AOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # 
https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/swinb_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ALIGN_CORNERS = False 11 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 12 | 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/pre.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultEngineConfig 2 | 3 | 4 | class EngineConfig(DefaultEngineConfig): 5 | def __init__(self, exp_name='default', model='AOTT'): 6 | super().__init__(exp_name, model) 7 | self.STAGE_NAME = 'PRE' 8 | 9 | self.init_dir() 10 | 11 | self.DATASETS = ['static'] 12 | 13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0 14 | 15 | self.TRAIN_LR = 4e-4 16 | self.TRAIN_LR_MIN = 2e-5 17 | self.TRAIN_WEIGHT_DECAY = 0.03 18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 19 | self.TRAIN_AUX_LOSS_RATIO = 0.1 20 | -------------------------------------------------------------------------------- /aot/configs/pre_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['davis2017'] 13 | 14 | self.TRAIN_TOTAL_STEPS = 50000 15 | 16 | pretrain_stage = 'PRE' 17 | pretrain_ckpt = 'save_step_100000.pth' 18 | self.PRETRAIN_FULL = True # if False, load encoder only 19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 20 | self.EXP_NAME, pretrain_stage, 21 | 'ema_ckpt', pretrain_ckpt) 22 | -------------------------------------------------------------------------------- /aot/configs/pre_ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB' 9 | 10 | self.init_dir() 11 | 12 | pretrain_stage = 'PRE' 13 | pretrain_ckpt = 'save_step_100000.pth' 14 | self.PRETRAIN_FULL = True # if False, load encoder only 15 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 16 | self.EXP_NAME, pretrain_stage, 17 | 'ema_ckpt', pretrain_ckpt) 18 | -------------------------------------------------------------------------------- /aot/configs/pre_ytb_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, 
exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['youtubevos', 'davis2017'] 13 | 14 | pretrain_stage = 'PRE' 15 | pretrain_ckpt = 'save_step_100000.pth' 16 | self.PRETRAIN_FULL = True # if False, load encoder only 17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 18 | self.EXP_NAME, pretrain_stage, 19 | 'ema_ckpt', pretrain_ckpt) 20 | -------------------------------------------------------------------------------- /aot/configs/ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'YTB' 9 | 10 | self.init_dir() 11 | -------------------------------------------------------------------------------- /aot/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/dataloaders/__init__.py -------------------------------------------------------------------------------- /aot/datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/datasets/.DS_Store -------------------------------------------------------------------------------- /aot/datasets/DAVIS/README.md: -------------------------------------------------------------------------------- 1 | Put DAVIS 2017 here. -------------------------------------------------------------------------------- /aot/datasets/Static/README.md: -------------------------------------------------------------------------------- 1 | Put the static dataset here. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training. 2 | -------------------------------------------------------------------------------- /aot/datasets/YTB/2018/train/README.md: -------------------------------------------------------------------------------- 1 | Put the training split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /aot/datasets/YTB/2018/valid/README.md: -------------------------------------------------------------------------------- 1 | Put the validation split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /aot/datasets/YTB/2018/valid_all_frames/README.md: -------------------------------------------------------------------------------- 1 | Put the all-frame validation split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /aot/datasets/YTB/2019/train/README.md: -------------------------------------------------------------------------------- 1 | Put the training split of YouTube-VOS 2019 here. -------------------------------------------------------------------------------- /aot/datasets/YTB/2019/valid/README.md: -------------------------------------------------------------------------------- 1 | Put the validation split of YouTube-VOS 2019 here. 
-------------------------------------------------------------------------------- /aot/datasets/YTB/2019/valid_all_frames/README.md: -------------------------------------------------------------------------------- 1 | Put the all-frame validation split of YouTube-VOS 2019 here. -------------------------------------------------------------------------------- /aot/networks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/networks/.DS_Store -------------------------------------------------------------------------------- /aot/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/networks/__init__.py -------------------------------------------------------------------------------- /aot/networks/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.decoders.fpn import FPNSegmentationHead 2 | 3 | 4 | def build_decoder(name, **kwargs): 5 | 6 | if name == 'fpn': 7 | return FPNSegmentationHead(**kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot/networks/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.basic import ConvGN 5 | 6 | 7 | class FPNSegmentationHead(nn.Module): 8 | def __init__(self, 9 | in_dim, 10 | out_dim, 11 | decode_intermediate_input=True, 12 | hidden_dim=256, 13 | shortcut_dims=[24, 32, 96, 1280], 14 | align_corners=True): 15 | super().__init__() 16 | self.align_corners = align_corners 17 | 18 | self.decode_intermediate_input = decode_intermediate_input 19 | 20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1) 21 | 22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) 23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) 24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) 25 | 26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) 27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) 28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) 29 | 30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) 31 | 32 | self._init_weight() 33 | 34 | def forward(self, inputs, shortcuts): 35 | 36 | if self.decode_intermediate_input: 37 | x = torch.cat(inputs, dim=1) 38 | else: 39 | x = inputs[-1] 40 | 41 | x = F.relu_(self.conv_in(x)) 42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) 43 | 44 | x = F.interpolate(x, 45 | size=shortcuts[-3].size()[-2:], 46 | mode="bilinear", 47 | align_corners=self.align_corners) 48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) 49 | 50 | x = F.interpolate(x, 51 | size=shortcuts[-4].size()[-2:], 52 | mode="bilinear", 53 | align_corners=self.align_corners) 54 | x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) 55 | 56 | x = self.conv_out(x) 57 | 58 | return x 59 | 60 | def _init_weight(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | -------------------------------------------------------------------------------- /aot/networks/encoders/.DS_Store:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/networks/encoders/.DS_Store -------------------------------------------------------------------------------- /aot/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.encoders.mobilenetv2 import MobileNetV2 2 | from networks.encoders.mobilenetv3 import MobileNetV3Large 3 | from networks.encoders.resnet import ResNet101, ResNet50 4 | from networks.encoders.resnest import resnest 5 | from networks.encoders.swin import build_swin_model 6 | from networks.layers.normalization import FrozenBatchNorm2d 7 | from torch import nn 8 | 9 | 10 | def build_encoder(name, frozen_bn=True, freeze_at=-1): 11 | if frozen_bn: 12 | BatchNorm = FrozenBatchNorm2d 13 | else: 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | if name == 'mobilenetv2': 17 | return MobileNetV2(16, BatchNorm, freeze_at=freeze_at) 18 | elif name == 'mobilenetv3': 19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) 20 | elif name == 'resnet50': 21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at) 22 | elif name == 'resnet101': 23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at) 24 | elif name == 'resnest50': 25 | return resnest.resnest50(norm_layer=BatchNorm, 26 | dilation=2, 27 | freeze_at=freeze_at) 28 | elif name == 'resnest101': 29 | return resnest.resnest101(norm_layer=BatchNorm, 30 | dilation=2, 31 | freeze_at=freeze_at) 32 | elif 'swin' in name: 33 | return build_swin_model(name, freeze_at=freeze_at) 34 | else: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet import ResNet, Bottleneck 3 | 4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 5 | 6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 7 | 8 | _model_sha256 = { 9 | name: checksum 10 | for checksum, name in [ 11 | ('528c19ca', 'resnest50'), 12 | ('22405ba7', 'resnest101'), 13 | ('75117900', 'resnest200'), 14 | ('0cc87c48', 'resnest269'), 15 | ] 16 | } 17 | 18 | 19 | def short_hash(name): 20 | if name not in _model_sha256: 21 | raise ValueError( 22 | 'Pretrained model for {name} is not available.'.format(name=name)) 23 | return _model_sha256[name][:8] 24 | 25 | 26 | resnest_model_urls = { 27 | name: _url_format.format(name, short_hash(name)) 28 | for name in _model_sha256.keys() 29 | } 30 | 31 | 32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 33 | model = ResNet(Bottleneck, [3, 4, 6, 3], 34 | radix=2, 35 | groups=1, 36 | bottleneck_width=64, 37 | deep_stem=True, 38 | stem_width=32, 39 | avg_down=True, 40 | avd=True, 41 | avd_first=False, 42 | **kwargs) 43 | if pretrained: 44 | model.load_state_dict( 45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], 46 | progress=True, 47 | check_hash=True)) 48 | return model 49 | 50 | 51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 52 | model = ResNet(Bottleneck, [3, 4, 23, 3], 53 | 
radix=2, 54 | groups=1, 55 | bottleneck_width=64, 56 | deep_stem=True, 57 | stem_width=64, 58 | avg_down=True, 59 | avd=True, 60 | avd_first=False, 61 | **kwargs) 62 | if pretrained: 63 | model.load_state_dict( 64 | torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest101'], 66 | progress=True, 67 | check_hash=True)) 68 | return model 69 | 70 | 71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 72 | model = ResNet(Bottleneck, [3, 24, 36, 3], 73 | radix=2, 74 | groups=1, 75 | bottleneck_width=64, 76 | deep_stem=True, 77 | stem_width=64, 78 | avg_down=True, 79 | avd=True, 80 | avd_first=False, 81 | **kwargs) 82 | if pretrained: 83 | model.load_state_dict( 84 | torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest200'], 86 | progress=True, 87 | check_hash=True)) 88 | return model 89 | 90 | 91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 92 | model = ResNet(Bottleneck, [3, 30, 48, 8], 93 | radix=2, 94 | groups=1, 95 | bottleneck_width=64, 96 | deep_stem=True, 97 | stem_width=64, 98 | avg_down=True, 99 | avd=True, 100 | avd_first=False, 101 | **kwargs) 102 | if pretrained: 103 | model.load_state_dict( 104 | torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest269'], 106 | progress=True, 107 | check_hash=True)) 108 | return model 109 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/splat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, Module, ReLU 5 | from torch.nn.modules.utils import _pair 6 | 7 | __all__ = ['SplAtConv2d', 'DropBlock2D'] 8 | 9 | 10 | class DropBlock2D(object): 11 | def __init__(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class SplAtConv2d(Module): 16 | """Split-Attention Conv2d 17 | """ 18 | def __init__(self, 19 | in_channels, 20 | channels, 21 | kernel_size, 22 | stride=(1, 1), 23 | padding=(0, 0), 24 | dilation=(1, 1), 25 | groups=1, 26 | bias=True, 27 | radix=2, 28 | reduction_factor=4, 29 | rectify=False, 30 | rectify_avg=False, 31 | norm_layer=None, 32 | dropblock_prob=0.0, 33 | **kwargs): 34 | super(SplAtConv2d, self).__init__() 35 | padding = _pair(padding) 36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 37 | self.rectify_avg = rectify_avg 38 | inter_channels = max(in_channels * radix // reduction_factor, 32) 39 | self.radix = radix 40 | self.cardinality = groups 41 | self.channels = channels 42 | self.dropblock_prob = dropblock_prob 43 | if self.rectify: 44 | from rfconv import RFConv2d 45 | self.conv = RFConv2d(in_channels, 46 | channels * radix, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups=groups * radix, 52 | bias=bias, 53 | average_mode=rectify_avg, 54 | **kwargs) 55 | else: 56 | self.conv = Conv2d(in_channels, 57 | channels * radix, 58 | kernel_size, 59 | stride, 60 | padding, 61 | dilation, 62 | groups=groups * radix, 63 | bias=bias, 64 | **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels * radix) 68 | self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, 73 | channels * radix, 74 | 1, 75 | groups=self.cardinality) 76 | if dropblock_prob > 0.0: 77 | self.dropblock = 
DropBlock2D(dropblock_prob, 3) 78 | self.rsoftmax = rSoftMax(radix, groups) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.use_bn: 83 | x = self.bn0(x) 84 | if self.dropblock_prob > 0.0: 85 | x = self.dropblock(x) 86 | x = self.relu(x) 87 | 88 | batch, rchannel = x.shape[:2] 89 | if self.radix > 1: 90 | if torch.__version__ < '1.5': 91 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 92 | else: 93 | splited = torch.split(x, rchannel // self.radix, dim=1) 94 | gap = sum(splited) 95 | else: 96 | gap = x 97 | gap = F.adaptive_avg_pool2d(gap, 1) 98 | gap = self.fc1(gap) 99 | 100 | if self.use_bn: 101 | gap = self.bn1(gap) 102 | gap = self.relu(gap) 103 | 104 | atten = self.fc2(gap) 105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 106 | 107 | if self.radix > 1: 108 | if torch.__version__ < '1.5': 109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 110 | else: 111 | attens = torch.split(atten, rchannel // self.radix, dim=1) 112 | out = sum([att * split for (att, split) in zip(attens, splited)]) 113 | else: 114 | out = atten * x 115 | return out.contiguous() 116 | 117 | 118 | class rSoftMax(nn.Module): 119 | def __init__(self, radix, cardinality): 120 | super().__init__() 121 | self.radix = radix 122 | self.cardinality = cardinality 123 | 124 | def forward(self, x): 125 | batch = x.size(0) 126 | if self.radix > 1: 127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 128 | x = F.softmax(x, dim=1) 129 | x = x.reshape(batch, -1) 130 | else: 131 | x = torch.sigmoid(x) 132 | return x 133 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from utils.learning import freeze_params 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, 10 | inplanes, 11 | planes, 12 | stride=1, 13 | dilation=1, 14 | downsample=None, 15 | BatchNorm=None): 16 | super(Bottleneck, self).__init__() 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = BatchNorm(planes) 19 | self.conv2 = nn.Conv2d(planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | dilation=dilation, 24 | padding=dilation, 25 | bias=False) 26 | self.bn2 = BatchNorm(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 28 | self.bn3 = BatchNorm(planes * 4) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.dilation = dilation 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | 62 | if output_stride == 16: 63 | strides = [1, 2, 2, 1] 64 | dilations = [1, 1, 1, 2] 65 | elif output_stride == 8: 66 | strides = [1, 2, 1, 1] 67 | dilations = [1, 1, 2, 4] 68 | else: 69 | raise NotImplementedError 70 | 71 | # Modules 72 | self.conv1 = nn.Conv2d(3, 73 | 64, 74 | kernel_size=7, 75 | stride=2, 76 | 
padding=3, 77 | bias=False) 78 | self.bn1 = BatchNorm(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | 82 | self.layer1 = self._make_layer(block, 83 | 64, 84 | layers[0], 85 | stride=strides[0], 86 | dilation=dilations[0], 87 | BatchNorm=BatchNorm) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=strides[1], 92 | dilation=dilations[1], 93 | BatchNorm=BatchNorm) 94 | self.layer3 = self._make_layer(block, 95 | 256, 96 | layers[2], 97 | stride=strides[2], 98 | dilation=dilations[2], 99 | BatchNorm=BatchNorm) 100 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 101 | 102 | self.stem = [self.conv1, self.bn1] 103 | self.stages = [self.layer1, self.layer2, self.layer3] 104 | 105 | self._init_weight() 106 | self.freeze(freeze_at) 107 | 108 | def _make_layer(self, 109 | block, 110 | planes, 111 | blocks, 112 | stride=1, 113 | dilation=1, 114 | BatchNorm=None): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, 119 | planes * block.expansion, 120 | kernel_size=1, 121 | stride=stride, 122 | bias=False), 123 | BatchNorm(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append( 128 | block(self.inplanes, planes, stride, max(dilation // 2, 1), 129 | downsample, BatchNorm)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append( 133 | block(self.inplanes, 134 | planes, 135 | dilation=dilation, 136 | BatchNorm=BatchNorm)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, input): 141 | x = self.conv1(input) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | xs = [] 147 | 148 | x = self.layer1(x) 149 | xs.append(x) # 4X 150 | x = self.layer2(x) 151 | xs.append(x) # 8X 152 | x = self.layer3(x) 153 | xs.append(x) # 16X 154 | # Following STMVOS, we drop stage 5. 155 | xs.append(x) # 16X 156 | 157 | return xs 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def freeze(self, freeze_at): 169 | if freeze_at >= 1: 170 | for m in self.stem: 171 | freeze_params(m) 172 | 173 | for idx, stage in enumerate(self.stages, start=2): 174 | if freeze_at >= idx: 175 | freeze_params(stage) 176 | 177 | 178 | def ResNet50(output_stride, BatchNorm, freeze_at=0): 179 | """Constructs a ResNet-50 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], 184 | output_stride, 185 | BatchNorm, 186 | freeze_at=freeze_at) 187 | return model 188 | 189 | 190 | def ResNet101(output_stride, BatchNorm, freeze_at=0): 191 | """Constructs a ResNet-101 model. 
192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 23, 3], 196 | output_stride, 197 | BatchNorm, 198 | freeze_at=freeze_at) 199 | return model 200 | 201 | 202 | if __name__ == "__main__": 203 | import torch 204 | model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8) 205 | input = torch.rand(1, 3, 512, 512) 206 | output, low_level_feat = model(input) 207 | print(output.size()) 208 | print(low_level_feat.size()) 209 | -------------------------------------------------------------------------------- /aot/networks/encoders/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_swin_model -------------------------------------------------------------------------------- /aot/networks/encoders/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | 11 | def build_swin_model(model_type, freeze_at=0): 12 | if model_type == 'swin_base': 13 | model = SwinTransformer(embed_dim=128, 14 | depths=[2, 2, 18, 2], 15 | num_heads=[4, 8, 16, 32], 16 | window_size=7, 17 | drop_path_rate=0.3, 18 | out_indices=(0, 1, 2), 19 | ape=False, 20 | patch_norm=True, 21 | frozen_stages=freeze_at, 22 | use_checkpoint=False) 23 | 24 | else: 25 | raise NotImplementedError(f"Unkown model: {model_type}") 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /aot/networks/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 2 | from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine 3 | 4 | 5 | def build_engine(name, phase='train', **kwargs): 6 | if name == 'aotengine': 7 | if phase == 'train': 8 | return AOTEngine(**kwargs) 9 | elif phase == 'eval': 10 | return AOTInferEngine(**kwargs) 11 | else: 12 | raise NotImplementedError 13 | elif name == 'deaotengine': 14 | if phase == 'train': 15 | return DeAOTEngine(**kwargs) 16 | elif phase == 'eval': 17 | return DeAOTInferEngine(**kwargs) 18 | else: 19 | raise NotImplementedError 20 | else: 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /aot/networks/engines/deaot_engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.image import one_hot_mask 4 | 5 | from networks.layers.basic import seq_to_2d 6 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 7 | 8 | 9 | class DeAOTEngine(AOTEngine): 10 | def __init__(self, 11 | aot_model, 12 | gpu_id=0, 13 | long_term_mem_gap=9999, 14 | short_term_mem_skip=1, 15 | layer_loss_scaling_ratio=2., 16 | max_len_long_term=9999): 17 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 18 | short_term_mem_skip, max_len_long_term) 19 | self.layer_loss_scaling_ratio = layer_loss_scaling_ratio 20 | def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False): 21 | 22 | if curr_id_emb is None: 23 | if len(curr_mask.size()) == 3 or 
curr_mask.size()[0] == 1: 24 | curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num) 25 | else: 26 | curr_one_hot_mask = curr_mask 27 | curr_id_emb = self.assign_identity(curr_one_hot_mask) 28 | 29 | lstt_curr_memories = self.curr_lstt_output[1] 30 | lstt_curr_memories_2d = [] 31 | for layer_idx in range(len(lstt_curr_memories)): 32 | curr_k, curr_v, curr_id_k, curr_id_v = lstt_curr_memories[ 33 | layer_idx] 34 | curr_id_k, curr_id_v = self.AOT.LSTT.layers[ 35 | layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb) 36 | lstt_curr_memories[layer_idx][2], lstt_curr_memories[layer_idx][ 37 | 3] = curr_id_k, curr_id_v 38 | local_curr_id_k = seq_to_2d( 39 | curr_id_k, self.enc_size_2d) if curr_id_k is not None else None 40 | local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d) 41 | lstt_curr_memories_2d.append([ 42 | seq_to_2d(curr_k, self.enc_size_2d), 43 | seq_to_2d(curr_v, self.enc_size_2d), local_curr_id_k, 44 | local_curr_id_v 45 | ]) 46 | 47 | self.short_term_memories_list.append(lstt_curr_memories_2d) 48 | self.short_term_memories_list = self.short_term_memories_list[ 49 | -self.short_term_mem_skip:] 50 | self.short_term_memories = self.short_term_memories_list[0] 51 | 52 | if self.frame_step - self.last_mem_step >= self.long_term_mem_gap: 53 | # skip the update of long-term memory or not 54 | if not skip_long_term_update: 55 | self.update_long_term_memory(lstt_curr_memories) 56 | self.last_mem_step = self.frame_step 57 | 58 | 59 | class DeAOTInferEngine(AOTInferEngine): 60 | def __init__(self, 61 | aot_model, 62 | gpu_id=0, 63 | long_term_mem_gap=9999, 64 | short_term_mem_skip=1, 65 | max_aot_obj_num=None, 66 | max_len_long_term=9999): 67 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 68 | short_term_mem_skip, max_aot_obj_num, max_len_long_term) 69 | def add_reference_frame(self, img, mask, obj_nums, frame_step=-1): 70 | if isinstance(obj_nums, list): 71 | obj_nums = obj_nums[0] 72 | self.obj_nums = obj_nums 73 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) 74 | while (aot_num > len(self.aot_engines)): 75 | new_engine = DeAOTEngine(self.AOT, self.gpu_id, 76 | self.long_term_mem_gap, 77 | self.short_term_mem_skip, 78 | max_len_long_term = self.max_len_long_term) 79 | new_engine.eval() 80 | self.aot_engines.append(new_engine) 81 | 82 | separated_masks, separated_obj_nums = self.separate_mask( 83 | mask, obj_nums) 84 | img_embs = None 85 | for aot_engine, separated_mask, separated_obj_num in zip( 86 | self.aot_engines, separated_masks, separated_obj_nums): 87 | if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num: 88 | aot_engine.add_reference_frame(img, 89 | separated_mask, 90 | obj_nums=[separated_obj_num], 91 | frame_step=frame_step, 92 | img_embs=img_embs) 93 | else: 94 | aot_engine.update_short_term_memory(separated_mask) 95 | if img_embs is None: # reuse image embeddings 96 | img_embs = aot_engine.curr_enc_embs 97 | 98 | self.update_size() 99 | -------------------------------------------------------------------------------- /aot/networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/networks/layers/__init__.py -------------------------------------------------------------------------------- /aot/networks/layers/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 
as F 3 | from torch import nn 4 | 5 | 6 | class GroupNorm1D(nn.Module): 7 | def __init__(self, indim, groups=8): 8 | super().__init__() 9 | self.gn = nn.GroupNorm(groups, indim) 10 | 11 | def forward(self, x): 12 | return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) 13 | 14 | 15 | class GNActDWConv2d(nn.Module): 16 | def __init__(self, indim, gn_groups=32): 17 | super().__init__() 18 | self.gn = nn.GroupNorm(gn_groups, indim) 19 | self.conv = nn.Conv2d(indim, 20 | indim, 21 | 5, 22 | dilation=1, 23 | padding=2, 24 | groups=indim, 25 | bias=False) 26 | 27 | def forward(self, x, size_2d): 28 | h, w = size_2d 29 | _, bs, c = x.size() 30 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1) 31 | x = self.gn(x) 32 | x = F.gelu(x) 33 | x = self.conv(x) 34 | x = x.view(bs, c, h * w).permute(2, 0, 1) 35 | return x 36 | 37 | 38 | class DWConv2d(nn.Module): 39 | def __init__(self, indim, dropout=0.1): 40 | super().__init__() 41 | self.conv = nn.Conv2d(indim, 42 | indim, 43 | 5, 44 | dilation=1, 45 | padding=2, 46 | groups=indim, 47 | bias=False) 48 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 49 | 50 | def forward(self, x, size_2d): 51 | h, w = size_2d 52 | _, bs, c = x.size() 53 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1) 54 | x = self.conv(x) 55 | x = self.dropout(x) 56 | x = x.view(bs, c, h * w).permute(2, 0, 1) 57 | return x 58 | 59 | 60 | class ScaleOffset(nn.Module): 61 | def __init__(self, indim): 62 | super().__init__() 63 | self.gamma = nn.Parameter(torch.ones(indim)) 64 | # torch.nn.init.normal_(self.gamma, std=0.02) 65 | self.beta = nn.Parameter(torch.zeros(indim)) 66 | 67 | def forward(self, x): 68 | if len(x.size()) == 3: 69 | return x * self.gamma + self.beta 70 | else: 71 | return x * self.gamma.view(1, -1, 1, 1) + self.beta.view( 72 | 1, -1, 1, 1) 73 | 74 | 75 | class ConvGN(nn.Module): 76 | def __init__(self, indim, outdim, kernel_size, gn_groups=8): 77 | super().__init__() 78 | self.conv = nn.Conv2d(indim, 79 | outdim, 80 | kernel_size, 81 | padding=kernel_size // 2) 82 | self.gn = nn.GroupNorm(gn_groups, outdim) 83 | 84 | def forward(self, x): 85 | return self.gn(self.conv(x)) 86 | 87 | 88 | def seq_to_2d(tensor, size_2d): 89 | h, w = size_2d 90 | _, n, c = tensor.size() 91 | tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() 92 | return tensor 93 | 94 | 95 | def drop_path(x, drop_prob: float = 0., training: bool = False): 96 | if drop_prob == 0. or not training: 97 | return x 98 | keep_prob = 1 - drop_prob 99 | shape = ( 100 | x.shape[0], 101 | x.shape[1], 102 | ) + (1, ) * (x.ndim - 2 103 | ) # work with diff dim tensors, not just 2D ConvNets 104 | random_tensor = keep_prob + torch.rand( 105 | shape, dtype=x.dtype, device=x.device) 106 | random_tensor.floor_() # binarize 107 | output = x.div(keep_prob) * random_tensor 108 | return output 109 | 110 | 111 | def mask_out(x, y, mask_rate=0.15, training=False): 112 | if mask_rate == 0. 
or not training: 113 | return x 114 | 115 | keep_prob = 1 - mask_rate 116 | shape = ( 117 | x.shape[0], 118 | x.shape[1], 119 | ) + (1, ) * (x.ndim - 2 120 | ) # work with diff dim tensors, not just 2D ConvNets 121 | random_tensor = keep_prob + torch.rand( 122 | shape, dtype=x.dtype, device=x.device) 123 | random_tensor.floor_() # binarize 124 | output = x * random_tensor + y * (1 - random_tensor) 125 | 126 | return output 127 | 128 | 129 | class DropPath(nn.Module): 130 | def __init__(self, drop_prob=None, batch_dim=0): 131 | super(DropPath, self).__init__() 132 | self.drop_prob = drop_prob 133 | self.batch_dim = batch_dim 134 | 135 | def forward(self, x): 136 | return self.drop_path(x, self.drop_prob) 137 | 138 | def drop_path(self, x, drop_prob): 139 | if drop_prob == 0. or not self.training: 140 | return x 141 | keep_prob = 1 - drop_prob 142 | shape = [1 for _ in range(x.ndim)] 143 | shape[self.batch_dim] = x.shape[self.batch_dim] 144 | random_tensor = keep_prob + torch.rand( 145 | shape, dtype=x.dtype, device=x.device) 146 | random_tensor.floor_() # binarize 147 | output = x.div(keep_prob) * random_tensor 148 | return output 149 | 150 | 151 | class DropOutLogit(nn.Module): 152 | def __init__(self, drop_prob=None): 153 | super(DropOutLogit, self).__init__() 154 | self.drop_prob = drop_prob 155 | 156 | def forward(self, x): 157 | return self.drop_logit(x, self.drop_prob) 158 | 159 | def drop_logit(self, x, drop_prob): 160 | if drop_prob == 0. or not self.training: 161 | return x 162 | random_tensor = drop_prob + torch.rand( 163 | x.shape, dtype=x.dtype, device=x.device) 164 | random_tensor.floor_() # binarize 165 | mask = random_tensor * 1e+8 if ( 166 | x.dtype == torch.float32) else random_tensor * 1e+4 167 | output = x - mask 168 | return output 169 | -------------------------------------------------------------------------------- /aot/networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from itertools import ifilterfalse 7 | except ImportError: # py3k 8 | from itertools import filterfalse as ifilterfalse 9 | 10 | 11 | def dice_loss(probas, labels, smooth=1): 12 | 13 | C = probas.size(1) 14 | losses = [] 15 | for c in list(range(C)): 16 | fg = (labels == c).float() 17 | if fg.sum() == 0: 18 | continue 19 | class_pred = probas[:, c] 20 | p0 = class_pred 21 | g0 = fg 22 | numerator = 2 * torch.sum(p0 * g0) + smooth 23 | denominator = torch.sum(p0) + torch.sum(g0) + smooth 24 | losses.append(1 - ((numerator) / (denominator))) 25 | return mean(losses) 26 | 27 | 28 | def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6): 29 | ''' 30 | Tversky loss function. 31 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 32 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 33 | 34 | Same as soft dice loss when alpha=beta=0.5. 35 | Same as Jaccord loss when alpha=beta=1.0. 
36 | See `Tversky loss function for image segmentation using 3D fully convolutional deep networks` 37 | https://arxiv.org/pdf/1706.05721.pdf 38 | ''' 39 | C = probas.size(1) 40 | losses = [] 41 | for c in list(range(C)): 42 | fg = (labels == c).float() 43 | if fg.sum() == 0: 44 | continue 45 | class_pred = probas[:, c] 46 | p0 = class_pred 47 | p1 = 1 - class_pred 48 | g0 = fg 49 | g1 = 1 - fg 50 | numerator = torch.sum(p0 * g0) 51 | denominator = numerator + alpha * \ 52 | torch.sum(p0*g1) + beta*torch.sum(p1*g0) 53 | losses.append(1 - ((numerator) / (denominator + epsilon))) 54 | return mean(losses) 55 | 56 | 57 | def flatten_probas(probas, labels, ignore=255): 58 | """ 59 | Flattens predictions in the batch 60 | """ 61 | B, C, H, W = probas.size() 62 | probas = probas.permute(0, 2, 3, 63 | 1).contiguous().view(-1, C) # B * H * W, C = P, C 64 | labels = labels.view(-1) 65 | if ignore is None: 66 | return probas, labels 67 | valid = (labels != ignore) 68 | vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C) 69 | # vprobas = probas[torch.nonzero(valid).squeeze()] 70 | vlabels = labels[valid] 71 | return vprobas, vlabels 72 | 73 | 74 | def isnan(x): 75 | return x != x 76 | 77 | 78 | def mean(l, ignore_nan=False, empty=0): 79 | """ 80 | nanmean compatible with generators. 81 | """ 82 | l = iter(l) 83 | if ignore_nan: 84 | l = ifilterfalse(isnan, l) 85 | try: 86 | n = 1 87 | acc = next(l) 88 | except StopIteration: 89 | if empty == 'raise': 90 | raise ValueError('Empty mean') 91 | return empty 92 | for n, v in enumerate(l, 2): 93 | acc += v 94 | if n == 1: 95 | return acc 96 | return acc / n 97 | 98 | 99 | class DiceLoss(nn.Module): 100 | def __init__(self, ignore_index=255): 101 | super(DiceLoss, self).__init__() 102 | self.ignore_index = ignore_index 103 | 104 | def forward(self, tmp_dic, label_dic, step=None): 105 | total_loss = [] 106 | for idx in range(len(tmp_dic)): 107 | pred = tmp_dic[idx] 108 | label = label_dic[idx] 109 | pred = F.softmax(pred, dim=1) 110 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 111 | loss = dice_loss( 112 | *flatten_probas(pred, label, ignore=self.ignore_index)) 113 | total_loss.append(loss.unsqueeze(0)) 114 | total_loss = torch.cat(total_loss, dim=0) 115 | return total_loss 116 | 117 | 118 | class SoftJaccordLoss(nn.Module): 119 | def __init__(self, ignore_index=255): 120 | super(SoftJaccordLoss, self).__init__() 121 | self.ignore_index = ignore_index 122 | 123 | def forward(self, tmp_dic, label_dic, step=None): 124 | total_loss = [] 125 | for idx in range(len(tmp_dic)): 126 | pred = tmp_dic[idx] 127 | label = label_dic[idx] 128 | pred = F.softmax(pred, dim=1) 129 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 130 | loss = tversky_loss(*flatten_probas(pred, 131 | label, 132 | ignore=self.ignore_index), 133 | alpha=1.0, 134 | beta=1.0) 135 | total_loss.append(loss.unsqueeze(0)) 136 | total_loss = torch.cat(total_loss, dim=0) 137 | return total_loss 138 | 139 | 140 | class CrossEntropyLoss(nn.Module): 141 | def __init__(self, 142 | top_k_percent_pixels=None, 143 | hard_example_mining_step=100000): 144 | super(CrossEntropyLoss, self).__init__() 145 | self.top_k_percent_pixels = top_k_percent_pixels 146 | if top_k_percent_pixels is not None: 147 | assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 148 | self.hard_example_mining_step = hard_example_mining_step + 1e-5 149 | if self.top_k_percent_pixels is None: 150 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 151 | reduction='mean') 152 | else: 153 | 
self.celoss = nn.CrossEntropyLoss(ignore_index=255, 154 | reduction='none') 155 | 156 | def forward(self, dic_tmp, y, step): 157 | total_loss = [] 158 | for i in range(len(dic_tmp)): 159 | pred_logits = dic_tmp[i] 160 | gts = y[i] 161 | if self.top_k_percent_pixels is None: 162 | final_loss = self.celoss(pred_logits, gts) 163 | else: 164 | # Only compute the loss for top k percent pixels. 165 | # First, compute the loss for all pixels. Note we do not put the loss 166 | # to loss_collection and set reduction = None to keep the shape. 167 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 168 | pred_logits = pred_logits.view( 169 | -1, pred_logits.size(1), 170 | pred_logits.size(2) * pred_logits.size(3)) 171 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 172 | pixel_losses = self.celoss(pred_logits, gts) 173 | if self.hard_example_mining_step == 0: 174 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 175 | else: 176 | ratio = min(1.0, 177 | step / float(self.hard_example_mining_step)) 178 | top_k_pixels = int((ratio * self.top_k_percent_pixels + 179 | (1.0 - ratio)) * num_pixels) 180 | top_k_loss, top_k_indices = torch.topk(pixel_losses, 181 | k=top_k_pixels, 182 | dim=1) 183 | 184 | final_loss = torch.mean(top_k_loss) 185 | final_loss = final_loss.unsqueeze(0) 186 | total_loss.append(final_loss) 187 | total_loss = torch.cat(total_loss, dim=0) 188 | return total_loss 189 | -------------------------------------------------------------------------------- /aot/networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n) - epsilon) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | """ 21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) 22 | """ 23 | if x.requires_grad: 24 | # When gradients are needed, F.batch_norm will use extra memory 25 | # because its backward op computes gradients for weight/bias as well. 26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 27 | bias = self.bias - self.running_mean * scale 28 | scale = scale.reshape(1, -1, 1, 1) 29 | bias = bias.reshape(1, -1, 1, 1) 30 | out_dtype = x.dtype # may be half 31 | return x * scale.to(out_dtype) + bias.to(out_dtype) 32 | else: 33 | # When gradients are not needed, F.batch_norm is a single fused op 34 | # and provide more optimization opportunities. 
35 | return F.batch_norm( 36 | x, 37 | self.running_mean, 38 | self.running_var, 39 | self.weight, 40 | self.bias, 41 | training=False, 42 | eps=self.epsilon, 43 | ) 44 | -------------------------------------------------------------------------------- /aot/networks/layers/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.math import truncated_normal_ 8 | 9 | 10 | class Downsample2D(nn.Module): 11 | def __init__(self, mode='nearest', scale=4): 12 | super().__init__() 13 | self.mode = mode 14 | self.scale = scale 15 | 16 | def forward(self, x): 17 | n, c, h, w = x.size() 18 | x = F.interpolate(x, 19 | size=(h // self.scale + 1, w // self.scale + 1), 20 | mode=self.mode) 21 | return x 22 | 23 | 24 | def generate_coord(x): 25 | _, _, h, w = x.size() 26 | device = x.device 27 | col = torch.arange(0, h, device=device) 28 | row = torch.arange(0, w, device=device) 29 | grid_h, grid_w = torch.meshgrid(col, row) 30 | return grid_h, grid_w 31 | 32 | 33 | class PositionEmbeddingSine(nn.Module): 34 | def __init__(self, 35 | num_pos_feats=64, 36 | temperature=10000, 37 | normalize=False, 38 | scale=None): 39 | super().__init__() 40 | self.num_pos_feats = num_pos_feats 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | self.scale = scale 48 | 49 | def forward(self, x): 50 | grid_y, grid_x = generate_coord(x) 51 | 52 | y_embed = grid_y.unsqueeze(0).float() 53 | x_embed = grid_x.unsqueeze(0).float() 54 | 55 | if self.normalize: 56 | eps = 1e-6 57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 59 | 60 | dim_t = torch.arange(self.num_pos_feats, 61 | dtype=torch.float32, 62 | device=x.device) 63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) 64 | 65 | pos_x = x_embed[:, :, :, None] / dim_t 66 | pos_y = y_embed[:, :, :, None] / dim_t 67 | pos_x = torch.stack( 68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 69 | dim=4).flatten(3) 70 | pos_y = torch.stack( 71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 72 | dim=4).flatten(3) 73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 74 | return pos 75 | 76 | 77 | class PositionEmbeddingLearned(nn.Module): 78 | def __init__(self, num_pos_feats=64, H=30, W=30): 79 | super().__init__() 80 | self.H = H 81 | self.W = W 82 | self.pos_emb = nn.Parameter( 83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) 84 | 85 | def forward(self, x): 86 | bs, _, h, w = x.size() 87 | pos_emb = self.pos_emb 88 | if h != self.H or w != self.W: 89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") 90 | return pos_emb 91 | -------------------------------------------------------------------------------- /aot/networks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.models.aot import AOT 2 | from networks.models.deaot import DeAOT 3 | 4 | 5 | def build_vos_model(name, cfg, **kwargs): 6 | if name == 'aot': 7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 8 | elif name == 'deaot': 9 | return DeAOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 10 | else: 11 | raise NotImplementedError 12 | 
-------------------------------------------------------------------------------- /aot/networks/models/aot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.encoders import build_encoder 4 | from networks.layers.transformer import LongShortTermTransformer 5 | from networks.decoders import build_decoder 6 | from networks.layers.position import PositionEmbeddingSine 7 | 8 | 9 | class AOT(nn.Module): 10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.encoder = build_encoder(encoder, 17 | frozen_bn=cfg.MODEL_FREEZE_BN, 18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) 19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], 20 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 21 | kernel_size=1) 22 | 23 | self.LSTT = LongShortTermTransformer( 24 | cfg.MODEL_LSTT_NUM, 25 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 26 | cfg.MODEL_SELF_HEADS, 27 | cfg.MODEL_ATT_HEADS, 28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 29 | droppath=cfg.TRAIN_LSTT_DROPPATH, 30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | return_intermediate=True) 36 | 37 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 38 | (cfg.MODEL_LSTT_NUM + 39 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM 40 | 41 | self.decoder = build_decoder( 42 | decoder, 43 | in_dim=decoder_indim, 44 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 45 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 46 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 47 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 48 | align_corners=cfg.MODEL_ALIGN_CORNERS) 49 | 50 | if cfg.MODEL_ALIGN_CORNERS: 51 | self.patch_wise_id_bank = nn.Conv2d( 52 | cfg.MODEL_MAX_OBJ_NUM + 1, 53 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 54 | kernel_size=17, 55 | stride=16, 56 | padding=8) 57 | else: 58 | self.patch_wise_id_bank = nn.Conv2d( 59 | cfg.MODEL_MAX_OBJ_NUM + 1, 60 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 61 | kernel_size=16, 62 | stride=16, 63 | padding=0) 64 | 65 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) 66 | 67 | self.pos_generator = PositionEmbeddingSine( 68 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) 69 | 70 | self._init_weight() 71 | 72 | def get_pos_emb(self, x): 73 | pos_emb = self.pos_generator(x) 74 | return pos_emb 75 | 76 | def get_id_emb(self, x): 77 | id_emb = self.patch_wise_id_bank(x) 78 | id_emb = self.id_dropout(id_emb) 79 | return id_emb 80 | 81 | def encode_image(self, img): 82 | xs = self.encoder(img) 83 | xs[-1] = self.encoder_projector(xs[-1]) 84 | return xs 85 | 86 | def decode_id_logits(self, lstt_emb, shortcuts): 87 | n, c, h, w = shortcuts[-1].size() 88 | decoder_inputs = [shortcuts[-1]] 89 | for emb in lstt_emb: 90 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) 91 | pred_logit = self.decoder(decoder_inputs, shortcuts) 92 | return pred_logit 93 | 94 | def LSTT_forward(self, 95 | curr_embs, 96 | long_term_memories, 97 | short_term_memories, 98 | curr_id_emb=None, 99 | pos_emb=None, 100 | size_2d=(30, 30)): 101 | n, c, h, w = curr_embs[-1].size() 102 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) 103 | lstt_embs, lstt_memories = self.LSTT(curr_emb, 
long_term_memories, 104 | short_term_memories, curr_id_emb, 105 | pos_emb, size_2d) 106 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( 107 | *lstt_memories) 108 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories 109 | 110 | def _init_weight(self): 111 | nn.init.xavier_uniform_(self.encoder_projector.weight) 112 | nn.init.orthogonal_( 113 | self.patch_wise_id_bank.weight.view( 114 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), 115 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) 116 | -------------------------------------------------------------------------------- /aot/networks/models/deaot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.layers.transformer import DualBranchGPM 4 | from networks.models.aot import AOT 5 | from networks.decoders import build_decoder 6 | 7 | 8 | class DeAOT(AOT): 9 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 10 | super().__init__(cfg, encoder, decoder) 11 | 12 | self.LSTT = DualBranchGPM( 13 | cfg.MODEL_LSTT_NUM, 14 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 15 | cfg.MODEL_SELF_HEADS, 16 | cfg.MODEL_ATT_HEADS, 17 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 18 | droppath=cfg.TRAIN_LSTT_DROPPATH, 19 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 20 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 21 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 22 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 23 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 24 | return_intermediate=True) 25 | 26 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 27 | (cfg.MODEL_LSTT_NUM * 2 + 28 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM * 2 29 | 30 | self.decoder = build_decoder( 31 | decoder, 32 | in_dim=decoder_indim, 33 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 34 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 36 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 37 | align_corners=cfg.MODEL_ALIGN_CORNERS) 38 | 39 | self.id_norm = nn.LayerNorm(cfg.MODEL_ENCODER_EMBEDDING_DIM) 40 | 41 | self._init_weight() 42 | 43 | def decode_id_logits(self, lstt_emb, shortcuts): 44 | n, c, h, w = shortcuts[-1].size() 45 | decoder_inputs = [shortcuts[-1]] 46 | for emb in lstt_emb: 47 | decoder_inputs.append(emb.view(h, w, n, -1).permute(2, 3, 0, 1)) 48 | pred_logit = self.decoder(decoder_inputs, shortcuts) 49 | return pred_logit 50 | 51 | def get_id_emb(self, x): 52 | id_emb = self.patch_wise_id_bank(x) 53 | id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1) 54 | id_emb = self.id_dropout(id_emb) 55 | return id_emb 56 | -------------------------------------------------------------------------------- /aot/pretrain_models/README.md: -------------------------------------------------------------------------------- 1 | Put pretrained models here. 
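# --- Editor's sketch (not part of the upstream README): loading a trained checkpoint. ---
# Inference code in this repo restores AOT/DeAOT weights with utils/checkpoint.py:load_network,
# exactly as AOTTracker.__init__ does in aot_tracker.py; `model` is an AOT/DeAOT instance as
# built by build_vos_model, and the path below is the default 'model_path' from model_args.py,
# assumed to have been downloaded into ckpt/.
from aot.utils.checkpoint import load_network   # import style used by aot_tracker.py
model, removed_keys = load_network(model, 'ckpt/R50_DeAOTL_PRE_YTB_DAV.pth', gpu=0)
model.eval()   # load_network already moves the model to cuda:<gpu>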
-------------------------------------------------------------------------------- /aot/source/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/source/.DS_Store -------------------------------------------------------------------------------- /aot/source/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/source/overview.png -------------------------------------------------------------------------------- /aot/source/overview_deaot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/source/overview_deaot.png -------------------------------------------------------------------------------- /aot/tools/eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | 10 | from networks.managers.evaluator import Evaluator 11 | 12 | 13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): 14 | # Initiate a evaluating manager 15 | evaluator = Evaluator(rank=gpu, 16 | cfg=cfg, 17 | seq_queue=seq_queue, 18 | info_queue=info_queue) 19 | # Start evaluation 20 | if enable_amp: 21 | with torch.cuda.amp.autocast(enabled=True): 22 | evaluator.evaluating() 23 | else: 24 | evaluator.evaluating() 25 | 26 | 27 | def main(): 28 | import argparse 29 | parser = argparse.ArgumentParser(description="Eval VOS") 30 | parser.add_argument('--exp_name', type=str, default='default') 31 | 32 | parser.add_argument('--stage', type=str, default='pre') 33 | parser.add_argument('--model', type=str, default='aott') 34 | parser.add_argument('--lstt_num', type=int, default=-1) 35 | parser.add_argument('--lt_gap', type=int, default=-1) 36 | parser.add_argument('--st_skip', type=int, default=-1) 37 | parser.add_argument('--max_id_num', type=int, default='-1') 38 | 39 | parser.add_argument('--gpu_id', type=int, default=0) 40 | parser.add_argument('--gpu_num', type=int, default=1) 41 | 42 | parser.add_argument('--ckpt_path', type=str, default='') 43 | parser.add_argument('--ckpt_step', type=int, default=-1) 44 | 45 | parser.add_argument('--dataset', type=str, default='') 46 | parser.add_argument('--split', type=str, default='') 47 | 48 | parser.add_argument('--ema', action='store_true') 49 | parser.set_defaults(ema=False) 50 | 51 | parser.add_argument('--flip', action='store_true') 52 | parser.set_defaults(flip=False) 53 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 54 | 55 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 56 | 57 | parser.add_argument('--amp', action='store_true') 58 | parser.set_defaults(amp=False) 59 | 60 | args = parser.parse_args() 61 | 62 | engine_config = importlib.import_module('configs.' 
+ args.stage) 63 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 64 | 65 | cfg.TEST_EMA = args.ema 66 | 67 | cfg.TEST_GPU_ID = args.gpu_id 68 | cfg.TEST_GPU_NUM = args.gpu_num 69 | 70 | if args.lstt_num > 0: 71 | cfg.MODEL_LSTT_NUM = args.lstt_num 72 | if args.lt_gap > 0: 73 | cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap 74 | if args.st_skip > 0: 75 | cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip 76 | 77 | if args.max_id_num > 0: 78 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 79 | 80 | if args.ckpt_path != '': 81 | cfg.TEST_CKPT_PATH = args.ckpt_path 82 | if args.ckpt_step > 0: 83 | cfg.TEST_CKPT_STEP = args.ckpt_step 84 | 85 | if args.dataset != '': 86 | cfg.TEST_DATASET = args.dataset 87 | 88 | if args.split != '': 89 | cfg.TEST_DATASET_SPLIT = args.split 90 | 91 | cfg.TEST_FLIP = args.flip 92 | cfg.TEST_MULTISCALE = args.ms 93 | 94 | if cfg.TEST_MULTISCALE != [1.]: 95 | cfg.TEST_MAX_SHORT_EDGE = args.max_resolution # for preventing OOM 96 | else: 97 | cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT 98 | cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480. 99 | 100 | if args.gpu_num > 1: 101 | mp.set_start_method('spawn') 102 | seq_queue = mp.Queue() 103 | info_queue = mp.Queue() 104 | mp.spawn(main_worker, 105 | nprocs=cfg.TEST_GPU_NUM, 106 | args=(cfg, seq_queue, info_queue, args.amp)) 107 | else: 108 | main_worker(0, cfg, enable_amp=args.amp) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /aot/tools/train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import sys 4 | 5 | sys.setrecursionlimit(10000) 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import torch.multiprocessing as mp 10 | 11 | from networks.managers.trainer import Trainer 12 | 13 | 14 | def main_worker(gpu, cfg, enable_amp=True): 15 | # Initiate a training manager 16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) 17 | # Start Training 18 | trainer.sequential_training() 19 | 20 | 21 | def main(): 22 | import argparse 23 | parser = argparse.ArgumentParser(description="Train VOS") 24 | parser.add_argument('--exp_name', type=str, default='') 25 | parser.add_argument('--stage', type=str, default='pre') 26 | parser.add_argument('--model', type=str, default='aott') 27 | parser.add_argument('--max_id_num', type=int, default='-1') 28 | 29 | parser.add_argument('--start_gpu', type=int, default=0) 30 | parser.add_argument('--gpu_num', type=int, default=-1) 31 | parser.add_argument('--batch_size', type=int, default=-1) 32 | parser.add_argument('--dist_url', type=str, default='') 33 | parser.add_argument('--amp', action='store_true') 34 | parser.set_defaults(amp=False) 35 | 36 | parser.add_argument('--pretrained_path', type=str, default='') 37 | 38 | parser.add_argument('--datasets', nargs='+', type=str, default=[]) 39 | parser.add_argument('--lr', type=float, default=-1.) 40 | parser.add_argument('--total_step', type=int, default=-1.) 41 | parser.add_argument('--start_step', type=int, default=-1.) 42 | 43 | args = parser.parse_args() 44 | 45 | engine_config = importlib.import_module('configs.' 
+ args.stage) 46 | 47 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 48 | 49 | if len(args.datasets) > 0: 50 | cfg.DATASETS = args.datasets 51 | 52 | cfg.DIST_START_GPU = args.start_gpu 53 | if args.gpu_num > 0: 54 | cfg.TRAIN_GPUS = args.gpu_num 55 | if args.batch_size > 0: 56 | cfg.TRAIN_BATCH_SIZE = args.batch_size 57 | 58 | if args.pretrained_path != '': 59 | cfg.PRETRAIN_MODEL = args.pretrained_path 60 | 61 | if args.max_id_num > 0: 62 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 63 | 64 | if args.lr > 0: 65 | cfg.TRAIN_LR = args.lr 66 | 67 | if args.total_step > 0: 68 | cfg.TRAIN_TOTAL_STEPS = args.total_step 69 | 70 | if args.start_step > 0: 71 | cfg.TRAIN_START_STEP = args.start_step 72 | 73 | if args.dist_url == '': 74 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( 75 | random.randint(0, 9)) 76 | else: 77 | cfg.DIST_URL = args.dist_url 78 | 79 | if cfg.TRAIN_GPUS > 1: 80 | # Use torch.multiprocessing.spawn to launch distributed processes 81 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp)) 82 | else: 83 | cfg.TRAIN_GPUS = 1 84 | main_worker(0, cfg, args.amp) 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /aot/train_eval.sh: -------------------------------------------------------------------------------- 1 | exp="default" 2 | gpu_num="4" 3 | 4 | model="aott" 5 | # model="aots" 6 | # model="aotb" 7 | # model="aotl" 8 | # model="r50_deaotl" 9 | # model="swinb_aotl" 10 | 11 | ## Training ## 12 | stage="pre" 13 | python tools/train.py --amp \ 14 | --exp_name ${exp} \ 15 | --stage ${stage} \ 16 | --model ${model} \ 17 | --gpu_num ${gpu_num} 18 | 19 | stage="pre_ytb_dav" 20 | python tools/train.py --amp \ 21 | --exp_name ${exp} \ 22 | --stage ${stage} \ 23 | --model ${model} \ 24 | --gpu_num ${gpu_num} 25 | 26 | ## Evaluation ## 27 | dataset="davis2017" 28 | split="test" 29 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 30 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 31 | 32 | dataset="davis2017" 33 | split="val" 34 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 35 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 36 | 37 | dataset="davis2016" 38 | split="val" 39 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 40 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 41 | 42 | dataset="youtubevos2018" 43 | split="val" # or "val_all_frames" 44 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 45 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 46 | 47 | dataset="youtubevos2019" 48 | split="val" # or "val_all_frames" 49 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 50 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} -------------------------------------------------------------------------------- /aot/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/aot/utils/__init__.py -------------------------------------------------------------------------------- /aot/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | 7 | def load_network_and_optimizer(net, opt, 
pretrained_dir, gpu, scaler=None): 8 | pretrained = torch.load(pretrained_dir, 9 | map_location=torch.device("cuda:" + str(gpu))) 10 | pretrained_dict = pretrained['state_dict'] 11 | model_dict = net.state_dict() 12 | pretrained_dict_update = {} 13 | pretrained_dict_remove = [] 14 | for k, v in pretrained_dict.items(): 15 | if k in model_dict: 16 | pretrained_dict_update[k] = v 17 | elif k[:7] == 'module.': 18 | if k[7:] in model_dict: 19 | pretrained_dict_update[k[7:]] = v 20 | else: 21 | pretrained_dict_remove.append(k) 22 | model_dict.update(pretrained_dict_update) 23 | net.load_state_dict(model_dict) 24 | opt.load_state_dict(pretrained['optimizer']) 25 | if scaler is not None and 'scaler' in pretrained.keys(): 26 | scaler.load_state_dict(pretrained['scaler']) 27 | del (pretrained) 28 | return net.cuda(gpu), opt, pretrained_dict_remove 29 | 30 | 31 | def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None): 32 | pretrained = torch.load(pretrained_dir, 33 | map_location=torch.device("cuda:" + str(gpu))) 34 | # load model 35 | pretrained_dict = pretrained['state_dict'] 36 | model_dict = net.state_dict() 37 | pretrained_dict_update = {} 38 | pretrained_dict_remove = [] 39 | for k, v in pretrained_dict.items(): 40 | if k in model_dict: 41 | pretrained_dict_update[k] = v 42 | elif k[:7] == 'module.': 43 | if k[7:] in model_dict: 44 | pretrained_dict_update[k[7:]] = v 45 | else: 46 | pretrained_dict_remove.append(k) 47 | model_dict.update(pretrained_dict_update) 48 | net.load_state_dict(model_dict) 49 | 50 | # load optimizer 51 | opt_dict = opt.state_dict() 52 | all_params = { 53 | param_group['name']: param_group['params'][0] 54 | for param_group in opt_dict['param_groups'] 55 | } 56 | pretrained_opt_dict = {'state': {}, 'param_groups': []} 57 | for idx in range(len(pretrained['optimizer']['param_groups'])): 58 | param_group = pretrained['optimizer']['param_groups'][idx] 59 | if param_group['name'] in all_params.keys(): 60 | pretrained_opt_dict['state'][all_params[ 61 | param_group['name']]] = pretrained['optimizer']['state'][ 62 | param_group['params'][0]] 63 | param_group['params'][0] = all_params[param_group['name']] 64 | pretrained_opt_dict['param_groups'].append(param_group) 65 | 66 | opt_dict.update(pretrained_opt_dict) 67 | opt.load_state_dict(opt_dict) 68 | 69 | # load scaler 70 | if scaler is not None and 'scaler' in pretrained.keys(): 71 | scaler.load_state_dict(pretrained['scaler']) 72 | del (pretrained) 73 | return net.cuda(gpu), opt, pretrained_dict_remove 74 | 75 | 76 | def load_network(net, pretrained_dir, gpu): 77 | pretrained = torch.load(pretrained_dir, 78 | map_location=torch.device("cuda:" + str(gpu))) 79 | if 'state_dict' in pretrained.keys(): 80 | pretrained_dict = pretrained['state_dict'] 81 | elif 'model' in pretrained.keys(): 82 | pretrained_dict = pretrained['model'] 83 | else: 84 | pretrained_dict = pretrained 85 | model_dict = net.state_dict() 86 | pretrained_dict_update = {} 87 | pretrained_dict_remove = [] 88 | for k, v in pretrained_dict.items(): 89 | if k in model_dict: 90 | pretrained_dict_update[k] = v 91 | elif k[:7] == 'module.': 92 | if k[7:] in model_dict: 93 | pretrained_dict_update[k[7:]] = v 94 | else: 95 | pretrained_dict_remove.append(k) 96 | model_dict.update(pretrained_dict_update) 97 | net.load_state_dict(model_dict) 98 | del (pretrained) 99 | return net.cuda(gpu), pretrained_dict_remove 100 | 101 | 102 | def save_network(net, 103 | opt, 104 | step, 105 | save_path, 106 | max_keep=8, 107 | backup_dir='./saved_models', 108 | 
scaler=None): 109 | ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()} 110 | if scaler is not None: 111 | ckpt['scaler'] = scaler.state_dict() 112 | 113 | try: 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | save_file = 'save_step_%s.pth' % (step) 117 | save_dir = os.path.join(save_path, save_file) 118 | torch.save(ckpt, save_dir) 119 | except: 120 | save_path = backup_dir 121 | if not os.path.exists(save_path): 122 | os.makedirs(save_path) 123 | save_file = 'save_step_%s.pth' % (step) 124 | save_dir = os.path.join(save_path, save_file) 125 | torch.save(ckpt, save_dir) 126 | 127 | all_ckpt = os.listdir(save_path) 128 | if len(all_ckpt) > max_keep: 129 | all_step = [] 130 | for ckpt_name in all_ckpt: 131 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 132 | all_step.append(step) 133 | all_step = list(np.sort(all_step))[:-max_keep] 134 | for step in all_step: 135 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) 136 | os.system('rm {}'.format(ckpt_path)) 137 | 138 | 139 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): 140 | exps = os.listdir(curr_dir) 141 | for exp in exps: 142 | exp_dir = os.path.join(curr_dir, exp) 143 | stages = os.listdir(exp_dir) 144 | for stage in stages: 145 | stage_dir = os.path.join(exp_dir, stage) 146 | finals = ["ema_ckpt", "ckpt"] 147 | for final in finals: 148 | final_dir = os.path.join(stage_dir, final) 149 | ckpts = os.listdir(final_dir) 150 | for ckpt in ckpts: 151 | if '.pth' not in ckpt: 152 | continue 153 | curr_ckpt_path = os.path.join(final_dir, ckpt) 154 | remote_ckpt_path = os.path.join(remote_dir, exp, stage, 155 | final, ckpt) 156 | if os.path.exists(remote_ckpt_path): 157 | os.system('rm {}'.format(remote_ckpt_path)) 158 | try: 159 | shutil.copy(curr_ckpt_path, remote_ckpt_path) 160 | print("Copy {} to {}.".format(curr_ckpt_path, 161 | remote_ckpt_path)) 162 | except OSError as Inst: 163 | return 164 | -------------------------------------------------------------------------------- /aot/utils/cp_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): 6 | exps = os.listdir(curr_dir) 7 | for exp in exps: 8 | print("Exp: ", exp) 9 | exp_dir = os.path.join(curr_dir, exp) 10 | stages = os.listdir(exp_dir) 11 | for stage in stages: 12 | print("Stage: ", stage) 13 | stage_dir = os.path.join(exp_dir, stage) 14 | finals = ["ema_ckpt", "ckpt"] 15 | for final in finals: 16 | print("Final: ", final) 17 | final_dir = os.path.join(stage_dir, final) 18 | ckpts = os.listdir(final_dir) 19 | for ckpt in ckpts: 20 | if '.pth' not in ckpt: 21 | continue 22 | curr_ckpt_path = os.path.join(final_dir, ckpt) 23 | remote_ckpt_path = os.path.join(remote_dir, exp, stage, 24 | final, ckpt) 25 | if os.path.exists(remote_ckpt_path): 26 | os.system('rm {}'.format(remote_ckpt_path)) 27 | try: 28 | shutil.copy(curr_ckpt_path, remote_ckpt_path) 29 | print(ckpt, ': OK') 30 | except OSError as Inst: 31 | print(Inst) 32 | print(ckpt, ': Fail') 33 | 34 | 35 | if __name__ == "__main__": 36 | cp_ckpt() 37 | -------------------------------------------------------------------------------- /aot/utils/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | def get_param_buffer_for_ema(model, 8 | 
update_buffer=False, 9 | required_buffers=['running_mean', 'running_var']): 10 | params = model.parameters() 11 | all_param_buffer = [p for p in params if p.requires_grad] 12 | if update_buffer: 13 | named_buffers = model.named_buffers() 14 | for key, value in named_buffers: 15 | for buffer_name in required_buffers: 16 | if buffer_name in key: 17 | all_param_buffer.append(value) 18 | break 19 | return all_param_buffer 20 | 21 | 22 | class ExponentialMovingAverage: 23 | """ 24 | Maintains (exponential) moving average of a set of parameters. 25 | """ 26 | def __init__(self, parameters, decay, use_num_updates=True): 27 | """ 28 | Args: 29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 30 | `model.parameters()`. 31 | decay: The exponential decay. 32 | use_num_updates: Whether to use number of updates when computing 33 | averages. 34 | """ 35 | if decay < 0.0 or decay > 1.0: 36 | raise ValueError('Decay must be between 0 and 1') 37 | self.decay = decay 38 | self.num_updates = 0 if use_num_updates else None 39 | self.shadow_params = [p.clone().detach() for p in parameters] 40 | self.collected_params = [] 41 | 42 | def update(self, parameters): 43 | """ 44 | Update currently maintained parameters. 45 | Call this every time the parameters are updated, such as the result of 46 | the `optimizer.step()` call. 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, 55 | (1 + self.num_updates) / (10 + self.num_updates)) 56 | one_minus_decay = 1.0 - decay 57 | with torch.no_grad(): 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | updated with the stored moving averages. 67 | """ 68 | for s_param, param in zip(self.shadow_params, parameters): 69 | param.data.copy_(s_param.data) 70 | 71 | def store(self, parameters): 72 | """ 73 | Save the current parameters for restoring later. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | temporarily stored. 77 | """ 78 | self.collected_params = [param.clone() for param in parameters] 79 | 80 | def restore(self, parameters): 81 | """ 82 | Restore the parameters stored with the `store` method. 83 | Useful to validate the model with EMA parameters without affecting the 84 | original optimization process. Store the parameters before the 85 | `copy_to` method. After validation (or model saving), use this to 86 | restore the former parameters. 87 | Args: 88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 89 | updated with the stored parameters. 
90 | """ 91 | for c_param, param in zip(self.collected_params, parameters): 92 | param.data.copy_(c_param.data) 93 | del (self.collected_params) 94 | -------------------------------------------------------------------------------- /aot/utils/eval.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /aot/utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import threading 5 | 6 | _palette = [ 7 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 8 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 9 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 10 | 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 11 | 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 12 | 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 13 | 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 14 | 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 15 | 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 16 | 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 17 | 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 18 | 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 19 | 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 20 | 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 21 | 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 22 | 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 23 | 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 24 | 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 25 | 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 26 | 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 27 | 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 28 | 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 29 | 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 30 | 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 31 | 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 32 | 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 33 | 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 34 | 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 35 | 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 36 | 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 37 | 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 38 | 175, 
175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 39 | 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 40 | 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 41 | 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 42 | 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 43 | 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 44 | 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 45 | 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 46 | 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 47 | 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 48 | 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 49 | 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 50 | 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 51 | 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 52 | 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 53 | 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 54 | 255, 255, 255 55 | ] 56 | 57 | 58 | def label2colormap(label): 59 | 60 | m = label.astype(np.uint8) 61 | r, c = m.shape 62 | cmap = np.zeros((r, c, 3), dtype=np.uint8) 63 | cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1 64 | cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2 65 | cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1 66 | return cmap 67 | 68 | 69 | def one_hot_mask(mask, cls_num): 70 | if len(mask.size()) == 3: 71 | mask = mask.unsqueeze(1) 72 | indices = torch.arange(0, cls_num + 1, 73 | device=mask.device).view(1, -1, 1, 1) 74 | return (mask == indices).float() 75 | 76 | 77 | def masked_image(image, colored_mask, mask, alpha=0.7): 78 | mask = np.expand_dims(mask > 0, axis=0) 79 | mask = np.repeat(mask, 3, axis=0) 80 | show_img = (image * alpha + colored_mask * 81 | (1 - alpha)) * mask + image * (1 - mask) 82 | return show_img 83 | 84 | 85 | def save_image(image, path): 86 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 87 | im.save(path) 88 | 89 | 90 | def _save_mask(mask, path, squeeze_idx=None): 91 | if squeeze_idx is not None: 92 | unsqueezed_mask = mask * 0 93 | for idx in range(1, len(squeeze_idx)): 94 | obj_id = squeeze_idx[idx] 95 | mask_i = mask == idx 96 | unsqueezed_mask += (mask_i * obj_id).astype(np.uint8) 97 | mask = unsqueezed_mask 98 | mask = Image.fromarray(mask).convert('P') 99 | mask.putpalette(_palette) 100 | mask.save(path) 101 | 102 | 103 | def save_mask(mask_tensor, path, squeeze_idx=None): 104 | mask = mask_tensor.cpu().numpy().astype('uint8') 105 | threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start() 106 | 107 | 108 | def flip_tensor(tensor, dim=0): 109 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, 110 | device=tensor.device).long() 111 | tensor = tensor.index_select(dim, inv_idx) 112 | return tensor 113 | 114 | 115 | def shuffle_obj_mask(mask): 116 | 117 | bs, obj_num, _, _ = mask.size() 118 | new_masks = [] 119 | for idx in range(bs): 120 | now_mask = mask[idx] 121 | random_matrix = torch.eye(obj_num, device=mask.device) 122 | fg = random_matrix[1:][torch.randperm(obj_num - 1)] 123 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 124 | now_mask = torch.einsum('nm,nhw->mhw', random_matrix, now_mask) 125 | 
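# (editor's comment) the einsum above multiplies the one-hot object channels by a random
# permutation matrix: rows 1..N of the identity are shuffled, so foreground object channels
# are permuted while channel 0 (background) stays in place.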
new_masks.append(now_mask) 126 | 127 | return torch.stack(new_masks, dim=0) 128 | -------------------------------------------------------------------------------- /aot/utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, 5 | base_lr, 6 | p, 7 | itr, 8 | max_itr, 9 | restart=1, 10 | warm_up_steps=1000, 11 | is_cosine_decay=False, 12 | min_lr=1e-5, 13 | encoder_lr_ratio=1.0, 14 | freeze_params=[]): 15 | 16 | if restart > 1: 17 | each_max_itr = int(math.ceil(float(max_itr) / restart)) 18 | itr = itr % each_max_itr 19 | warm_up_steps /= restart 20 | max_itr = each_max_itr 21 | 22 | if itr < warm_up_steps: 23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps 24 | else: 25 | itr = itr - warm_up_steps 26 | max_itr = max_itr - warm_up_steps 27 | if is_cosine_decay: 28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / 29 | (max_itr + 1)) + 30 | 1.) * 0.5 31 | else: 32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p 33 | 34 | for param_group in optimizer.param_groups: 35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: 36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr 37 | else: 38 | param_group['lr'] = now_lr 39 | 40 | for freeze_param in freeze_params: 41 | if freeze_param in param_group["name"]: 42 | param_group['lr'] = 0 43 | param_group['weight_decay'] = 0 44 | break 45 | 46 | return now_lr 47 | 48 | 49 | def get_trainable_params(model, 50 | base_lr, 51 | weight_decay, 52 | use_frozen_bn=False, 53 | exclusive_wd_dict={}, 54 | no_wd_keys=[]): 55 | params = [] 56 | memo = set() 57 | total_param = 0 58 | for key, value in model.named_parameters(): 59 | if value in memo: 60 | continue 61 | total_param += value.numel() 62 | if not value.requires_grad: 63 | continue 64 | memo.add(value) 65 | wd = weight_decay 66 | for exclusive_key in exclusive_wd_dict.keys(): 67 | if exclusive_key in key: 68 | wd = exclusive_wd_dict[exclusive_key] 69 | break 70 | if len(value.shape) == 1: # normalization layers 71 | if 'bias' in key: # bias requires no weight decay 72 | wd = 0. 73 | elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay 74 | wd = 0. 75 | elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder 76 | wd = 0. 77 | else: 78 | for no_wd_key in no_wd_keys: 79 | if no_wd_key in key: 80 | wd = 0. 
81 | break 82 | params += [{ 83 | "params": [value], 84 | "lr": base_lr, 85 | "weight_decay": wd, 86 | "name": key 87 | }] 88 | 89 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 90 | return params 91 | 92 | 93 | def freeze_params(module): 94 | for p in module.parameters(): 95 | p.requires_grad = False 96 | 97 | 98 | def calculate_params(state_dict): 99 | memo = set() 100 | total_param = 0 101 | for key, value in state_dict.items(): 102 | if value in memo: 103 | continue 104 | memo.add(value) 105 | total_param += value.numel() 106 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 107 | -------------------------------------------------------------------------------- /aot/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): 5 | all_matrix = [] 6 | for idx in range(num): 7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) 8 | if keep_first: 9 | fg = random_matrix[1:][torch.randperm(dim - 1)] 10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 11 | else: 12 | random_matrix = random_matrix[torch.randperm(dim)] 13 | all_matrix.append(random_matrix) 14 | return torch.stack(all_matrix, dim=0) 15 | 16 | 17 | def truncated_normal_(tensor, mean=0, std=.02): 18 | size = tensor.shape 19 | tmp = tensor.new_empty(size + (4, )).normal_() 20 | valid = (tmp < 2) & (tmp > -2) 21 | ind = valid.max(-1, keepdim=True)[1] 22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 23 | tensor.data.mul_(std).add_(mean) 24 | return tensor 25 | -------------------------------------------------------------------------------- /aot/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, momentum=0.999): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.long_count = 0 12 | self.momentum = momentum 13 | self.moving_avg = 0 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | if self.long_count == 0: 23 | self.moving_avg = val 24 | else: 25 | momentum = min(self.momentum, 1. - 1. 
/ self.long_count) 26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.long_count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /aot/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 5 | ''' 6 | pred: [bs, h, w] 7 | target: [bs, h, w] 8 | obj_num: [bs] 9 | ''' 10 | bs = pred.size(0) 11 | all_iou = [] 12 | for idx in range(bs): 13 | now_pred = pred[idx].unsqueeze(0) 14 | now_target = target[idx].unsqueeze(0) 15 | now_obj_num = obj_num[idx] 16 | 17 | obj_ids = torch.arange(0, now_obj_num + 1, 18 | device=now_pred.device).int().view(-1, 1, 1) 19 | if obj_ids.size(0) == 1: # only contain background 20 | continue 21 | else: 22 | obj_ids = obj_ids[1:] 23 | now_pred = (now_pred == obj_ids).float() 24 | now_target = (now_target == obj_ids).float() 25 | 26 | intersection = (now_pred * now_target).sum((1, 2)) 27 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 28 | 29 | now_iou = (intersection + epsilon) / (union + epsilon) 30 | 31 | all_iou.append(now_iou.mean()) 32 | if len(all_iou) > 0: 33 | all_iou = torch.stack(all_iou).mean() 34 | else: 35 | all_iou = torch.ones((1), device=pred.device) 36 | return all_iou 37 | -------------------------------------------------------------------------------- /aot_tracker.py: -------------------------------------------------------------------------------- 1 | from statistics import mode 2 | import torch 3 | import torch.nn.functional as F 4 | import os 5 | import sys 6 | sys.path.append("./aot") 7 | from aot.networks.engines.aot_engine import AOTEngine,AOTInferEngine 8 | from aot.networks.engines.deaot_engine import DeAOTEngine,DeAOTInferEngine 9 | import importlib 10 | import numpy as np 11 | from PIL import Image 12 | from skimage.morphology.binary import binary_dilation 13 | 14 | 15 | np.random.seed(200) 16 | _palette = ((np.random.random((3*255))*0.7+0.3)*255).astype(np.uint8).tolist() 17 | _palette = [0,0,0]+_palette 18 | 19 | import aot.dataloaders.video_transforms as tr 20 | from aot.utils.checkpoint import load_network 21 | from aot.networks.models import build_vos_model 22 | from aot.networks.engines import build_engine 23 | from torchvision import transforms 24 | 25 | class AOTTracker(object): 26 | def __init__(self, cfg, gpu_id=0): 27 | self.gpu_id = gpu_id 28 | self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) 29 | self.model, _ = load_network(self.model, cfg.TEST_CKPT_PATH, gpu_id) 30 | # self.engine = self.build_tracker_engine(cfg.MODEL_ENGINE, 31 | # aot_model=self.model, 32 | # gpu_id=gpu_id, 33 | # short_term_mem_skip=4, 34 | # long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) 35 | self.engine = build_engine(cfg.MODEL_ENGINE, 36 | phase='eval', 37 | aot_model=self.model, 38 | gpu_id=gpu_id, 39 | short_term_mem_skip=1, 40 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP, 41 | max_len_long_term=cfg.MAX_LEN_LONG_TERM) 42 | 43 | self.transform = transforms.Compose([ 44 | tr.MultiRestrictSize(cfg.TEST_MAX_SHORT_EDGE, 45 | cfg.TEST_MAX_LONG_EDGE, cfg.TEST_FLIP, 46 | cfg.TEST_MULTISCALE, cfg.MODEL_ALIGN_CORNERS), 47 | tr.MultiToTensor() 48 | ]) 49 | 50 | self.model.eval() 51 | 52 | @torch.no_grad() 53 | def add_reference_frame(self, frame, mask, obj_nums, frame_step, incremental=False): 54 | # mask = cv2.resize(mask, 
frame.shape[:2][::-1], interpolation = cv2.INTER_NEAREST) 55 | 56 | sample = { 57 | 'current_img': frame, 58 | 'current_label': mask, 59 | } 60 | 61 | sample = self.transform(sample) 62 | frame = sample[0]['current_img'].unsqueeze(0).float().cuda(self.gpu_id) 63 | mask = sample[0]['current_label'].unsqueeze(0).float().cuda(self.gpu_id) 64 | _mask = F.interpolate(mask,size=frame.shape[-2:],mode='nearest') 65 | 66 | if incremental: 67 | self.engine.add_reference_frame_incremental(frame, _mask, obj_nums=obj_nums, frame_step=frame_step) 68 | else: 69 | self.engine.add_reference_frame(frame, _mask, obj_nums=obj_nums, frame_step=frame_step) 70 | 71 | 72 | 73 | @torch.no_grad() 74 | def track(self, image): 75 | output_height, output_width = image.shape[0], image.shape[1] 76 | sample = {'current_img': image} 77 | sample = self.transform(sample) 78 | image = sample[0]['current_img'].unsqueeze(0).float().cuda(self.gpu_id) 79 | self.engine.match_propogate_one_frame(image) 80 | pred_logit = self.engine.decode_current_logits((output_height, output_width)) 81 | 82 | # pred_prob = torch.softmax(pred_logit, dim=1) 83 | pred_label = torch.argmax(pred_logit, dim=1, 84 | keepdim=True).float() 85 | 86 | return pred_label 87 | 88 | @torch.no_grad() 89 | def update_memory(self, pred_label): 90 | self.engine.update_memory(pred_label) 91 | 92 | @torch.no_grad() 93 | def restart(self): 94 | self.engine.restart_engine() 95 | 96 | @torch.no_grad() 97 | def build_tracker_engine(self, name, **kwargs): 98 | if name == 'aotengine': 99 | return AOTTrackerInferEngine(**kwargs) 100 | elif name == 'deaotengine': 101 | return DeAOTTrackerInferEngine(**kwargs) 102 | else: 103 | raise NotImplementedError 104 | 105 | 106 | class AOTTrackerInferEngine(AOTInferEngine): 107 | def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None): 108 | super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num) 109 | def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1): 110 | if isinstance(obj_nums, list): 111 | obj_nums = obj_nums[0] 112 | self.obj_nums = obj_nums 113 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) 114 | while (aot_num > len(self.aot_engines)): 115 | new_engine = AOTEngine(self.AOT, self.gpu_id, 116 | self.long_term_mem_gap, 117 | self.short_term_mem_skip) 118 | new_engine.eval() 119 | self.aot_engines.append(new_engine) 120 | 121 | separated_masks, separated_obj_nums = self.separate_mask( 122 | mask, obj_nums) 123 | img_embs = None 124 | for aot_engine, separated_mask, separated_obj_num in zip( 125 | self.aot_engines, separated_masks, separated_obj_nums): 126 | if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num: 127 | aot_engine.add_reference_frame(img, 128 | separated_mask, 129 | obj_nums=[separated_obj_num], 130 | frame_step=frame_step, 131 | img_embs=img_embs) 132 | else: 133 | aot_engine.update_short_term_memory(separated_mask) 134 | 135 | if img_embs is None: # reuse image embeddings 136 | img_embs = aot_engine.curr_enc_embs 137 | 138 | self.update_size() 139 | 140 | 141 | 142 | class DeAOTTrackerInferEngine(DeAOTInferEngine): 143 | def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None): 144 | super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num) 145 | def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1): 146 | if isinstance(obj_nums, list): 147 | obj_nums = 
obj_nums[0] 148 | self.obj_nums = obj_nums 149 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) 150 | while (aot_num > len(self.aot_engines)): 151 | new_engine = DeAOTEngine(self.AOT, self.gpu_id, 152 | self.long_term_mem_gap, 153 | self.short_term_mem_skip) 154 | new_engine.eval() 155 | self.aot_engines.append(new_engine) 156 | 157 | separated_masks, separated_obj_nums = self.separate_mask( 158 | mask, obj_nums) 159 | img_embs = None 160 | for aot_engine, separated_mask, separated_obj_num in zip( 161 | self.aot_engines, separated_masks, separated_obj_nums): 162 | if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num: 163 | aot_engine.add_reference_frame(img, 164 | separated_mask, 165 | obj_nums=[separated_obj_num], 166 | frame_step=frame_step, 167 | img_embs=img_embs) 168 | else: 169 | aot_engine.update_short_term_memory(separated_mask) 170 | 171 | if img_embs is None: # reuse image embeddings 172 | img_embs = aot_engine.curr_enc_embs 173 | 174 | self.update_size() 175 | 176 | 177 | def get_aot(args): 178 | # build vos engine 179 | engine_config = importlib.import_module('configs.' + 'pre_ytb_dav') 180 | cfg = engine_config.EngineConfig(args['phase'], args['model']) 181 | cfg.TEST_CKPT_PATH = args['model_path'] 182 | cfg.TEST_LONG_TERM_MEM_GAP = args['long_term_mem_gap'] 183 | cfg.MAX_LEN_LONG_TERM = args['max_len_long_term'] 184 | # init AOTTracker 185 | tracker = AOTTracker(cfg, args['gpu_id']) 186 | return tracker 187 | -------------------------------------------------------------------------------- /assets/840_iSXIa0hE8Ek.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/840_iSXIa0hE8Ek.zip -------------------------------------------------------------------------------- /assets/blackswan.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/blackswan.mp4 -------------------------------------------------------------------------------- /assets/cars.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/cars.mp4 -------------------------------------------------------------------------------- /assets/cell.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/cell.mp4 -------------------------------------------------------------------------------- /assets/demo_3x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/demo_3x2.gif -------------------------------------------------------------------------------- /assets/gradio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/gradio.jpg -------------------------------------------------------------------------------- /assets/interactive_webui.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/interactive_webui.jpg -------------------------------------------------------------------------------- /assets/top.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/assets/top.gif -------------------------------------------------------------------------------- /img2vid.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | # set the directory containing the images 5 | img_dir = './assets/840_iSXIa0hE8Ek' 6 | 7 | # set the output video file name and codec 8 | out_file = './assets/840_iSXIa0hE8Ek.mp4' 9 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 10 | 11 | # get the dimensions of the first image 12 | img_path = os.path.join(img_dir, os.listdir(img_dir)[0]) 13 | img = cv2.imread(img_path) 14 | height, width, channels = img.shape 15 | 16 | # create the VideoWriter object 17 | out = cv2.VideoWriter(out_file, fourcc, 10, (width, height)) 18 | 19 | # loop through the images and write them to the video 20 | for img_name in sorted(os.listdir(img_dir)): 21 | img_path = os.path.join(img_dir, img_name) 22 | img = cv2.imread(img_path) 23 | out.write(img) 24 | 25 | # release the VideoWriter object and close the video file 26 | out.release() 27 | -------------------------------------------------------------------------------- /model_args.py: -------------------------------------------------------------------------------- 1 | # Explanation of generator_args is in sam/segment_anything/automatic_mask_generator.py: SamAutomaticMaskGenerator 2 | sam_args = { 3 | 'sam_checkpoint': "ckpt/sam_vit_b_01ec64.pth", 4 | 'model_type': "vit_b", 5 | 'generator_args':{ 6 | 'points_per_side': 16, 7 | 'pred_iou_thresh': 0.8, 8 | 'stability_score_thresh': 0.9, 9 | 'crop_n_layers': 1, 10 | 'crop_n_points_downscale_factor': 2, 11 | 'min_mask_region_area': 200, 12 | }, 13 | 'gpu_id': 0, 14 | } 15 | aot_args = { 16 | 'phase': 'PRE_YTB_DAV', 17 | 'model': 'r50_deaotl', 18 | 'model_path': 'ckpt/R50_DeAOTL_PRE_YTB_DAV.pth', 19 | 'long_term_mem_gap': 9999, 20 | 'max_len_long_term': 9999, 21 | 'gpu_id': 0, 22 | } 23 | segtracker_args = { 24 | 'sam_gap': 10, # the interval to run sam to segment new objects 25 | 'min_area': 200, # minimal mask area to add a new mask as a new object 26 | 'max_obj_num': 255, # maximal object number to track in a video 27 | 'min_new_obj_iou': 0.8, # the background area ratio of a new object should > 80% 28 | } -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | import os, csv, argparse 2 | import sys 3 | import torch, torchaudio, timm 4 | import numpy as np 5 | from torch.cuda.amp import autocast 6 | import IPython 7 | current_directory = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(current_directory) 9 | from src.models import ASTModel 10 | 11 | # Create a new class that inherits the original ASTModel class 12 | class ASTModelVis(ASTModel): 13 | def get_att_map(self, block, x): 14 | qkv = block.attn.qkv 15 | num_heads = block.attn.num_heads 16 | scale = block.attn.scale 17 | B, N, C = x.shape 18 | qkv = qkv(x).reshape(B, N, 3, num_heads, C // 
num_heads).permute(2, 0, 3, 1, 4) 19 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 20 | attn = (q @ k.transpose(-2, -1)) * scale 21 | attn = attn.softmax(dim=-1) 22 | return attn 23 | 24 | def forward_visualization(self, x): 25 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 26 | x = x.unsqueeze(1) 27 | x = x.transpose(2, 3) 28 | 29 | B = x.shape[0] 30 | x = self.v.patch_embed(x) 31 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 32 | dist_token = self.v.dist_token.expand(B, -1, -1) 33 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 34 | x = x + self.v.pos_embed 35 | x = self.v.pos_drop(x) 36 | # save the attention map of each of the 12 Transformer layers 37 | att_list = [] 38 | for blk in self.v.blocks: 39 | cur_att = self.get_att_map(blk, x) 40 | att_list.append(cur_att) 41 | x = blk(x) 42 | return att_list 43 | 44 | def make_features(wav_name, mel_bins, target_length=1024): 45 | waveform, sr = torchaudio.load(wav_name) 46 | # assert sr == 16000, 'input audio sampling rate must be 16kHz' 47 | 48 | fbank = torchaudio.compliance.kaldi.fbank( 49 | waveform, htk_compat=True, sample_frequency=sr, use_energy=False, 50 | window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10) 51 | 52 | n_frames = fbank.shape[0] 53 | 54 | p = target_length - n_frames 55 | if p > 0: 56 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 57 | fbank = m(fbank) 58 | elif p < 0: 59 | fbank = fbank[0:target_length, :] 60 | 61 | fbank = (fbank - (-4.2677393)) / (4.5689974 * 2) 62 | return fbank 63 | 64 | 65 | def load_label(label_csv): 66 | with open(label_csv, 'r') as f: 67 | reader = csv.reader(f, delimiter=',') 68 | lines = list(reader) 69 | labels = [] 70 | ids = [] # Each label has a unique id such as "/m/068hy" 71 | for i1 in range(1, len(lines)): 72 | id = lines[i1][1] 73 | label = lines[i1][2] 74 | ids.append(id) 75 | labels.append(label) 76 | return labels 77 | 78 | def ASTpredict(): 79 | # Assume each input spectrogram has 1024 time frames 80 | input_tdim = 1024 81 | checkpoint_path = './ast_master/pretrained_models/audio_mdl.pth' 82 | # now load the visualization model 83 | ast_mdl = ASTModelVis(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False) 84 | print(f'[*INFO] load checkpoint: {checkpoint_path}') 85 | checkpoint = torch.load(checkpoint_path, map_location='cuda') 86 | audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0]) 87 | audio_model.load_state_dict(checkpoint) 88 | audio_model = audio_model.to(torch.device("cuda:0")) 89 | audio_model.eval() 90 | 91 | # Load the AudioSet label set 92 | label_csv = './ast_master/egs/audioset/data/class_labels_indices.csv' # label and indices for audioset data 93 | labels = load_label(label_csv) 94 | 95 | feats = make_features("./audio.flac", mel_bins=128) # shape(1024, 128) 96 | feats_data = feats.expand(1, input_tdim, 128) # reshape the feature 97 | feats_data = feats_data.to(torch.device("cuda:0")) 98 | # do some masking of the input 99 | #feats_data[:, :512, :] = 0.
100 | 101 | # Make the prediction 102 | with torch.no_grad(): 103 | with autocast(): 104 | output = audio_model.forward(feats_data) 105 | output = torch.sigmoid(output) 106 | result_output = output.data.cpu().numpy()[0] 107 | sorted_indexes = np.argsort(result_output)[::-1] 108 | 109 | # Print audio tagging top probabilities 110 | print('Prediction results:') 111 | for k in range(10): 112 | print('- {}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]], result_output[sorted_indexes[k]])) 113 | #return the top 10 labels and their probabilities 114 | top_labels_probs = {} 115 | top_labels = {} 116 | for k in range(10): 117 | label = np.array(labels)[sorted_indexes[k]] 118 | prob = result_output[sorted_indexes[k]] 119 | top_labels[k]= label 120 | top_labels_probs[k]= prob 121 | return top_labels, top_labels_probs 122 | -------------------------------------------------------------------------------- /sam/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 3 | max-line-length = 100 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | per-file-ignores = 7 | **/__init__.py:F401,F403,E402 8 | -------------------------------------------------------------------------------- /sam/.gitignore: -------------------------------------------------------------------------------- 1 | .nfs* 2 | 3 | # compilation and distribution 4 | __pycache__ 5 | _ext 6 | *.pyc 7 | *.pyd 8 | *.so 9 | *.dll 10 | *.egg-info/ 11 | build/ 12 | dist/ 13 | wheels/ 14 | 15 | # pytorch/python/numpy formats 16 | *.pth 17 | *.pkl 18 | *.npy 19 | *.ts 20 | model_ts*.txt 21 | 22 | # onnx models 23 | *.onnx 24 | 25 | # ipython/jupyter notebooks 26 | **/.ipynb_checkpoints/ 27 | 28 | # Editor temporaries 29 | *.swn 30 | *.swo 31 | *.swp 32 | *~ 33 | 34 | # editor settings 35 | .idea 36 | .vscode 37 | _darcs 38 | -------------------------------------------------------------------------------- /sam/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /sam/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /sam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/__init__.py -------------------------------------------------------------------------------- /sam/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/assets/masks1.png -------------------------------------------------------------------------------- /sam/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/assets/masks2.jpg -------------------------------------------------------------------------------- /sam/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/assets/model_diagram.png -------------------------------------------------------------------------------- /sam/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/assets/notebook1.png -------------------------------------------------------------------------------- /sam/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/assets/notebook2.png -------------------------------------------------------------------------------- /sam/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 
33 | -------------------------------------------------------------------------------- /sam/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /sam/scripts/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 # type: ignore 8 | 9 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 10 | 11 | import argparse 12 | import json 13 | import os 14 | from typing import Any, Dict, List 15 | 16 | parser = argparse.ArgumentParser( 17 | description=( 18 | "Runs automatic mask generation on an input image or directory of images, " 19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 20 | "as well as pycocotools if saving in RLE format." 21 | ) 22 | ) 23 | 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="Path to either a single input image or folder of images.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--output", 33 | type=str, 34 | required=True, 35 | help=( 36 | "Path to the directory where masks will be output. Output will be either a folder " 37 | "of PNGs per image or a single json with COCO-style masks." 38 | ), 39 | ) 40 | 41 | parser.add_argument( 42 | "--model-type", 43 | type=str, 44 | required=True, 45 | help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", 46 | ) 47 | 48 | parser.add_argument( 49 | "--checkpoint", 50 | type=str, 51 | required=True, 52 | help="The path to the SAM checkpoint to use for mask generation.", 53 | ) 54 | 55 | parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") 56 | 57 | parser.add_argument( 58 | "--convert-to-rle", 59 | action="store_true", 60 | help=( 61 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " 62 | "Requires pycocotools." 
63 | ), 64 | ) 65 | 66 | amg_settings = parser.add_argument_group("AMG Settings") 67 | 68 | amg_settings.add_argument( 69 | "--points-per-side", 70 | type=int, 71 | default=None, 72 | help="Generate masks by sampling a grid over the image with this many points to a side.", 73 | ) 74 | 75 | amg_settings.add_argument( 76 | "--points-per-batch", 77 | type=int, 78 | default=None, 79 | help="How many input points to process simultaneously in one batch.", 80 | ) 81 | 82 | amg_settings.add_argument( 83 | "--pred-iou-thresh", 84 | type=float, 85 | default=None, 86 | help="Exclude masks with a predicted score from the model that is lower than this threshold.", 87 | ) 88 | 89 | amg_settings.add_argument( 90 | "--stability-score-thresh", 91 | type=float, 92 | default=None, 93 | help="Exclude masks with a stability score lower than this threshold.", 94 | ) 95 | 96 | amg_settings.add_argument( 97 | "--stability-score-offset", 98 | type=float, 99 | default=None, 100 | help="Larger values perturb the mask more when measuring stability score.", 101 | ) 102 | 103 | amg_settings.add_argument( 104 | "--box-nms-thresh", 105 | type=float, 106 | default=None, 107 | help="The overlap threshold for excluding a duplicate mask.", 108 | ) 109 | 110 | amg_settings.add_argument( 111 | "--crop-n-layers", 112 | type=int, 113 | default=None, 114 | help=( 115 | "If >0, mask generation is run on smaller crops of the image to generate more masks. " 116 | "The value sets how many different scales to crop at." 117 | ), 118 | ) 119 | 120 | amg_settings.add_argument( 121 | "--crop-nms-thresh", 122 | type=float, 123 | default=None, 124 | help="The overlap threshold for excluding duplicate masks across different crops.", 125 | ) 126 | 127 | amg_settings.add_argument( 128 | "--crop-overlap-ratio", 129 | type=int, 130 | default=None, 131 | help="Larger numbers mean image crops will overlap more.", 132 | ) 133 | 134 | amg_settings.add_argument( 135 | "--crop-n-points-downscale-factor", 136 | type=int, 137 | default=None, 138 | help="The number of points-per-side in each layer of crop is reduced by this factor.", 139 | ) 140 | 141 | amg_settings.add_argument( 142 | "--min-mask-region-area", 143 | type=int, 144 | default=None, 145 | help=( 146 | "Disconnected mask regions or holes with area smaller than this value " 147 | "in pixels are removed by postprocessing." 
148 | ), 149 | ) 150 | 151 | 152 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 153 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa 154 | metadata = [header] 155 | for i, mask_data in enumerate(masks): 156 | mask = mask_data["segmentation"] 157 | filename = f"{i}.png" 158 | cv2.imwrite(os.path.join(path, filename), mask * 255) 159 | mask_metadata = [ 160 | str(i), 161 | str(mask_data["area"]), 162 | *[str(x) for x in mask_data["bbox"]], 163 | *[str(x) for x in mask_data["point_coords"][0]], 164 | str(mask_data["predicted_iou"]), 165 | str(mask_data["stability_score"]), 166 | *[str(x) for x in mask_data["crop_box"]], 167 | ] 168 | row = ",".join(mask_metadata) 169 | metadata.append(row) 170 | metadata_path = os.path.join(path, "metadata.csv") 171 | with open(metadata_path, "w") as f: 172 | f.write("\n".join(metadata)) 173 | 174 | return 175 | 176 | 177 | def get_amg_kwargs(args): 178 | amg_kwargs = { 179 | "points_per_side": args.points_per_side, 180 | "points_per_batch": args.points_per_batch, 181 | "pred_iou_thresh": args.pred_iou_thresh, 182 | "stability_score_thresh": args.stability_score_thresh, 183 | "stability_score_offset": args.stability_score_offset, 184 | "box_nms_thresh": args.box_nms_thresh, 185 | "crop_n_layers": args.crop_n_layers, 186 | "crop_nms_thresh": args.crop_nms_thresh, 187 | "crop_overlap_ratio": args.crop_overlap_ratio, 188 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 189 | "min_mask_region_area": args.min_mask_region_area, 190 | } 191 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 192 | return amg_kwargs 193 | 194 | 195 | def main(args: argparse.Namespace) -> None: 196 | print("Loading model...") 197 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 198 | _ = sam.to(device=args.device) 199 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 200 | amg_kwargs = get_amg_kwargs(args) 201 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 202 | 203 | if not os.path.isdir(args.input): 204 | targets = [args.input] 205 | else: 206 | targets = [ 207 | f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) 208 | ] 209 | targets = [os.path.join(args.input, f) for f in targets] 210 | 211 | os.makedirs(args.output, exist_ok=True) 212 | 213 | for t in targets: 214 | print(f"Processing '{t}'...") 215 | image = cv2.imread(t) 216 | if image is None: 217 | print(f"Could not load '{t}' as an image, skipping...") 218 | continue 219 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 220 | 221 | masks = generator.generate(image) 222 | 223 | base = os.path.basename(t) 224 | base = os.path.splitext(base)[0] 225 | save_base = os.path.join(args.output, base) 226 | if output_mode == "binary_mask": 227 | os.makedirs(save_base, exist_ok=False) 228 | write_masks_to_folder(masks, save_base) 229 | else: 230 | save_file = save_base + ".json" 231 | with open(save_file, "w") as f: 232 | json.dump(masks, f) 233 | print("Done!") 234 | 235 | 236 | if __name__ == "__main__": 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /sam/scripts/export_onnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from segment_anything import sam_model_registry 10 | from segment_anything.utils.onnx import SamOnnxModel 11 | 12 | import argparse 13 | import warnings 14 | 15 | try: 16 | import onnxruntime # type: ignore 17 | 18 | onnxruntime_exists = True 19 | except ImportError: 20 | onnxruntime_exists = False 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Export the SAM prompt encoder and mask decoder to an ONNX model." 24 | ) 25 | 26 | parser.add_argument( 27 | "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." 28 | ) 29 | 30 | parser.add_argument( 31 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 32 | ) 33 | 34 | parser.add_argument( 35 | "--model-type", 36 | type=str, 37 | required=True, 38 | help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.", 39 | ) 40 | 41 | parser.add_argument( 42 | "--return-single-mask", 43 | action="store_true", 44 | help=( 45 | "If true, the exported ONNX model will only return the best mask, " 46 | "instead of returning multiple masks. For high resolution images " 47 | "this can improve runtime when upscaling masks is expensive." 48 | ), 49 | ) 50 | 51 | parser.add_argument( 52 | "--opset", 53 | type=int, 54 | default=17, 55 | help="The ONNX opset version to use. Must be >=11", 56 | ) 57 | 58 | parser.add_argument( 59 | "--quantize-out", 60 | type=str, 61 | default=None, 62 | help=( 63 | "If set, will quantize the model and save it with this name. " 64 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 65 | ), 66 | ) 67 | 68 | parser.add_argument( 69 | "--gelu-approximate", 70 | action="store_true", 71 | help=( 72 | "Replace GELU operations with approximations using tanh. Useful " 73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 74 | ), 75 | ) 76 | 77 | parser.add_argument( 78 | "--use-stability-score", 79 | action="store_true", 80 | help=( 81 | "Replaces the model's predicted mask quality score with the stability " 82 | "score calculated on the low resolution masks using an offset of 1.0. " 83 | ), 84 | ) 85 | 86 | parser.add_argument( 87 | "--return-extra-metrics", 88 | action="store_true", 89 | help=( 90 | "The model will return five results: (masks, scores, stability_scores, " 91 | "areas, low_res_logits) instead of the usual three. This can be " 92 | "significantly slower for high resolution outputs." 
93 | ), 94 | ) 95 | 96 | 97 | def run_export( 98 | model_type: str, 99 | checkpoint: str, 100 | output: str, 101 | opset: int, 102 | return_single_mask: bool, 103 | gelu_approximate: bool = False, 104 | use_stability_score: bool = False, 105 | return_extra_metrics=False, 106 | ): 107 | print("Loading model...") 108 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 109 | 110 | onnx_model = SamOnnxModel( 111 | model=sam, 112 | return_single_mask=return_single_mask, 113 | use_stability_score=use_stability_score, 114 | return_extra_metrics=return_extra_metrics, 115 | ) 116 | 117 | if gelu_approximate: 118 | for n, m in onnx_model.named_modules(): 119 | if isinstance(m, torch.nn.GELU): 120 | m.approximate = "tanh" 121 | 122 | dynamic_axes = { 123 | "point_coords": {1: "num_points"}, 124 | "point_labels": {1: "num_points"}, 125 | } 126 | 127 | embed_dim = sam.prompt_encoder.embed_dim 128 | embed_size = sam.prompt_encoder.image_embedding_size 129 | mask_input_size = [4 * x for x in embed_size] 130 | dummy_inputs = { 131 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 132 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 133 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 134 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 135 | "has_mask_input": torch.tensor([1], dtype=torch.float), 136 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 137 | } 138 | 139 | _ = onnx_model(**dummy_inputs) 140 | 141 | output_names = ["masks", "iou_predictions", "low_res_masks"] 142 | 143 | with warnings.catch_warnings(): 144 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 145 | warnings.filterwarnings("ignore", category=UserWarning) 146 | with open(output, "wb") as f: 147 | print(f"Exporting onnx model to {output}...") 148 | torch.onnx.export( 149 | onnx_model, 150 | tuple(dummy_inputs.values()), 151 | f, 152 | export_params=True, 153 | verbose=False, 154 | opset_version=opset, 155 | do_constant_folding=True, 156 | input_names=list(dummy_inputs.keys()), 157 | output_names=output_names, 158 | dynamic_axes=dynamic_axes, 159 | ) 160 | 161 | if onnxruntime_exists: 162 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 163 | ort_session = onnxruntime.InferenceSession(output) 164 | _ = ort_session.run(None, ort_inputs) 165 | print("Model has successfully been run with ONNXRuntime.") 166 | 167 | 168 | def to_numpy(tensor): 169 | return tensor.cpu().numpy() 170 | 171 | 172 | if __name__ == "__main__": 173 | args = parser.parse_args() 174 | run_export( 175 | model_type=args.model_type, 176 | checkpoint=args.checkpoint, 177 | output=args.output, 178 | opset=args.opset, 179 | return_single_mask=args.return_single_mask, 180 | gelu_approximate=args.gelu_approximate, 181 | use_stability_score=args.use_stability_score, 182 | return_extra_metrics=args.return_extra_metrics, 183 | ) 184 | 185 | if args.quantize_out is not None: 186 | assert onnxruntime_exists, "onnxruntime is required to quantize the model."
187 | from onnxruntime.quantization import QuantType # type: ignore 188 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore 189 | 190 | print(f"Quantizing model and writing to {args.quantize_out}...") 191 | quantize_dynamic( 192 | model_input=args.output, 193 | model_output=args.quantize_out, 194 | optimize_model=True, 195 | per_channel=False, 196 | reduce_range=False, 197 | weight_type=QuantType.QUInt8, 198 | ) 199 | print("Done!") 200 | -------------------------------------------------------------------------------- /sam/segment_anything/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/sam/segment_anything/.DS_Store -------------------------------------------------------------------------------- /sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /sam/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture.
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks.
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 
32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | a dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction.
96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information is returned. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /script/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | # download aot-ckpt 2 | gdown --id '1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ' --output ./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth 3 | 4 | # download sam-ckpt 5 | wget -P ./ckpt https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 6 | 7 | # download grounding-dino ckpt 8 | wget -P ./ckpt https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth 9 | 10 | # download mit-ast-finetuned ckpt 11 | wget -O ./ast_master/pretrained_models/audio_mdl.pth https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1 -------------------------------------------------------------------------------- /script/install.sh: -------------------------------------------------------------------------------- 1 | # Install SAM 2 | cd sam; pip install -e . 3 | cd - 4 | 5 | # Install Grounding-Dino 6 | pip install -e git+https://github.com/IDEA-Research/GroundingDINO.git@main#egg=GroundingDINO 7 | 8 | # Install other lib 9 | pip install numpy opencv-python pycocotools matplotlib Pillow==9.2.0 scikit-image 10 | pip install gradio==3.39.0 zip gdown ffmpeg==1.4 11 | pip install timm==0.4.5 12 | pip install wget 13 | 14 | # Install Pytorch Correlation 15 | git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git 16 | cd Pytorch-Correlation-extension 17 | python setup.py install 18 | cd - 19 | 20 | # Install AST 21 | git clone https://github.com/YuanGongND/ast.git ast_master 22 | cp ./prepare.py ./ast_master 23 | 24 | -------------------------------------------------------------------------------- /tool/detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import PIL 5 | 6 | from groundingdino.models import build_model as build_grounding_dino 7 | from groundingdino.util.slconfig import SLConfig 8 | from groundingdino.util.utils import clean_state_dict 9 | from groundingdino.util.inference import annotate, load_image, predict 10 | import groundingdino.datasets.transforms as T 11 | 12 | from torchvision.ops import box_convert 13 | 14 | class Detector: 15 | def __init__(self, device): 16 | config_file = "src/groundingdino/groundingdino/config/GroundingDINO_SwinT_OGC.py" 17 | grounding_dino_ckpt = './ckpt/groundingdino_swint_ogc.pth' 18 | args = SLConfig.fromfile(config_file) 19 | args.device = device 20 | self.device = device 21 | self.gd = build_grounding_dino(args) 22 | 23 | checkpoint = torch.load(grounding_dino_ckpt, map_location='cpu') 24 | log = self.gd.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) 25 | print("Model loaded from {} \n => {}".format(grounding_dino_ckpt, log)) 26 | self.gd.eval() 27 | 28 | def image_transform_grounding(self, init_image): 29 | transform = T.Compose([ 30 | T.RandomResize([800], max_size=1333), 31 | T.ToTensor(), 32 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | image, _ = transform(init_image, None) # 3, h, w 35 | return init_image, image 36 |
37 | def image_transform_grounding_for_vis(self, init_image): 38 | transform = T.Compose([ 39 | T.RandomResize([800], max_size=1333), 40 | ]) 41 | image, _ = transform(init_image, None) # 3, h, w 42 | return image 43 | 44 | def transfer_boxes_format(self, boxes, height, width): 45 | boxes = boxes * torch.Tensor([width, height, width, height]) # scale normalized boxes to pixel coordinates 46 | boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") # center format -> corner format 47 | 48 | transfered_boxes = [] 49 | for i in range(len(boxes)): 50 | box = boxes[i] 51 | transfered_box = [[int(box[0]), int(box[1])], [int(box[2]), int(box[3])]] 52 | transfered_boxes.append(transfered_box) 53 | 54 | transfered_boxes = np.array(transfered_boxes) 55 | return transfered_boxes 56 | 57 | @torch.no_grad() 58 | def run_grounding(self, origin_frame, grounding_caption, box_threshold, text_threshold): 59 | ''' 60 | return: 61 | annotated_frame: np.ndarray 62 | transfered_boxes: np.ndarray of shape (N, 2, 2): [[x0, y0], [x1, y1]] per box 63 | ''' 64 | height, width, _ = origin_frame.shape 65 | img_pil = PIL.Image.fromarray(origin_frame) 66 | re_width, re_height = img_pil.size 67 | _, image_tensor = self.image_transform_grounding(img_pil) 68 | # img_pil = self.image_transform_grounding_for_vis(img_pil) 69 | 70 | # run grounding 71 | boxes, logits, phrases = predict(self.gd, image_tensor, grounding_caption, box_threshold, text_threshold, device=self.device) 72 | annotated_frame = annotate(image_source=np.asarray(img_pil), boxes=boxes, logits=logits, phrases=phrases)[:, :, ::-1] 73 | annotated_frame = cv2.resize(annotated_frame, (width, height), interpolation=cv2.INTER_LINEAR) 74 | 75 | # transfer boxes to sam-format 76 | transfered_boxes = self.transfer_boxes_format(boxes, re_height, re_width) 77 | return annotated_frame, transfered_boxes 78 | 79 | if __name__ == "__main__": 80 | detector = Detector("cuda") 81 | origin_frame = cv2.imread('./debug/point.png') 82 | origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB) 83 | grounding_caption = "swan.water" 84 | box_threshold = 0.25 85 | text_threshold = 0.25 86 | 87 | annotated_frame, boxes = detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold) 88 | cv2.imwrite('./debug/x.png', annotated_frame) 89 | 90 | for i in range(len(boxes)): 91 | bbox = boxes[i] 92 | origin_frame = cv2.rectangle(origin_frame, bbox[0], bbox[1], (0, 0, 255)) 93 | cv2.imwrite('./debug/bbox_frame.png', origin_frame) -------------------------------------------------------------------------------- /tool/segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from sam.segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 5 | 6 | class Segmentor: 7 | def __init__(self, sam_args): 8 | """ 9 | sam_args: 10 | sam_checkpoint: path of SAM checkpoint 11 | generator_args: args for everything_generator 12 | gpu_id: device 13 | """ 14 | self.device = sam_args["gpu_id"] 15 | self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"]) 16 | self.sam.to(device=self.device) 17 | self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args']) 18 | self.interactive_predictor = self.everything_generator.predictor 19 | self.have_embedded = False 20 | 21 | @torch.no_grad() 22 | def set_image(self, image): 23 | # calculate the embedding only once per frame.
24 | if not self.have_embedded: 25 | self.interactive_predictor.set_image(image) 26 | self.have_embedded = True 27 | @torch.no_grad() 28 | def interactive_predict(self, prompts, mode, multimask=True): 29 | assert self.have_embedded, 'image embedding for sam need be set before predict.' 30 | 31 | if mode == 'point': 32 | masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], 33 | point_labels=prompts['point_modes'], 34 | multimask_output=multimask) 35 | elif mode == 'mask': 36 | masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'], 37 | multimask_output=multimask) 38 | elif mode == 'point_mask': 39 | masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], 40 | point_labels=prompts['point_modes'], 41 | mask_input=prompts['mask_prompt'], 42 | multimask_output=multimask) 43 | 44 | return masks, scores, logits 45 | 46 | @torch.no_grad() 47 | def segment_with_click(self, origin_frame, coords, modes, multimask=True): 48 | ''' 49 | 50 | return: 51 | mask: one-hot 52 | ''' 53 | self.set_image(origin_frame) 54 | 55 | prompts = { 56 | 'point_coords': coords, 57 | 'point_modes': modes, 58 | } 59 | masks, scores, logits = self.interactive_predict(prompts, 'point', multimask) 60 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 61 | prompts = { 62 | 'point_coords': coords, 63 | 'point_modes': modes, 64 | 'mask_prompt': logit[None, :, :] 65 | } 66 | masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask) 67 | mask = masks[np.argmax(scores)] 68 | 69 | return mask.astype(np.uint8) 70 | 71 | def segment_with_box(self, origin_frame, bbox, reset_image=False): 72 | if reset_image: 73 | self.interactive_predictor.set_image(origin_frame) 74 | else: 75 | self.set_image(origin_frame) 76 | # coord = np.array([[int((bbox[1][0] - bbox[0][0]) / 2.), int((bbox[1][1] - bbox[0][1]) / 2)]]) 77 | # point_label = np.array([1]) 78 | 79 | masks, scores, logits = self.interactive_predictor.predict( 80 | point_coords=None, 81 | point_labels=None, 82 | box=np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]), 83 | multimask_output=True 84 | ) 85 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 86 | 87 | masks, scores, logits = self.interactive_predictor.predict( 88 | point_coords=None, 89 | point_labels=None, 90 | box=np.array([[bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]]), 91 | mask_input=logit[None, :, :], 92 | multimask_output=True 93 | ) 94 | mask = masks[np.argmax(scores)] 95 | 96 | return [mask] 97 | -------------------------------------------------------------------------------- /tool/transfer_tools.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def mask2bbox(mask): 5 | if len(np.where(mask > 0)[0]) == 0: 6 | print(f'not mask') 7 | return np.array([[0, 0], [0, 0]]).astype(np.int64) 8 | 9 | x_ = np.sum(mask, axis=0) 10 | y_ = np.sum(mask, axis=1) 11 | 12 | x0 = np.min(np.nonzero(x_)[0]) 13 | x1 = np.max(np.nonzero(x_)[0]) 14 | y0 = np.min(np.nonzero(y_)[0]) 15 | y1 = np.max(np.nonzero(y_)[0]) 16 | 17 | return np.array([[x0, y0], [x1, y1]]).astype(np.int64) 18 | 19 | def draw_outline(mask, frame): 20 | _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY) 21 | 22 | contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 23 | 24 | cv2.drawContours(frame, contours, -1, (0, 0, 255), 2) 25 | 26 | 
return frame 27 | 28 | def draw_points(points, modes, frame): 29 | neg_points = points[np.argwhere(modes==0)[:, 0]] 30 | pos_points = points[np.argwhere(modes==1)[:, 0]] 31 | 32 | for i in range(len(neg_points)): 33 | point = neg_points[i] 34 | cv2.circle(frame, (point[0], point[1]), 8, (255, 80, 80), -1) 35 | 36 | for i in range(len(pos_points)): 37 | point = pos_points[i] 38 | cv2.circle(frame, (point[0], point[1]), 8, (0, 153, 255), -1) 39 | 40 | return frame 41 | 42 | if __name__ == '__main__': 43 | mask = cv2.imread('./debug/mask.jpg', cv2.IMREAD_GRAYSCALE) 44 | frame = cv2.imread('./debug/frame.jpg') 45 | draw_frame = draw_outline(mask, frame) 46 | 47 | cv2.imwrite('./debug/outline.jpg', draw_frame) 48 | 49 | # bbox = mask2bbox(mask) 50 | # draw_0 = cv2.rectangle(mask, bbox[0], bbox[1], (0, 0, 255)) 51 | # cv2.imwrite('./debug/rect.png', draw_0) -------------------------------------------------------------------------------- /tutorial/img/Drawing_board.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/Drawing_board.jpg -------------------------------------------------------------------------------- /tutorial/img/add_positive_base_on_everything.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/add_positive_base_on_everything.jpg -------------------------------------------------------------------------------- /tutorial/img/add_positive_base_on_everything_cxk.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/add_positive_base_on_everything_cxk.jpg -------------------------------------------------------------------------------- /tutorial/img/add_positive_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/add_positive_points.jpg -------------------------------------------------------------------------------- /tutorial/img/add_positive_points_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/add_positive_points_2.jpg -------------------------------------------------------------------------------- /tutorial/img/audio_tab.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/audio_tab.jpg -------------------------------------------------------------------------------- /tutorial/img/click_input_video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/click_input_video.jpg -------------------------------------------------------------------------------- /tutorial/img/click_segment.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/click_segment.jpg -------------------------------------------------------------------------------- /tutorial/img/click_segment_everything.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/click_segment_everything.jpg -------------------------------------------------------------------------------- /tutorial/img/detect_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/detect_result.jpg -------------------------------------------------------------------------------- /tutorial/img/enter_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/enter_text.jpg -------------------------------------------------------------------------------- /tutorial/img/grounding-tab.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/grounding-tab.jpg -------------------------------------------------------------------------------- /tutorial/img/input_video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/input_video.jpg -------------------------------------------------------------------------------- /tutorial/img/new_object.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/new_object.jpg -------------------------------------------------------------------------------- /tutorial/img/second_object.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/second_object.jpg -------------------------------------------------------------------------------- /tutorial/img/segment_everything_blackswan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/segment_everything_blackswan.jpg -------------------------------------------------------------------------------- /tutorial/img/select_fps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/select_fps.jpg -------------------------------------------------------------------------------- /tutorial/img/start_tracking.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/start_tracking.jpg 
-------------------------------------------------------------------------------- /tutorial/img/switch2ImgSeq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/switch2ImgSeq.jpg -------------------------------------------------------------------------------- /tutorial/img/switch2textT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/switch2textT.jpg -------------------------------------------------------------------------------- /tutorial/img/upload_Image_seq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/upload_Image_seq.jpg -------------------------------------------------------------------------------- /tutorial/img/use_exa4ImgSeq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z-x-yang/Segment-and-Track-Anything/228b6f6ba6961e3618c9f27684746efcaed99a74/tutorial/img/use_exa4ImgSeq.jpg -------------------------------------------------------------------------------- /tutorial/tutorial for Image-Sequence input.md: -------------------------------------------------------------------------------- 1 | # Tutorial for Image-Sequence input 2 | 3 | ## Zip the Image-Sequence as input for the WebUI. 4 | **The structure of test-data-seq.zip must look like this. Please confirm that the image names are in ascending order.** 5 | ``` 6 | - test-data-seq 7 | - 000000.png 8 | - 000001.png 9 | - 000002.png 10 | - 000003.png 11 | .... 12 | - 0000xx.png 13 | ``` 14 | **Note: Please ensure that the image names sort in ascending alphabetical order.** (A minimal Python sketch for creating such a zip is shown after the first screenshot below.) 15 | 16 | ## Use the WebUI with test Image-Sequence data 17 | ### 1. Switch to the `Image-Seq type input` tab. 18 | 19 |

[screenshot: switch2ImgSeq]
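If you need to build such a zip yourself, the following minimal Python sketch is one way to do it. It assumes your frames already sit in a local `test-data-seq/` folder with zero-padded names; the folder and file names are only placeholders, not something provided by this repository.

```
import zipfile
from pathlib import Path

frames_dir = Path("test-data-seq")   # hypothetical folder containing 000000.png, 000001.png, ...

with zipfile.ZipFile("test-data-seq.zip", "w", zipfile.ZIP_DEFLATED) as zf:
    # Sort by name so the archive keeps the ascending frame order the WebUI expects.
    for frame in sorted(frames_dir.glob("*.png")):
        zf.write(frame, arcname=f"{frames_dir.name}/{frame.name}")
```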

20 | 21 | ### 2. Upload the test dataset or use the provided examples directly. 22 | - Once the test dataset has finished uploading, the WebUI will automatically extract the first frame and display it in the `Segment result of first frame` component. 23 | - If you use the provided examples, you may need to manually extract the results by clicking the `extract` button. 24 | - Below are examples of how to upload Image-Sequence data. 25 | 26 |

27 | 28 | ### 3. Select fps for the output video 29 | 30 |
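The fps value only affects how the tracked frames are assembled into the output video. As a rough illustration (not the repository's actual implementation), an OpenCV sketch of that assembly step could look like this; the folder and file names are assumptions:

```
import cv2
from pathlib import Path

fps = 8                                                   # the value picked in the WebUI
frames = sorted(Path("tracking_results").glob("*.png"))   # hypothetical folder of tracked frames
height, width = cv2.imread(str(frames[0])).shape[:2]

writer = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
for frame_path in frames:
    writer.write(cv2.imread(str(frame_path)))             # frames are written in ascending name order
writer.release()
```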

31 | 32 | ### 4. You can follow the [tutorial for WebUI-1.0-Version](./tutorial%20for%20WebUI-1.0-Version.md) to obtain your result. -------------------------------------------------------------------------------- /tutorial/tutorial for WebUI-1.0-Version.md: -------------------------------------------------------------------------------- 1 | # Tutorial for WebUI 1.0 Version 2 | 3 | ## Note: 4 | - We recommend reinitializing SegTracker by clicking the `Reset button` after processing each video to avoid encountering bugs. 5 | - If the `SegTracker-Args` are changed, the SegTracker needs to be reinitialized by clicking the Reset button. 6 | - If the `Drawing board` does not display the image properly, you can refresh the Drawing board by clicking on the `refresh icon` located in the upper right corner of the Drawing board. 7 | - A video tutorial will be released in the next few days. 8 | ## 1. About Components 9 | - `input video`: where the uploaded video is displayed for the user to view. 10 | - `Segment result of first frame`: where the segmentation result of the first frame is displayed for the user to view. Under the `Everything-Tab` and `Click-Tab`, users can interactively add a mask by clicking on the displayed result. 11 | - `Drawing board`: where users can circle the object they want to track. This component is only visible under the `Stroke-Tab`. 12 | - `SegTracker-Args`: used to adjust the parameters for initializing SegTracker. 13 | - `Undo`: used to undo a previously added point prompt or segment-everything operation. 14 | - `Reset`: used to reset all components and reinitialize SegTracker. 15 | - `Start Tracking`: used to begin tracking the objects selected by automatic/interactive methods in the video using SegTracker. 16 | - `Output video`: where the tracking results of the video are displayed for the user to view. 17 | - `Predicted masks`: show the predicted masks for each frame of the video. 18 | 19 | ## 2. Upload your video 20 | - To upload a video, click on the `input video` component. Once uploaded, the `segment result of first frame` component will display the first frame of the video automatically. 21 | - The examples for uploading a video are shown below. 22 | 23 |

[screenshots: click_input_video, input_video]

24 | 25 | ## 3. Adjust the SegTracker-Args to suit your needs 26 | - **aot_model**: used to select which version of DeAOT/AOT to use for tracking and propagation. 27 | - **sam_gap**: used to control how often SAM is used to add newly appearing objects, at the specified frame interval. Increasing it reduces how often new targets are discovered but significantly speeds up inference. 28 | - **points_per_side**: used to control the number of points per side of the grid sampled over the image when generating masks. Increasing the value improves the ability to detect small objects, but large targets may be segmented at a finer granularity. 29 | - **max_obj_num**: used to limit the maximum number of objects that SegTracker can detect and track. Tracking more objects requires more memory; approximately 16GB of memory can process at most 255 objects. 30 | 31 | ## 4. Interactively modify the single-object mask for the first frame of the video 32 | ### 4.1 Interactively add a single object based on segment-everything (`Everything-Tab`) 33 | - `Segment everything for first frame`: By clicking the button, SegTracker will be initialized based on the `SegTracker-Args`, and `Segment-everything` will be performed on the first frame of the video. 34 | - An example of the `segment-everything` approach is shown below. 35 | 36 |

[screenshots: click_segment_everything, segment_everything_blackswan]
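For readers who prefer the programmatic view: the segment-everything step is built on SAM's automatic mask generator, which `tool/segmentor.py` above wraps. A minimal stand-alone sketch is given below; the checkpoint path and model type follow `script/download_ckpt.sh`, while the frame path and the `points_per_side` value are just example assumptions:

```
import cv2
from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="./ckpt/sam_vit_b_01ec64.pth")
sam.to(device="cuda")

generator = SamAutomaticMaskGenerator(model=sam, points_per_side=30)     # mirrors the points_per_side arg above
frame = cv2.cvtColor(cv2.imread("first_frame.png"), cv2.COLOR_BGR2RGB)   # hypothetical first frame (RGB)
masks = generator.generate(frame)   # list of dicts, each carrying a boolean 'segmentation' mask
print(f"segment-everything found {len(masks)} masks")
```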

37 | 38 | - `Point Prompt`: After applying the Segment-everything function, you can click on the image to add objects that segment-everything ignored, or to assign a separate ID to an object. 39 | - Two examples are provided below: one adds the water that the `segment-everything` approach previously ignored, and the other assigns a separate ID to a man's face. 40 | 41 |

[screenshot: add_positive_base_on_everything] 42 | [screenshot: add_positive_base_on_everything_cxk]

43 | 44 | - `Note`: The current version only supports adding a single-object mask (the added objects are assigned the same ID) on top of segment-everything. Support for adding multi-object masks (the added objects are assigned different IDs) will be added in the future. 45 | 46 | ### 4.2 Interactively add an object by click (`Click-Tab`) 47 | - `Point Prompt`: You can select objects to track by clicking on the image with positive and negative points. 48 | - SegTracker will segment objects according to the specified prompt points, as demonstrated in the example below. 49 | 50 |

[screenshots: add_positive_points, add_positive_points_2]
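Internally, each positive or negative click becomes a point prompt for SAM's predictor, exactly as in `Segmentor.segment_with_click` shown earlier. A minimal sketch (same checkpoint assumption as before; the click coordinates are made up):

```
import cv2
import numpy as np
from sam.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="./ckpt/sam_vit_b_01ec64.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)

frame = cv2.cvtColor(cv2.imread("first_frame.png"), cv2.COLOR_BGR2RGB)   # hypothetical frame
predictor.set_image(frame)              # embed once, then prompt as many times as needed

point_coords = np.array([[320, 240]])   # one positive click at (x, y)
point_labels = np.array([1])            # 1 = positive point, 0 = negative point
masks, scores, logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]    # keep the highest-scoring candidate, as segment_with_click does
```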

51 | 52 | ### 4.3 Interactively add an object by stroke (`Stroke-Tab`) 53 | - `Drawing board`: You can circle the object you want to track on it. 54 | - `Undo`: To undo a stroke on the `Drawing board`, click the `Undo button` located in the upper right corner of the `Drawing board`. 55 | - `Reset`: Click on the `Reset button` in the upper right corner of the `Drawing board` to reset the `Drawing board`. 56 | - `Segment`: SegTracker will receive the mask you draw and display the segmentation results. 57 | - Below is an example demonstrating how to circle and segment an object using strokes. 58 |

[screenshots: Drawing_board (two images)]

59 | 60 | - `Note`: 61 | - The current version only supports adding a mask for a single object (the added objects are assigned the same ID). 62 | - We do not recommend adding a mask by clicking on `Segment result of first frame` under the `Stroke-Tab`, as this may result in bugs. 63 | 64 | ## 5. Segment and Track in Video 65 | - Once the object to be tracked in the video is identified, you can begin tracking by clicking on the `Start Tracking` button. 66 | - The results are displayed in the `output video` and `predicted masks` components. You can download them. 67 | 68 |

[screenshot: Drawing_board]

69 | -------------------------------------------------------------------------------- /tutorial/tutorial for WebUI-1.5-Version.md: -------------------------------------------------------------------------------- 1 | # Tutorial for WebUI 1.5 Version 2 | ## We have added two new features 3 | - We have added text prompts to allow for interactive selection of objects that will be tracked in the video. 4 | - We can now interactively add multiple objects for tracking in the video. 5 | 6 | 7 | ## Text-Prompts 8 | ### 1. Clone Grounding-DINO to `./src` 9 | ``` 10 | pip install -e git+https://github.com/IDEA-Research/GroundingDINO.git@main#egg=GroundingDINO 11 | ``` 12 | 13 | ### 2. Switch to Text-Tab by clicking `Text` Tab 14 | 15 |

16 | 17 |

18 | 19 | ### 3. Upload a video or use an example directly 20 | 21 | ### 4. Enter text to select the objects you are interested in 22 | - The `.` is used to separate text prompts, just like in the original Grounding-DINO setting. 23 | 24 |

25 | 26 |

27 | 28 | ### 5. Get the mask of the selected object by clicking the `Detect` button 29 | - SAMTrack initialization may take some time. 30 | 31 |

32 | 33 |
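Programmatically, the `Detect` button corresponds to `Detector.run_grounding` from `tool/detector.py` shown earlier (its `__main__` block demonstrates the same call). A trimmed sketch, assuming the repo root is on `PYTHONPATH`, Grounding-DINO is installed to `./src` as in step 1, and the frame path is a placeholder:

```
import cv2
from tool.detector import Detector

detector = Detector("cuda")   # loads ./ckpt/groundingdino_swint_ogc.pth and the SwinT config

frame = cv2.cvtColor(cv2.imread("first_frame.png"), cv2.COLOR_BGR2RGB)   # hypothetical first frame
# "." separates the text prompts, following the Grounding-DINO convention mentioned above.
annotated_frame, boxes = detector.run_grounding(
    frame, grounding_caption="swan.water", box_threshold=0.25, text_threshold=0.25
)
print(boxes.shape)            # (N, 2, 2): one [[x0, y0], [x1, y1]] corner pair per detected box
```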

34 | 35 | ### 6. Track in video 36 | 37 | ## Multi-Object selection 38 | ### 1. Once you have interactively added an object mask, click the `Add new object` button to prepare to add a new object. 39 | 40 |

41 | 42 |

43 | 44 | ### 2. Add a new object by clicking on the object 45 | 46 |

47 | 48 |

49 | 50 | ### 3. You can add as many objects as you want by clicking the `Add new object` button. -------------------------------------------------------------------------------- /tutorial/tutorial for WebUI-1.6-Version.md: -------------------------------------------------------------------------------- 1 | # Tutorial for WebUI 1.6 Version 2 | ## We have added one new feature 3 | - We have added an audio-grounding feature that tracks the sound-making object within the video's soundtrack. 4 | 5 | 6 | ## audio-grounding 7 | ### 1. Clone the Audio Spectrogram Transformer (AST) model to `./ast_master` and download the pretrained model 8 | ``` 9 | git clone https://github.com/YuanGongND/ast.git ast_master 10 | wget -O ./ast_master/pretrained_models/audio_mdl.pth https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1 11 | ``` 12 | 13 | ### 2. Switch to the Audio-Tab by clicking the `audio-grounding` tab 14 | 15 |

16 | 17 |

18 | 19 | ### 3. Upload a video or use an example directly 20 | 21 | ### 4. Adjust the number of labels and the confidence interval of the resulting labels to your preference 22 | 23 | Step 1: detect the label of the sound-making object 24 | 25 | Step 2: ground the sound-making object 26 | 27 |

28 | 29 |

30 | 31 | ### 5. Get the mask of the selected object by clicking the `Detect` button 32 | 33 | ### 6. Track in video --------------------------------------------------------------------------------