├── .gitignore ├── LICENSE ├── README.md ├── ext ├── class_names │ ├── coco_4817_ids.py │ ├── lvis_ids.py │ └── lvis_list.py ├── meta │ └── sam_meta.py ├── open_clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── coca_model.py │ ├── constants.py │ ├── factory.py │ ├── generation_utils.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA01-g-14.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-E-14-plus.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── EVA02-L-14.json │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── push_to_hf_hub.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ ├── utils.py │ ├── version.py │ ├── zero_shot_classifier.py │ └── zero_shot_metadata.py ├── rwkv │ └── cls_backbones │ │ ├── __init__.py │ │ ├── backbones │ │ ├── __init__.py │ │ ├── cuda │ │ │ ├── wkv_cuda.cu │ │ │ └── wkv_op.cpp │ │ └── vrwkv.py │ │ └── utils │ │ ├── __init__.py │ │ ├── drop.py │ │ └── resize_pos.py ├── sam │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ └── transformer.py └── templates │ ├── __init__.py │ └── vild.py ├── projects └── rwkvsam │ ├── README.md │ ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ ├── DIS │ │ │ │ └── dis_5k_1024.py │ │ │ ├── ade │ │ │ │ └── ade20k.py │ │ │ ├── coco │ │ │ │ ├── coco_detection.py │ │ │ │ ├── coco_instance.py │ │ │ │ ├── coco_instance_1024.py │ │ │ │ └── coco_instance_lsj.py │ │ │ ├── coconut │ │ │ │ └── coconut_b_instance_lsj.py │ │ │ ├── entity │ │ │ │ └── entity_lr_instance_lsj.py │ │ │ ├── hq_concat │ │ │ │ └── concat_coconutbpan_entity_dis5k_sam.py │ │ │ ├── imagenet │ │ │ │ └── imagenet_bs64_swin_224.py │ │ │ ├── sam │ │ │ │ ├── sam_001.py │ │ │ │ └── sam_distill.py │ │ │ └── thin_obj_det │ │ │ │ ├── coift_1024.py │ │ │ │ ├── hrsod_1024.py │ │ │ │ └── thin_obj_5k_1024.py │ │ ├── default_runtime.py │ │ ├── default_runtime_iterbased.py │ │ └── schedules │ │ │ ├── 
schedule_120e_bs1024_for_imagenet.py │ │ │ ├── schedule_12e_distillation.py │ │ │ ├── schedule_160k_seg.py │ │ │ ├── schedule_160k_seg_adam.py │ │ │ ├── schedule_1x.py │ │ │ ├── schedule_1x_adam.py │ │ │ ├── schedule_24e_distillation.py │ │ │ └── schedule_300e_bs1024_for_imagenet.py │ └── backbone_dist │ │ ├── rwkvsam1001_000_vith_vitamin_rwkv_small_mlp2.py │ │ └── sam_vith_dump.py │ ├── datasets │ ├── __init__.py │ ├── coconut_panoptic.py │ ├── concat_dataset.py │ ├── dis5k.py │ ├── entity_seg.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── loading.py │ │ └── optimization.py │ ├── sam.py │ └── thin_obj_det.py │ ├── evaluation │ ├── __init__.py │ ├── api_wrappers │ │ └── coco_api.py │ ├── biou_metric.py │ ├── coco_boundary_metric.py │ ├── iou_metric.py │ └── lvis_boundary_metric.py │ ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── sam_backbone.py │ │ └── vitamin.py │ ├── detectors │ │ ├── __init__.py │ │ ├── det_and_seg.py │ │ ├── feature_extraction.py │ │ ├── json_loader.py │ │ ├── sam_clip_distill.py │ │ ├── sam_dump.py │ │ └── sam_model.py │ ├── heads │ │ ├── __init__.py │ │ ├── sam_mask_decoder.py │ │ └── sam_mask_decoder_rwkv_mlpmerge.py │ ├── necks │ │ ├── __init__.py │ │ ├── gap.py │ │ ├── last_layer.py │ │ └── sam_pe.py │ └── preprocessors │ │ ├── __init__.py │ │ ├── data_preprocessors.py │ │ ├── ovsam_preprocessor.py │ │ └── sameval_preprocessor.py │ └── utils │ ├── __init__.py │ ├── boundary_iou.py │ └── load_checkpoint.py ├── seg ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ ├── coco_ov_instance_lsj.py │ │ │ ├── lvis_norare.py │ │ │ ├── sam.py │ │ │ └── sam_img.py │ │ ├── default_runtime.py │ │ └── schedules │ │ │ ├── schedule_12e.py │ │ │ ├── schedule_24e.py │ │ │ └── schedule_distillation.py │ ├── clip2sam │ │ ├── clip2sam_coco_rn50x16.py │ │ └── clip2sam_lvis_rn50x16.py │ ├── ovsam │ │ ├── ovsam_coco_rn50x16_point.py │ │ └── ovsam_lvis_rn50x16_point.py │ └── sam2clip │ │ ├── sam2clip_vith_rn50x16.py │ │ └── sam_vith_dump.py ├── datasets │ ├── coco_ins_ov.py │ ├── concat_dataset.py │ ├── pipeliens │ │ ├── formatting.py │ │ ├── frame_copy.py │ │ ├── frame_sampling.py │ │ ├── loading.py │ │ └── transforms.py │ ├── sam.py │ └── samplers │ │ ├── batch_sampler.py │ │ └── multi_dataset_sampler.py ├── evaluation │ └── ins_cls_iou_metric.py └── models │ ├── backbones │ ├── __init__.py │ ├── openclip_backbone.py │ └── sam_backbone.py │ ├── data_preprocessor │ ├── __init__.py │ └── ovsam_preprocessor.py │ ├── detectors │ ├── __init__.py │ ├── clip2sam.py │ ├── ovsam.py │ ├── sam2clip_distill.py │ └── sam_dump.py │ ├── heads │ ├── __init__.py │ └── ovsam_head.py │ ├── necks │ ├── __init__.py │ ├── last_layer.py │ ├── sam_pe.py │ └── transformer_neck.py │ └── utils │ ├── __init__.py │ ├── class_overlapping.py │ ├── load_checkpoint.py │ ├── mask_pool.py │ ├── no_obj.py │ ├── offline_video_metrics.py │ ├── pan_seg_transform.py │ └── video_gt_preprocess.py └── tools ├── dist.sh ├── gen_cls.py ├── slurm.sh ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these 
files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | /models/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 
--------------------------------------------------------------------------------
/ext/class_names/coco_4817_ids.py:
--------------------------------------------------------------------------------
# Open-vocabulary COCO split: 48 base and 17 novel category indices.
COCO4817_BASE_IDS = [
    0, 1, 2, 3, 7, 8, 13, 14, 17, 18, 22, 23, 24, 26, 30, 33, 37, 39, 44, 45, 46, 47, 49, 50, 51, 53,
    54, 56, 59, 62, 63, 65, 68, 69, 72, 73, 74, 75, 79, 6, 21, 28, 29, 42, 48, 61, 64, 70
]
COCO4817_NOVEL_IDS = [
    5, 16, 19, 20, 25, 27, 36, 41, 43, 55, 57, 66, 71, 76, 4, 15, 31
]
--------------------------------------------------------------------------------
/ext/meta/sam_meta.py:
--------------------------------------------------------------------------------
# Image-encoder hyper-parameters for the three official SAM ViT variants, plus their checkpoint URLs.
meta_dict = {
    'vit_h': dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    ),
    'vit_l': dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    ),
    'vit_b': dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    )
}

checkpoint_dict = {
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
}
--------------------------------------------------------------------------------
/ext/open_clip/__init__.py:
--------------------------------------------------------------------------------
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
--------------------------------------------------------------------------------
/ext/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarborYuan/ovsam/137d2c2e6daea060668cf50d7c966ed86e9c45ce/ext/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /ext/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /ext/open_clip/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarborYuan/ovsam/137d2c2e6daea060668cf50d7c966ed86e9c45ce/ext/open_clip/generation_utils.py -------------------------------------------------------------------------------- /ext/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | } 57 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA01-g-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | "layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA02-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- 
/ext/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | 
"width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } 
-------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- 
/ext/open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | 
"image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | 
"patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_large_d.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } 
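
Usage sketch (not part of the repository): configs that set timm_model_name, such as the convnext_* files above, delegate the visual tower to a timm backbone, while the registry itself is keyed by the JSON file stem. Assuming the vendored factory matches upstream open_clip; the my_configs/ directory below is hypothetical.

from ext.open_clip import list_models, add_model_config, get_model_config

print('convnext_xxlarge' in list_models())   # configs shipped under model_configs/ are auto-registered
cfg = get_model_config('convnext_xxlarge')
print(cfg['vision_cfg']['timm_model_name'])  # 'convnext_xxlarge' -> visual tower built from timm
add_model_config('my_configs/')              # hypothetical folder of extra *.json configs to register
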
-------------------------------------------------------------------------------- /ext/open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /ext/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /ext/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype -------------------------------------------------------------------------------- /ext/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.20.0' 2 | -------------------------------------------------------------------------------- /ext/rwkv/cls_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | -------------------------------------------------------------------------------- /ext/rwkv/cls_backbones/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .vrwkv import VRWKV 2 | 3 | __all__ = ['VRWKV'] -------------------------------------------------------------------------------- 
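A minimal usage sketch of `freeze_batch_norm_2d` from `ext/open_clip/utils.py` above (illustrative only: the toy backbone and the assumption that the repository root is on `PYTHONPATH` are not part of the codebase):

```python
# Illustrative sketch: freezing BatchNorm layers with freeze_batch_norm_2d (ext/open_clip/utils.py).
# Assumes the repository root is on PYTHONPATH; the toy backbone exists only for this example.
import torch.nn as nn

from ext.open_clip.utils import freeze_batch_norm_2d

backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

frozen = freeze_batch_norm_2d(backbone)
# The BatchNorm2d child has been swapped for torchvision's FrozenBatchNorm2d, so its
# running statistics and affine parameters stay fixed during fine-tuning.
print(type(frozen[1]).__name__)  # FrozenBatchNorm2d
```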
/ext/rwkv/cls_backbones/backbones/cuda/wkv_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
5 | 
6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
7 |     cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
8 | }
9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
10 |     cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
11 | }
12 | 
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |     m.def("forward", &forward, "wkv forward");
15 |     m.def("backward", &backward, "wkv backward");
16 | }
17 | 
18 | TORCH_LIBRARY(wkv, m) {
19 |     m.def("forward", forward);
20 |     m.def("backward", backward);
21 | }
22 | 
--------------------------------------------------------------------------------
/ext/rwkv/cls_backbones/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .drop import DropPath
2 | from .resize_pos import resize_pos_embed
3 | 
4 | __all__ = ['DropPath', 'resize_pos_embed']
--------------------------------------------------------------------------------
/ext/rwkv/cls_backbones/utils/drop.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
4 |     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
5 | 
6 |     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
7 |     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
8 |     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
9 |     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
10 |     'survival rate' as the argument.
11 | 
12 |     """
13 |     if drop_prob == 0. or not training:
14 |         return x
15 |     keep_prob = 1 - drop_prob
16 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
17 |     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
18 |     if keep_prob > 0.0 and scale_by_keep:
19 |         random_tensor.div_(keep_prob)
20 |     return x * random_tensor
21 | 
22 | 
23 | class DropPath(nn.Module):
24 |     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
25 |     """
26 |     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
27 |         super(DropPath, self).__init__()
28 |         self.drop_prob = drop_prob
29 |         self.scale_by_keep = scale_by_keep
30 | 
31 |     def forward(self, x):
32 |         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
33 | 
34 |     def extra_repr(self):
35 |         return f'drop_prob={round(self.drop_prob,3):0.3f}'
36 | 
37 | 
38 | 
--------------------------------------------------------------------------------
/ext/rwkv/cls_backbones/utils/resize_pos.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def resize_pos_embed(pos_embed,
6 |                      src_shape,
7 |                      dst_shape,
8 |                      mode='bicubic',
9 |                      num_extra_tokens=1):
10 |     """Resize pos_embed weights.
11 | 
12 |     Args:
13 |         pos_embed (torch.Tensor): Position embedding weights with shape
14 |             [1, L, C].
15 |         src_shape (tuple): The resolution of downsampled origin training
16 |             image, in format (H, W).
17 |         dst_shape (tuple): The resolution of downsampled new training
18 |             image, in format (H, W).
19 |         mode (str): Algorithm used for upsampling. Choose one from 'nearest',
20 |             'linear', 'bilinear', 'bicubic' and 'trilinear'.
21 |             Defaults to 'bicubic'.
22 |         num_extra_tokens (int): The number of extra tokens, such as cls_token.
23 |             Defaults to 1.
24 | 
25 |     Returns:
26 |         torch.Tensor: The resized pos_embed of shape [1, L_new, C]
27 |     """
28 |     if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
29 |         return pos_embed
30 |     assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
31 |     _, L, C = pos_embed.shape
32 |     src_h, src_w = src_shape
33 |     assert L == src_h * src_w + num_extra_tokens, \
34 |         f"The length of `pos_embed` ({L}) doesn't match the expected " \
35 |         f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
36 |         '`img_size` argument.'
37 |     extra_tokens = pos_embed[:, :num_extra_tokens]
38 | 
39 |     src_weight = pos_embed[:, num_extra_tokens:]
40 |     src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
41 | 
42 |     # The cubic interpolate algorithm only accepts float32
43 |     dst_weight = F.interpolate(
44 |         src_weight.float(), size=dst_shape, align_corners=False, mode=mode)
45 |     dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
46 |     dst_weight = dst_weight.to(src_weight.dtype)
47 | 
48 |     return torch.cat((extra_tokens, dst_weight), dim=1)
49 | 
--------------------------------------------------------------------------------
/ext/sam/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_encoder import ImageEncoderViT
2 | from .prompt_encoder import PromptEncoder
3 | from .mask_decoder import MaskDecoder
4 | 
--------------------------------------------------------------------------------
/ext/sam/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /ext/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .vild import VILD_PROMPT 2 | -------------------------------------------------------------------------------- /ext/templates/vild.py: -------------------------------------------------------------------------------- 1 | # https://github.com/bytedance/fc-clip/blob/93f3122518e8a3ef98926e5ea761a776d5050430/fcclip/fcclip.py#L26C1-L41C2 2 | VILD_PROMPT = [ 3 | "a photo of a {}.", 4 | "This is a photo of a {}", 5 | "There is a {} in the scene", 6 | "There is the {} in the scene", 7 | "a photo of a {} in the scene", 8 | "a photo of a small {}.", 9 | "a photo of a medium {}.", 10 | "a photo of a large {}.", 11 | "This is a photo of a small {}.", 12 | "This is a photo of a medium {}.", 13 | "This is a photo of a large {}.", 14 | "There is a small {} in the scene.", 15 | "There is a medium {} in the scene.", 16 | "There is a large {} in the scene.", 17 | ] 18 | -------------------------------------------------------------------------------- /projects/rwkvsam/README.md: -------------------------------------------------------------------------------- 1 | # RWKV-SAM 2 | 3 | [Haobo Yuan1](https://yuanhaobo.me), 4 | [Xiangtai Li1,2](https://lxtgh.github.io), 5 | [Lu Qi3](http://luqi.info), 6 | [Tao Zhang2](https://scholar.google.com.hk/citations?user=3xu4a5oAAAAJ&hl=zh-CN), 7 | [Ming-Hsuan Yang3](http://faculty.ucmerced.edu/mhyang/), 8 | [Shuicheng Yan2](https://yanshuicheng.info), 9 | [Chen Change Loy1](https://www.mmlab-ntu.com/person/ccloy/). 10 | 11 | [1S-Lab, Nanyang Technological University](https://www.mmlab-ntu.com/), 12 | [2Skywork AI](), 13 | [3UC Merced](https://www.ucmerced.edu) 14 | 15 | [![arXiv](https://img.shields.io/badge/arXiv-2406.19369-b31b1b.svg)](https://arxiv.org/abs/2406.19369) 16 | 17 | ## 📰 News 18 | * **` Jun. 28th, 2024`:** We release the arxiv paper and code for RWKV-SAM. The datasets, weights, and training scripts will be available soon. Please stay tuned. 
19 | 
20 | ## 👀 Overview
21 | We introduce RWKV-SAM, which includes an efficient segmentation backbone and a complete training pipeline that enables high-quality segmentation for the Segment Anything Model.
22 | Compared with a transformer model of the same scale, RWKV-SAM achieves more than a 2× speedup and better segmentation performance on various datasets.
23 | 
24 | [figure: RWKV-SAM overview]
27 | 
28 | ## 📸 Demo
29 | Our RWKV-SAM achieves high-quality promptable segmentation with high efficiency at high resolution.
30 | [figure: RWKV-SAM demo]
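The training scripts are not released yet (see the news above), so treat the following as a rough sketch rather than official instructions: the configs under `projects/rwkvsam/configs` follow the standard MMEngine layout consumed by `tools/train.py` and `tools/test.py` in this repository, and an equivalent programmatic launch looks roughly like this (config path and `work_dir` are placeholders):

```python
# Rough sketch only: launching an RWKV-SAM config with MMEngine's Runner.
# The config path and work_dir are placeholders, not official instructions.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'projects/rwkvsam/configs/backbone_dist/rwkvsam1001_000_vith_vitamin_rwkv_small_mlp2.py')
cfg.work_dir = './work_dirs/rwkvsam_example'

runner = Runner.from_cfg(cfg)
runner.train()
```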
33 | 34 | ## 📚 Citation 35 | ```bibtex 36 | @article{yuan2024rwkvsam, 37 | title={Mamba or RWKV: Exploring High-Quality and High-Efficiency Segment Anything Model}, 38 | author={Yuan, Haobo and Li, Xiangtai and Qi, Lu and Zhang, Tao and Yang, Ming-Hsuan and Yan, Shuicheng and Loy, Chen Change}, 39 | journal={arXiv preprint}, 40 | year={2024} 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/DIS/dis_5k_1024.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmdet.datasets.transforms import PackDetInputs 3 | from mmengine.dataset import DefaultSampler 4 | from projects.rwkvsam.datasets import DIS5KDataset 5 | from projects.rwkvsam.datasets.pipelines import LoadMaskFromFile 6 | 7 | from seg.datasets.pipeliens.transforms import ResizeSAM 8 | 9 | dataset_type = DIS5KDataset 10 | data_root = 'data/DIS5K' 11 | 12 | backend_args = None 13 | image_size = (1024, 1024) 14 | 15 | # dataset settings 16 | train_pipeline = [ 17 | dict(type=LoadImageFromFile, backend_args=backend_args), 18 | dict(type=LoadMaskFromFile,), 19 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 20 | dict( 21 | type=PackDetInputs, 22 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 23 | ), 24 | ] 25 | 26 | test_pipeline = [ 27 | dict(type=LoadImageFromFile, backend_args=backend_args), 28 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 29 | dict(type=LoadMaskFromFile, ), 30 | dict( 31 | type=PackDetInputs, 32 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 33 | ) 34 | ] 35 | 36 | # dataloader 37 | train_dataloader = dict( 38 | batch_size=2, 39 | num_workers=2, 40 | persistent_workers=True, 41 | sampler=dict(type=DefaultSampler, shuffle=True), 42 | batch_sampler=None, 43 | dataset=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | ann_file='TR', 47 | data_prefix=dict(img=''), 48 | filter_cfg=None, 49 | pipeline=train_pipeline, 50 | backend_args=backend_args 51 | ) 52 | ) 53 | val_dataloader = dict( 54 | batch_size=1, 55 | num_workers=2, 56 | persistent_workers=True, 57 | drop_last=True, 58 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 59 | batch_sampler=None, 60 | dataset=dict( 61 | type=dataset_type, 62 | data_root=data_root, 63 | ann_file='VD', 64 | data_prefix=dict(img=''), 65 | test_mode=True, 66 | pipeline=test_pipeline, 67 | backend_args=backend_args 68 | ) 69 | ) 70 | test_dataloader = val_dataloader 71 | 72 | val_evaluator = [] 73 | test_evaluator = val_evaluator 74 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/ade/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize, RandomFlip, Resize, TestTimeAug 3 | from mmengine.dataset import InfiniteSampler, DefaultSampler 4 | from mmseg.datasets import ADE20KDataset, LoadAnnotations, RandomCrop, PhotoMetricDistortion, PackSegInputs 5 | from mmseg.evaluation import IoUMetric 6 | 7 | dataset_type = ADE20KDataset 8 | data_root = 'data/ade/ADEChallengeData2016' 9 | crop_size = (512, 512) 10 | train_pipeline = [ 11 | dict(type=LoadImageFromFile), 12 | dict(type=LoadAnnotations, reduce_zero_label=True), 13 | dict( 14 | type=RandomResize, 15 | scale=(2048, 512), 16 | ratio_range=(0.5, 2.0), 17 | keep_ratio=True, 18 | 
resize_type=Resize, 19 | ), 20 | dict(type=RandomCrop, crop_size=crop_size, cat_max_ratio=0.75), 21 | dict(type=RandomFlip, prob=0.5), 22 | dict(type=PhotoMetricDistortion), 23 | dict(type=PackSegInputs) 24 | ] 25 | test_pipeline = [ 26 | dict(type=LoadImageFromFile), 27 | dict(type=Resize, scale=(2048, 512), keep_ratio=True), 28 | # add loading annotation after ``Resize`` because ground truth 29 | # does not need to do resize data transform 30 | dict(type=LoadAnnotations, reduce_zero_label=True), 31 | dict(type=PackSegInputs) 32 | ] 33 | img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 34 | tta_pipeline = [ 35 | dict(type=LoadImageFromFile, backend_args=None), 36 | dict( 37 | type=TestTimeAug, 38 | transforms=[ 39 | [ 40 | dict(type=Resize, scale_factor=r, keep_ratio=True) 41 | for r in img_ratios 42 | ], 43 | [ 44 | dict(type=RandomFlip, prob=0., direction='horizontal'), 45 | dict(type=RandomFlip, prob=1., direction='horizontal') 46 | ], [dict(type=LoadAnnotations)], [dict(type='PackSegInputs')] 47 | ]) 48 | ] 49 | train_dataloader = dict( 50 | batch_size=4, 51 | num_workers=4, 52 | persistent_workers=True, 53 | sampler=dict(type=InfiniteSampler, shuffle=True), 54 | dataset=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | data_prefix=dict( 58 | img_path='images/training', 59 | seg_map_path='annotations/training' 60 | ), 61 | pipeline=train_pipeline) 62 | ) 63 | val_dataloader = dict( 64 | batch_size=1, 65 | num_workers=4, 66 | persistent_workers=True, 67 | sampler=dict(type=DefaultSampler, shuffle=False), 68 | dataset=dict( 69 | type=dataset_type, 70 | data_root=data_root, 71 | data_prefix=dict( 72 | img_path='images/validation', 73 | seg_map_path='annotations/validation'), 74 | pipeline=test_pipeline) 75 | ) 76 | test_dataloader = val_dataloader 77 | 78 | val_evaluator = dict(type=IoUMetric, iou_metrics=['mIoU']) 79 | test_evaluator = val_evaluator 80 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/coco/coco_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile 3 | 4 | from mmdet.datasets import CocoDataset, AspectRatioBatchSampler 5 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs 6 | from mmdet.evaluation import CocoMetric 7 | from mmengine.dataset import DefaultSampler 8 | 9 | dataset_type = CocoDataset 10 | data_root = 'data/coco/' 11 | 12 | backend_args = None 13 | 14 | train_pipeline = [ 15 | dict(type=LoadImageFromFile, backend_args=backend_args), 16 | dict(type=LoadAnnotations, with_bbox=True), 17 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 18 | dict(type=RandomFlip, prob=0.5), 19 | dict(type=PackDetInputs) 20 | ] 21 | test_pipeline = [ 22 | dict(type=LoadImageFromFile, backend_args=backend_args), 23 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 24 | # If you don't have a gt annotation, delete the pipeline 25 | dict(type=LoadAnnotations, with_bbox=True), 26 | dict( 27 | type=PackDetInputs, 28 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 29 | ] 30 | train_dataloader = dict( 31 | batch_size=2, 32 | num_workers=2, 33 | persistent_workers=True, 34 | sampler=dict(type=DefaultSampler, shuffle=True), 35 | batch_sampler=dict(type=AspectRatioBatchSampler), 36 | dataset=dict( 37 | type=dataset_type, 38 | data_root=data_root, 39 | ann_file='annotations/instances_train2017.json', 40 | 
data_prefix=dict(img='train2017/'), 41 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 42 | pipeline=train_pipeline, 43 | backend_args=backend_args)) 44 | val_dataloader = dict( 45 | batch_size=1, 46 | num_workers=2, 47 | persistent_workers=True, 48 | drop_last=False, 49 | sampler=dict(type=DefaultSampler, shuffle=False), 50 | dataset=dict( 51 | type=dataset_type, 52 | data_root=data_root, 53 | ann_file='annotations/instances_val2017.json', 54 | data_prefix=dict(img='val2017/'), 55 | test_mode=True, 56 | pipeline=test_pipeline, 57 | backend_args=backend_args)) 58 | test_dataloader = val_dataloader 59 | 60 | val_evaluator = dict( 61 | type=CocoMetric, 62 | ann_file=data_root + 'annotations/instances_val2017.json', 63 | metric='bbox', 64 | format_only=False, 65 | backend_args=backend_args) 66 | test_evaluator = val_evaluator 67 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/coco/coco_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile 3 | from mmdet.datasets import AspectRatioBatchSampler, CocoDataset 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs 5 | from mmdet.evaluation import CocoMetric 6 | from mmengine.dataset import DefaultSampler 7 | 8 | data_root = 'data/coco/' 9 | backend_args = None 10 | dataset_type = CocoDataset 11 | 12 | image_size = (1024, 1024) 13 | 14 | train_pipeline = [ 15 | dict( 16 | type=LoadImageFromFile, 17 | to_float32=True, 18 | backend_args=backend_args), 19 | dict( 20 | type=LoadAnnotations, 21 | with_bbox=True, 22 | with_mask=True, 23 | backend_args=backend_args), 24 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 25 | dict(type=RandomFlip, prob=0.5), 26 | dict(type=PackDetInputs) 27 | ] 28 | train_dataloader = dict( 29 | batch_size=2, 30 | num_workers=2, 31 | persistent_workers=True, 32 | sampler=dict(type=DefaultSampler, shuffle=True), 33 | batch_sampler=dict(type=AspectRatioBatchSampler), 34 | dataset=dict( 35 | type=dataset_type, 36 | data_root=data_root, 37 | ann_file='annotations/instances_train2017.json', 38 | data_prefix=dict(img='train2017/'), 39 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 40 | pipeline=train_pipeline, 41 | backend_args=backend_args) 42 | ) 43 | 44 | test_pipeline = [ 45 | dict(type=LoadImageFromFile, backend_args=backend_args), 46 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 47 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 48 | dict( 49 | type=PackDetInputs, 50 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 51 | ] 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=False, 57 | sampler=dict(type=DefaultSampler, shuffle=False), 58 | dataset=dict( 59 | type=dataset_type, 60 | data_root=data_root, 61 | ann_file='annotations/instances_val2017.json', 62 | data_prefix=dict(img='val2017/'), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | backend_args=backend_args) 66 | ) 67 | test_dataloader = val_dataloader 68 | 69 | val_evaluator = dict( 70 | type=CocoMetric, 71 | ann_file=data_root + 'annotations/instances_val2017.json', 72 | metric=['bbox', 'segm'], 73 | format_only=False, 74 | backend_args=backend_args 75 | ) 76 | test_evaluator = val_evaluator 77 | -------------------------------------------------------------------------------- 
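The dataset files under `configs/_base_/datasets` are pure-Python MMEngine configs. Project-level configs pull them in with `read_base()` and then override fields in place, the same pattern used by the `hq_concat` config later in this project. A minimal sketch of that pattern (the relative import path and the override values are illustrative assumptions):

```python
# Illustrative sketch of composing the _base_ COCO instance config above in a downstream config.
# Assumes the config lives in a sibling directory of configs/_base_/ (e.g. configs/backbone_dist/).
from mmengine import read_base

with read_base():
    from .._base_.datasets.coco.coco_instance import *  # noqa: F401,F403

# After read_base(), the imported names are plain dicts and can be overridden in place.
train_dataloader.update(batch_size=4, num_workers=4)
val_evaluator.update(metric=['segm'])
```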
/projects/rwkvsam/configs/_base_/datasets/coco/coco_instance_1024.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import AspectRatioBatchSampler, CocoDataset 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs, RandomCrop 5 | from mmdet.evaluation import CocoMetric 6 | from mmengine.dataset import DefaultSampler 7 | 8 | from seg.datasets.pipeliens.frame_copy import AddSemSeg 9 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 10 | 11 | data_root = 'data/coco/' 12 | backend_args = None 13 | dataset_type = CocoDataset 14 | 15 | image_size = (1024, 1024) 16 | 17 | train_pipeline = [ 18 | dict( 19 | type=LoadImageFromFile, 20 | to_float32=True, 21 | backend_args=backend_args), 22 | dict( 23 | type=LoadAnnotations, 24 | with_bbox=True, 25 | with_mask=True, 26 | backend_args=backend_args), 27 | dict( 28 | type=AddSemSeg, 29 | ), 30 | dict(type=RandomFlip, prob=0.5), 31 | dict( 32 | type=RandomResize, 33 | resize_type=Resize, 34 | scale=image_size, 35 | ratio_range=(0.1, 2.0), 36 | keep_ratio=True, 37 | ), 38 | dict( 39 | type=RandomCrop, 40 | crop_size=image_size, 41 | crop_type='absolute', 42 | recompute_bbox=True, 43 | allow_negative_crop=True), 44 | dict( 45 | type=FilterAnnotationsHB, 46 | by_box=False, 47 | by_mask=True, 48 | min_gt_mask_area=32, 49 | ), 50 | dict(type=PackDetInputs) 51 | ] 52 | train_dataloader = dict( 53 | batch_size=2, 54 | num_workers=2, 55 | persistent_workers=True, 56 | sampler=dict(type=DefaultSampler, shuffle=True), 57 | batch_sampler=dict(type=AspectRatioBatchSampler), 58 | dataset=dict( 59 | type=dataset_type, 60 | data_root=data_root, 61 | ann_file='annotations/instances_train2017.json', 62 | data_prefix=dict(img='train2017/'), 63 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 64 | pipeline=train_pipeline, 65 | backend_args=backend_args) 66 | ) 67 | 68 | test_pipeline = [ 69 | dict(type=LoadImageFromFile, backend_args=backend_args), 70 | dict(type=Resize, scale=(1024, 1024), keep_ratio=True), 71 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 72 | dict( 73 | type=PackDetInputs, 74 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 75 | ] 76 | val_dataloader = dict( 77 | batch_size=1, 78 | num_workers=2, 79 | persistent_workers=True, 80 | drop_last=False, 81 | sampler=dict(type=DefaultSampler, shuffle=False), 82 | dataset=dict( 83 | type=dataset_type, 84 | data_root=data_root, 85 | ann_file='annotations/instances_val2017.json', 86 | data_prefix=dict(img='val2017/'), 87 | test_mode=True, 88 | pipeline=test_pipeline, 89 | backend_args=backend_args) 90 | ) 91 | test_dataloader = val_dataloader 92 | 93 | val_evaluator = dict( 94 | type=CocoMetric, 95 | ann_file=data_root + 'annotations/instances_val2017.json', 96 | metric=['bbox', 'segm'], 97 | format_only=False, 98 | backend_args=backend_args 99 | ) 100 | test_evaluator = val_evaluator 101 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/coco/coco_instance_lsj.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import AspectRatioBatchSampler, CocoDataset 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs, RandomCrop 5 | from 
mmdet.evaluation import CocoMetric 6 | from mmengine.dataset import DefaultSampler 7 | 8 | from seg.datasets.pipeliens.frame_copy import AddSemSeg 9 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 10 | 11 | data_root = 'data/coco/' 12 | backend_args = None 13 | dataset_type = CocoDataset 14 | 15 | image_size = (1024, 1024) 16 | 17 | train_pipeline = [ 18 | dict( 19 | type=LoadImageFromFile, 20 | to_float32=True, 21 | backend_args=backend_args), 22 | dict( 23 | type=LoadAnnotations, 24 | with_bbox=True, 25 | with_mask=True, 26 | backend_args=backend_args), 27 | dict( 28 | type=AddSemSeg, 29 | ), 30 | dict(type=RandomFlip, prob=0.5), 31 | dict( 32 | type=RandomResize, 33 | resize_type=Resize, 34 | scale=image_size, 35 | ratio_range=(0.1, 2.0), 36 | keep_ratio=True, 37 | ), 38 | dict( 39 | type=RandomCrop, 40 | crop_size=image_size, 41 | crop_type='absolute', 42 | recompute_bbox=True, 43 | allow_negative_crop=True), 44 | dict( 45 | type=FilterAnnotationsHB, 46 | by_box=False, 47 | by_mask=True, 48 | min_gt_mask_area=32, 49 | ), 50 | dict(type=PackDetInputs) 51 | ] 52 | train_dataloader = dict( 53 | batch_size=2, 54 | num_workers=2, 55 | persistent_workers=True, 56 | sampler=dict(type=DefaultSampler, shuffle=True), 57 | batch_sampler=dict(type=AspectRatioBatchSampler), 58 | dataset=dict( 59 | type=dataset_type, 60 | data_root=data_root, 61 | ann_file='annotations/instances_train2017.json', 62 | data_prefix=dict(img='train2017/'), 63 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 64 | pipeline=train_pipeline, 65 | backend_args=backend_args) 66 | ) 67 | 68 | test_pipeline = [ 69 | dict(type=LoadImageFromFile, backend_args=backend_args), 70 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 71 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 72 | dict( 73 | type=PackDetInputs, 74 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 75 | ] 76 | val_dataloader = dict( 77 | batch_size=1, 78 | num_workers=2, 79 | persistent_workers=True, 80 | drop_last=False, 81 | sampler=dict(type=DefaultSampler, shuffle=False), 82 | dataset=dict( 83 | type=dataset_type, 84 | data_root=data_root, 85 | ann_file='annotations/instances_val2017.json', 86 | data_prefix=dict(img='val2017/'), 87 | test_mode=True, 88 | pipeline=test_pipeline, 89 | backend_args=backend_args) 90 | ) 91 | test_dataloader = val_dataloader 92 | 93 | val_evaluator = dict( 94 | type=CocoMetric, 95 | ann_file=data_root + 'annotations/instances_val2017.json', 96 | metric=['segm'], 97 | format_only=False, 98 | backend_args=backend_args 99 | ) 100 | test_evaluator = val_evaluator 101 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/coconut/coconut_b_instance_lsj.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import AspectRatioBatchSampler 4 | from mmdet.datasets.transforms import Resize, RandomFlip, PackDetInputs, RandomCrop, \ 5 | LoadPanopticAnnotations 6 | from mmdet.evaluation import CocoMetric 7 | from mmengine.dataset import DefaultSampler 8 | 9 | from projects.rwkvsam.datasets import CocoNutPanopticDataset 10 | from seg.datasets.pipeliens.frame_copy import AddSemSeg 11 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 12 | 13 | data_root = 'data/coconut/' 14 | backend_args = None 15 | dataset_type = CocoNutPanopticDataset 16 | 17 | image_size = (1024, 
1024) 18 | 19 | train_pipeline = [ 20 | dict( 21 | type=LoadImageFromFile, 22 | to_float32=True, 23 | backend_args=backend_args), 24 | dict( 25 | type=LoadPanopticAnnotations, 26 | with_bbox=True, 27 | with_mask=True, 28 | with_seg=False, 29 | backend_args=backend_args), 30 | dict( 31 | type=AddSemSeg, 32 | ), 33 | dict(type=RandomFlip, prob=0.5), 34 | dict( 35 | type=RandomResize, 36 | resize_type=Resize, 37 | scale=image_size, 38 | ratio_range=(0.1, 2.0), 39 | keep_ratio=True, 40 | ), 41 | dict( 42 | type=RandomCrop, 43 | crop_size=image_size, 44 | crop_type='absolute', 45 | recompute_bbox=True, 46 | allow_negative_crop=True), 47 | dict( 48 | type=FilterAnnotationsHB, 49 | by_box=False, 50 | by_mask=True, 51 | min_gt_mask_area=32, 52 | ), 53 | dict(type=PackDetInputs) 54 | ] 55 | train_dataloader = dict( 56 | batch_size=2, 57 | num_workers=2, 58 | persistent_workers=True, 59 | sampler=dict(type=DefaultSampler, shuffle=True), 60 | batch_sampler=dict(type=AspectRatioBatchSampler), 61 | dataset=dict( 62 | type=dataset_type, 63 | data_root=data_root, 64 | ann_file='annotations/coconut_b.json', 65 | data_prefix=dict(img='train2017/', seg='coconut_b/panoptic_coconut_b/'), 66 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 67 | pipeline=train_pipeline, 68 | backend_args=backend_args) 69 | ) 70 | 71 | test_pipeline = [ 72 | dict(type=LoadImageFromFile, backend_args=backend_args), 73 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 74 | dict(type=LoadPanopticAnnotations, with_bbox=True, with_mask=True, with_seg=False, backend_args=backend_args), 75 | dict( 76 | type=PackDetInputs, 77 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 78 | ] 79 | val_dataloader = dict( 80 | batch_size=1, 81 | num_workers=2, 82 | persistent_workers=True, 83 | drop_last=False, 84 | sampler=dict(type=DefaultSampler, shuffle=False), 85 | dataset=dict( 86 | type=dataset_type, 87 | data_root=data_root, 88 | ann_file='annotations/relabeled_coco_val.json', 89 | data_prefix=dict(img='val2017/', seg='relabeled_coco_val/relabeled_coco_val/'), 90 | test_mode=True, 91 | pipeline=test_pipeline, 92 | backend_args=backend_args) 93 | ) 94 | test_dataloader = val_dataloader 95 | 96 | val_evaluator = dict( 97 | type=CocoMetric, 98 | ann_file=data_root + 'annotations/instances_val2017.json', 99 | metric=['segm'], 100 | format_only=False, 101 | backend_args=backend_args 102 | ) 103 | test_evaluator = val_evaluator 104 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/entity/entity_lr_instance_lsj.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import AspectRatioBatchSampler 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs, RandomCrop 5 | from mmdet.evaluation import CocoMetric 6 | from mmengine.dataset import DefaultSampler 7 | 8 | from projects.rwkvsam.datasets.entity_seg import EntitySegDataset 9 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 10 | 11 | data_root = 'data/entity_lr/' 12 | backend_args = None 13 | dataset_type = EntitySegDataset 14 | 15 | image_size = (1024, 1024) 16 | 17 | train_pipeline = [ 18 | dict( 19 | type=LoadImageFromFile, 20 | to_float32=True, 21 | ignore_empty=True, 22 | backend_args=backend_args), 23 | dict( 24 | type=LoadAnnotations, 25 | with_bbox=True, 26 | with_mask=True, 27 | 
backend_args=backend_args), 28 | dict(type=RandomFlip, prob=0.5), 29 | dict( 30 | type=RandomResize, 31 | resize_type=Resize, 32 | scale=image_size, 33 | ratio_range=(0.8, 2.0), 34 | keep_ratio=True, 35 | ), 36 | dict( 37 | type=RandomCrop, 38 | crop_size=image_size, 39 | crop_type='absolute', 40 | recompute_bbox=True, 41 | allow_negative_crop=True), 42 | dict( 43 | type=FilterAnnotationsHB, 44 | by_box=False, 45 | by_mask=True, 46 | min_gt_mask_area=32, 47 | ), 48 | dict(type=PackDetInputs), 49 | ] 50 | train_dataloader = dict( 51 | batch_size=2, 52 | num_workers=2, 53 | persistent_workers=True, 54 | sampler=dict(type=DefaultSampler, shuffle=True), 55 | batch_sampler=dict(type=AspectRatioBatchSampler), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file='annotations/entityseg_train_lr.json', 60 | data_prefix=dict(img=''), 61 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 62 | pipeline=train_pipeline, 63 | backend_args=backend_args) 64 | ) 65 | 66 | test_pipeline = [ 67 | dict(type=LoadImageFromFile, backend_args=backend_args), 68 | dict(type=Resize, scale=(1333, 800), keep_ratio=True), 69 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 70 | dict( 71 | type=PackDetInputs, 72 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')) 73 | ] 74 | val_dataloader = dict( 75 | batch_size=1, 76 | num_workers=2, 77 | persistent_workers=True, 78 | drop_last=False, 79 | sampler=dict(type=DefaultSampler, shuffle=False), 80 | dataset=dict( 81 | type=dataset_type, 82 | data_root=data_root, 83 | ann_file='annotations/entityseg_val_lr.json', 84 | data_prefix=dict(img=''), 85 | test_mode=True, 86 | pipeline=test_pipeline, 87 | backend_args=backend_args) 88 | ) 89 | test_dataloader = val_dataloader 90 | 91 | val_evaluator = dict( 92 | type=CocoMetric, 93 | ann_file=data_root + 'annotations/instances_val_lr.json', 94 | metric=['segm'], 95 | format_only=False, 96 | backend_args=backend_args 97 | ) 98 | test_evaluator = val_evaluator 99 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/hq_concat/concat_coconutbpan_entity_dis5k_sam.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmengine import read_base 3 | from mmengine.dataset import RepeatDataset 4 | from projects.rwkvsam.datasets import AdvancedConcatDataset 5 | 6 | with read_base(): 7 | from ..coconut.coconut_b_instance_lsj import train_dataloader as _coconut 8 | from ..entity.entity_lr_instance_lsj import train_dataloader as _entity 9 | from ..DIS.dis_5k_1024 import train_dataloader as _dis 10 | from ..sam.sam_001 import train_dataloader as _sam 11 | 12 | # import tests 13 | from ..DIS.dis_5k_1024 import * 14 | 15 | train_dataloader = dict( 16 | batch_size=2, 17 | num_workers=2, 18 | persistent_workers=True, 19 | sampler=dict(type=DefaultSampler, shuffle=True), 20 | # batch_sampler=dict(type=AspectRatioBatchSampler), 21 | dataset=dict( 22 | type=AdvancedConcatDataset, # 233960 (5x; 2 : 1 : 1: 1) 23 | data_tag=['sam', 'sam', 'sam', 'sam'], 24 | datasets=[ 25 | dict( 26 | type=RepeatDataset, # 233960 (2x) 27 | dataset=_coconut.dataset, 28 | times=1, 29 | ), 30 | dict( 31 | type=RepeatDataset, # 31913 (0.27x) 32 | dataset=_entity.dataset, 33 | times=4, 34 | ), 35 | dict( 36 | type=RepeatDataset, # 3000 (0.025x) 37 | dataset=_dis.dataset, 38 | times=40, 39 | ), 40 | dict( 41 | type=RepeatDataset, # 111860 ~1x 42 | dataset=_sam.dataset, 43 | 
times=1, 44 | ), 45 | ], 46 | ) 47 | ) 48 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/imagenet/imagenet_bs64_swin_224.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomFlip, CenterCrop 3 | from mmengine.dataset import DefaultSampler 4 | from mmpretrain.datasets import ImageNet, RandomResizedCrop, RandAugment, RandomErasing, PackInputs, ResizeEdge 5 | from mmpretrain.evaluation import Accuracy 6 | 7 | dataset_type = ImageNet 8 | data_preprocessor = dict( 9 | num_classes=1000, 10 | # RGB format normalization parameters 11 | mean=[123.675, 116.28, 103.53], 12 | std=[58.395, 57.12, 57.375], 13 | # convert image from BGR to RGB 14 | to_rgb=True, 15 | ) 16 | 17 | _bgr_mean = data_preprocessor['mean'][::-1] 18 | _bgr_std = data_preprocessor['std'][::-1] 19 | 20 | train_pipeline = [ 21 | dict(type=LoadImageFromFile), 22 | dict( 23 | type=RandomResizedCrop, 24 | scale=224, 25 | backend='pillow', 26 | interpolation='bicubic'), 27 | dict(type=RandomFlip, prob=0.5, direction='horizontal'), 28 | dict( 29 | type=RandAugment, 30 | policies='timm_increasing', 31 | num_policies=2, 32 | total_level=10, 33 | magnitude_level=9, 34 | magnitude_std=0.5, 35 | hparams=dict( 36 | pad_val=[round(x) for x in _bgr_mean], interpolation='bicubic')), 37 | dict( 38 | type=RandomErasing, 39 | erase_prob=0.25, 40 | mode='rand', 41 | min_area_ratio=0.02, 42 | max_area_ratio=1 / 3, 43 | fill_color=_bgr_mean, 44 | fill_std=_bgr_std), 45 | dict(type=PackInputs), 46 | ] 47 | 48 | test_pipeline = [ 49 | dict(type=LoadImageFromFile), 50 | dict( 51 | type=ResizeEdge, 52 | scale=256, 53 | edge='short', 54 | backend='pillow', 55 | interpolation='bicubic'), 56 | dict(type=CenterCrop, crop_size=224), 57 | dict(type=PackInputs), 58 | ] 59 | 60 | train_dataloader = dict( 61 | batch_size=64, 62 | num_workers=5, 63 | dataset=dict( 64 | type=dataset_type, 65 | data_root='data/imagenet', 66 | split='train', 67 | pipeline=train_pipeline), 68 | sampler=dict(type=DefaultSampler, shuffle=True), 69 | ) 70 | 71 | val_dataloader = dict( 72 | batch_size=64, 73 | num_workers=5, 74 | dataset=dict( 75 | type=dataset_type, 76 | data_root='data/imagenet', 77 | split='val', 78 | pipeline=test_pipeline), 79 | sampler=dict(type=DefaultSampler, shuffle=False), 80 | ) 81 | val_evaluator = dict(type=Accuracy, topk=(1, 5)) 82 | 83 | # If you want standard test, please manually configure the test dataset 84 | test_dataloader = val_dataloader 85 | test_evaluator = val_evaluator 86 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/sam/sam_001.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile, RandomResize 2 | from mmengine.dataset import DefaultSampler 3 | from mmdet.datasets.transforms import Resize, RandomFlip, PackDetInputs, RandomCrop 4 | 5 | from seg.datasets.pipeliens.formatting import PackSAMInputs 6 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB, LoadJSONFromFile, LoadAnnotationsSAM 7 | from seg.datasets.pipeliens.transforms import ResizeSAM 8 | 9 | from projects.rwkvsam.datasets import SAMDataset 10 | 11 | 12 | dataset_type = SAMDataset 13 | data_root = 'data/sam' 14 | 15 | backend_args = None 16 | image_size = (1024, 1024) 17 | 18 | # dataset settings 19 | train_pipeline = [ 20 | dict( 21 | 
type=LoadImageFromFile, 22 | to_float32=True, 23 | backend_args=backend_args), 24 | dict(type=LoadJSONFromFile, backend_args=backend_args, limit=30, max_ratio=1/64), 25 | dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 26 | dict(type=RandomFlip, prob=0.5), 27 | dict( 28 | type=RandomResize, 29 | resize_type=Resize, 30 | scale=image_size, 31 | ratio_range=(1., 1.5), 32 | keep_ratio=True, 33 | ), 34 | dict( 35 | type=RandomCrop, 36 | crop_size=image_size, 37 | crop_type='absolute', 38 | recompute_bbox=True, 39 | allow_negative_crop=True), 40 | dict( 41 | type=FilterAnnotationsHB, 42 | by_box=False, 43 | by_mask=True, 44 | min_gt_mask_area=32, 45 | ), 46 | dict(type=PackDetInputs) 47 | ] 48 | 49 | test_pipeline = [ 50 | dict(type=LoadImageFromFile, backend_args=backend_args), 51 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 52 | dict(type=LoadJSONFromFile, backend_args=backend_args), 53 | dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 54 | dict( 55 | type=PackSAMInputs, 56 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 57 | ) 58 | ] 59 | 60 | # dataloader 61 | train_dataloader = dict( 62 | batch_size=2, 63 | num_workers=2, 64 | persistent_workers=True, 65 | sampler=dict(type=DefaultSampler, shuffle=True), 66 | batch_sampler=None, 67 | dataset=dict( 68 | type=dataset_type, 69 | data_root=data_root, 70 | ann_file='train.txt', 71 | data_prefix=dict(img=''), 72 | filter_cfg=None, 73 | pipeline=train_pipeline, 74 | backend_args=backend_args 75 | ) 76 | ) 77 | val_dataloader = dict( 78 | batch_size=1, 79 | num_workers=2, 80 | persistent_workers=True, 81 | drop_last=True, 82 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 83 | batch_sampler=None, 84 | dataset=dict( 85 | type=dataset_type, 86 | data_root=data_root, 87 | ann_file='train.txt', 88 | data_prefix=dict(img=''), 89 | test_mode=True, 90 | pipeline=test_pipeline, 91 | backend_args=backend_args 92 | ) 93 | ) 94 | test_dataloader = val_dataloader 95 | 96 | val_evaluator = [] 97 | test_evaluator = val_evaluator 98 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/sam/sam_distill.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmengine.dataset import DefaultSampler 3 | 4 | from seg.datasets.pipeliens.formatting import PackSAMInputs 5 | from seg.datasets.pipeliens.loading import LoadFeatFromFile 6 | from seg.datasets.pipeliens.transforms import ResizeSAM 7 | from seg.datasets.sam import SAMDataset 8 | 9 | dataset_type = SAMDataset 10 | data_root = 'data/sam' 11 | 12 | backend_args = None 13 | image_size = (1024, 1024) 14 | 15 | # dataset settings 16 | train_pipeline = [ 17 | dict(type=LoadImageFromFile, backend_args=backend_args), 18 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 19 | dict(type=LoadFeatFromFile), 20 | dict( 21 | type=PackSAMInputs, 22 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 23 | ), 24 | ] 25 | 26 | test_pipeline = [ 27 | dict(type=LoadImageFromFile, backend_args=backend_args), 28 | # dict(type=LoadJSONFromFile, backend_args=backend_args), 29 | # dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 30 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 31 | dict( 32 | type=PackSAMInputs, 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 
'img_shape', 'scale_factor') 34 | ) 35 | ] 36 | 37 | # dataloader 38 | train_dataloader = dict( 39 | batch_size=2, 40 | num_workers=2, 41 | persistent_workers=True, 42 | sampler=dict(type=DefaultSampler, shuffle=True), 43 | batch_sampler=None, 44 | dataset=dict( 45 | type=dataset_type, 46 | data_root=data_root, 47 | ann_file='train.txt', 48 | data_prefix=dict(img=''), 49 | filter_cfg=None, 50 | pipeline=train_pipeline, 51 | backend_args=backend_args 52 | ) 53 | ) 54 | val_dataloader = dict( 55 | batch_size=1, 56 | num_workers=2, 57 | persistent_workers=True, 58 | drop_last=True, 59 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 60 | batch_sampler=None, 61 | dataset=dict( 62 | type=dataset_type, 63 | data_root=data_root, 64 | ann_file='train.txt', 65 | data_prefix=dict(img=''), 66 | test_mode=True, 67 | pipeline=test_pipeline, 68 | backend_args=backend_args 69 | ) 70 | ) 71 | test_dataloader = val_dataloader 72 | 73 | val_evaluator = [] 74 | test_evaluator = val_evaluator 75 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/thin_obj_det/coift_1024.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmdet.datasets.transforms import PackDetInputs 3 | from mmengine.dataset import DefaultSampler 4 | from projects.rwkvsam.datasets import ThinOBJDataset 5 | from projects.rwkvsam.datasets.pipelines import LoadMaskFromFile 6 | 7 | from seg.datasets.pipeliens.transforms import ResizeSAM 8 | 9 | dataset_type = ThinOBJDataset 10 | data_root = 'data/thin_object_detection/COIFT/' 11 | 12 | backend_args = None 13 | image_size = (1024, 1024) 14 | 15 | # dataset settings 16 | train_pipeline = [ 17 | dict(type=LoadImageFromFile, backend_args=backend_args), 18 | dict(type=LoadMaskFromFile,), 19 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 20 | dict( 21 | type=PackDetInputs, 22 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 23 | ), 24 | ] 25 | 26 | test_pipeline = [ 27 | dict(type=LoadImageFromFile, backend_args=backend_args), 28 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 29 | dict(type=LoadMaskFromFile, ), 30 | dict( 31 | type=PackDetInputs, 32 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 33 | ) 34 | ] 35 | 36 | # dataloader 37 | train_dataloader = dict( 38 | batch_size=2, 39 | num_workers=2, 40 | persistent_workers=True, 41 | sampler=dict(type=DefaultSampler, shuffle=True), 42 | batch_sampler=None, 43 | dataset=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | data_prefix=dict(img='images', ann='masks'), 47 | filter_cfg=None, 48 | pipeline=train_pipeline, 49 | backend_args=backend_args 50 | ) 51 | ) 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=True, 57 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 58 | batch_sampler=None, 59 | dataset=dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | data_prefix=dict(img='images', ann='masks'), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | backend_args=backend_args 66 | ) 67 | ) 68 | test_dataloader = val_dataloader 69 | 70 | val_evaluator = [] 71 | test_evaluator = val_evaluator 72 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/thin_obj_det/hrsod_1024.py: 
-------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmdet.datasets.transforms import PackDetInputs 3 | from mmengine.dataset import DefaultSampler 4 | from projects.rwkvsam.datasets import ThinOBJDataset 5 | from projects.rwkvsam.datasets.pipelines import LoadMaskFromFile 6 | 7 | from seg.datasets.pipeliens.transforms import ResizeSAM 8 | 9 | dataset_type = ThinOBJDataset 10 | data_root = 'data/thin_object_detection/HRSOD/' 11 | 12 | backend_args = None 13 | image_size = (1024, 1024) 14 | 15 | # dataset settings 16 | train_pipeline = [ 17 | dict(type=LoadImageFromFile, backend_args=backend_args), 18 | dict(type=LoadMaskFromFile,), 19 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 20 | dict( 21 | type=PackDetInputs, 22 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 23 | ), 24 | ] 25 | 26 | test_pipeline = [ 27 | dict(type=LoadImageFromFile, backend_args=backend_args), 28 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 29 | dict(type=LoadMaskFromFile, ), 30 | dict( 31 | type=PackDetInputs, 32 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 33 | ) 34 | ] 35 | 36 | # dataloader 37 | train_dataloader = dict( 38 | batch_size=2, 39 | num_workers=2, 40 | persistent_workers=True, 41 | sampler=dict(type=DefaultSampler, shuffle=True), 42 | batch_sampler=None, 43 | dataset=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | data_prefix=dict(img='images', ann='masks_max255'), 47 | filter_cfg=None, 48 | pipeline=train_pipeline, 49 | backend_args=backend_args 50 | ) 51 | ) 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=True, 57 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 58 | batch_sampler=None, 59 | dataset=dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | data_prefix=dict(img='images', ann='masks_max255'), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | backend_args=backend_args 66 | ) 67 | ) 68 | test_dataloader = val_dataloader 69 | 70 | val_evaluator = [] 71 | test_evaluator = val_evaluator 72 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/datasets/thin_obj_det/thin_obj_5k_1024.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmdet.datasets.transforms import PackDetInputs 3 | from mmengine.dataset import DefaultSampler 4 | from projects.rwkvsam.datasets import ThinOBJDataset 5 | from projects.rwkvsam.datasets.pipelines import LoadMaskFromFile 6 | 7 | from seg.datasets.pipeliens.transforms import ResizeSAM 8 | 9 | dataset_type = ThinOBJDataset 10 | data_root = 'data/thin_object_detection/ThinObject5K/' 11 | 12 | backend_args = None 13 | image_size = (1024, 1024) 14 | 15 | # dataset settings 16 | train_pipeline = [ 17 | dict(type=LoadImageFromFile, backend_args=backend_args), 18 | dict(type=LoadMaskFromFile,), 19 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 20 | dict( 21 | type=PackDetInputs, 22 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 23 | ), 24 | ] 25 | 26 | test_pipeline = [ 27 | dict(type=LoadImageFromFile, backend_args=backend_args), 28 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 29 | dict(type=LoadMaskFromFile, ), 30 | dict( 31 | type=PackDetInputs, 32 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 
'scale_factor') 33 | ) 34 | ] 35 | 36 | # dataloader 37 | train_dataloader = dict( 38 | batch_size=2, 39 | num_workers=2, 40 | persistent_workers=True, 41 | sampler=dict(type=DefaultSampler, shuffle=True), 42 | batch_sampler=None, 43 | dataset=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | data_prefix=dict(img='images_train', ann='masks_train'), 47 | filter_cfg=None, 48 | pipeline=train_pipeline, 49 | backend_args=backend_args 50 | ) 51 | ) 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=True, 57 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 58 | batch_sampler=None, 59 | dataset=dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | data_prefix=dict(img='images_test', ann='masks_test'), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | backend_args=backend_args 66 | ) 67 | ) 68 | test_dataloader = val_dataloader 69 | 70 | val_evaluator = [] 71 | test_evaluator = val_evaluator 72 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | from mmengine.hooks import IterTimerHook, LoggerHook, ParamSchedulerHook, CheckpointHook, DistSamplerSeedHook 2 | from mmengine.runner import LogProcessor 3 | from mmengine.visualization import LocalVisBackend 4 | 5 | from mmdet.engine import DetVisualizationHook 6 | from mmdet.visualization import DetLocalVisualizer 7 | 8 | default_scope = None 9 | 10 | default_hooks = dict( 11 | timer=dict(type=IterTimerHook), 12 | logger=dict(type=LoggerHook, interval=50), 13 | param_scheduler=dict(type=ParamSchedulerHook), 14 | checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=1, save_last=True, save_best=['auto']), 15 | sampler_seed=dict(type=DistSamplerSeedHook), 16 | visualization=dict(type=DetVisualizationHook) 17 | ) 18 | 19 | env_cfg = dict( 20 | cudnn_benchmark=False, 21 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 22 | dist_cfg=dict(backend='nccl'), 23 | ) 24 | 25 | vis_backends = [dict(type=LocalVisBackend)] 26 | visualizer = dict( 27 | type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer') 28 | log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True) 29 | 30 | log_level = 'INFO' 31 | load_from = None 32 | resume = False 33 | randomness = dict(seed=None, deterministic=False) 34 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/default_runtime_iterbased.py: -------------------------------------------------------------------------------- 1 | from mmengine.hooks import IterTimerHook, LoggerHook, ParamSchedulerHook, CheckpointHook, DistSamplerSeedHook 2 | from mmengine.runner import LogProcessor 3 | from mmengine.visualization import LocalVisBackend 4 | 5 | from mmdet.engine import DetVisualizationHook 6 | from mmdet.visualization import DetLocalVisualizer 7 | 8 | default_scope = None 9 | 10 | default_hooks = dict( 11 | timer=dict(type=IterTimerHook), 12 | logger=dict(type=LoggerHook, interval=50), 13 | param_scheduler=dict(type=ParamSchedulerHook), 14 | checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=1000, max_keep_ckpts=1, save_last=True), 15 | sampler_seed=dict(type=DistSamplerSeedHook), 16 | visualization=dict(type=DetVisualizationHook) 17 | ) 18 | 19 | env_cfg = dict( 20 | cudnn_benchmark=False, 21 | mp_cfg=dict(mp_start_method='fork', 
opencv_num_threads=0), 22 | dist_cfg=dict(backend='nccl'), 23 | ) 24 | 25 | vis_backends = [dict(type=LocalVisBackend)] 26 | visualizer = dict( 27 | type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer') 28 | log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True) 29 | 30 | log_level = 'INFO' 31 | load_from = None 32 | resume = False 33 | randomness = dict(seed=None, deterministic=False) 34 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_120e_bs1024_for_imagenet.py: -------------------------------------------------------------------------------- 1 | # for batch in each gpu is 128, 8 gpu 2 | # lr = 5e-4 * 128 * 8 / 512 = 0.001 3 | from mmengine.optim import OptimWrapper, CosineAnnealingLR, LinearLR 4 | from mmengine.runner import ValLoop, TestLoop, EpochBasedTrainLoop 5 | from torch.optim import AdamW 6 | 7 | optim_wrapper = dict( 8 | type=OptimWrapper, 9 | optimizer=dict( 10 | type=AdamW, 11 | lr=5e-4 * 1024 / 512, 12 | weight_decay=0.05, 13 | eps=1e-8, 14 | betas=(0.9, 0.999)), 15 | paramwise_cfg=dict( 16 | norm_decay_mult=0.0, 17 | bias_decay_mult=0.0, 18 | flat_decay_mult=0.0, 19 | custom_keys={ 20 | '.absolute_pos_embed': dict(decay_mult=0.0), 21 | '.relative_position_bias_table': dict(decay_mult=0.0) 22 | }), 23 | clip_grad=dict(max_norm=5.0), 24 | ) 25 | 26 | # learning policy 27 | param_scheduler = [ 28 | # warm up learning rate scheduler 29 | dict( 30 | type=LinearLR, 31 | start_factor=1e-3, 32 | by_epoch=True, 33 | end=10, 34 | # update by iter 35 | convert_to_iter_based=True 36 | ), 37 | # main learning rate scheduler 38 | dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=10) 39 | ] 40 | 41 | # train, val, test setting 42 | train_cfg = dict( 43 | type=EpochBasedTrainLoop, 44 | max_epochs=120, 45 | val_interval=1, 46 | dynamic_intervals=[ 47 | (10, 10), 48 | (100, 1), 49 | ] 50 | ) 51 | val_cfg = dict(type=ValLoop) 52 | test_cfg = dict(type=TestLoop) 53 | 54 | # NOTE: `auto_scale_lr` is for automatically scaling LR, 55 | # based on the actual training batch size. 
56 | auto_scale_lr = dict(enable=True, base_batch_size=1024) 57 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_12e_distillation.py: -------------------------------------------------------------------------------- 1 | from mmengine.optim import LinearLR, OptimWrapper, CosineAnnealingLR 2 | from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop 3 | from torch.optim import AdamW 4 | 5 | # training schedule for 12e 6 | train_cfg = dict( 7 | type=EpochBasedTrainLoop, 8 | max_epochs=12, 9 | val_interval=2, 10 | ) 11 | val_cfg = dict(type=ValLoop) 12 | test_cfg = dict(type=TestLoop) 13 | 14 | # learning rate 15 | param_scheduler = [ 16 | dict( 17 | type=LinearLR, 18 | start_factor=0.001, 19 | by_epoch=False, 20 | begin=0, 21 | end=500 22 | ), 23 | dict( 24 | type=CosineAnnealingLR, 25 | convert_to_iter_based=True, 26 | begin=0, 27 | end=12, 28 | by_epoch=True, 29 | eta_min_ratio=0.01, 30 | ) 31 | ] 32 | 33 | _embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 34 | optim_wrapper = dict( 35 | type=OptimWrapper, 36 | optimizer=dict( 37 | type=AdamW, 38 | lr=0.0001, 39 | weight_decay=0.05, 40 | eps=1e-8, 41 | betas=(0.9, 0.999) 42 | ), 43 | paramwise_cfg=dict( 44 | norm_decay_mult=0.0 45 | ), 46 | clip_grad=dict(max_norm=5., norm_type=2) 47 | ) 48 | 49 | # Default setting for scaling LR automatically 50 | # - `enable` means enable scaling LR automatically 51 | # or not by default. 52 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 53 | auto_scale_lr = dict(enable=True, base_batch_size=16) 54 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_160k_seg.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | from mmengine.optim import OptimWrapper, PolyLR 3 | from mmengine.runner import IterBasedTrainLoop, ValLoop, TestLoop 4 | from torch.optim import SGD 5 | 6 | optimizer = dict(type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005) 7 | optim_wrapper = dict(type=OptimWrapper, optimizer=optimizer, clip_grad=None) 8 | # learning policy 9 | param_scheduler = [ 10 | dict( 11 | type=PolyLR, 12 | eta_min=1e-4, 13 | power=0.9, 14 | begin=0, 15 | end=160000, 16 | by_epoch=False 17 | ) 18 | ] 19 | # training schedule for 160k 20 | train_cfg = dict( 21 | type=IterBasedTrainLoop, max_iters=160000, val_interval=16000 22 | ) 23 | val_cfg = dict(type=ValLoop) 24 | test_cfg = dict(type=TestLoop) 25 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_160k_seg_adam.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | from mmengine.optim import OptimWrapper, PolyLR, LinearLR 3 | from mmengine.runner import IterBasedTrainLoop, ValLoop, TestLoop 4 | from torch.optim import AdamW 5 | 6 | optimizer = dict(type=AdamW, lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01) 7 | optim_wrapper = dict( 8 | type=OptimWrapper, 9 | optimizer=optimizer, 10 | clip_grad=None, 11 | paramwise_cfg=dict( 12 | custom_keys={ 13 | 'pos_block': dict(decay_mult=0.), 14 | 'norm': dict(decay_mult=0.), 15 | 'head': dict(lr_mult=10.) 
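# In other words, the decode head trains with 10x the base lr while 'pos_block' and 'norm'
# parameters are exempt from weight decay; this mirrors the recipe commonly used for
# transformer segmentation backbones in the mm* codebases.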
16 | } 17 | ) 18 | ) 19 | 20 | # learning policy 21 | param_scheduler = [ 22 | dict(type=LinearLR, start_factor=1e-6, by_epoch=False, begin=0, end=1500), 23 | dict( 24 | type=PolyLR, 25 | power=1.0, 26 | begin=1500, 27 | end=160000, 28 | eta_min=0.0, 29 | by_epoch=False, 30 | ) 31 | ] 32 | # training schedule for 160k 33 | train_cfg = dict( 34 | type=IterBasedTrainLoop, max_iters=160000, val_interval=16000 35 | ) 36 | val_cfg = dict(type=ValLoop) 37 | test_cfg = dict(type=TestLoop) 38 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # training schedule for 1x 2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) 3 | val_cfg = dict(type='ValLoop') 4 | test_cfg = dict(type='TestLoop') 5 | 6 | # learning rate 7 | param_scheduler = [ 8 | dict( 9 | type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), 10 | dict( 11 | type='MultiStepLR', 12 | begin=0, 13 | end=12, 14 | by_epoch=True, 15 | milestones=[8, 11], 16 | gamma=0.1) 17 | ] 18 | 19 | # optimizer 20 | optim_wrapper = dict( 21 | type='OptimWrapper', 22 | optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)) 23 | 24 | # Default setting for scaling LR automatically 25 | # - `enable` means enable scaling LR automatically 26 | # or not by default. 27 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 28 | auto_scale_lr = dict(enable=False, base_batch_size=16) 29 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_1x_adam.py: -------------------------------------------------------------------------------- 1 | # training schedule for 1x 2 | from mmengine.optim import OptimWrapper 3 | from torch.optim import AdamW 4 | 5 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) 6 | val_cfg = dict(type='ValLoop') 7 | test_cfg = dict(type='TestLoop') 8 | 9 | # learning rate 10 | param_scheduler = [ 11 | dict( 12 | type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=1000), 13 | dict( 14 | type='MultiStepLR', 15 | begin=0, 16 | end=12, 17 | by_epoch=True, 18 | milestones=[8, 11], 19 | gamma=0.1) 20 | ] 21 | 22 | # optimizer 23 | optim_wrapper = dict( 24 | type=OptimWrapper, 25 | optimizer=dict( 26 | type=AdamW, 27 | lr=0.0001 * 0.25, 28 | weight_decay=0.05, 29 | eps=1e-8, 30 | betas=(0.9, 0.999) 31 | ), 32 | paramwise_cfg=dict( 33 | norm_decay_mult=0.0 34 | ), 35 | # clip_grad=dict(max_norm=5., norm_type=2) 36 | ) 37 | 38 | # Default setting for scaling LR automatically 39 | # - `enable` means enable scaling LR automatically 40 | # or not by default. 41 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 
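# Illustrative numbers (assumed, not taken from this repo): with enable=True and base_batch_size=16,
# running on 4 GPUs x 2 samples = 8 images per step halves the AdamW lr above
# (0.0001 * 0.25 = 2.5e-5) to 1.25e-5, while the default 8 x 2 = 16 leaves it unchanged.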
42 | auto_scale_lr = dict(enable=True, base_batch_size=16) 43 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_24e_distillation.py: -------------------------------------------------------------------------------- 1 | from mmengine.optim import LinearLR, OptimWrapper, CosineAnnealingLR 2 | from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop 3 | from torch.optim import AdamW 4 | 5 | # training schedule for 24e 6 | train_cfg = dict( 7 | type=EpochBasedTrainLoop, 8 | max_epochs=24, 9 | val_interval=2, 10 | ) 11 | val_cfg = dict(type=ValLoop) 12 | test_cfg = dict(type=TestLoop) 13 | 14 | # learning rate 15 | param_scheduler = [ 16 | dict( 17 | type=LinearLR, 18 | start_factor=0.001, 19 | by_epoch=False, 20 | begin=0, 21 | end=500 22 | ), 23 | dict( 24 | type=CosineAnnealingLR, 25 | convert_to_iter_based=True, 26 | begin=0, 27 | end=24, 28 | by_epoch=True, 29 | eta_min_ratio=0.01, 30 | ) 31 | ] 32 | 33 | _embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 34 | optim_wrapper = dict( 35 | type=OptimWrapper, 36 | optimizer=dict( 37 | type=AdamW, 38 | lr=0.0001, 39 | weight_decay=0.05, 40 | eps=1e-8, 41 | betas=(0.9, 0.999) 42 | ), 43 | paramwise_cfg=dict( 44 | norm_decay_mult=0.0 45 | ), 46 | clip_grad=dict(max_norm=5., norm_type=2) 47 | ) 48 | 49 | # Default setting for scaling LR automatically 50 | # - `enable` means enable scaling LR automatically 51 | # or not by default. 52 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 53 | auto_scale_lr = dict(enable=True, base_batch_size=16) 54 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/_base_/schedules/schedule_300e_bs1024_for_imagenet.py: -------------------------------------------------------------------------------- 1 | # for batch in each gpu is 128, 8 gpu 2 | # lr = 5e-4 * 128 * 8 / 512 = 0.001 3 | from mmengine.optim import OptimWrapper, CosineAnnealingLR, LinearLR 4 | from mmengine.runner import ValLoop, TestLoop, EpochBasedTrainLoop 5 | from torch.optim import AdamW 6 | 7 | optim_wrapper = dict( 8 | type=OptimWrapper, 9 | optimizer=dict( 10 | type=AdamW, 11 | lr=5e-4 * 1024 / 512, 12 | weight_decay=0.05, 13 | eps=1e-8, 14 | betas=(0.9, 0.999)), 15 | paramwise_cfg=dict( 16 | norm_decay_mult=0.0, 17 | bias_decay_mult=0.0, 18 | flat_decay_mult=0.0, 19 | custom_keys={ 20 | '.absolute_pos_embed': dict(decay_mult=0.0), 21 | '.relative_position_bias_table': dict(decay_mult=0.0) 22 | }), 23 | clip_grad=dict(max_norm=5.0), 24 | ) 25 | 26 | # learning policy 27 | param_scheduler = [ 28 | # warm up learning rate scheduler 29 | dict( 30 | type=LinearLR, 31 | start_factor=1e-3, 32 | by_epoch=True, 33 | end=20, 34 | # update by iter 35 | convert_to_iter_based=True 36 | ), 37 | # main learning rate scheduler 38 | dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=20) 39 | ] 40 | 41 | # train, val, test setting 42 | train_cfg = dict( 43 | type=EpochBasedTrainLoop, 44 | max_epochs=300, 45 | val_interval=1, 46 | dynamic_intervals=[ 47 | (20, 10), 48 | (250, 5), 49 | (280, 1) 50 | ] 51 | ) 52 | val_cfg = dict(type=ValLoop) 53 | test_cfg = dict(type=TestLoop) 54 | 55 | # NOTE: `auto_scale_lr` is for automatically scaling LR, 56 | # based on the actual training batch size.
57 | auto_scale_lr = dict(enable=True, base_batch_size=1024) 58 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/backbone_dist/rwkvsam1001_000_vith_vitamin_rwkv_small_mlp2.py: -------------------------------------------------------------------------------- 1 | from mmdet.models import BatchFixedSizePad, DetDataPreprocessor, MSELoss 2 | from mmengine.config import read_base 3 | 4 | from projects.rwkvsam.models import BackboneDistillation, SAMBackbone, LastLayerNeck, LastLayerProjNeck, VITAMINBackbone 5 | 6 | with read_base(): 7 | from .._base_.default_runtime import * 8 | from .._base_.datasets.sam.sam_distill import * 9 | from .._base_.schedules.schedule_24e_distillation import * 10 | 11 | batch_augments = [ 12 | dict( 13 | type=BatchFixedSizePad, 14 | size=image_size, 15 | img_pad_value=0, 16 | ) 17 | ] 18 | data_preprocessor = dict( 19 | type=DetDataPreprocessor, 20 | mean=[123.675, 116.28, 103.53], 21 | std=[58.395, 57.12, 57.375], 22 | bgr_to_rgb=True, 23 | pad_size_divisor=1, 24 | batch_augments=batch_augments 25 | ) 26 | 27 | model = dict( 28 | type=BackboneDistillation, 29 | use_cache=True, 30 | data_preprocessor=data_preprocessor, 31 | backbone_teacher=dict( 32 | type=SAMBackbone, 33 | model_name='vit_h', 34 | fix=True, 35 | init_cfg=dict( 36 | type='sam_pretrain', 37 | checkpoint='vit_h' 38 | ) 39 | ), 40 | backbone_student=dict( 41 | type=VITAMINBackbone, 42 | img_size=(224, 224), 43 | model_variant='small', 44 | attn_type='rwkv', 45 | attn_cfg=dict( 46 | mlp_ratio=2, 47 | ), 48 | with_pos_embd=False, 49 | init_cfg=dict( 50 | type='Pretrained', 51 | checkpoint='work_dirs/ckpt/ssmseg_pretrain_vitamin_rwkv_mlp2_pretrain_16xbs64_best_accuracy_top1_epoch_294.pth', 52 | prefix='backbone.', 53 | ) 54 | ), 55 | neck_teacher=dict(type=LastLayerNeck), 56 | neck_student=dict( 57 | type=LastLayerProjNeck, 58 | in_channels=384, 59 | out_channels=256, 60 | ), 61 | loss_distill=dict( 62 | type=MSELoss, 63 | reduction='mean', 64 | loss_weight=1. 
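# Summary of the distillation wiring above: the frozen SAM ViT-H teacher and the
# ViTAMIN-RWKV small student each produce a last-layer feature map; the student's
# 384-channel output is projected to 256 channels by LastLayerProjNeck to match the
# teacher, and the two maps are aligned with this MSE loss.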
65 | ) 66 | ) 67 | 68 | val_dataloader = None 69 | val_evaluator = None 70 | val_cfg = None 71 | -------------------------------------------------------------------------------- /projects/rwkvsam/configs/backbone_dist/sam_vith_dump.py: -------------------------------------------------------------------------------- 1 | from mmdet.models import BatchFixedSizePad, DetDataPreprocessor 2 | from mmengine.config import read_base 3 | 4 | from projects.rwkvsam.models import BackboneDump, LastLayerNeck, SAMBackbone 5 | 6 | with read_base(): 7 | from .._base_.default_runtime import * 8 | from .._base_.datasets.sam.sam_distill import * 9 | from .._base_.schedules.schedule_12e_distillation import * 10 | 11 | image_size = (1024, 1024) 12 | batch_augments = [ 13 | dict( 14 | type=BatchFixedSizePad, 15 | size=image_size, 16 | img_pad_value=0, 17 | pad_mask=False, 18 | mask_pad_value=0, 19 | pad_seg=False, 20 | ) 21 | ] 22 | data_preprocessor = dict( 23 | type=DetDataPreprocessor, 24 | mean=[123.675, 116.28, 103.53], 25 | std=[58.395, 57.12, 57.375], 26 | bgr_to_rgb=True, 27 | pad_size_divisor=1024, 28 | pad_mask=False, 29 | mask_pad_value=0, 30 | pad_seg=False, 31 | batch_augments=batch_augments 32 | ) 33 | 34 | model = dict( 35 | type=BackboneDump, 36 | data_preprocessor=data_preprocessor, 37 | backbone=dict( 38 | type=SAMBackbone, 39 | model_name='vit_h', 40 | fix=True, 41 | init_cfg=dict( 42 | type='sam_pretrain', 43 | checkpoint='vit_h' 44 | ) 45 | ), 46 | neck=dict( 47 | type=LastLayerNeck 48 | ) 49 | ) 50 | 51 | val_dataloader = None 52 | val_evaluator = None 53 | val_cfg = None 54 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sam import SAMDataset 2 | from .dis5k import DIS5KDataset 3 | from .entity_seg import EntitySegDataset 4 | from .concat_dataset import AdvancedConcatDataset 5 | from .thin_obj_det import ThinOBJDataset 6 | from .coconut_panoptic import CocoNutPanopticDataset 7 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/dis5k.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mmdet.registry import DATASETS 4 | from mmengine import print_log, join_path, list_dir_or_file 5 | from mmdet.datasets import BaseDetDataset 6 | 7 | @DATASETS.register_module() 8 | class DIS5KDataset(BaseDetDataset): 9 | METAINFO = { 10 | 'classes': None, 11 | 'palette': None, 12 | } 13 | 14 | def __init__(self, *args, img_map_suffix='.jpg', **kwargs): 15 | self.img_map_suffix = img_map_suffix 16 | self.id2folder = dict() 17 | super().__init__(*args, **kwargs) 18 | 19 | def load_data_list(self) -> List[dict]: 20 | print_log('Starting to load DIS5K dataset', 'current') 21 | folders = [] 22 | if 'TR' in self.ann_file: 23 | folders.append('DIS-TR') 24 | elif 'VD' in self.ann_file: 25 | folders.append('DIS-VD') 26 | 27 | img_ids_list = [] 28 | for folder in folders: 29 | folder_path = join_path(self.data_prefix['img'], folder) 30 | im_folder_path = join_path(folder_path, 'im') 31 | img_ids = sorted( 32 | list_dir_or_file(im_folder_path, recursive=False, list_dir=False, suffix=self.img_map_suffix) 33 | ) 34 | img_ids_list.extend(img_ids) 35 | for img_id in img_ids: 36 | self.id2folder[img_id] = folder 37 | 38 | img_ids = img_ids_list 39 | data_list = [] 40 | for img_id in img_ids: 41 | data_info = { 42 | 'img_id': img_id, 
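# Expected DIS5K layout for the two paths below: images under <folder>/im/<id>.jpg and the
# matching ground-truth masks under <folder>/gt/<id>.png (the .jpg suffix is swapped for .png).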
43 | 'img_path': join_path(self.data_prefix['img'], self.id2folder[img_id], 'im', img_id), 44 | 'ann_path': join_path(self.data_prefix['img'], self.id2folder[img_id], 'gt', 45 | img_id.replace('.jpg', '.png')), 46 | } 47 | data_list.append(data_info) 48 | print_log(f'Found {len(data_list)} in {len(folders)} folders.', 'current') 49 | return data_list 50 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .loading import LoadMaskFromFile 2 | from .optimization import * 3 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/pipelines/optimization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from mmcv import BaseTransform 4 | from mmdet.registry import TRANSFORMS 5 | 6 | @TRANSFORMS.register_module() 7 | class FixPadOptimization(BaseTransform): 8 | 9 | def __init__( 10 | self, 11 | img_scale=None, 12 | mask_pad_val=0, 13 | num_proposals=0, 14 | ): 15 | self.img_scale = img_scale 16 | self.mask_pad_val = mask_pad_val 17 | self.num_proposals = num_proposals 18 | 19 | def transform(self, results: dict) -> Optional[dict]: 20 | if self.num_proposals > 0: 21 | results['data_samples'].gt_instances = results['data_samples'].gt_instances[:self.num_proposals] 22 | results['data_samples'].gt_instances.masks = results['data_samples'].gt_instances.masks.pad( 23 | self.img_scale, 24 | pad_val=self.mask_pad_val) 25 | return results 26 | 27 | def __repr__(self): 28 | repr_str = f'{self.__class__.__name__}' 29 | return repr_str 30 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/sam.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mmdet.registry import DATASETS 4 | from mmengine import get_local_path, list_from_file, join_path, list_dir_or_file, print_log 5 | 6 | from mmdet.datasets import BaseDetDataset 7 | 8 | 9 | @DATASETS.register_module() 10 | class SAMDataset(BaseDetDataset): 11 | 12 | def __init__(self, *args, img_map_suffix='.jpg', custom_structure=False, **kwargs): 13 | self.img_map_suffix = img_map_suffix 14 | self.id2folder = dict() 15 | self.custom_structure = custom_structure 16 | super().__init__(*args, **kwargs) 17 | 18 | def load_data_list(self) -> List[dict]: 19 | print_log('Starting to load sam dataset', 'current') 20 | if 'sa_1b' in self.ann_file: 21 | if 'sa_1b_one' in self.ann_file: 22 | folders = [f'sa_000{str(idx).zfill(3)}' for idx in range(1)] 23 | elif 'sa_1b_001' in self.ann_file: 24 | # sa_000000 -> sa_000009 25 | folders = [f'sa_000{str(idx).zfill(3)}' for idx in range(10)] 26 | elif 'sa_1b_01' in self.ann_file: 27 | folders = [f'sa_000{str(idx).zfill(3)}' for idx in range(100)] 28 | else: 29 | # sa_000000 -> sa_000999 30 | folders = [f'sa_000{str(idx).zfill(3)}' for idx in range(1000)] 31 | else: 32 | with get_local_path( 33 | self.ann_file, backend_args=self.backend_args) as local_path: 34 | folders = list_from_file(local_path) 35 | 36 | 37 | img_ids_list = [] 38 | for folder in folders: 39 | folder_path = join_path(self.data_prefix['img'], folder) 40 | img_ids = sorted(list(map( 41 | lambda x: int(x.split('.')[0].split('_')[-1]), 42 | list_dir_or_file(folder_path, recursive=False, list_dir=False, suffix='.jpg', 
backend_args=self.backend_args) 43 | ))) 44 | print_log(f'Found {len(img_ids)} in {folder}.', 'current') 45 | img_ids_list.extend(img_ids) 46 | for img_id in img_ids: 47 | self.id2folder[img_id] = folder 48 | 49 | img_ids = img_ids_list 50 | data_list = [] 51 | for img_id in img_ids: 52 | if self.custom_structure: 53 | # tt structure 54 | data_info = { 55 | 'img_id': img_id, 56 | 'img_path': join_path(self.data_prefix['img'], 57 | self.id2folder[img_id], 58 | 'img', 59 | self.id2folder[img_id], 60 | f"sa_{img_id}.jpg", 61 | backend_args=self.backend_args), 62 | 'info_path': join_path(self.data_prefix['img'], 63 | self.id2folder[img_id], 64 | 'label', 65 | self.id2folder[img_id], 66 | f"sa_{img_id}.json", 67 | backend_args=self.backend_args), 68 | } 69 | else: 70 | data_info = { 71 | 'img_id': img_id, 72 | 'img_path': join_path(self.data_prefix['img'], self.id2folder[img_id], f"sa_{img_id}.jpg", backend_args=self.backend_args), 73 | 'info_path': join_path(self.data_prefix['img'], self.id2folder[img_id], f"sa_{img_id}.json", backend_args=self.backend_args), 74 | } 75 | data_list.append(data_info) 76 | print_log(f'Found {len(data_list)} in {len(folders)} folders.', 'current') 77 | return data_list 78 | -------------------------------------------------------------------------------- /projects/rwkvsam/datasets/thin_obj_det.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mmdet.registry import DATASETS 4 | from mmengine import print_log, join_path, list_dir_or_file 5 | from mmdet.datasets import BaseDetDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class ThinOBJDataset(BaseDetDataset): 10 | METAINFO = { 11 | 'classes': None, 12 | 'palette': None, 13 | } 14 | 15 | def __init__(self, *args, img_map_suffix='.jpg', **kwargs): 16 | self.img_map_suffix = img_map_suffix 17 | self.id2folder = dict() 18 | super().__init__(*args, **kwargs) 19 | 20 | def load_data_list(self) -> List[dict]: 21 | print_log(f'Starting to load Thin Obj Detection dataset from {self.data_root}', 'current') 22 | 23 | img_ids_list = [] 24 | img_ids = sorted( 25 | list_dir_or_file(self.data_prefix['img'], recursive=False, list_dir=False, suffix=self.img_map_suffix) 26 | ) 27 | img_ids_list.extend(img_ids) 28 | 29 | img_ids = img_ids_list 30 | data_list = [] 31 | for img_id in img_ids: 32 | data_info = { 33 | 'img_id': img_id, 34 | 'img_path': join_path(self.data_prefix['img'], img_id), 35 | 'ann_path': join_path(self.data_prefix['ann'], img_id.replace('.jpg', '.png')), 36 | } 37 | data_list.append(data_info) 38 | print_log(f'Found {len(data_list)} samples.', 'current') 39 | return data_list 40 | -------------------------------------------------------------------------------- /projects/rwkvsam/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .iou_metric import IoUMetric 2 | from .biou_metric import BIoUMetric 3 | -------------------------------------------------------------------------------- /projects/rwkvsam/evaluation/biou_metric.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Dict 2 | 3 | import numpy as np 4 | import torch 5 | from mmengine.dist import collect_results, broadcast_object_list, is_main_process 6 | 7 | from mmengine.evaluator import BaseMetric 8 | from mmdet.registry import METRICS 9 | from mmengine.evaluator.metric import _to_cpu 10 | 11 | from projects.rwkvsam.utils.boundary_iou import 
mask_to_boundary 12 | 13 | 14 | @METRICS.register_module() 15 | class BIoUMetric(BaseMetric): 16 | 17 | def __init__(self, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: 18 | super().__init__(collect_device=collect_device, prefix=prefix) 19 | 20 | self.ious = [] 21 | 22 | def get_iou(self, gt_masks, pred_masks): 23 | gt_masks = gt_masks 24 | n, h, w = gt_masks.shape 25 | intersection = (gt_masks & pred_masks).reshape(n, h * w).sum(dim=-1) 26 | union = (gt_masks | pred_masks).reshape(n, h * w).sum(dim=-1) 27 | ious = (intersection / (union + 1.e-8)) 28 | return ious 29 | 30 | def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: 31 | for data_sample in data_samples: 32 | pred_masks = data_sample['pred_instances']['masks'] 33 | gt_masks = data_sample['gt_instances']['masks'] 34 | device = pred_masks.device 35 | 36 | gt_masks = gt_masks.to_tensor(dtype=torch.bool, device=pred_masks.device) 37 | gt_boundary = mask_to_boundary(gt_masks.cpu().numpy()[0].astype(np.uint8)) 38 | pred_boundary = mask_to_boundary(pred_masks.cpu().numpy()[0].astype(np.uint8)) 39 | 40 | biou = self.get_iou( 41 | torch.tensor(gt_boundary[None]).to(device=device), 42 | torch.tensor(pred_boundary[None]).to(device=device) 43 | ) 44 | self.ious.append(biou) 45 | 46 | def compute_metrics(self, iou_list) -> Dict[str, float]: 47 | mean_iou = sum(iou_list) / len(iou_list) 48 | results = dict() 49 | results['biou'] = mean_iou * 100 50 | return results 51 | 52 | def evaluate(self, size: int) -> dict: 53 | _ious = collect_results(self.ious, size, self.collect_device) 54 | if is_main_process(): 55 | _ious = _to_cpu(_ious) 56 | ious = torch.cat(_ious) 57 | _metrics = self.compute_metrics(ious) 58 | metrics = [_metrics] 59 | else: 60 | metrics = [None] # type: ignore 61 | broadcast_object_list(metrics) 62 | return metrics[0] 63 | -------------------------------------------------------------------------------- /projects/rwkvsam/evaluation/iou_metric.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Dict 2 | 3 | import torch 4 | from mmengine.dist import collect_results, broadcast_object_list, is_main_process 5 | 6 | from mmengine.evaluator import BaseMetric 7 | from mmdet.registry import METRICS 8 | from mmengine.evaluator.metric import _to_cpu 9 | 10 | 11 | @METRICS.register_module() 12 | class IoUMetric(BaseMetric): 13 | 14 | def __init__(self, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: 15 | super().__init__(collect_device=collect_device, prefix=prefix) 16 | 17 | self.ious = [] 18 | 19 | def get_iou(self, gt_masks, pred_masks): 20 | gt_masks = gt_masks 21 | n, h, w = gt_masks.shape 22 | intersection = (gt_masks & pred_masks).reshape(n, h * w).sum(dim=-1) 23 | union = (gt_masks | pred_masks).reshape(n, h * w).sum(dim=-1) 24 | ious = (intersection / (union + 1.e-8)) 25 | return ious 26 | 27 | def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: 28 | for data_sample in data_samples: 29 | pred_masks = data_sample['pred_instances']['masks'] 30 | gt_masks = data_sample['gt_instances']['masks'] 31 | gt_masks = gt_masks.to_tensor(dtype=torch.bool, device=pred_masks.device) 32 | iou = self.get_iou(gt_masks, pred_masks) 33 | self.ious.append(iou) 34 | 35 | def compute_metrics(self, iou_list) -> Dict[str, float]: 36 | mean_iou = sum(iou_list) / len(iou_list) 37 | results = dict() 38 | results['miou'] = mean_iou * 100 39 | return results 40 | 41 | def evaluate(self, size: int) -> 
dict: 42 | _ious = collect_results(self.ious, size, self.collect_device) 43 | if is_main_process(): 44 | _ious = _to_cpu(_ious) 45 | ious = torch.cat(_ious) 46 | _metrics = self.compute_metrics(ious) 47 | metrics = [_metrics] 48 | else: 49 | metrics = [None] # type: ignore 50 | broadcast_object_list(metrics) 51 | return metrics[0] 52 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | from .detectors import * 3 | from .necks import * 4 | from .heads import * 5 | from .preprocessors import * -------------------------------------------------------------------------------- /projects/rwkvsam/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .vitamin import VITAMINBackbone 2 | from .sam_backbone import SAMBackbone 3 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_extraction import FeatExtraction, MaskExtraction 2 | from .sam_clip_distill import BackboneDistillation 3 | from .sam_model import SAMModel 4 | from .det_and_seg import DetSeg 5 | from .json_loader import JsonLoader 6 | from .sam_dump import BackboneDump 7 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/detectors/feature_extraction.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Dict 2 | 3 | import torch 4 | from mmdet.registry import MODELS 5 | from mmdet.utils import OptConfigType 6 | from mmengine.model import BaseModel 7 | from mmengine.structures import InstanceData 8 | 9 | 10 | @MODELS.register_module() 11 | class FeatExtraction(BaseModel): 12 | 13 | def __init__( 14 | self, 15 | data_preprocessor, 16 | backbone: OptConfigType = None, 17 | init_cfg=None, 18 | ): 19 | super(FeatExtraction, self).__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) 20 | self.backbone = MODELS.build(backbone) 21 | 22 | def forward(self, 23 | inputs: torch.Tensor, 24 | data_samples: Optional[list] = None, 25 | mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: 26 | feats = self.backbone(inputs) 27 | return feats 28 | 29 | 30 | @MODELS.register_module() 31 | class MaskExtraction(BaseModel): 32 | def __init__( 33 | self, 34 | data_preprocessor, 35 | backbone: OptConfigType, 36 | neck: OptConfigType, 37 | prompt_encoder: OptConfigType, 38 | mask_decoder: OptConfigType, 39 | init_cfg=None, 40 | ): 41 | super(MaskExtraction, self).__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) 42 | self.backbone = MODELS.build(backbone) 43 | self.neck = MODELS.build(neck) 44 | self.pe = MODELS.build(prompt_encoder) 45 | self.mask_decoder = MODELS.build(mask_decoder) 46 | 47 | self.add_extra_hr_feat = True 48 | self.num_ins = 1 49 | 50 | def forward(self, 51 | inputs: torch.Tensor, 52 | data_samples: Optional[list] = None, 53 | mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: 54 | backbone_feat = self.backbone(inputs) 55 | batch_feats = self.neck(backbone_feat) 56 | prompt_instances = InstanceData( 57 | bboxes=torch.tensor([[0, 0, 1, 1]], dtype=torch.float32, device=batch_feats.device) 58 | ) 59 | sparse_embed, dense_embed = self.pe( 60 | prompt_instances, 
61 | image_size=data_samples[0].batch_input_shape, 62 | with_bboxes=True, 63 | ) 64 | 65 | kwargs = dict() 66 | if self.add_extra_hr_feat: 67 | kwargs['hr_feat'] = backbone_feat[0] 68 | kwargs['mr_feat'] = backbone_feat[1] 69 | 70 | low_res_masks, iou_predictions = self.mask_decoder( 71 | image_embeddings=batch_feats, 72 | image_pe=self.pe.get_dense_pe(), 73 | sparse_prompt_embeddings=sparse_embed, 74 | dense_prompt_embeddings=dense_embed, 75 | multi_mask_output=False, 76 | **kwargs 77 | ) 78 | return low_res_masks 79 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/detectors/sam_dump.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Union, Tuple, Dict, List 3 | 4 | import mmengine 5 | import torch 6 | from mmdet.models.detectors.base import ForwardResults 7 | from mmengine import print_log 8 | from mmengine.model import BaseModel 9 | from torch import Tensor 10 | 11 | from mmdet.registry import MODELS 12 | from mmdet.structures import SampleList, OptSampleList 13 | from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig 14 | 15 | 16 | @MODELS.register_module() 17 | class BackboneDump(BaseModel): 18 | 19 | def __init__( 20 | self, 21 | backbone: ConfigType, 22 | neck: ConfigType, 23 | data_preprocessor: OptConfigType = None, 24 | init_cfg: OptMultiConfig = None, 25 | ) -> None: 26 | super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) 27 | self.backbone = MODELS.build(backbone) 28 | self.neck = MODELS.build(neck) 29 | 30 | self.register_parameter('dummy', torch.nn.Parameter(torch.zeros(1))) 31 | 32 | def forward(self, 33 | inputs: torch.Tensor, 34 | data_samples: OptSampleList = None, 35 | mode: str = 'tensor') -> ForwardResults: 36 | if mode == 'loss': 37 | return self.loss(inputs, data_samples) 38 | elif mode == 'predict': 39 | return self.predict(inputs, data_samples) 40 | elif mode == 'tensor': 41 | return self._forward(inputs, data_samples) 42 | else: 43 | raise RuntimeError(f'Invalid mode "{mode}". 
' 44 | 'Only supports loss, predict and tensor mode') 45 | 46 | def _forward(self, *args, **kwargs) -> Tuple[Tensor]: 47 | raise NotImplementedError 48 | 49 | def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor, Tensor]: 50 | return self.neck(self.backbone(batch_inputs)) 51 | 52 | def predict(self, batch_inputs: Tensor, 53 | batch_data_samples: SampleList) -> Union[Dict, List]: 54 | feat = self.extract_feat(batch_inputs) 55 | 56 | assert len(batch_data_samples) == 1 57 | dir_path = os.path.dirname(batch_data_samples[0].metainfo['img_path']) 58 | dir_path = os.path.dirname(dir_path) 59 | dir_path = dir_path.replace('sam', 'sam_feat') 60 | 61 | img_path = os.path.basename(batch_data_samples[0].metainfo['img_path']) 62 | img_path = img_path.replace('.jpg', f'_{self.backbone.model_name}_cache.pth') 63 | img_path = os.path.join(dir_path, img_path) 64 | 65 | if not mmengine.exists(img_path): 66 | feat = feat.to(device='cpu')[0] 67 | torch.save(feat.to(device='cpu'), img_path) 68 | else: 69 | print_log(f'{img_path} already exists, but still regenerate.') 70 | feat = feat.to(device='cpu')[0] 71 | torch.save(feat.to(device='cpu'), img_path) 72 | return {} 73 | 74 | def loss(self, batch_inputs: Tensor, 75 | batch_data_samples: SampleList) -> Union[Dict, List]: 76 | raise NotImplementedError 77 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .sam_mask_decoder import SAMMaskDecoder 2 | # exps 3 | # option1 4 | from .sam_mask_decoder_rwkv_mlpmerge import SAMRWKVHRMaskDecoderMLPMerge 5 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .last_layer import LastLayerNeck, LastLayerProjNeck, LastTwoLayerProjNeck 2 | from .sam_pe import SAMPromptEncoder 3 | from .gap import GlobalAveragePoolingBugFix 4 | # from .efficient_sam_pe import EfficientSAMPromptEncoder 5 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/necks/gap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from mmpretrain.registry import MODELS 8 | 9 | 10 | @MODELS.register_module() 11 | class GlobalAveragePoolingBugFix(nn.Module): 12 | """Global Average Pooling neck. 13 | 14 | Note that we use `view` to remove extra channel after pooling. We do not 15 | use `squeeze` as it will also remove the batch dimension when the tensor 16 | has a batch dimension of size 1, which can lead to unexpected errors. 17 | 18 | Args: 19 | dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. 20 | Default: 2 21 | """ 22 | 23 | def __init__(self, dim=2): 24 | super().__init__() 25 | assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \ 26 | f'{1, 2, 3}, get {dim} instead.' 
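# e.g. with the default dim=2, a (N, C, H, W) feature map is pooled to (N, C, 1, 1) by the
# AdaptiveAvgPool2d chosen below and flattened to (N, C) in forward().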
27 | if dim == 1: 28 | self.gap = nn.AdaptiveAvgPool1d(1) 29 | elif dim == 2: 30 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 31 | else: 32 | self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) 33 | 34 | def init_weights(self): 35 | pass 36 | 37 | def forward(self, inputs): 38 | if isinstance(inputs, tuple) or isinstance(inputs, List): 39 | outs = tuple([self.gap(x) for x in inputs]) 40 | outs = tuple( 41 | [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) 42 | elif isinstance(inputs, torch.Tensor): 43 | outs = self.gap(inputs) 44 | outs = outs.view(inputs.size(0), -1) 45 | else: 46 | raise TypeError('neck inputs should be tuple or torch.tensor') 47 | return outs 48 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_preprocessors import DetDataInferenceTimePreprocessor 2 | from .sameval_preprocessor import SAMEvalDataPreprocessor 3 | from .ovsam_preprocessor import OVSAMDataPreprocessor 4 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/preprocessors/data_preprocessors.py: -------------------------------------------------------------------------------- 1 | from mmdet.models import DetDataPreprocessor 2 | from mmdet.registry import MODELS 3 | 4 | 5 | @MODELS.register_module() 6 | class DetDataInferenceTimePreprocessor(DetDataPreprocessor): 7 | def forward(self, data: dict, training: bool = False) -> dict: 8 | data = super().forward(data=data, training=True) 9 | return data 10 | -------------------------------------------------------------------------------- /projects/rwkvsam/models/preprocessors/sameval_preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmdet.models import DetDataPreprocessor 4 | from mmdet.registry import MODELS 5 | from mmdet.structures.mask import BitmapMasks 6 | 7 | 8 | @MODELS.register_module() 9 | class SAMEvalDataPreprocessor(DetDataPreprocessor): 10 | 11 | def forward(self, data: dict, training: bool = False) -> dict: 12 | data = super().forward(data, training=True) 13 | for data_sample in data['data_samples']: 14 | pred_instances = data_sample.pred_instances 15 | bboxes = pred_instances.bboxes 16 | scale_factor = bboxes.new_tensor(data_sample.scale_factor).repeat(2) 17 | bboxes = bboxes * scale_factor 18 | pred_instances.bboxes = bboxes 19 | 20 | if 'masks' in pred_instances: 21 | masks = BitmapMasks(pred_instances.masks.to(device='cpu', dtype=torch.uint8).numpy(), 22 | *pred_instances.masks.shape[-2:]) 23 | masks = masks.resize(data_sample.img_shape) 24 | if self.pad_mask: 25 | masks = masks.pad(data_sample.batch_input_shape, pad_val=self.mask_pad_value) 26 | pred_instances.masks = masks 27 | 28 | return data 29 | -------------------------------------------------------------------------------- /projects/rwkvsam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_checkpoint import load_checkpoint_with_prefix 2 | -------------------------------------------------------------------------------- /projects/rwkvsam/utils/boundary_iou.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | import pycocotools.mask as mask_utils 7 | 8 | 9 | # General util function to get the boundary of a binary mask. 
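# Sketch of the computation implemented below: pad the mask by 1 px, erode it with a 3x3 kernel
# for d iterations where d = max(1, round(dilation_ratio * sqrt(h^2 + w^2))), and return
# mask - eroded so that only a thin band along the object contour remains.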
10 | def mask_to_boundary(mask, dilation_ratio=0.02): 11 | """ 12 | Convert binary mask to boundary mask. 13 | :param mask (numpy array, uint8): binary mask 14 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal 15 | :return: boundary mask (numpy array) 16 | """ 17 | h, w = mask.shape 18 | img_diag = np.sqrt(h ** 2 + w ** 2) 19 | dilation = int(round(dilation_ratio * img_diag)) 20 | if dilation < 1: 21 | dilation = 1 22 | # Pad image so mask truncated by the image border is also considered as boundary. 23 | new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) 24 | kernel = np.ones((3, 3), dtype=np.uint8) 25 | new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) 26 | mask_erode = new_mask_erode[1: h + 1, 1: w + 1] 27 | # G_d intersects G in the paper. 28 | return mask - mask_erode 29 | 30 | 31 | # COCO/LVIS related util functions, to get the boundary for every annotations. 32 | def augment_annotations_with_boundary_single_core(proc_id, annotations, ann_to_mask, dilation_ratio=0.02): 33 | new_annotations = [] 34 | 35 | for ann in annotations: 36 | mask = ann_to_mask(ann) 37 | # Find mask boundary. 38 | boundary = mask_to_boundary(mask, dilation_ratio) 39 | # Add boundary to annotation in RLE format. 40 | ann['boundary'] = mask_utils.encode( 41 | np.array(boundary[:, :, None], order="F", dtype="uint8"))[0] 42 | new_annotations.append(ann) 43 | 44 | return new_annotations 45 | 46 | 47 | def augment_annotations_with_boundary_multi_core(annotations, ann_to_mask, dilation_ratio=0.02): 48 | cpu_num = multiprocessing.cpu_count() 49 | annotations_split = np.array_split(annotations, cpu_num) 50 | print("Number of cores: {}, annotations per core: {}".format(cpu_num, len(annotations_split[0]))) 51 | workers = multiprocessing.Pool(processes=cpu_num) 52 | processes = [] 53 | 54 | for proc_id, annotation_set in enumerate(annotations_split): 55 | p = workers.apply_async(augment_annotations_with_boundary_single_core, 56 | (proc_id, annotation_set, ann_to_mask, dilation_ratio)) 57 | processes.append(p) 58 | 59 | new_annotations = [] 60 | for p in processes: 61 | new_annotations.extend(p.get()) 62 | 63 | workers.close() 64 | workers.join() 65 | 66 | return new_annotations 67 | -------------------------------------------------------------------------------- /projects/rwkvsam/utils/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | from mmengine.runner.checkpoint import CheckpointLoader 2 | from huggingface_hub import hf_hub_download 3 | 4 | HF_HUB_PREFIX = 'hf-hub:' 5 | 6 | def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'): 7 | """Load partial pretrained model with specific prefix. 8 | 9 | Args: 10 | prefix (str): The prefix of sub-module. 11 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 12 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 13 | details. 14 | map_location (str | None): Same as :func:`torch.load`. 15 | Defaults to None. 16 | logger: logger 17 | 18 | Returns: 19 | dict or OrderedDict: The loaded checkpoint. 
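    Example (illustrative; the checkpoint path and prefix are borrowed from the
    configs in this repo and may need adjusting to your setup)::

        >>> state_dict = load_checkpoint_with_prefix(
        ...     './models/sam2clip_vith_rn50x16.pth', prefix='neck_student')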
20 | """ 21 | if filename.startswith('hf-hub:'): 22 | model_id = filename[len(HF_HUB_PREFIX):] 23 | filename = hf_hub_download(model_id, 'pytorch_model.bin') 24 | 25 | checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger) 26 | 27 | if 'state_dict' in checkpoint: 28 | state_dict = checkpoint['state_dict'] 29 | elif 'model' in checkpoint: 30 | state_dict = checkpoint['model'] 31 | else: 32 | state_dict = checkpoint 33 | if not prefix: 34 | return state_dict 35 | if not prefix.endswith('.'): 36 | prefix += '.' 37 | prefix_len = len(prefix) 38 | 39 | state_dict = { 40 | k[prefix_len:]: v 41 | for k, v in state_dict.items() if k.startswith(prefix) 42 | } 43 | 44 | assert state_dict, f'{prefix} is not in the pretrained model' 45 | return state_dict 46 | -------------------------------------------------------------------------------- /seg/configs/_base_/datasets/coco_ov_instance_lsj.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import AspectRatioBatchSampler 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs, RandomCrop 5 | from mmengine.dataset import DefaultSampler 6 | 7 | from seg.datasets.coco_ins_ov import CocoOVDataset 8 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 9 | from seg.evaluation.ins_cls_iou_metric import InsClsIoUMetric 10 | 11 | from ext.class_names.coco_4817_ids import COCO4817_BASE_IDS, COCO4817_NOVEL_IDS 12 | 13 | data_root = 'data/coco/' 14 | backend_args = None 15 | dataset_type = CocoOVDataset 16 | 17 | image_size = (1024, 1024) 18 | 19 | train_pipeline = [ 20 | dict( 21 | type=LoadImageFromFile, 22 | to_float32=True, 23 | backend_args=backend_args), 24 | dict( 25 | type=LoadAnnotations, 26 | with_bbox=True, 27 | with_mask=True, 28 | backend_args=backend_args), 29 | dict(type=RandomFlip, prob=0.5), 30 | dict( 31 | type=RandomResize, 32 | resize_type=Resize, 33 | scale=image_size, 34 | ratio_range=(.9, 2.), 35 | keep_ratio=True, 36 | ), 37 | dict( 38 | type=RandomCrop, 39 | crop_size=image_size, 40 | crop_type='absolute', 41 | recompute_bbox=True, 42 | allow_negative_crop=True), 43 | dict( 44 | type=FilterAnnotationsHB, 45 | by_box=False, 46 | by_mask=True, 47 | min_gt_mask_area=32, 48 | ), 49 | dict(type=PackDetInputs) 50 | ] 51 | train_dataloader = dict( 52 | batch_size=2, 53 | num_workers=2, 54 | persistent_workers=True, 55 | sampler=dict(type=DefaultSampler, shuffle=True), 56 | batch_sampler=dict(type=AspectRatioBatchSampler), 57 | dataset=dict( 58 | type=dataset_type, 59 | data_root=data_root, 60 | ann_file='annotations/instances_train2017.json', 61 | data_prefix=dict(img='train2017/'), 62 | filter_cfg=dict(filter_empty_gt=True, min_size=32, sub_split='48_17'), 63 | pipeline=train_pipeline, 64 | backend_args=backend_args) 65 | ) 66 | 67 | 68 | test_pipeline = [ 69 | dict(type=LoadImageFromFile, backend_args=backend_args), 70 | dict(type=Resize, scale=(1024, 1024), keep_ratio=True), 71 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 72 | dict( 73 | type=PackDetInputs, 74 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 75 | ) 76 | ] 77 | val_dataloader = dict( 78 | batch_size=1, 79 | num_workers=2, 80 | persistent_workers=True, 81 | drop_last=False, 82 | sampler=dict(type=DefaultSampler, shuffle=False), 83 | dataset=dict( 84 | type=dataset_type, 85 | data_root=data_root, 86 | 
ann_file='annotations/instances_val2017.json', 87 | data_prefix=dict(img='val2017/'), 88 | filter_cfg=dict(sub_split='48_17'), 89 | test_mode=True, 90 | return_classes=True, 91 | pipeline=test_pipeline, 92 | backend_args=backend_args) 93 | ) 94 | test_dataloader = val_dataloader 95 | 96 | val_evaluator = [ 97 | dict( 98 | type=InsClsIoUMetric, 99 | prefix='coco_ins', 100 | base_classes=COCO4817_BASE_IDS, 101 | novel_classes=COCO4817_NOVEL_IDS, 102 | ), 103 | ] 104 | test_evaluator = val_evaluator 105 | -------------------------------------------------------------------------------- /seg/configs/_base_/datasets/lvis_norare.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | from mmcv import LoadImageFromFile, RandomResize 3 | from mmdet.datasets import LVISV1Dataset, AspectRatioBatchSampler 4 | from mmdet.datasets.transforms import LoadAnnotations, Resize, RandomFlip, PackDetInputs, RandomCrop 5 | from mmengine.dataset import DefaultSampler 6 | 7 | from seg.datasets.pipeliens.loading import FilterAnnotationsHB 8 | from seg.evaluation.ins_cls_iou_metric import InsClsIoUMetric 9 | 10 | from ext.class_names.lvis_ids import LVIS_BASE_IDS, LVIS_RARE_IDS 11 | 12 | data_root = 'data/lvis/' 13 | backend_args = None 14 | dataset_type = LVISV1Dataset 15 | 16 | image_size = (1024, 1024) 17 | 18 | train_pipeline = [ 19 | dict( 20 | type=LoadImageFromFile, 21 | to_float32=True, 22 | backend_args=backend_args), 23 | dict( 24 | type=LoadAnnotations, 25 | with_bbox=True, 26 | with_mask=True, 27 | backend_args=backend_args), 28 | dict(type=RandomFlip, prob=0.5), 29 | dict( 30 | type=RandomResize, 31 | resize_type=Resize, 32 | scale=image_size, 33 | ratio_range=(.1, 2.), 34 | keep_ratio=True, 35 | ), 36 | dict( 37 | type=RandomCrop, 38 | crop_size=image_size, 39 | crop_type='absolute', 40 | recompute_bbox=True, 41 | allow_negative_crop=True), 42 | dict( 43 | type=FilterAnnotationsHB, 44 | by_box=False, 45 | by_mask=True, 46 | min_gt_mask_area=32, 47 | ), 48 | dict(type=PackDetInputs) 49 | ] 50 | train_dataloader = dict( 51 | batch_size=2, 52 | num_workers=2, 53 | persistent_workers=True, 54 | sampler=dict(type=DefaultSampler, shuffle=True), 55 | batch_sampler=dict(type=AspectRatioBatchSampler), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file='annotations/lvis_v1_train_norare.json', 60 | data_prefix=dict(img=''), 61 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 62 | pipeline=train_pipeline, 63 | backend_args=backend_args) 64 | ) 65 | 66 | test_pipeline = [ 67 | dict(type=LoadImageFromFile, backend_args=backend_args), 68 | dict(type=Resize, scale=(1024, 1024), keep_ratio=True), 69 | dict(type=LoadAnnotations, with_bbox=True, with_mask=True), 70 | dict( 71 | type=PackDetInputs, 72 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'instances') 73 | ) 74 | ] 75 | val_dataloader = dict( 76 | batch_size=1, 77 | num_workers=2, 78 | persistent_workers=True, 79 | drop_last=False, 80 | sampler=dict(type=DefaultSampler, shuffle=False), 81 | dataset=dict( 82 | type=dataset_type, 83 | data_root=data_root, 84 | ann_file='annotations/lvis_v1_val.json', 85 | data_prefix=dict(img=''), 86 | test_mode=True, 87 | return_classes=True, 88 | pipeline=test_pipeline, 89 | backend_args=backend_args) 90 | ) 91 | test_dataloader = val_dataloader 92 | 93 | val_evaluator = dict( 94 | type=InsClsIoUMetric, 95 | prefix='lvis_ins', 96 | base_classes=LVIS_BASE_IDS, 97 | novel_classes=LVIS_RARE_IDS, 98 | ) 99 | 
test_evaluator = val_evaluator 100 | -------------------------------------------------------------------------------- /seg/configs/_base_/datasets/sam.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmengine.dataset import DefaultSampler 3 | 4 | from seg.datasets.pipeliens.formatting import PackSAMInputs, GeneratePoint 5 | from seg.datasets.pipeliens.loading import LoadJSONFromFile, FilterAnnotationsHB 6 | from seg.datasets.pipeliens.transforms import ResizeSAM 7 | from seg.datasets.sam import SAMDataset 8 | from seg.datasets.pipeliens.loading import LoadAnnotationsSAM 9 | 10 | dataset_type = SAMDataset 11 | data_root = 'data/sam' 12 | 13 | backend_args = None 14 | image_size = (1024, 1024) 15 | 16 | # dataset settings 17 | train_pipeline = [ 18 | dict(type=LoadImageFromFile, backend_args=backend_args), 19 | dict(type=LoadJSONFromFile, backend_args=backend_args), 20 | dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 21 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 22 | dict( 23 | type=FilterAnnotationsHB, 24 | by_box=False, 25 | by_mask=True, 26 | min_gt_mask_area=256, 27 | ), 28 | dict( 29 | type=PackSAMInputs, 30 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 31 | ), 32 | dict(type=GeneratePoint) 33 | ] 34 | 35 | test_pipeline = [ 36 | dict(type=LoadImageFromFile, backend_args=backend_args), 37 | dict(type=LoadJSONFromFile, backend_args=backend_args), 38 | dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 39 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 40 | dict( 41 | type=PackSAMInputs, 42 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 43 | ) 44 | ] 45 | 46 | # dataloader 47 | train_dataloader = dict( 48 | batch_size=1, 49 | num_workers=2, 50 | persistent_workers=True, 51 | sampler=dict(type=DefaultSampler, shuffle=True), 52 | batch_sampler=None, 53 | dataset=dict( 54 | type=dataset_type, 55 | data_root=data_root, 56 | ann_file='train.txt', 57 | data_prefix=dict(img=''), 58 | filter_cfg=None, 59 | pipeline=train_pipeline, 60 | backend_args=backend_args 61 | ) 62 | ) 63 | val_dataloader = dict( 64 | batch_size=1, 65 | num_workers=2, 66 | persistent_workers=True, 67 | drop_last=True, 68 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 69 | batch_sampler=None, 70 | dataset=dict( 71 | type=dataset_type, 72 | data_root=data_root, 73 | ann_file='val.txt', 74 | data_prefix=dict(img=''), 75 | test_mode=True, 76 | pipeline=test_pipeline, 77 | backend_args=backend_args 78 | ) 79 | ) 80 | test_dataloader = val_dataloader 81 | 82 | val_evaluator = [] 83 | test_evaluator = val_evaluator 84 | -------------------------------------------------------------------------------- /seg/configs/_base_/datasets/sam_img.py: -------------------------------------------------------------------------------- 1 | from mmcv import LoadImageFromFile 2 | from mmengine.dataset import DefaultSampler 3 | 4 | from seg.datasets.pipeliens.formatting import PackSAMInputs 5 | from seg.datasets.pipeliens.loading import LoadJSONFromFile, LoadFeatFromFile 6 | from seg.datasets.pipeliens.transforms import ResizeSAM 7 | from seg.datasets.sam import SAMDataset 8 | from seg.datasets.pipeliens.loading import LoadAnnotationsSAM 9 | 10 | dataset_type = SAMDataset 11 | data_root = 'data/sam' 12 | 13 | backend_args = None 14 | image_size = (1024, 1024) 15 | 16 | # dataset 
settings 17 | train_pipeline = [ 18 | dict(type=LoadImageFromFile, backend_args=backend_args), 19 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 20 | dict(type=LoadFeatFromFile), 21 | dict( 22 | type=PackSAMInputs, 23 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 24 | ), 25 | ] 26 | 27 | test_pipeline = [ 28 | dict(type=LoadImageFromFile, backend_args=backend_args), 29 | dict(type=LoadJSONFromFile, backend_args=backend_args), 30 | dict(type=LoadAnnotationsSAM, with_bbox=True, with_mask=True, with_point_coords=True), 31 | dict(type=ResizeSAM, scale=image_size, keep_ratio=True), 32 | dict( 33 | type=PackSAMInputs, 34 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') 35 | ) 36 | ] 37 | 38 | # dataloader 39 | train_dataloader = dict( 40 | batch_size=2, 41 | num_workers=2, 42 | persistent_workers=True, 43 | sampler=dict(type=DefaultSampler, shuffle=True), 44 | batch_sampler=None, 45 | dataset=dict( 46 | type=dataset_type, 47 | data_root=data_root, 48 | ann_file='train.txt', 49 | data_prefix=dict(img=''), 50 | filter_cfg=None, 51 | pipeline=train_pipeline, 52 | backend_args=backend_args 53 | ) 54 | ) 55 | val_dataloader = dict( 56 | batch_size=1, 57 | num_workers=2, 58 | persistent_workers=True, 59 | drop_last=True, 60 | sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), 61 | batch_sampler=None, 62 | dataset=dict( 63 | type=dataset_type, 64 | data_root=data_root, 65 | ann_file='val.txt', 66 | data_prefix=dict(img=''), 67 | test_mode=True, 68 | pipeline=test_pipeline, 69 | backend_args=backend_args 70 | ) 71 | ) 72 | test_dataloader = val_dataloader 73 | 74 | val_evaluator = [] 75 | test_evaluator = val_evaluator 76 | -------------------------------------------------------------------------------- /seg/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 3 | LoggerHook, ParamSchedulerHook) 4 | from mmengine.runner import LogProcessor 5 | from mmengine.visualization import LocalVisBackend 6 | 7 | from mmdet.engine.hooks import DetVisualizationHook 8 | from mmdet.visualization import DetLocalVisualizer 9 | 10 | default_scope = None 11 | 12 | default_hooks = dict( 13 | timer=dict(type=IterTimerHook), 14 | logger=dict(type=LoggerHook, interval=50), 15 | param_scheduler=dict(type=ParamSchedulerHook), 16 | checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=1), 17 | sampler_seed=dict(type=DistSamplerSeedHook), 18 | visualization=dict(type=DetVisualizationHook)) 19 | 20 | env_cfg = dict( 21 | cudnn_benchmark=False, 22 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 23 | dist_cfg=dict(backend='nccl'), 24 | ) 25 | 26 | vis_backends = [dict(type=LocalVisBackend)] 27 | visualizer = dict( 28 | type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer') 29 | log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True) 30 | 31 | log_level = 'INFO' 32 | load_from = None 33 | resume = False 34 | -------------------------------------------------------------------------------- /seg/configs/_base_/schedules/schedule_12e.py: -------------------------------------------------------------------------------- 1 | from mmengine.optim import LinearLR, MultiStepLR, OptimWrapper 2 | from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop 3 | from torch.optim import AdamW 4 | 5 | # training schedule for 12e 6 | train_cfg = dict( 7 | type=EpochBasedTrainLoop, 8 | max_epochs=12, 9 | val_interval=2, 10 | ) 11 | val_cfg = dict(type=ValLoop) 12 | test_cfg = dict(type=TestLoop) 13 | 14 | # learning rate 15 | param_scheduler = [ 16 | dict( 17 | type=LinearLR, 18 | start_factor=0.001, 19 | by_epoch=False, 20 | begin=0, 21 | end=500 22 | ), 23 | dict( 24 | type=MultiStepLR, 25 | begin=0, 26 | end=12, 27 | by_epoch=True, 28 | milestones=[8, 11], 29 | gamma=0.1 30 | ) 31 | ] 32 | 33 | _embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 34 | optim_wrapper = dict( 35 | type=OptimWrapper, 36 | optimizer=dict( 37 | type=AdamW, 38 | lr=0.0001, 39 | weight_decay=0.05, 40 | eps=1e-8, 41 | betas=(0.9, 0.999) 42 | ), 43 | paramwise_cfg=dict( 44 | custom_keys={ 45 | 'backbone': dict(lr_mult=0.1, decay_mult=1.0), 46 | 'query_embed': _embed_multi, 47 | 'query_feat': _embed_multi, 48 | 'level_embed': _embed_multi, 49 | }, 50 | norm_decay_mult=0.0 51 | ), 52 | clip_grad=dict(max_norm=0.01, norm_type=2) 53 | ) 54 | 55 | # Default setting for scaling LR automatically 56 | # - `enable` means enable scaling LR automatically 57 | # or not by default. 58 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
59 | auto_scale_lr = dict(enable=True, base_batch_size=16) 60 | -------------------------------------------------------------------------------- /seg/configs/_base_/schedules/schedule_24e.py: -------------------------------------------------------------------------------- 1 | from mmengine.optim import LinearLR, MultiStepLR, OptimWrapper 2 | from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop 3 | from torch.optim import AdamW 4 | 5 | # training schedule for 24e 6 | train_cfg = dict( 7 | type=EpochBasedTrainLoop, 8 | max_epochs=24, 9 | val_interval=2, 10 | ) 11 | val_cfg = dict(type=ValLoop) 12 | test_cfg = dict(type=TestLoop) 13 | 14 | # learning rate 15 | param_scheduler = [ 16 | dict( 17 | type=LinearLR, 18 | start_factor=0.001, 19 | by_epoch=False, 20 | begin=0, 21 | end=500 22 | ), 23 | dict( 24 | type=MultiStepLR, 25 | begin=0, 26 | end=24, 27 | by_epoch=True, 28 | milestones=[16, 22], 29 | gamma=0.1 30 | ) 31 | ] 32 | 33 | _embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 34 | optim_wrapper = dict( 35 | type=OptimWrapper, 36 | optimizer=dict( 37 | type=AdamW, 38 | lr=0.0001, 39 | weight_decay=0.05, 40 | eps=1e-8, 41 | betas=(0.9, 0.999) 42 | ), 43 | paramwise_cfg=dict( 44 | custom_keys={ 45 | 'backbone': dict(lr_mult=0.1, decay_mult=1.0), 46 | 'query_embed': _embed_multi, 47 | 'query_feat': _embed_multi, 48 | 'level_embed': _embed_multi, 49 | }, 50 | norm_decay_mult=0.0 51 | ), 52 | clip_grad=dict(max_norm=0.01, norm_type=2) 53 | ) 54 | 55 | # Default setting for scaling LR automatically 56 | # - `enable` means enable scaling LR automatically 57 | # or not by default. 58 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 59 | auto_scale_lr = dict(enable=True, base_batch_size=16) 60 | -------------------------------------------------------------------------------- /seg/configs/_base_/schedules/schedule_distillation.py: -------------------------------------------------------------------------------- 1 | from mmengine.optim import LinearLR, OptimWrapper, CosineAnnealingLR 2 | from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop 3 | from torch.optim import AdamW 4 | 5 | # training schedule for 24e 6 | train_cfg = dict( 7 | type=EpochBasedTrainLoop, 8 | max_epochs=24, 9 | val_interval=2, 10 | ) 11 | val_cfg = dict(type=ValLoop) 12 | test_cfg = dict(type=TestLoop) 13 | 14 | # learning rate 15 | param_scheduler = [ 16 | dict( 17 | type=LinearLR, 18 | start_factor=0.001, 19 | by_epoch=False, 20 | begin=0, 21 | end=500 22 | ), 23 | dict( 24 | type=CosineAnnealingLR, 25 | convert_to_iter_based=True, 26 | begin=0, 27 | end=24, 28 | by_epoch=True, 29 | eta_min_ratio=0.01, 30 | ) 31 | ] 32 | 33 | _embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 34 | optim_wrapper = dict( 35 | type=OptimWrapper, 36 | optimizer=dict( 37 | type=AdamW, 38 | lr=0.0001, 39 | weight_decay=0.05, 40 | eps=1e-8, 41 | betas=(0.9, 0.999) 42 | ), 43 | paramwise_cfg=dict( 44 | norm_decay_mult=0.0 45 | ), 46 | clip_grad=dict(max_norm=5., norm_type=2) 47 | ) 48 | 49 | # Default setting for scaling LR automatically 50 | # - `enable` means enable scaling LR automatically 51 | # or not by default. 52 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
53 | auto_scale_lr = dict(enable=True, base_batch_size=16) 54 | -------------------------------------------------------------------------------- /seg/configs/clip2sam/clip2sam_coco_rn50x16.py: -------------------------------------------------------------------------------- 1 | from mmcv.ops import RoIAlign 2 | from mmdet.models import CrossEntropyLoss, DiceLoss, FPN, SingleRoIExtractor 3 | from mmengine.config import read_base 4 | 5 | from seg.models.detectors import CLIP2SAM 6 | from seg.models.backbones import OpenCLIPBackbone 7 | from seg.models.necks import MultiLayerTransformerNeck, SAMPromptEncoder 8 | from seg.models.heads import OVSAMHead 9 | from seg.models.data_preprocessor import OVSAMDataPreprocessor 10 | from seg.models.utils import NO_OBJ 11 | 12 | with read_base(): 13 | from .._base_.default_runtime import * 14 | from .._base_.datasets.coco_ov_instance_lsj import * 15 | from .._base_.schedules.schedule_12e import * 16 | 17 | image_size = (1024, 1024) 18 | data_preprocessor = dict( 19 | type=OVSAMDataPreprocessor, 20 | mean=[123.675, 116.28, 103.53], 21 | std=[58.395, 57.12, 57.375], 22 | bgr_to_rgb=True, 23 | pad_size_divisor=1024, 24 | pad_mask=True, 25 | mask_pad_value=0, 26 | pad_seg=False, 27 | seg_pad_value=NO_OBJ, 28 | batch_augments=None, 29 | use_point_det=True, 30 | num_proposals=40, 31 | ) 32 | 33 | model = dict( 34 | type=CLIP2SAM, 35 | data_preprocessor=data_preprocessor, 36 | with_box=True, 37 | with_points=True, 38 | backbone=dict( 39 | type=OpenCLIPBackbone, 40 | model_name='RN50x16', 41 | fix=True, 42 | init_cfg=dict( 43 | type='clip_pretrain', 44 | checkpoint='openai' 45 | ) 46 | ), 47 | neck=dict( 48 | type=MultiLayerTransformerNeck, 49 | input_size=(1024, 1024), 50 | in_channels=[384, 768, 1536, 3072], 51 | strides=[4, 8, 16, 32], 52 | layer_ids=(0, 1, 2, 3), 53 | embed_channels=1280, 54 | out_channels=256, 55 | fix=True, 56 | init_cfg=dict( 57 | type='Pretrained', 58 | checkpoint='./models/sam2clip_vith_rn50x16.pth', 59 | prefix='neck_student', 60 | ) 61 | ), 62 | fpn_neck=dict( 63 | type=FPN, 64 | in_channels=[384, 768, 1536, 3072], 65 | out_channels=256, 66 | num_outs=4, 67 | ), 68 | prompt_encoder=dict( 69 | type=SAMPromptEncoder, 70 | model_name='vit_h', 71 | fix=True, 72 | init_cfg=dict( 73 | type='sam_pretrain', 74 | checkpoint='vit_h' 75 | ) 76 | ), 77 | mask_decoder=dict( 78 | type=OVSAMHead, 79 | model_name='vit_h', 80 | with_label_token=True, 81 | ov_classifier_name='RN50x16_CocoOVDataset', 82 | roi_extractor=dict( 83 | type=SingleRoIExtractor, 84 | roi_layer=dict(type=RoIAlign, output_size=12, sampling_ratio=0), 85 | out_channels=256, 86 | featmap_strides=[4, 8, 16, 32] 87 | ), 88 | fix=False, 89 | init_cfg=dict( 90 | type='sam_pretrain', 91 | checkpoint='vit_h' 92 | ), 93 | loss_cls=dict( 94 | type=CrossEntropyLoss, 95 | use_sigmoid=False, 96 | loss_weight=2.0, 97 | reduction='mean' 98 | ), 99 | loss_mask=dict( 100 | type=CrossEntropyLoss, 101 | use_sigmoid=True, 102 | reduction='mean', 103 | loss_weight=5.0 104 | ), 105 | loss_dice=dict( 106 | type=DiceLoss, 107 | use_sigmoid=True, 108 | activate=True, 109 | reduction='mean', 110 | naive_dice=True, 111 | eps=1.0, 112 | loss_weight=5.0 113 | ) 114 | ) 115 | ) 116 | 117 | val_dataloader = None 118 | val_evaluator = None 119 | val_cfg = None 120 | test_dataloader = None 121 | test_evaluator = None 122 | test_cfg = None 123 | -------------------------------------------------------------------------------- /seg/configs/clip2sam/clip2sam_lvis_rn50x16.py: 
-------------------------------------------------------------------------------- 1 | from mmcv.ops import RoIAlign 2 | from mmdet.models import CrossEntropyLoss, DiceLoss, FPN, SingleRoIExtractor 3 | from mmengine.config import read_base 4 | 5 | from seg.models.detectors import CLIP2SAM 6 | from seg.models.backbones import OpenCLIPBackbone 7 | from seg.models.necks import MultiLayerTransformerNeck, SAMPromptEncoder 8 | from seg.models.heads import OVSAMHead 9 | from seg.models.data_preprocessor import OVSAMDataPreprocessor 10 | from seg.models.utils import NO_OBJ 11 | 12 | with read_base(): 13 | from .._base_.default_runtime import * 14 | from .._base_.datasets.lvis_norare import * 15 | from .._base_.schedules.schedule_12e import * 16 | 17 | image_size = (1024, 1024) 18 | data_preprocessor = dict( 19 | type=OVSAMDataPreprocessor, 20 | mean=[123.675, 116.28, 103.53], 21 | std=[58.395, 57.12, 57.375], 22 | bgr_to_rgb=True, 23 | pad_size_divisor=1024, 24 | pad_mask=True, 25 | mask_pad_value=0, 26 | pad_seg=False, 27 | seg_pad_value=NO_OBJ, 28 | batch_augments=None, 29 | use_point_det=True, 30 | num_proposals=40, 31 | ) 32 | 33 | model = dict( 34 | type=CLIP2SAM, 35 | data_preprocessor=data_preprocessor, 36 | with_box=True, 37 | with_points=True, 38 | backbone=dict( 39 | type=OpenCLIPBackbone, 40 | model_name='RN50x16', 41 | fix=True, 42 | init_cfg=dict( 43 | type='clip_pretrain', 44 | checkpoint='openai' 45 | ) 46 | ), 47 | neck=dict( 48 | type=MultiLayerTransformerNeck, 49 | input_size=(1024, 1024), 50 | in_channels=[384, 768, 1536, 3072], 51 | strides=[4, 8, 16, 32], 52 | layer_ids=(0, 1, 2, 3), 53 | embed_channels=1280, 54 | out_channels=256, 55 | fix=True, 56 | init_cfg=dict( 57 | type='Pretrained', 58 | checkpoint='./models/sam2clip_vith_rn50x16.pth', 59 | prefix='neck_student', 60 | ) 61 | ), 62 | fpn_neck=dict( 63 | type=FPN, 64 | in_channels=[384, 768, 1536, 3072], 65 | out_channels=256, 66 | num_outs=4, 67 | ), 68 | prompt_encoder=dict( 69 | type=SAMPromptEncoder, 70 | model_name='vit_h', 71 | fix=True, 72 | init_cfg=dict( 73 | type='sam_pretrain', 74 | checkpoint='vit_h' 75 | ) 76 | ), 77 | mask_decoder=dict( 78 | type=OVSAMHead, 79 | model_name='vit_h', 80 | with_label_token=True, 81 | ov_classifier_name='RN50x16_LVISV1Dataset', 82 | roi_extractor=dict( 83 | type=SingleRoIExtractor, 84 | roi_layer=dict(type=RoIAlign, output_size=12, sampling_ratio=0), 85 | out_channels=256, 86 | featmap_strides=[4, 8, 16, 32] 87 | ), 88 | fix=False, 89 | init_cfg=dict( 90 | type='sam_pretrain', 91 | checkpoint='vit_h' 92 | ), 93 | loss_cls=dict( 94 | type=CrossEntropyLoss, 95 | use_sigmoid=False, 96 | loss_weight=2.0, 97 | reduction='mean' 98 | ), 99 | loss_mask=dict( 100 | type=CrossEntropyLoss, 101 | use_sigmoid=True, 102 | reduction='mean', 103 | loss_weight=5.0 104 | ), 105 | loss_dice=dict( 106 | type=DiceLoss, 107 | use_sigmoid=True, 108 | activate=True, 109 | reduction='mean', 110 | naive_dice=True, 111 | eps=1.0, 112 | loss_weight=5.0 113 | ) 114 | ) 115 | ) 116 | 117 | val_dataloader = None 118 | val_evaluator = None 119 | val_cfg = None 120 | test_dataloader = None 121 | test_evaluator = None 122 | test_cfg = None 123 | -------------------------------------------------------------------------------- /seg/configs/ovsam/ovsam_coco_rn50x16_point.py: -------------------------------------------------------------------------------- 1 | from mmcv.ops import RoIAlign 2 | from mmdet.models import FPN, SingleRoIExtractor 3 | from mmengine.config import read_base 4 | 5 | from 
seg.models.data_preprocessor import OVSAMDataPreprocessor 6 | from seg.models.backbones import OpenCLIPBackbone 7 | from seg.models.detectors import OVSAM 8 | from seg.models.heads import OVSAMHead 9 | from seg.models.necks import SAMPromptEncoder, MultiLayerTransformerNeck 10 | 11 | with read_base(): 12 | from .._base_.default_runtime import * 13 | from .._base_.datasets.coco_ov_instance_lsj import * 14 | from .._base_.schedules.schedule_12e import * 15 | 16 | image_size = (1024, 1024) 17 | _data_preprocessor = dict( 18 | type=OVSAMDataPreprocessor, 19 | mean=[123.675, 116.28, 103.53], 20 | std=[58.395, 57.12, 57.375], 21 | bgr_to_rgb=True, 22 | pad_size_divisor=image_size[0], 23 | pad_mask=False, 24 | mask_pad_value=0, 25 | pad_seg=False, 26 | seg_pad_value=255, 27 | batch_augments=None, 28 | use_center_point=True 29 | ) 30 | model = dict( 31 | type=OVSAM, 32 | data_preprocessor=_data_preprocessor, 33 | use_gt_prompt=True, 34 | use_clip_feat=True, 35 | use_head_feat=True, 36 | use_point=True, 37 | num_classes=80, 38 | base_classes=COCO4817_BASE_IDS, 39 | novel_classes=COCO4817_NOVEL_IDS, 40 | backbone=dict( 41 | type=OpenCLIPBackbone, 42 | model_name='RN50x16', 43 | fix=True, 44 | init_cfg=dict( 45 | type='clip_pretrain', 46 | checkpoint='openai' 47 | ) 48 | ), 49 | neck=dict( 50 | type=MultiLayerTransformerNeck, 51 | input_size=(1024, 1024), 52 | in_channels=[384, 768, 1536, 3072], 53 | strides=[4, 8, 16, 32], 54 | layer_ids=(0, 1, 2, 3), 55 | embed_channels=1280, 56 | out_channels=256, 57 | fix=True, 58 | init_cfg=dict( 59 | type='Pretrained', 60 | checkpoint='./models/sam2clip_vith_rn50x16.pth', 61 | prefix='neck_student', 62 | ) 63 | ), 64 | fpn_neck=dict( 65 | type=FPN, 66 | in_channels=[384, 768, 1536, 3072], 67 | out_channels=256, 68 | num_outs=4, 69 | init_cfg=dict( 70 | type='Pretrained', 71 | checkpoint='./models/clip2sam_coco_rn50x16.pth', 72 | prefix='fpn_neck', 73 | ), 74 | ), 75 | prompt_encoder=dict( 76 | type=SAMPromptEncoder, 77 | model_name='vit_h', 78 | fix=True, 79 | init_cfg=dict( 80 | type='sam_pretrain', 81 | checkpoint='vit_h' 82 | ) 83 | ), 84 | mask_decoder=dict( 85 | type=OVSAMHead, 86 | gen_box=True, 87 | model_name='vit_h', 88 | with_label_token=True, 89 | fix=False, 90 | ov_classifier_name='RN50x16_CocoOVDataset', 91 | roi_extractor=dict( 92 | type=SingleRoIExtractor, 93 | roi_layer=dict(type=RoIAlign, output_size=12, sampling_ratio=0), 94 | out_channels=256, 95 | featmap_strides=[4, 8, 16, 32] 96 | ), 97 | init_cfg=dict( 98 | type='Pretrained', 99 | checkpoint='./models/clip2sam_coco_rn50x16.pth', 100 | prefix='mask_decoder', 101 | ) 102 | ) 103 | ) 104 | -------------------------------------------------------------------------------- /seg/configs/ovsam/ovsam_lvis_rn50x16_point.py: -------------------------------------------------------------------------------- 1 | from mmcv.ops import RoIAlign 2 | from mmdet.models import FPN, SingleRoIExtractor 3 | from mmengine.config import read_base 4 | 5 | from seg.models.data_preprocessor import OVSAMDataPreprocessor 6 | from seg.models.backbones import OpenCLIPBackbone 7 | from seg.models.detectors import OVSAM 8 | from seg.models.heads import OVSAMHead 9 | from seg.models.necks import SAMPromptEncoder, MultiLayerTransformerNeck 10 | 11 | 12 | with read_base(): 13 | from .._base_.default_runtime import * 14 | from .._base_.datasets.lvis_norare import * 15 | from .._base_.schedules.schedule_12e import * 16 | 17 | image_size = (1024, 1024) 18 | _data_preprocessor = dict( 19 | type=OVSAMDataPreprocessor, 20 | 
mean=[123.675, 116.28, 103.53], 21 | std=[58.395, 57.12, 57.375], 22 | bgr_to_rgb=True, 23 | pad_size_divisor=image_size[0], 24 | pad_mask=False, 25 | mask_pad_value=0, 26 | pad_seg=False, 27 | seg_pad_value=255, 28 | batch_augments=None, 29 | use_center_point=True 30 | ) 31 | model = dict( 32 | type=OVSAM, 33 | data_preprocessor=_data_preprocessor, 34 | use_gt_prompt=True, 35 | use_clip_feat=True, 36 | use_head_feat=True, 37 | use_point=True, 38 | num_classes=1203, 39 | base_classes=LVIS_BASE_IDS, 40 | novel_classes=LVIS_RARE_IDS, 41 | backbone=dict( 42 | type=OpenCLIPBackbone, 43 | model_name='RN50x16', 44 | fix=True, 45 | init_cfg=dict( 46 | type='clip_pretrain', 47 | checkpoint='openai' 48 | ) 49 | ), 50 | neck=dict( 51 | type=MultiLayerTransformerNeck, 52 | input_size=(1024, 1024), 53 | in_channels=[384, 768, 1536, 3072], 54 | strides=[4, 8, 16, 32], 55 | layer_ids=(0, 1, 2, 3), 56 | embed_channels=1280, 57 | out_channels=256, 58 | fix=True, 59 | init_cfg=dict( 60 | type='Pretrained', 61 | checkpoint='./models/sam2clip_vith_rn50x16.pth', 62 | prefix='neck_student', 63 | ) 64 | ), 65 | fpn_neck=dict( 66 | type=FPN, 67 | in_channels=[384, 768, 1536, 3072], 68 | out_channels=256, 69 | num_outs=4, 70 | init_cfg=dict( 71 | type='Pretrained', 72 | checkpoint='./models/clip2sam_lvis_rn50x16.pth', 73 | prefix='fpn_neck', 74 | ), 75 | ), 76 | prompt_encoder=dict( 77 | type=SAMPromptEncoder, 78 | model_name='vit_h', 79 | fix=True, 80 | init_cfg=dict( 81 | type='sam_pretrain', 82 | checkpoint='vit_h' 83 | ) 84 | ), 85 | mask_decoder=dict( 86 | type=OVSAMHead, 87 | gen_box=True, 88 | model_name='vit_h', 89 | with_label_token=True, 90 | fix=False, 91 | ov_classifier_name='RN50x16_LVISV1Dataset', 92 | roi_extractor=dict( 93 | type=SingleRoIExtractor, 94 | roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), 95 | out_channels=256, 96 | featmap_strides=[4, 8, 16, 32] 97 | ), 98 | init_cfg=dict( 99 | type='Pretrained', 100 | checkpoint='./models/clip2sam_lvis_rn50x16.pth', 101 | prefix='mask_decoder', 102 | ) 103 | ) 104 | ) 105 | -------------------------------------------------------------------------------- /seg/configs/sam2clip/sam2clip_vith_rn50x16.py: -------------------------------------------------------------------------------- 1 | from mmdet.models import BatchFixedSizePad, DetDataPreprocessor, MSELoss 2 | from mmengine.config import read_base 3 | 4 | from seg.models.detectors import BackboneDistillation 5 | from seg.models.backbones import OpenCLIPBackbone, SAMBackbone 6 | from seg.models.necks import LastLayerNeck 7 | from seg.models.necks.transformer_neck import MultiLayerTransformerNeck 8 | from seg.models.utils import NO_OBJ 9 | 10 | with read_base(): 11 | from .._base_.default_runtime import * 12 | from .._base_.datasets.sam_img import * 13 | from .._base_.schedules.schedule_distillation import * 14 | 15 | image_size = (1024, 1024) 16 | batch_augments = [ 17 | dict( 18 | type=BatchFixedSizePad, 19 | size=image_size, 20 | img_pad_value=0, 21 | pad_mask=False, 22 | mask_pad_value=0, 23 | pad_seg=False, 24 | seg_pad_value=255 25 | ) 26 | ] 27 | data_preprocessor = dict( 28 | type=DetDataPreprocessor, 29 | mean=[123.675, 116.28, 103.53], 30 | std=[58.395, 57.12, 57.375], 31 | bgr_to_rgb=True, 32 | pad_size_divisor=1, 33 | pad_mask=False, 34 | mask_pad_value=0, 35 | pad_seg=False, 36 | seg_pad_value=NO_OBJ, 37 | batch_augments=batch_augments 38 | ) 39 | 40 | model = dict( 41 | type=BackboneDistillation, 42 | use_cache=True, 43 | data_preprocessor=data_preprocessor, 44 | 
backbone_teacher=dict( 45 | type=SAMBackbone, 46 | model_name='vit_h', 47 | fix=True, 48 | init_cfg=dict( 49 | type='sam_pretrain', 50 | checkpoint='vit_h' 51 | ) 52 | ), 53 | backbone_student=dict( 54 | type=OpenCLIPBackbone, 55 | model_name='RN50x16', 56 | fix=True, 57 | init_cfg=dict( 58 | type='clip_pretrain', 59 | checkpoint='openai' 60 | ) 61 | ), 62 | neck_teacher=dict(type=LastLayerNeck), 63 | neck_student=dict( 64 | type=MultiLayerTransformerNeck, 65 | input_size=(1024, 1024), 66 | in_channels=[384, 768, 1536, 3072], 67 | strides=[4, 8, 16, 32], 68 | layer_ids=(0, 1, 2, 3), 69 | embed_channels=1280, 70 | out_channels=256, 71 | embedding_path='sam_vit_h' 72 | ), 73 | loss_distill=dict( 74 | type=MSELoss, 75 | reduction='mean', 76 | loss_weight=1. 77 | ) 78 | ) 79 | 80 | val_dataloader = None 81 | val_evaluator = None 82 | val_cfg = None 83 | -------------------------------------------------------------------------------- /seg/configs/sam2clip/sam_vith_dump.py: -------------------------------------------------------------------------------- 1 | from mmdet.models import BatchFixedSizePad, DetDataPreprocessor 2 | from mmengine.config import read_base 3 | 4 | from seg.models.backbones import SAMBackbone 5 | from seg.models.detectors import BackboneDump 6 | from seg.models.necks import LastLayerNeck 7 | from seg.models.utils import NO_OBJ 8 | 9 | with read_base(): 10 | from .._base_.default_runtime import * 11 | from .._base_.datasets.sam import * 12 | from .._base_.schedules.schedule_distillation import * 13 | 14 | image_size = (1024, 1024) 15 | batch_augments = [ 16 | dict( 17 | type=BatchFixedSizePad, 18 | size=image_size, 19 | img_pad_value=0, 20 | pad_mask=False, 21 | mask_pad_value=0, 22 | pad_seg=False, 23 | seg_pad_value=NO_OBJ 24 | ) 25 | ] 26 | data_preprocessor = dict( 27 | type=DetDataPreprocessor, 28 | mean=[123.675, 116.28, 103.53], 29 | std=[58.395, 57.12, 57.375], 30 | bgr_to_rgb=True, 31 | pad_size_divisor=1024, 32 | pad_mask=False, 33 | mask_pad_value=0, 34 | pad_seg=False, 35 | seg_pad_value=NO_OBJ, 36 | batch_augments=batch_augments 37 | ) 38 | 39 | model = dict( 40 | type=BackboneDump, 41 | data_preprocessor=data_preprocessor, 42 | backbone=dict( 43 | type=SAMBackbone, 44 | model_name='vit_h', 45 | fix=True, 46 | init_cfg=dict( 47 | type='sam_pretrain', 48 | checkpoint='vit_h' 49 | ) 50 | ), 51 | neck=dict( 52 | type=LastLayerNeck 53 | ) 54 | ) 55 | 56 | val_dataloader = None 57 | val_evaluator = None 58 | val_cfg = None 59 | -------------------------------------------------------------------------------- /seg/datasets/pipeliens/frame_copy.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from mmcv import BaseTransform 5 | from mmdet.registry import TRANSFORMS 6 | 7 | from seg.models.utils import NO_OBJ 8 | 9 | 10 | @TRANSFORMS.register_module() 11 | class ImageCopy(BaseTransform): 12 | """Copy an image several times to build a video seq. 
13 | """ 14 | DIVISOR = 10000 15 | 16 | def __init__( 17 | self, 18 | num_frames: int = 1, 19 | ) -> None: 20 | assert num_frames > 1 21 | self.num_frames = num_frames 22 | 23 | def transform(self, results: dict) -> dict: 24 | for key in results: 25 | value = results[key] 26 | results[key] = [] 27 | for _ in range(self.num_frames): 28 | results[key].append(copy.deepcopy(value)) 29 | 30 | num_instances = len(results['gt_bboxes_labels'][0]) 31 | num_frames = len(results['gt_bboxes_labels']) 32 | gt_instance_ids = results['gt_bboxes_labels'][0] * self.DIVISOR + np.arange(num_instances) + 1 33 | results['gt_instances_ids'] = [copy.deepcopy(gt_instance_ids) for _ in range(num_frames)] 34 | return results 35 | 36 | def __repr__(self) -> str: 37 | repr_str = self.__class__.__name__ 38 | repr_str += f'(num_frames={self.num_frames})' 39 | return repr_str 40 | 41 | 42 | @TRANSFORMS.register_module() 43 | class AddSemSeg(BaseTransform): 44 | """Add dummy semantic segmentation map. 45 | """ 46 | 47 | def __init__(self, ) -> None: 48 | pass 49 | 50 | def transform(self, results: dict) -> dict: 51 | gt_seg = np.zeros(results['img'].shape[:2], dtype=np.int32) + NO_OBJ 52 | results['gt_seg_map'] = gt_seg 53 | return results 54 | 55 | def __repr__(self) -> str: 56 | repr_str = self.__class__.__name__ 57 | return repr_str 58 | -------------------------------------------------------------------------------- /seg/datasets/pipeliens/frame_sampling.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, List, Optional 3 | 4 | import numpy as np 5 | from mmdet.registry import TRANSFORMS 6 | from mmdet.datasets.transforms import BaseFrameSample 7 | 8 | 9 | @TRANSFORMS.register_module() 10 | class VideoClipSample(BaseFrameSample): 11 | def __init__(self, 12 | num_selected: int = 1, 13 | interval: int = 1, 14 | collect_video_keys: List[str] = ['video_id', 'video_length']): 15 | self.num_selected = num_selected 16 | self.interval = interval 17 | super().__init__(collect_video_keys=collect_video_keys) 18 | 19 | def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: 20 | """Transform the video information. 21 | 22 | Args: 23 | video_infos (dict): The whole video information. 24 | 25 | Returns: 26 | dict: The data information of the sampled frames. 
27 | """ 28 | len_with_interval = self.num_selected + (self.num_selected - 1) * (self.interval - 1) 29 | len_video = video_infos['video_length'] 30 | if len_with_interval > len_video: 31 | return None 32 | 33 | first_frame_id = random.sample(range(len_video - len_with_interval + 1), 1)[0] 34 | 35 | sampled_frames_ids = first_frame_id + np.arange(self.num_selected) * self.interval 36 | results = self.prepare_data(video_infos, sampled_frames_ids) 37 | 38 | return results 39 | 40 | def __repr__(self) -> str: 41 | repr_str = self.__class__.__name__ 42 | repr_str += f'(num_selected={self.num_selected}, ' 43 | repr_str += f'interval={self.interval}, ' 44 | repr_str += f'collect_video_keys={self.collect_video_keys})' 45 | return repr_str 46 | -------------------------------------------------------------------------------- /seg/datasets/sam.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mmengine import get_local_path, list_from_file, join_path, scandir, print_log 4 | 5 | from mmdet.datasets import BaseDetDataset 6 | 7 | 8 | class SAMDataset(BaseDetDataset): 9 | 10 | def __init__(self, *args, img_map_suffix='.jpg', **kwargs): 11 | self.img_map_suffix = img_map_suffix 12 | self.id2folder = dict() 13 | super().__init__(*args, **kwargs) 14 | 15 | def load_data_list(self) -> List[dict]: 16 | print_log('Starting to load sam dataset', 'current') 17 | with get_local_path( 18 | self.ann_file, backend_args=self.backend_args) as local_path: 19 | folders = list_from_file(local_path) 20 | 21 | img_ids_list = [] 22 | for folder in folders: 23 | folder_path = join_path(self.data_prefix['img'], folder) 24 | img_ids = sorted(list(map( 25 | lambda x: int(x.split('.')[0].split('_')[1]), 26 | scandir(folder_path, recursive=False, suffix='.jpg') 27 | ))) 28 | img_ids_list.extend(img_ids) 29 | for img_id in img_ids: 30 | self.id2folder[img_id] = folder 31 | 32 | img_ids = img_ids_list 33 | data_list = [] 34 | for img_id in img_ids: 35 | data_info = { 36 | 'img_id': img_id, 37 | 'img_path': join_path(self.data_prefix['img'], self.id2folder[img_id], f"sa_{img_id}.jpg"), 38 | 'info_path': join_path(self.data_prefix['img'], self.id2folder[img_id], f"sa_{img_id}.json"), 39 | } 40 | data_list.append(data_info) 41 | print_log(f'Found {len(data_list)} images in {len(folders)} folders.', 'current') 42 | return data_list 43 | -------------------------------------------------------------------------------- /seg/datasets/samplers/multi_dataset_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | from typing import Iterator, Optional, Sequence, Sized 4 | 5 | import torch 6 | from mmengine.dist import get_dist_info, sync_random_seed 7 | from mmengine.registry import DATA_SAMPLERS 8 | from torch.utils.data import Sampler 9 | 10 | 11 | @DATA_SAMPLERS.register_module() 12 | class MultiDataSampler(Sampler): 13 | """The default data sampler for both distributed and non-distributed 14 | environment. 15 | 16 | It has several differences from the PyTorch ``DistributedSampler`` as 17 | below: 18 | 19 | 1. This sampler supports non-distributed environment. 20 | 21 | 2. The round up behaviors are a little different. 22 | 23 | - If ``round_up=True``, this sampler will add extra samples to make the 24 | number of samples evenly divisible by the world size. And 25 | this behavior is the same as the ``DistributedSampler`` with 26 | ``drop_last=False``.
27 | - If ``round_up=False``, this sampler won't remove or add any samples 28 | while the ``DistributedSampler`` with ``drop_last=True`` will remove 29 | tail samples. 30 | 31 | Args: 32 | dataset (Sized): The dataset. 33 | dataset_ratio (Sequence(int)) The ratios of different datasets. 34 | seed (int, optional): Random seed used to shuffle the sampler if 35 | :attr:`shuffle=True`. This number should be identical across all 36 | processes in the distributed group. Defaults to None. 37 | round_up (bool): Whether to add extra samples to make the number of 38 | samples evenly divisible by the world size. Defaults to True. 39 | """ 40 | 41 | def __init__(self, 42 | dataset: Sized, 43 | dataset_ratio: Sequence[int], 44 | seed: Optional[int] = None, 45 | round_up: bool = True) -> None: 46 | rank, world_size = get_dist_info() 47 | self.rank = rank 48 | self.world_size = world_size 49 | 50 | self.dataset = dataset 51 | self.dataset_ratio = dataset_ratio 52 | 53 | if seed is None: 54 | seed = sync_random_seed() 55 | self.seed = seed 56 | self.epoch = 0 57 | self.round_up = round_up 58 | 59 | if self.round_up: 60 | self.num_samples = math.ceil(len(self.dataset) / world_size) 61 | self.total_size = self.num_samples * self.world_size 62 | else: 63 | self.num_samples = math.ceil( 64 | (len(self.dataset) - rank) / world_size) 65 | self.total_size = len(self.dataset) 66 | 67 | self.sizes = [len(dataset) for dataset in self.dataset.datasets] 68 | 69 | dataset_weight = [ 70 | torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio) 71 | for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes)) 72 | ] 73 | self.weights = torch.cat(dataset_weight) 74 | 75 | def __iter__(self) -> Iterator[int]: 76 | """Iterate the indices.""" 77 | # deterministically shuffle based on epoch and seed 78 | g = torch.Generator() 79 | g.manual_seed(self.seed + self.epoch) 80 | 81 | indices = torch.multinomial( 82 | self.weights, len(self.weights), generator=g, 83 | replacement=True).tolist() 84 | 85 | # add extra samples to make it evenly divisible 86 | if self.round_up: 87 | indices = ( 88 | indices * 89 | int(self.total_size / len(indices) + 1))[:self.total_size] 90 | 91 | # subsample 92 | indices = indices[self.rank:self.total_size:self.world_size] 93 | 94 | return iter(indices) 95 | 96 | def __len__(self) -> int: 97 | """The number of samples in this rank.""" 98 | return self.num_samples 99 | 100 | def set_epoch(self, epoch: int) -> None: 101 | """Sets the epoch for this sampler. 102 | 103 | When :attr:`shuffle=True`, this ensures all replicas use a different 104 | random ordering for each epoch. Otherwise, the next iteration of this 105 | sampler will yield the same ordering. 106 | 107 | Args: 108 | epoch (int): Epoch number. 
109 | """ 110 | self.epoch = epoch 111 | -------------------------------------------------------------------------------- /seg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .openclip_backbone import OpenCLIPBackbone 2 | from .openclip_backbone import OpenCLIPBackboneText 3 | from .sam_backbone import SAMBackbone 4 | -------------------------------------------------------------------------------- /seg/models/data_preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from .ovsam_preprocessor import OVSAMDataPreprocessor 2 | -------------------------------------------------------------------------------- /seg/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .sam2clip_distill import BackboneDistillation 2 | from .clip2sam import CLIP2SAM 3 | from .sam_dump import BackboneDump 4 | from .ovsam import OVSAM 5 | -------------------------------------------------------------------------------- /seg/models/detectors/sam_dump.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Dict, List 2 | 3 | import mmengine 4 | import torch 5 | from mmdet.models.detectors.base import ForwardResults 6 | from mmengine import print_log 7 | from mmengine.model import BaseModel 8 | from torch import Tensor 9 | 10 | from mmdet.registry import MODELS 11 | from mmdet.structures import SampleList, OptSampleList 12 | from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig 13 | 14 | 15 | @MODELS.register_module() 16 | class BackboneDump(BaseModel): 17 | 18 | def __init__( 19 | self, 20 | backbone: ConfigType, 21 | neck: ConfigType, 22 | data_preprocessor: OptConfigType = None, 23 | init_cfg: OptMultiConfig = None, 24 | ) -> None: 25 | super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) 26 | self.backbone = MODELS.build(backbone) 27 | self.neck = MODELS.build(neck) 28 | 29 | self.register_parameter('dummy', torch.nn.Parameter(torch.zeros(1))) 30 | 31 | def forward(self, 32 | inputs: torch.Tensor, 33 | data_samples: OptSampleList = None, 34 | mode: str = 'tensor') -> ForwardResults: 35 | if mode == 'loss': 36 | return self.loss(inputs, data_samples) 37 | elif mode == 'predict': 38 | return self.predict(inputs, data_samples) 39 | elif mode == 'tensor': 40 | return self._forward(inputs, data_samples) 41 | else: 42 | raise RuntimeError(f'Invalid mode "{mode}". 
' 43 | 'Only supports loss, predict and tensor mode') 44 | 45 | def _forward(self, *args, **kwargs) -> Tuple[Tensor]: 46 | raise NotImplementedError 47 | 48 | def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor, Tensor]: 49 | return self.neck(self.backbone(batch_inputs)) 50 | 51 | def predict(self, batch_inputs: Tensor, 52 | batch_data_samples: SampleList) -> Union[Dict, List]: 53 | feat = self.extract_feat(batch_inputs) 54 | 55 | assert len(batch_data_samples) == 1 56 | img_path = batch_data_samples[0].metainfo['img_path'] 57 | img_path = img_path.replace('.jpg', f'_{self.backbone.model_name}_cache.pth') 58 | if not mmengine.exists(img_path): 59 | feat = feat.to(device='cpu')[0] 60 | torch.save(feat.to(device='cpu'), img_path) 61 | else: 62 | print_log(f'{img_path} already exists') 63 | return {} 64 | 65 | def loss(self, batch_inputs: Tensor, 66 | batch_data_samples: SampleList) -> Union[Dict, List]: 67 | raise NotImplementedError 68 | -------------------------------------------------------------------------------- /seg/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .ovsam_head import OVSAMHead 2 | -------------------------------------------------------------------------------- /seg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .last_layer import LastLayerNeck, LastLayerProjNeck 2 | from .sam_pe import SAMPromptEncoder 3 | from .transformer_neck import SingleLayerTransformerNeck, MultiLayerTransformerNeck 4 | -------------------------------------------------------------------------------- /seg/models/necks/last_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from mmengine.model import BaseModule 4 | from torch import Tensor, nn 5 | 6 | from mmdet.registry import MODELS 7 | 8 | from ext.sam.common import LayerNorm2d 9 | from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix 10 | 11 | 12 | @MODELS.register_module() 13 | class LastLayerNeck(BaseModule): 14 | r"""Last Layer Neck 15 | 16 | Return the last layer feature of the backbone. 
17 | """ 18 | 19 | def __init__(self) -> None: 20 | super().__init__(init_cfg=None) 21 | 22 | def forward(self, inputs: Tuple[Tensor]) -> Tensor: 23 | return inputs[-1] 24 | 25 | 26 | @MODELS.register_module() 27 | class LastLayerProjNeck(BaseModule): 28 | 29 | def __init__( 30 | self, 31 | in_channels, 32 | out_channels, 33 | init_cfg=None 34 | ) -> None: 35 | super().__init__(init_cfg=None) 36 | self.out_proj = nn.Sequential( 37 | nn.Conv2d( 38 | in_channels, 39 | out_channels, 40 | kernel_size=1, 41 | bias=False, 42 | ), 43 | LayerNorm2d(out_channels), 44 | nn.Conv2d( 45 | out_channels, 46 | out_channels, 47 | kernel_size=3, 48 | padding=1, 49 | bias=False, 50 | ), 51 | LayerNorm2d(out_channels), 52 | ) 53 | 54 | if init_cfg is not None and init_cfg['type'] == 'Pretrained': 55 | checkpoint_path = init_cfg['checkpoint'] 56 | state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix']) 57 | self.load_state_dict(state_dict, strict=True) 58 | self._is_init = True 59 | 60 | def init_weights(self): 61 | pass 62 | 63 | def forward(self, inputs: Tuple[Tensor]) -> Tensor: 64 | return self.out_proj(inputs[-1]) 65 | -------------------------------------------------------------------------------- /seg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .video_gt_preprocess import preprocess_video_panoptic_gt 2 | from .mask_pool import mask_pool 3 | from .pan_seg_transform import INSTANCE_OFFSET_HB, mmpan2hbpan, mmgt2hbpan 4 | from .class_overlapping import calculate_class_overlapping 5 | from .no_obj import NO_OBJ 6 | from .offline_video_metrics import vpq_eval, stq 7 | -------------------------------------------------------------------------------- /seg/models/utils/class_overlapping.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def calculate_class_overlapping(classes1: List[str], classes2: List[str]) -> List[bool]: 5 | words1 = [word for item in classes1 for word in item.split(',')] 6 | results = [] 7 | for item in classes2: 8 | flag: bool = False 9 | for word in item.split(','): 10 | if word in words1: 11 | flag = True 12 | break 13 | results.append(flag) 14 | return results 15 | -------------------------------------------------------------------------------- /seg/models/utils/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | from mmengine.runner.checkpoint import CheckpointLoader 2 | 3 | 4 | def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'): 5 | """Load partial pretrained model with specific prefix. 6 | 7 | Args: 8 | prefix (str): The prefix of sub-module. 9 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 10 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 11 | details. 12 | map_location (str | None): Same as :func:`torch.load`. 13 | Defaults to None. 14 | logger: logger 15 | 16 | Returns: 17 | dict or OrderedDict: The loaded checkpoint. 18 | """ 19 | 20 | checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger) 21 | 22 | if 'state_dict' in checkpoint: 23 | state_dict = checkpoint['state_dict'] 24 | else: 25 | state_dict = checkpoint 26 | if not prefix: 27 | return state_dict 28 | if not prefix.endswith('.'): 29 | prefix += '.' 
30 | prefix_len = len(prefix) 31 | 32 | state_dict = { 33 | k[prefix_len:]: v 34 | for k, v in state_dict.items() if k.startswith(prefix) 35 | } 36 | 37 | assert state_dict, f'{prefix} is not in the pretrained model' 38 | return state_dict 39 | -------------------------------------------------------------------------------- /seg/models/utils/mask_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923 6 | 7 | def mask_pool(x, mask): 8 | """ 9 | Args: 10 | x: [B, C, H, W] 11 | mask: [B, Q, H, W] 12 | """ 13 | if not x.shape[-2:] == mask.shape[-2:]: 14 | # reshape mask to x 15 | mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False) 16 | with torch.no_grad(): 17 | mask = mask.detach() 18 | mask = (mask > 0).to(mask.dtype) 19 | denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 20 | 21 | mask_pooled_x = torch.einsum( 22 | "bchw,bqhw->bqc", 23 | x, 24 | mask / denorm, 25 | ) 26 | return mask_pooled_x 27 | 28 | -------------------------------------------------------------------------------- /seg/models/utils/no_obj.py: -------------------------------------------------------------------------------- 1 | NO_OBJ = 65535 2 | -------------------------------------------------------------------------------- /seg/models/utils/pan_seg_transform.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import numpy as np 5 | from mmdet.evaluation import INSTANCE_OFFSET 6 | 7 | INSTANCE_OFFSET_HB = 10000 8 | 9 | 10 | def mmpan2hbpan(pred_pan_map, num_classes): 11 | pan_seg_map = - np.ones_like(pred_pan_map) 12 | for itm in np.unique(pred_pan_map): 13 | if itm >= INSTANCE_OFFSET: 14 | # cls labels (from segmentation maps) 15 | cls = itm % INSTANCE_OFFSET 16 | # id labels (from tracking maps) 17 | ins = itm // INSTANCE_OFFSET 18 | pan_seg_map[pred_pan_map == itm] = cls * INSTANCE_OFFSET_HB + ins 19 | elif itm == num_classes: 20 | pan_seg_map[pred_pan_map == itm] = num_classes * INSTANCE_OFFSET_HB 21 | else: 22 | pan_seg_map[pred_pan_map == itm] = itm * INSTANCE_OFFSET_HB 23 | assert -1 not in pan_seg_map 24 | return pan_seg_map 25 | 26 | 27 | def mmgt2hbpan(data_samples): 28 | pan_map = copy.deepcopy(data_samples.gt_sem_seg.sem_seg[0]) 29 | pan_map = pan_map * INSTANCE_OFFSET_HB 30 | gt_instances = data_samples.gt_instances 31 | for idx in range(len(gt_instances)): 32 | mask = torch.tensor(gt_instances.masks.masks[idx], dtype=torch.bool) 33 | instance_id = gt_instances.instances_ids[idx].item() 34 | pan_map[mask] = instance_id 35 | 36 | return pan_map 37 | -------------------------------------------------------------------------------- /seg/models/utils/video_gt_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def preprocess_video_panoptic_gt( 5 | gt_labels, 6 | gt_masks, 7 | gt_semantic_seg, 8 | gt_instance_ids, 9 | num_things, 10 | num_stuff, 11 | ): 12 | num_classes = num_things + num_stuff 13 | num_frames = len(gt_masks) 14 | mask_size = gt_masks[0].masks.shape[-2:] 15 | 16 | thing_masks_list = [] 17 | for frame_id in range(num_frames): 18 | thing_masks_list.append(gt_masks[frame_id].pad( 19 | mask_size, pad_val=0).to_tensor( 20 | dtype=torch.bool, device=gt_labels.device) 21 | ) 22 | instances = 
torch.unique(gt_instance_ids[:, 1]) 23 | things_masks = [] 24 | labels = [] 25 | for instance in instances: 26 | pos_ins = torch.nonzero(torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0] # 0 is for redundant tuple 27 | labels_instance = gt_labels[:, 1][pos_ins] 28 | assert torch.allclose(labels_instance, labels_instance[0]) 29 | labels.append(labels_instance[0]) 30 | instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(dtype=torch.int32).tolist() 31 | instance_masks = [] 32 | for frame_id in range(num_frames): 33 | frame_instance_ids = gt_instance_ids[gt_instance_ids[:, 0] == frame_id, 1] 34 | if frame_id not in instance_frame_ids: 35 | empty_mask = torch.zeros( 36 | mask_size, 37 | dtype=thing_masks_list[frame_id].dtype, device=thing_masks_list[frame_id].device 38 | ) 39 | instance_masks.append(empty_mask) 40 | else: 41 | pos_inner_frame = torch.nonzero(torch.eq(frame_instance_ids, instance), as_tuple=True)[0].item() 42 | frame_mask = thing_masks_list[frame_id][pos_inner_frame] 43 | instance_masks.append(frame_mask) 44 | things_masks.append(torch.stack(instance_masks)) 45 | 46 | if len(instances) == 0: 47 | things_masks = torch.stack(thing_masks_list, dim=1) 48 | labels = torch.empty_like(instances) 49 | else: 50 | things_masks = torch.stack(things_masks) 51 | labels = torch.stack(labels) 52 | assert torch.all(torch.less(labels, num_things)) 53 | 54 | if gt_semantic_seg is not None: 55 | things_labels = labels 56 | gt_semantic_seg = gt_semantic_seg.squeeze(1) 57 | 58 | semantic_labels = torch.unique( 59 | gt_semantic_seg, 60 | sorted=False, 61 | return_inverse=False, 62 | return_counts=False) 63 | stuff_masks_list = [] 64 | stuff_labels_list = [] 65 | for label in semantic_labels: 66 | if label < num_things or label >= num_classes: 67 | continue 68 | stuff_mask = gt_semantic_seg == label 69 | stuff_masks_list.append(stuff_mask) 70 | stuff_labels_list.append(label) 71 | 72 | if len(stuff_masks_list) > 0: 73 | stuff_masks = torch.stack(stuff_masks_list, dim=0) 74 | stuff_labels = torch.stack(stuff_labels_list, dim=0) 75 | assert torch.all(torch.ge(stuff_labels, num_things)) and torch.all(torch.less(stuff_labels, num_classes)) 76 | labels = torch.cat([things_labels, stuff_labels], dim=0) 77 | masks = torch.cat([things_masks, stuff_masks], dim=0) 78 | else: 79 | labels = things_labels 80 | masks = things_masks 81 | assert len(labels) == len(masks) 82 | else: 83 | masks = things_masks 84 | 85 | labels = labels.to(dtype=torch.long) 86 | masks = masks.to(dtype=torch.long) 87 | return labels, masks 88 | -------------------------------------------------------------------------------- /tools/dist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | FILE=$1 4 | CONFIG=$2 5 | GPUS=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-$((28500 + $RANDOM % 2000))} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | 12 | if command -v torchrun &> /dev/null 13 | then 14 | echo "Using torchrun mode." 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 16 | torchrun \ 17 | --nnodes=${NNODES} \ 18 | --node_rank=${NODE_RANK} \ 19 | --master_addr=${MASTER_ADDR} \ 20 | --master_port=${PORT} \ 21 | --nproc_per_node=${GPUS} \ 22 | $(dirname "$0")/${FILE}.py ${CONFIG} --launcher pytorch ${@:4} 23 | else 24 | echo "Using launch mode."
25 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 26 | python -m torch.distributed.launch \ 27 | --nnodes=${NNODES} \ 28 | --node_rank=${NODE_RANK} \ 29 | --master_addr=${MASTER_ADDR} \ 30 | --master_port=${PORT} \ 31 | --nproc_per_node=${GPUS} \ 32 | $(dirname "$0")/${FILE}.py ${CONFIG} --launcher pytorch ${@:4} 33 | fi 34 | -------------------------------------------------------------------------------- /tools/slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | FILE=$1 6 | CONFIG=$2 7 | GPUS=${GPUS:-8} 8 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 10 | MASTER_PORT=${MASTER_PORT:-$((28500 + $RANDOM % 2000))} 11 | PARTITION=${PARTITION:-DUMMY} 12 | JOB_NAME=${JOB_NAME:-DUMMY} 13 | QUOTATYPE=${QUOTATYPE:-auto} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | PY_ARGS=${@:3} 16 | 17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 18 | CUDA_HOME=$(dirname $(dirname $(which nvcc))) \ 19 | MASTER_PORT=$MASTER_PORT \ 20 | srun -p ${PARTITION} \ 21 | --job-name=${JOB_NAME} \ 22 | --gres=gpu:${GPUS_PER_NODE} \ 23 | --ntasks=${GPUS} \ 24 | --ntasks-per-node=${GPUS_PER_NODE} \ 25 | --cpus-per-task=${CPUS_PER_TASK} \ 26 | --kill-on-bad-exit=1 \ 27 | --quotatype=${QUOTATYPE} \ 28 | ${SRUN_ARGS} \ 29 | python -u tools/${FILE}.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 30 | --------------------------------------------------------------------------------
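A minimal launch sketch for the two wrappers above, assuming a single node with 8 GPUs; the entry name (train) and config path come from this repository, while the GPU count, partition, and job name are placeholders, and any extra arguments after the listed ones are forwarded to tools/train.py:

# torchrun / launch wrapper: tools/dist.sh <entry> <config> <num_gpus> [extra args]
bash tools/dist.sh train seg/configs/sam2clip/sam2clip_vith_rn50x16.py 8

# SLURM wrapper: resources and job metadata are passed as environment variables
PARTITION=<partition> JOB_NAME=sam2clip GPUS=8 GPUS_PER_NODE=8 bash tools/slurm.sh train seg/configs/sam2clip/sam2clip_vith_rn50x16.py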