├── LICENSE ├── README.md ├── assets ├── arch.png └── teaser.png ├── model ├── backbone │ ├── __init__.py │ ├── custom_clip.py │ ├── custom_deit.py │ ├── custom_dino.py │ ├── custom_sam.py │ ├── src_dift │ │ ├── models │ │ │ └── dift_sd.py │ │ └── utils │ │ │ └── visualization.py │ └── src_dino │ │ ├── dinov1.py │ │ └── hubconf.py ├── segic.py └── segment_anything_training │ ├── __init__.py │ ├── build_sam.py │ ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── transformer.py │ └── transformer.py.bak │ └── utils │ ├── __init__.py │ ├── coco_eval.py │ ├── transforms.py │ └── transforms_gdino.py ├── requirements.txt ├── scripts └── segic_dist.sh ├── train.py ├── utils ├── coco80.txt ├── dataloader.py ├── dataset.py ├── dataset │ ├── ade20k_classes.json │ ├── ade20k_icl.pth │ ├── ade847_classes.json │ ├── ade_icl.json │ ├── ade_icl.pth │ ├── pc459_classes.json │ ├── sd_ade847_classes.json │ ├── train_ade20k_icl.pth │ ├── val_ade20k_icl.pth │ ├── val_ade847_icl.pth │ ├── val_pc459_icl.pth │ ├── val_sd_ade20k_icl.pth │ └── val_sd_ade847_icl.pth ├── fss.py ├── fss_inst.py ├── inst_aug.py ├── instance_evaluation.py ├── logger.py ├── loss_mask.py ├── lr_sched.py ├── meter.py ├── misc.py ├── register_seginw_dataset.py ├── seginw_data_mapper.py └── vos_dataset.py └── vos_benchmark ├── __init__.py ├── benchmark.py ├── evaluator.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lingchen Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SEGIC: Unleashing the Emergent Correspondence for In-Context Segmentation 2 | 3 | ### [Paper (ArXiv)](https://arxiv.org/abs/2311.14671) 4 | 5 | 6 | We introduce SEGIC, an end-to-end segment-in-context framework built upon a single frozen vision foundation model. 
7 | 8 | ![teaser](assets/teaser.png) 9 | 10 | ## Model ZOO 11 | 12 | | Model | Backbone | Iters | Config | Download | 13 | | ------ | -------- | ------- | ----- | ----- | 14 | | SEGIC | DINOv2-l | 80k*12e | [config](scripts/segic_dist.sh) | [model](https://huggingface.co/menglc/SEGIC/blob/main/segic_dinov2_l_80kx12e.pth) 15 | | SEGIC | DINOv2-l | 160k*12e | [config](scripts/segic_dist.sh) | [model](https://huggingface.co/menglc/SEGIC/blob/main/segic_dinov2_l_160kx12e.pth) 16 | 17 | 18 | ## Environment Setup 19 | ``` 20 | conda create --name segic python=3.10 -y 21 | conda activate segic 22 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Train SEGIC 27 | ``` 28 | bash scripts/segic_dist.sh 8 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --dinov2_model l --samples_per_epoch 80000 29 | ``` 30 | 31 | ## Evaluate SEGIC 32 | 33 | ### Download Datasets 34 | 35 | 36 | The dataset should be organized as: 37 | ``` 38 | data 39 | ├── COCO2014 40 | │ ├── annotations 41 | │ ├── train2014 42 | │ └── val2014 43 | ├── DAVIS 44 | │ ├── 2016 45 | │ └── 2017 46 | ├── FSS-1000 47 | │ ├── abacus 48 | │ ├── abe's_flyingfish 49 | │ ├── ab_wheel 50 | │ ├── ... 51 | └── ytbvos18 52 | └── val 53 | 54 | ``` 55 | 56 | ### Evaluate One-shot Segmentation 57 | ``` 58 | # coco 59 | bash scripts/segic_dist.sh 8 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --eval --restore-model /your/ckpt/path --eval_datasets coco 60 | 61 | # fss 62 | bash scripts/segic_dist.sh 8 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --eval --restore-model /your/ckpt/path --eval_datasets fss 63 | ``` 64 | 65 | ### Evaluate Zero-shot Video Object Segmentation 66 | ``` 67 | # davis-17 68 | bash scripts/segic_dist.sh 8 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --eval_vos --vos_data davis17 --restore-model /your/ckpt/path 69 | 70 | # youtubevos-18 71 | bash scripts/segic_dist.sh 8 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --eval_vos --vos_data youtube --restore-model /your/ckpt/path 72 | ``` 73 | 74 | ### Custom Inference 75 | ``` 76 | bash scripts/segic_dist.sh 1 dinov2 OUTPUT/all_exps/abs_backbone/dinov2_l --custom_eval --restore-model /your/ckpt/path 77 | ``` 78 | 79 | ## Acknowledgement 80 | Many thanks to these excellent opensource projects 81 | * [Segment Anything](https://github.com/facebookresearch/segment-anything) 82 | * [SAM-HQ](https://github.com/SysCV/sam-hq) 83 | 84 | ## Citation 85 | If you find this project useful for your research, please use the following BibTeX entry. 
86 | ```bibtex 87 | @inproceedings{meng2023segic, 88 | title={SEGIC: Unleashing the Emergent Correspondence for In-Context Segmentation}, 89 | author={Meng, Lingchen and Lan, Shiyi and Li, Hengduo and Alvarez, Jose M and Wu, Zuxuan and Jiang, Yu-Gang}, 90 | journal={ECCV}, 91 | year={2024} 92 | } -------------------------------------------------------------------------------- /assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/assets/arch.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/assets/teaser.png -------------------------------------------------------------------------------- /model/backbone/custom_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import open_clip 5 | from copy import deepcopy 6 | from transformers import CLIPVisionModel 7 | from transformers.models.clip.modeling_clip import (CLIPPreTrainedModel, CLIPVisionConfig, CLIPVisionTransformer, CLIPVisionEmbeddings, CLIPEncoder, 8 | CLIPConfig, CLIPTextConfig, CLIPTextModel) 9 | from transformers import AutoTokenizer 10 | 11 | class CustomCLIPVisionEmbeddings(CLIPVisionEmbeddings): 12 | 13 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 14 | batch_size = pixel_values.shape[0] 15 | patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] 16 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 17 | 18 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 19 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 20 | 21 | h, w = pixel_values.shape[-2:] 22 | pos_embeddings = self.position_embedding(self.position_ids) 23 | pos_embeddings = self.interpolate_pos_encoding(embeddings, w, h) 24 | # pos_embeddings = self.interpolate_pos_encoding(embeddings, pos_embeddings[0], w, h) 25 | 26 | embeddings = embeddings + pos_embeddings 27 | return embeddings 28 | 29 | def interpolate_pos_encoding(self, x, w, h): 30 | previous_dtype = x.dtype 31 | npatch = x.shape[1] - 1 32 | N = self.position_embedding.weight.shape[0] - 1 33 | if npatch == N and w == h: 34 | return self.position_embedding.weight 35 | pos_embed = self.position_embedding.weight.float() 36 | class_pos_embed = pos_embed[:1, :] 37 | patch_pos_embed = pos_embed[1:, :] 38 | dim = x.shape[-1] 39 | w0 = w // self.patch_size 40 | h0 = h // self.patch_size 41 | # we add a small number to avoid floating point error in the interpolation 42 | # see discussion at https://github.com/facebookresearch/dino/issues/8 43 | w0, h0 = w0 + 0.1, h0 + 0.1 44 | 45 | patch_pos_embed = nn.functional.interpolate( 46 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 47 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 48 | mode="bicubic", 49 | ) 50 | 51 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 52 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 53 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 54 | 55 | class CustomCLIPVisionTransformer(CLIPVisionTransformer): 56 | def __init__(self, config: CLIPVisionConfig): 
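        # Skip CLIPVisionTransformer.__init__ (call nn.Module.__init__ via the grandparent class)
        # so the stock embeddings can be replaced below with CustomCLIPVisionEmbeddings,
        # whose interpolate_pos_encoding allows non-default input resolutions.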
57 | super(CLIPVisionTransformer, self).__init__() 58 | self.config = config 59 | embed_dim = config.hidden_size 60 | 61 | self.embeddings = CustomCLIPVisionEmbeddings(config) 62 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 63 | self.encoder = CLIPEncoder(config) 64 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 65 | 66 | class CustomCLIPVisionModel(CLIPVisionModel): 67 | def __init__(self, config: CLIPVisionConfig): 68 | super(CLIPVisionModel, self).__init__(config) 69 | self.vision_model = CustomCLIPVisionTransformer(config) 70 | # Initialize weights and apply final processing 71 | self.post_init() 72 | 73 | 74 | class CLIPModel(CLIPPreTrainedModel): 75 | config_class = CLIPConfig 76 | 77 | def __init__(self, config: CLIPConfig): 78 | super().__init__(config) 79 | 80 | if not isinstance(config.text_config, CLIPTextConfig): 81 | raise ValueError( 82 | "config.text_config is expected to be of type CLIPTextConfig but is of type" 83 | f" {type(config.text_config)}." 84 | ) 85 | 86 | if not isinstance(config.vision_config, CLIPVisionConfig): 87 | raise ValueError( 88 | "config.vision_config is expected to be of type CLIPVisionConfig but is of type" 89 | f" {type(config.vision_config)}." 90 | ) 91 | 92 | text_config = config.text_config 93 | vision_config = config.vision_config 94 | 95 | self.projection_dim = config.projection_dim 96 | self.text_embed_dim = text_config.hidden_size 97 | self.vision_embed_dim = vision_config.hidden_size 98 | 99 | self.text_model = CLIPTextModel(text_config) 100 | self.vision_model = CustomCLIPVisionModel(vision_config) 101 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) 102 | 103 | def get_prompt_features(self, prompt): 104 | device = next(self.text_model.parameters()).device 105 | text_inputs = self.tokenizer( 106 | prompt, 107 | padding="max_length", 108 | max_length=self.tokenizer.model_max_length, 109 | truncation=True, 110 | return_tensors="pt", 111 | ) 112 | 113 | for k, x in text_inputs.items(): 114 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 115 | 116 | return self.text_model(**text_inputs).pooler_output 117 | 118 | 119 | 120 | class CLIPModelConv(nn.Module): 121 | def __init__(self) -> None: 122 | super().__init__() 123 | # self.model, *_ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg') 124 | # self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg') 125 | self.model, *_ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg') 126 | self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg') 127 | del self.model.visual.trunk.head 128 | self.model.visual.trunk.head = nn.Identity() 129 | self.requires_grad_(False) 130 | if True : 131 | stages = deepcopy(self.model.visual.trunk.stages) 132 | del self.model.visual.trunk.stages 133 | self.model.visual.trunk.stages = stages[:3] 134 | 135 | def forward(self, x): 136 | return self.model.visual.trunk(x) 137 | 138 | @property 139 | def device(self): 140 | return next(self.parameters()).device 141 | 142 | def get_prompt_features(self, prompt): 143 | text = self.tokenizer(prompt).to(self.device) 144 | return self.model.encode_text(text) 145 | 146 | 147 | class PyramidCLIPModel(nn.Module): 148 | def __init__(self) -> None: 149 | super().__init__() 150 | from models.models import build_model 151 | from models.simple_tokenizer 
import tokenize 152 | 153 | self.tokenizer = tokenize 154 | model = build_model('RN50') 155 | ckpt_path = 'pretrained_checkpoint/PyramidCLIP-YFCC15MV2-RN50.pth' # specify path of checkpoint 156 | if ckpt_path: 157 | model.load_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict']) 158 | self.model = model 159 | 160 | self.requires_grad_(False) 161 | 162 | def forward(self, x): 163 | return self.model.encode_image(x, extract_dense_feature=True) 164 | 165 | @property 166 | def device(self): 167 | return next(self.parameters()).device 168 | 169 | def get_prompt_features(self, prompt): 170 | text = self.tokenizer(prompt).to(self.device) 171 | return self.model.encode_text(text) 172 | 173 | 174 | class DeCLIPModel(nn.Module): 175 | def __init__(self) -> None: 176 | super().__init__() 177 | from prototype.model import model_entry 178 | from prototype.utils.misc import parse_config 179 | from models.simple_tokenizer import tokenize 180 | 181 | config_file = 'DeCLIP/experiments/declip_experiments/declip88m/declip88m_r50_declip/config.yaml' 182 | config = parse_config(config_file) 183 | model = model_entry(config.model) 184 | 185 | import torch 186 | from collections import OrderedDict 187 | ckpt = torch.load('DeCLIP/r50.pth.tar', map_location='cpu') 188 | new_ckpt = OrderedDict() 189 | model_dict = model.state_dict() 190 | for k,v in ckpt['model'].items(): 191 | k = k.replace('module.', '') 192 | if k in model_dict and model_dict[k].shape == v.shape: 193 | new_ckpt[k]= v 194 | else: 195 | print(k) 196 | 197 | model.load_state_dict(new_ckpt, strict=False) 198 | self.tokenizer = tokenize 199 | self.model = model 200 | 201 | self.requires_grad_(False) 202 | 203 | def forward(self, x): 204 | return self.model.encode_image(x, return_dense=True) 205 | 206 | @property 207 | def device(self): 208 | return next(self.parameters()).device 209 | 210 | def get_prompt_features(self, prompt): 211 | return self.model.encode_text(prompt) 212 | 213 | -------------------------------------------------------------------------------- /model/backbone/custom_deit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from timm.models.vision_transformer import VisionTransformer, _cfg 6 | from timm.models.registry import register_model 7 | from timm.models.layers import trunc_normal_ 8 | from transformers import AutoTokenizer, CLIPTextModel 9 | import math 10 | 11 | 12 | __all__ = [ 13 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 14 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 15 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 16 | 'deit_base_distilled_patch16_384', 17 | ] 18 | 19 | 20 | class DistilledVisionTransformer(VisionTransformer): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 24 | num_patches = self.patch_embed.num_patches 25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 27 | 28 | trunc_normal_(self.dist_token, std=.02) 29 | trunc_normal_(self.pos_embed, std=.02) 30 | self.head_dist.apply(self._init_weights) 31 | 32 | def forward_features(self, x): 33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 34 | # 
with slight modifications to add the dist_token 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | dist_token = self.dist_token.expand(B, -1, -1) 40 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 41 | 42 | x = x + self.pos_embed 43 | x = self.pos_drop(x) 44 | 45 | for blk in self.blocks: 46 | x = blk(x) 47 | 48 | x = self.norm(x) 49 | return x[:, 0], x[:, 1] 50 | 51 | def forward(self, x): 52 | x, x_dist = self.forward_features(x) 53 | x = self.head(x) 54 | x_dist = self.head_dist(x_dist) 55 | if self.training: 56 | return x, x_dist 57 | else: 58 | # during inference, return the average of both classifier predictions 59 | return (x + x_dist) / 2 60 | 61 | 62 | # class CustomViT(VisionTransformer): 63 | 64 | 65 | @register_model 66 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 67 | model = VisionTransformer( 68 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 69 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 70 | model.default_cfg = _cfg() 71 | if pretrained: 72 | checkpoint = torch.hub.load_state_dict_from_url( 73 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 74 | map_location="cpu", check_hash=True 75 | ) 76 | model.load_state_dict(checkpoint["model"]) 77 | return model 78 | 79 | 80 | @register_model 81 | def deit_small_patch16_224(pretrained=False, **kwargs): 82 | model = VisionTransformer( 83 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 84 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 85 | model.default_cfg = _cfg() 86 | if pretrained: 87 | checkpoint = torch.hub.load_state_dict_from_url( 88 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 89 | map_location="cpu", check_hash=True 90 | ) 91 | model.load_state_dict(checkpoint["model"]) 92 | return model 93 | 94 | 95 | @register_model 96 | def deit_base_patch16_224(pretrained=False, **kwargs): 97 | model = VisionTransformer( 98 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 99 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 100 | model.default_cfg = _cfg() 101 | if pretrained: 102 | checkpoint = torch.hub.load_state_dict_from_url( 103 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 104 | map_location="cpu", check_hash=True 105 | ) 106 | model.load_state_dict(checkpoint["model"]) 107 | return model 108 | 109 | 110 | @register_model 111 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 112 | model = DistilledVisionTransformer( 113 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 114 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 115 | model.default_cfg = _cfg() 116 | if pretrained: 117 | checkpoint = torch.hub.load_state_dict_from_url( 118 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 119 | map_location="cpu", check_hash=True 120 | ) 121 | model.load_state_dict(checkpoint["model"]) 122 | return model 123 | 124 | 125 | @register_model 126 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 127 | model = DistilledVisionTransformer( 128 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 129 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 130 | model.default_cfg = _cfg() 131 | if pretrained: 132 | checkpoint = 
torch.hub.load_state_dict_from_url( 133 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 134 | map_location="cpu", check_hash=True 135 | ) 136 | model.load_state_dict(checkpoint["model"]) 137 | return model 138 | 139 | 140 | @register_model 141 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 142 | model = DistilledVisionTransformer( 143 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 144 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 145 | model.default_cfg = _cfg() 146 | if pretrained: 147 | checkpoint = torch.hub.load_state_dict_from_url( 148 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", 149 | map_location="cpu", check_hash=True 150 | ) 151 | model.load_state_dict(checkpoint["model"]) 152 | return model 153 | 154 | 155 | @register_model 156 | def deit_base_patch16_384(pretrained=False, **kwargs): 157 | model = VisionTransformer( 158 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 159 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 160 | model.default_cfg = _cfg() 161 | if pretrained: 162 | checkpoint = torch.hub.load_state_dict_from_url( 163 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", 164 | map_location="cpu", check_hash=True 165 | ) 166 | model.load_state_dict(checkpoint["model"]) 167 | return model 168 | 169 | 170 | @register_model 171 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 172 | model = DistilledVisionTransformer( 173 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 174 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 175 | model.default_cfg = _cfg() 176 | if pretrained: 177 | checkpoint = torch.hub.load_state_dict_from_url( 178 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 179 | map_location="cpu", check_hash=True 180 | ) 181 | model.load_state_dict(checkpoint["model"]) 182 | return model 183 | 184 | 185 | class CustomDeiT(nn.Module): 186 | def __init__(self) -> None: 187 | super().__init__() 188 | model = deit_base_patch16_224(True) 189 | model = deit_base_patch16_224(True) 190 | model.patch_embed.img_size = None 191 | model.patch_embed.flatten = False 192 | model.patch_embed.output_fmt = 'NHWC' 193 | model.dynamic_img_size = True 194 | self.model = model 195 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 196 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 197 | 198 | self.requires_grad_(False) 199 | 200 | def forward(self, x): 201 | image_embeddings = self.model.forward_features(x)[:,1:] 202 | bs, l, c = image_embeddings.shape 203 | ft_list = image_embeddings.reshape(bs, int(math.sqrt(l)), int(math.sqrt(l)), c).permute(0,3,1,2).contiguous() 204 | return ft_list 205 | 206 | @property 207 | def device(self): 208 | return next(self.parameters()).device 209 | 210 | def get_prompt_features(self, prompt): 211 | device = next(self.parameters()).device 212 | text_inputs = self.tokenizer( 213 | prompt, 214 | padding="max_length", 215 | max_length=self.tokenizer.model_max_length, 216 | truncation=True, 217 | return_tensors="pt", 218 | ) 219 | 220 | for k, x in text_inputs.items(): 221 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 222 | 223 | return self.text_model(**text_inputs).pooler_output 
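
CustomDeiT mirrors the interface of the other backbone wrappers in this directory: `forward` returns a dense `(bs, c, h, w)` patch-feature map with the class token stripped, and `get_prompt_features` encodes text prompts with a frozen CLIP text model. Below is a minimal sketch of exercising that interface on dummy inputs; it assumes the DeiT and CLIP checkpoints can be downloaded and that the installed timm version matches the wrapper's expectations, and the shapes in the comments (DeiT-base with the CLIP ViT-L/14 text encoder) are illustrative.

```python
import torch
from model.backbone.custom_deit import CustomDeiT

# Illustrative sketch only: probe the frozen DeiT backbone wrapper
# with a dummy image batch and a text prompt.
backbone = CustomDeiT().eval()

images = torch.randn(2, 3, 224, 224)  # dummy 224x224 batch for DeiT-base/16
with torch.no_grad():
    feats = backbone(images)                                         # dense patch features, (bs, c, h, w)
    text_feats = backbone.get_prompt_features(["a photo of a cat"])  # pooled CLIP text features

print(feats.shape)       # expected roughly (2, 768, 14, 14) for DeiT-base
print(text_feats.shape)  # expected roughly (1, 768) from the CLIP-L text pooler output
```
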
-------------------------------------------------------------------------------- /model/backbone/custom_dino.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoTokenizer, CLIPTextModel, ConvNextModel 5 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEEmbeddings, ViTMAEEncoder, ViTMAEPatchEmbeddings, ViTMAEModel 6 | from copy import deepcopy 7 | from .src_dino.hubconf import dino_vitb16 8 | 9 | class CustomDINOv2(nn.Module): 10 | def __init__(self, dinov2_model) -> None: 11 | super().__init__() 12 | assert dinov2_model in ('b', 'l', 'g') 13 | # import pdb; pdb.set_trace() 14 | if dinov2_model == 'b': 15 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') 16 | elif dinov2_model == 'l': 17 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') 18 | elif dinov2_model == 'g': 19 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') 20 | elif dinov2_model == 'b_reg': 21 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') 22 | elif dinov2_model == 'l_reg': 23 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') 24 | elif dinov2_model == 'g_reg': 25 | self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') 26 | 27 | else : 28 | raise NotImplementedError 29 | 30 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 31 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 32 | 33 | def forward(self, x): 34 | h, w = [xx// 14 for xx in x.shape[-2:]] 35 | output = self.dinov2.forward_features(x) 36 | output = output['x_norm_patchtokens'] #(bs,l,c) 37 | bs, l, c = output.shape 38 | assert h*w == l 39 | return output.view(bs, h, w, c).permute(0,3,1,2).contiguous() #(bs, h, w ,c) 40 | 41 | def get_prompt_features(self, prompt): 42 | device = next(self.parameters()).device 43 | text_inputs = self.tokenizer( 44 | prompt, 45 | padding="max_length", 46 | max_length=self.tokenizer.model_max_length, 47 | truncation=True, 48 | return_tensors="pt", 49 | ) 50 | 51 | for k, x in text_inputs.items(): 52 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 53 | 54 | return self.text_model(**text_inputs).pooler_output 55 | 56 | 57 | class CustomDINOv1(nn.Module): 58 | def __init__(self) -> None: 59 | super().__init__() 60 | # self.dinov1 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 61 | self.dinov1 = dino_vitb16() 62 | 63 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 64 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 65 | 66 | def forward_features(self, x, masks=None): 67 | 68 | x = self.dinov1.prepare_tokens(x) 69 | 70 | for blk in self.dinov1.blocks: 71 | x = blk(x) 72 | 73 | x_norm = self.dinov1.norm(x) 74 | return x_norm[:, 1:] 75 | 76 | def forward(self, x): 77 | output = self.forward_features(x) #(bs,l,c) 78 | bs, l, c = output.shape 79 | h = w = int(math.sqrt(l)) 80 | assert h*w == l 81 | return output.view(bs, h, w, c).permute(0,3,1,2).contiguous() #(bs, h, w ,c) 82 | 83 | def get_prompt_features(self, prompt): 84 | device = next(self.parameters()).device 85 | text_inputs = self.tokenizer( 86 | prompt, 87 | padding="max_length", 88 | max_length=self.tokenizer.model_max_length, 89 | truncation=True, 90 | return_tensors="pt", 91 | ) 92 | 93 | for k, x in text_inputs.items(): 94 | if 
hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 95 | 96 | return self.text_model(**text_inputs).pooler_output 97 | 98 | 99 | 100 | class CustomEncoder(nn.Module): 101 | def __init__(self, img_encoder) -> None: 102 | super().__init__() 103 | self.img_encoder = img_encoder 104 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 105 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 106 | 107 | def forward(self, x): 108 | return self.img_encoder(x) 109 | 110 | def get_prompt_features(self, prompt): 111 | device = next(self.parameters()).device 112 | text_inputs = self.tokenizer( 113 | prompt, 114 | padding="max_length", 115 | max_length=self.tokenizer.model_max_length, 116 | truncation=True, 117 | return_tensors="pt", 118 | ) 119 | 120 | for k, x in text_inputs.items(): 121 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 122 | 123 | return self.text_model(**text_inputs).pooler_output 124 | 125 | class CustomViTMAEPatchEmbeddings(ViTMAEPatchEmbeddings): 126 | def forward(self, pixel_values): 127 | x = self.projection(pixel_values).flatten(2).transpose(1, 2) 128 | return x 129 | 130 | class CustomViTMAEEmbeddings(ViTMAEEmbeddings): 131 | def __init__(self, config): 132 | nn.Module.__init__(self) 133 | 134 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 135 | self.patch_embeddings = CustomViTMAEPatchEmbeddings(config) 136 | self.num_patches = self.patch_embeddings.num_patches 137 | # fixed sin-cos embedding 138 | self.position_embeddings = nn.Parameter( 139 | torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False 140 | ) 141 | self.config = config 142 | self.initialize_weights() 143 | 144 | def interpolate_pos_encoding(self, x, w, h): 145 | npatch = x.shape[1] - 1 146 | N = self.position_embeddings.shape[1] - 1 147 | if npatch == N and w == h: 148 | return self.position_embeddings 149 | class_pos_embed = self.position_embeddings[:, 0] 150 | patch_pos_embed = self.position_embeddings[:, 1:] 151 | dim = x.shape[-1] 152 | # patch_size = self.patch_embeddings.patch_size 153 | w0 = w // self.patch_embeddings.patch_size[0] 154 | h0 = h // self.patch_embeddings.patch_size[1] 155 | # we add a small number to avoid floating point error in the interpolation 156 | # see discussion at https://github.com/facebookresearch/dino/issues/8 157 | w0, h0 = w0 + 0.1, h0 + 0.1 158 | patch_pos_embed = nn.functional.interpolate( 159 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 160 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 161 | mode='bicubic', 162 | ) 163 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 164 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 165 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 166 | 167 | def forward(self, pixel_values, noise=None): 168 | batch_size, num_channels, height, width = pixel_values.shape 169 | embeddings = self.patch_embeddings(pixel_values) 170 | 171 | 172 | # add position embeddings w/o cls token 173 | # import pdb; pdb.set_trace() 174 | # aa = torch.rand(1, 1025, 768) 175 | B, nc, h, w = pixel_values.shape 176 | position_embeddings = self.interpolate_pos_encoding(embeddings, w, h) 177 | embeddings = embeddings + position_embeddings[:, 1:, :] 178 | 179 | # masking: length -> length * config.mask_ratio 180 | embeddings, mask, ids_restore = self.random_masking(embeddings, noise) 181 | 182 | # 
append cls token 183 | cls_token = self.cls_token + self.position_embeddings[:, :1, :] 184 | cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1) 185 | embeddings = torch.cat((cls_tokens, embeddings), dim=1) 186 | 187 | return embeddings, mask, ids_restore 188 | 189 | class CustomMAEEncoder(ViTMAEModel): 190 | def __init__(self, config): 191 | super().__init__(config) 192 | self.config = config 193 | 194 | self.embeddings = CustomViTMAEEmbeddings(config) 195 | self.encoder = ViTMAEEncoder(config) 196 | 197 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 198 | 199 | # Initialize weights and apply final processing 200 | self.post_init() 201 | 202 | 203 | class CustomMAE(nn.Module): 204 | def __init__(self, config='facebook/vit-mae-base') -> None: 205 | super().__init__() 206 | self.img_encoder = CustomMAEEncoder.from_pretrained(config) 207 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 208 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 209 | 210 | def forward(self, x): 211 | image_embeddings = self.img_encoder(x).last_hidden_state[:,1:] 212 | bs, l, c = image_embeddings.shape 213 | ft_list = image_embeddings.reshape(bs, int(math.sqrt(l)), int(math.sqrt(l)), c).permute(0,3,1,2).contiguous() 214 | return ft_list 215 | 216 | def get_prompt_features(self, prompt): 217 | device = next(self.parameters()).device 218 | text_inputs = self.tokenizer( 219 | prompt, 220 | padding="max_length", 221 | max_length=self.tokenizer.model_max_length, 222 | truncation=True, 223 | return_tensors="pt", 224 | ) 225 | 226 | for k, x in text_inputs.items(): 227 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 228 | 229 | return self.text_model(**text_inputs).pooler_output 230 | 231 | class CustomConvNext(nn.Module): 232 | def __init__(self) -> None: 233 | super().__init__() 234 | # self.model, *_ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg') 235 | # self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg') 236 | self.model = ConvNextModel.from_pretrained("facebook/convnext-base-224") 237 | self.text_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14') 238 | self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14') 239 | 240 | del self.model.layernorm 241 | self.model.layernorm = nn.Identity() 242 | 243 | if True : 244 | stages = deepcopy(self.model.encoder.stages) 245 | del self.model.encoder.stages 246 | self.model.encoder.stages = stages[:3] 247 | self.requires_grad_(False) 248 | 249 | def forward(self, x): 250 | return self.model(x).last_hidden_state 251 | 252 | @property 253 | def device(self): 254 | return next(self.parameters()).device 255 | 256 | def get_prompt_features(self, prompt): 257 | device = next(self.parameters()).device 258 | text_inputs = self.tokenizer( 259 | prompt, 260 | padding="max_length", 261 | max_length=self.tokenizer.model_max_length, 262 | truncation=True, 263 | return_tensors="pt", 264 | ) 265 | 266 | for k, x in text_inputs.items(): 267 | if hasattr(x, 'to'): text_inputs[k] = text_inputs[k].to(device) 268 | 269 | return self.text_model(**text_inputs).pooler_output -------------------------------------------------------------------------------- /model/backbone/custom_sam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import 
copy 5 | import utils.misc as misc 6 | from utils.loss_mask import loss_masks 7 | import torch.nn.functional as F 8 | from detectron2.layers import ROIAlign 9 | import random 10 | from segment_anything_training.modeling.transformer import TwoWayAttentionBlock 11 | 12 | from detectron2.modeling.poolers import ROIPooler 13 | from detectron2.structures import Boxes 14 | 15 | class MLP(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | hidden_dim: int, 20 | output_dim: int, 21 | num_layers: int, 22 | sigmoid_output: bool = False, 23 | ) -> None: 24 | super().__init__() 25 | self.num_layers = num_layers 26 | h = [hidden_dim] * (num_layers - 1) 27 | self.layers = nn.ModuleList( 28 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 29 | ) 30 | self.sigmoid_output = sigmoid_output 31 | 32 | def forward(self, x): 33 | for i, layer in enumerate(self.layers): 34 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 35 | if self.sigmoid_output: 36 | x = F.sigmoid(x) 37 | return x 38 | 39 | class CorrespondenceEncoder(nn.Module): 40 | def __init__(self, 41 | depth, 42 | num_heads: int, 43 | mlp_dim: int, 44 | keypoint_size=7, 45 | embed_dim=256, 46 | activation = nn.ReLU, 47 | attention_downsample_rate: int = 2, 48 | ): 49 | super().__init__() 50 | self.keypoint_size = keypoint_size 51 | self.keypoint_embedding = nn.Embedding(keypoint_size*keypoint_size, embed_dim) 52 | self.fgbg_embedding = nn.Embedding(2, embed_dim) 53 | self.layers = nn.ModuleList() 54 | 55 | for i in range(depth): 56 | self.layers.append( 57 | TwoWayAttentionBlock( 58 | embedding_dim=embed_dim, 59 | num_heads=num_heads, 60 | mlp_dim=mlp_dim, 61 | activation=activation, 62 | attention_downsample_rate=attention_downsample_rate, 63 | skip_first_layer_pe=(i == 0), 64 | ) 65 | ) 66 | 67 | self.mlp_text_prompt = MLP(embed_dim, mlp_dim, embed_dim, 3) 68 | 69 | 70 | def pool_mask_ref(self, mask_ref): 71 | h = w = int(math.sqrt(mask_ref.shape[1])) 72 | assert h*w == mask_ref.shape[1] 73 | mask_ref = mask_ref.view(mask_ref.shape[0], 1, h, w) 74 | mask_ref_resize = F.adaptive_avg_pool2d(mask_ref, self.keypoint_size).flatten(1) 75 | return mask_ref_resize 76 | 77 | 78 | def forward(self, img_ref_feat, mask_ref): 79 | ''' 80 | img_ref_feat: (bs, l, c) 81 | mask_ref: (bs, l) 82 | ''' 83 | queries = self.keypoint_embedding.weight 84 | queries = queries[None].expand(img_ref_feat.shape[0], -1, -1) 85 | mask_ref = self.pool_mask_ref(mask_ref) 86 | mask_ref_embed = self.keypoint_embedding(mask_ref) 87 | queries = queries + mask_ref_embed 88 | keys = img_ref_feat 89 | for i, layer in enumerate(self.layers): 90 | queries, keys = layer( 91 | queries=queries, 92 | keys=keys, 93 | query_pe=0, 94 | key_pe=0, 95 | ) 96 | 97 | sparse_embedding = self.mlp_text_prompt(queries) 98 | return queries, sparse_embedding 99 | 100 | 101 | class CustomSam(nn.Module): 102 | def __init__(self, sam, net, 103 | keypoint_size = 7, 104 | args=None 105 | ): 106 | super().__init__() 107 | self.args = args 108 | self.sam = sam 109 | self.net = net 110 | 111 | # self.keypoint_embedding = nn.Embedding(keypoint_size*keypoint_size, 256) 112 | 113 | def forward(self, data, inference=False): 114 | args = self.args 115 | sam, net = self.sam, self.net 116 | inputs, labels = data['image'], data['label'] 117 | if torch.cuda.is_available(): 118 | inputs = inputs.cuda() 119 | labels = labels.cuda() 120 | if 'image_dual' in data: 121 | inputs_dual = data['image_dual'].cuda() 122 | labels_dual = data['label_dual'].cuda() 123 | def 
custom_roi(input, bbox, resolution=1024): 124 | bbox = bbox.int() 125 | return F.interpolate(input[:, bbox[1]:bbox[3], bbox[0]:bbox[2]][:, None], resolution)[:, 0] 126 | if 1 : 127 | pooler = ROIPooler((1024,1024), [1], 0, "ROIAlignV2") 128 | inst_bbox = misc.masks_to_boxes(labels_dual[:,0,:,:]) 129 | inst_bbox = [Boxes(x) for x in inst_bbox[:, None]] 130 | inputs_dual = pooler([inputs_dual], inst_bbox) 131 | labels_dual = pooler([labels_dual], inst_bbox) 132 | # inputs_dual = torch.stack([custom_roi(input, bbox)] for input, bbox in zip(inputs_dual, inst_bbox)) 133 | # labels_dual = torch.stack([custom_roi(input, bbox)] for input, bbox in zip(labels_dual, inst_bbox)) 134 | # import pdb; pdb.set_trace() 135 | else: 136 | inputs_dual, labels_dual = None, None 137 | 138 | bs = len(inputs) 139 | 140 | if not inference : 141 | input_keys = copy.deepcopy(args.input_keys) 142 | else : 143 | input_keys = copy.deepcopy(args.eval_keys) 144 | labels_box = misc.masks_to_boxes(labels[:,0,:,:]) 145 | try: 146 | labels_points = misc.masks_sample_points(labels[:,0,:,:]) 147 | except: 148 | # less than 10 points 149 | # input_keys = ['box','noise_mask'] 150 | if 'point' in input_keys: 151 | input_keys.remove('point') 152 | labels_256 = F.interpolate(labels, size=(256, 256), mode='bilinear') 153 | labels_noisemask = misc.masks_noise(labels_256) 154 | 155 | batched_input = [{'image':x} for x in inputs] 156 | if inputs_dual is not None : 157 | batched_input.extend([{'image':x} for x in inputs_dual]) 158 | 159 | with torch.no_grad(): 160 | image_embed_output = sam(batched_input, only_forward_img=True) 161 | inst_image_embeddings = image_embeddings = image_embed_output[0] 162 | 163 | if True : 164 | image_embeddings = image_embed_output[0] 165 | inst_image_embeddings = image_embeddings 166 | if inputs_dual is not None : 167 | image_embeddings = image_embeddings[:bs] 168 | inst_image_embeddings = image_embeddings[-bs:] 169 | def proccess_image_embed_output(image_embed_output): 170 | a, b = image_embed_output 171 | a = a[:bs] 172 | b = [x[:bs] for x in b] 173 | return a, b 174 | image_embed_output = proccess_image_embed_output(image_embed_output) 175 | 176 | inst_label = labels_dual if labels_dual is not None else labels 177 | if args.noised_inst : 178 | inst_label = misc.masks_noise(inst_label, apply_incoherent=True) 179 | inst_labels_64 = F.interpolate(inst_label, size=(64, 64), mode='bilinear') / 255 180 | inst_embedding = torch.einsum('nchw,nhw->nc', inst_image_embeddings, inst_labels_64.squeeze(1)) / inst_labels_64.sum((-1,2)).clamp(min=1) 181 | try: 182 | labels_points_inst = misc.masks_sample_points(inst_label[:,0,:,:]) 183 | except : 184 | labels_points_inst = torch.zeros((bs, 0, 2), device=inputs.device) 185 | 186 | if args.use_ref_keypoint : 187 | pooler = ROIAlign(32, 1/16, 0) 188 | pooler_label = ROIAlign(32, 1, 0) 189 | inst_bbox = misc.masks_to_boxes(inst_label[:,0,:,:]) 190 | bid_bbox = torch.tensor(range(len(inst_bbox)), dtype=inst_bbox.dtype, device=inst_bbox.device)[:, None] 191 | inst_bbox_roi = torch.cat([bid_bbox, inst_bbox], dim=-1) 192 | inst_roi_features = pooler(inst_image_embeddings, inst_bbox_roi) 193 | inst_roi_mask = pooler_label(inst_label/255, inst_bbox_roi) 194 | inst_roi_masked_features = inst_roi_features * inst_roi_mask 195 | else : 196 | inst_roi_masked_features = None 197 | 198 | labels_pionts_labels_inst = torch.ones(labels_points_inst.shape[:2], device=labels_points_inst.device) 199 | point_embeddings_inst = sam.prompt_encoder._embed_points(labels_points_inst, 
labels_pionts_labels_inst, pad=True) 200 | 201 | if args.use_ref_keypoint : 202 | sim = misc.cal_sim(image_embeddings, inst_roi_masked_features) 203 | else : 204 | sim = misc.cal_sim(image_embeddings, inst_embedding).unsqueeze(1) 205 | 206 | sim = F.interpolate(sim, size=(256, 256), mode='bilinear') 207 | 208 | batched_input = [{'image':x, 'original_size':x.shape[-2:]} for x in inputs] 209 | for b_i in range(bs): 210 | dict_input = batched_input[b_i] 211 | input_type = random.choice(input_keys) 212 | if input_type == 'box': 213 | dict_input['boxes'] = labels_box[b_i:b_i+1] 214 | elif input_type == 'point': 215 | point_coords = labels_points[b_i:b_i+1] 216 | dict_input['point_coords'] = point_coords 217 | dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:] 218 | elif input_type == 'sem_corr': 219 | # dict_input['mask_inputs'] = sim[b_i:b_i+1] 220 | def get_max(sim, target_hw=1024): 221 | h, w = sim.shape[-2:] 222 | sim = sim.squeeze(1).flatten(1) 223 | _, idx = sim.topk(args.n_point, dim=1) 224 | idx_h = idx // w * target_hw / h 225 | idx_w = idx % w * target_hw / w 226 | idx = torch.stack([idx_w, idx_h], dim=-1) 227 | return idx 228 | point_coords = get_max(sim[b_i:b_i+1]) 229 | dict_input['point_coords'] = point_coords 230 | dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:] 231 | elif input_type == 'noise_mask': 232 | dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1] 233 | else: 234 | raise NotImplementedError 235 | 236 | batched_output, interm_embeddings = sam(batched_input, multimask_output=False, 237 | image_embed_output=image_embed_output) 238 | 239 | batch_len = len(batched_output) 240 | encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0) 241 | image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)] 242 | sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)] 243 | dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)] 244 | 245 | masks_hq, bbox_preds = net( 246 | image_embeddings=encoder_embedding, 247 | image_pe=image_pe, 248 | sparse_prompt_embeddings=sparse_embeddings, 249 | dense_prompt_embeddings=dense_embeddings, 250 | multimask_output=False, 251 | hq_token_only=True, 252 | interm_embeddings=interm_embeddings, 253 | simm_input = sim, 254 | image_embedding_ref = inst_image_embeddings, 255 | point_embedding_ref = point_embeddings_inst, 256 | inst_roi_masked_features=inst_roi_masked_features 257 | ) 258 | 259 | if inference and bbox_preds is not None: 260 | point_coords = torch.cat([x['point_coords'] for x in batched_input]) 261 | point_labels = torch.cat([x['point_labels'] for x in batched_input]) 262 | bbox_preds_xyxy = misc.box_cxcywh_to_xyxy(bbox_preds) * 1024 263 | sparse_embeddings, dense_embeddings = sam.prompt_encoder( 264 | points=(point_coords, point_labels), 265 | boxes=bbox_preds_xyxy, 266 | masks=None 267 | ) 268 | masks_hq, bbox_preds = net( 269 | image_embeddings=encoder_embedding, 270 | image_pe=image_pe, 271 | sparse_prompt_embeddings=sparse_embeddings.unsqueeze(1), 272 | dense_prompt_embeddings=dense_embeddings.unsqueeze(1), 273 | multimask_output=False, 274 | hq_token_only=True, 275 | interm_embeddings=interm_embeddings, 276 | simm_input = sim, 277 | image_embedding_ref = inst_image_embeddings, 278 | point_embedding_ref = point_embeddings_inst, 279 | 280 | ) 281 | 282 | loss_mask, loss_dice = loss_masks(masks_hq, 
labels/255.0, len(masks_hq)) 283 | loss = loss_mask + loss_dice 284 | loss_dict = {"loss_mask": loss_mask, "loss_dice":loss_dice} 285 | if args.use_bbox_head : 286 | labels_box_xywh = misc.box_xyxy_to_cxcywh(labels_box) / 1024 287 | bbox_preds = bbox_preds 288 | num_boxes = bs 289 | 290 | loss_bbox = F.l1_loss(bbox_preds, labels_box_xywh, reduction='none') 291 | loss_giou = 1 - torch.diag(misc.generalized_box_iou( 292 | misc.box_cxcywh_to_xyxy(bbox_preds), 293 | misc.box_cxcywh_to_xyxy(labels_box_xywh))) 294 | 295 | loss_dict['loss_giou'] = loss_giou = loss_giou.sum() / num_boxes 296 | loss_dict['loss_bbox'] = loss_bbox = loss_bbox.sum() / num_boxes 297 | loss = loss + loss_bbox*5 + loss_giou*2 298 | 299 | return masks_hq, bbox_preds, loss, loss_dict -------------------------------------------------------------------------------- /model/backbone/src_dift/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | class Demo: 8 | 9 | def __init__(self, imgs, ft, img_size): 10 | self.ft = ft # NCHW 11 | self.imgs = imgs 12 | self.num_imgs = len(imgs) 13 | self.img_size = img_size 14 | 15 | def plot_img_pairs(self, fig_size=3, alpha=0.45, scatter_size=70): 16 | 17 | fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size)) 18 | 19 | plt.tight_layout() 20 | 21 | for i in range(self.num_imgs): 22 | axes[i].imshow(self.imgs[i]) 23 | axes[i].axis('off') 24 | if i == 0: 25 | axes[i].set_title('source image') 26 | else: 27 | axes[i].set_title('target image') 28 | 29 | num_channel = self.ft.size(1) 30 | cos = nn.CosineSimilarity(dim=1) 31 | 32 | def onclick(event): 33 | if event.inaxes == axes[0]: 34 | with torch.no_grad(): 35 | 36 | x, y = int(np.round(event.xdata)), int(np.round(event.ydata)) 37 | 38 | src_ft = self.ft[0].unsqueeze(0) 39 | src_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(src_ft) 40 | src_vec = src_ft[0, :, y, x].view(1, num_channel, 1, 1) # 1, C, 1, 1 41 | 42 | del src_ft 43 | gc.collect() 44 | torch.cuda.empty_cache() 45 | 46 | trg_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(self.ft[1:]) 47 | cos_map = cos(src_vec, trg_ft).cpu().numpy() # N, H, W 48 | 49 | del trg_ft 50 | gc.collect() 51 | torch.cuda.empty_cache() 52 | 53 | axes[0].clear() 54 | axes[0].imshow(self.imgs[0]) 55 | axes[0].axis('off') 56 | axes[0].scatter(x, y, c='r', s=scatter_size) 57 | axes[0].set_title('source image') 58 | 59 | for i in range(1, self.num_imgs): 60 | max_yx = np.unravel_index(cos_map[i-1].argmax(), cos_map[i-1].shape) 61 | axes[i].clear() 62 | 63 | heatmap = cos_map[i-1] 64 | heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1] 65 | axes[i].imshow(self.imgs[i]) 66 | axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis') 67 | axes[i].axis('off') 68 | axes[i].scatter(max_yx[1].item(), max_yx[0].item(), c='r', s=scatter_size) 69 | axes[i].set_title('target image') 70 | 71 | del cos_map 72 | del heatmap 73 | gc.collect() 74 | 75 | fig.canvas.mpl_connect('button_press_event', onclick) 76 | plt.show() 77 | 78 | -------------------------------------------------------------------------------- /model/backbone/src_dino/hubconf.py: -------------------------------------------------------------------------------- 1 | # import vision_transformer as vits 2 | import torch 3 | from . 
import dinov1 as vits 4 | 5 | def dino_vitb16(pretrained=True, **kwargs): 6 | """ 7 | ViT-Base/16x16 pre-trained with DINO. 8 | Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification. 9 | """ 10 | model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs) 11 | if pretrained: 12 | state_dict = torch.hub.load_state_dict_from_url( 13 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 14 | map_location="cpu", 15 | ) 16 | model.load_state_dict(state_dict, strict=True) 17 | return model -------------------------------------------------------------------------------- /model/segment_anything_training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | -------------------------------------------------------------------------------- /model/segment_anything_training/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, ThreeWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None, use_ref_decoder=False): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | use_ref_decoder=use_ref_decoder 22 | ) 23 | 24 | 25 | build_sam = build_sam_vit_h 26 | 27 | 28 | def build_sam_vit_l(checkpoint=None, use_ref_decoder=False): 29 | return _build_sam( 30 | encoder_embed_dim=1024, 31 | encoder_depth=24, 32 | encoder_num_heads=16, 33 | encoder_global_attn_indexes=[5, 11, 17, 23], 34 | checkpoint=checkpoint, 35 | use_ref_decoder=use_ref_decoder 36 | ) 37 | 38 | 39 | def build_sam_vit_b(checkpoint=None, use_ref_decoder=False): 40 | return _build_sam( 41 | encoder_embed_dim=768, 42 | encoder_depth=12, 43 | encoder_num_heads=12, 44 | encoder_global_attn_indexes=[2, 5, 8, 11], 45 | checkpoint=checkpoint, 46 | use_ref_decoder=use_ref_decoder 47 | ) 48 | 49 | 50 | sam_model_registry = { 51 | "default": build_sam, 52 | "vit_h": build_sam, 53 | "vit_l": build_sam_vit_l, 54 | "vit_b": build_sam_vit_b, 55 | } 56 | 57 | 58 | def _build_sam( 59 | encoder_embed_dim, 60 | encoder_depth, 61 | encoder_num_heads, 62 | encoder_global_attn_indexes, 63 | checkpoint=None, 64 | use_ref_decoder=False 65 | ): 66 | prompt_embed_dim = 256 67 | image_size = 1024 68 | vit_patch_size = 16 69 | image_embedding_size = image_size // vit_patch_size 70 | mask_decoder_transformer = TwoWayTransformer if not use_ref_decoder else ThreeWayTransformer 71 | sam = Sam( 72 | image_encoder=ImageEncoderViT( 73 | depth=encoder_depth, 74 | embed_dim=encoder_embed_dim, 75 | img_size=image_size, 76 | mlp_ratio=4, 77 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 78 | num_heads=encoder_num_heads, 79 | patch_size=vit_patch_size, 80 | 
qkv_bias=True, 81 | use_rel_pos=True, 82 | global_attn_indexes=encoder_global_attn_indexes, 83 | window_size=14, 84 | out_chans=prompt_embed_dim, 85 | ), 86 | prompt_encoder=PromptEncoder( 87 | embed_dim=prompt_embed_dim, 88 | image_embedding_size=(image_embedding_size, image_embedding_size), 89 | input_image_size=(image_size, image_size), 90 | mask_in_chans=16, 91 | ), 92 | mask_decoder=MaskDecoder( 93 | num_multimask_outputs=3, 94 | transformer=mask_decoder_transformer( 95 | depth=2, 96 | embedding_dim=prompt_embed_dim, 97 | mlp_dim=2048, 98 | num_heads=8, 99 | ), 100 | transformer_dim=prompt_embed_dim, 101 | iou_head_depth=3, 102 | iou_head_hidden_dim=256, 103 | ), 104 | pixel_mean=[123.675, 116.28, 103.53], 105 | pixel_std=[58.395, 57.12, 57.375], 106 | ) 107 | sam.eval() 108 | if checkpoint is not None: 109 | with open(checkpoint, "rb") as f: 110 | state_dict = torch.load(f) 111 | sam.load_state_dict(state_dict) 112 | return sam 113 | -------------------------------------------------------------------------------- /model/segment_anything_training/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer, ThreeWayTransformer 12 | -------------------------------------------------------------------------------- /model/segment_anything_training/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
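# Shared building blocks for the SAM modeling code: MLPBlock (a two-layer
# feed-forward block) and LayerNorm2d (channels-first LayerNorm, used by the
# mask decoder and prompt encoder below).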
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /model/segment_anything_training/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | task_indicator: torch.Tensor = None, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 
90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | ) 101 | 102 | # Select the correct mask or masks for outptu 103 | if task_indicator is not None : 104 | # 0 for sem, 1 for inst 105 | assert task_indicator.dim() == 1 106 | mask_slice = task_indicator.long() 107 | assert len(mask_slice) == len(masks) 108 | bs = len(masks) 109 | masks = masks[range(bs), mask_slice, :, :][:, None] 110 | iou_pred = iou_pred[range(bs), mask_slice][:, None] 111 | else : 112 | if multimask_output: 113 | mask_slice = slice(1, None) 114 | else: 115 | mask_slice = slice(0, 1) 116 | masks = masks[:, mask_slice, :, :] 117 | iou_pred = iou_pred[:, mask_slice] 118 | 119 | # Prepare output 120 | return masks, iou_pred 121 | 122 | def predict_masks( 123 | self, 124 | image_embeddings: torch.Tensor, 125 | image_pe: torch.Tensor, 126 | sparse_prompt_embeddings: torch.Tensor, 127 | dense_prompt_embeddings: torch.Tensor, 128 | ) -> Tuple[torch.Tensor, torch.Tensor]: 129 | """Predicts masks. See 'forward' for more details.""" 130 | # Concatenate output tokens 131 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 132 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 133 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 134 | 135 | # Expand per-image data in batch direction to be per-mask 136 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 137 | src = src + dense_prompt_embeddings 138 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 139 | b, c, h, w = src.shape 140 | 141 | # Run the transformer 142 | hs, src = self.transformer(src, pos_src, tokens) 143 | iou_token_out = hs[:, 0, :] 144 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 145 | 146 | # Upscale mask embeddings and predict masks using the mask tokens 147 | src = src.transpose(1, 2).view(b, c, h, w) 148 | upscaled_embedding = self.output_upscaling(src) 149 | hyper_in_list: List[torch.Tensor] = [] 150 | for i in range(self.num_mask_tokens): 151 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 152 | hyper_in = torch.stack(hyper_in_list, dim=1) 153 | b, c, h, w = upscaled_embedding.shape 154 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 155 | 156 | # Generate mask quality predictions 157 | iou_pred = self.iou_prediction_head(iou_token_out) 158 | 159 | return masks, iou_pred 160 | 161 | 162 | # Lightly adapted from 163 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 164 | class MLP(nn.Module): 165 | def __init__( 166 | self, 167 | input_dim: int, 168 | hidden_dim: int, 169 | output_dim: int, 170 | num_layers: int, 171 | sigmoid_output: bool = False, 172 | ) -> None: 173 | super().__init__() 174 | self.num_layers = num_layers 175 | h = [hidden_dim] * (num_layers - 1) 176 | self.layers = nn.ModuleList( 177 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 178 | ) 179 | self.sigmoid_output = sigmoid_output 180 | 181 | def forward(self, x): 182 | for i, layer in enumerate(self.layers): 183 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 184 | if 
self.sigmoid_output: 185 | x = F.sigmoid(x) 186 | return x 187 | -------------------------------------------------------------------------------- /model/segment_anything_training/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | texts: Optional[torch.Tensor] = None, 113 | ) -> int: 114 | """ 115 | Gets the batch size of the output given the batch size of the input prompts. 116 | """ 117 | if points is not None: 118 | return points[0].shape[0] 119 | elif boxes is not None: 120 | return boxes.shape[0] 121 | elif masks is not None: 122 | return masks.shape[0] 123 | elif texts is not None: 124 | return texts.shape[0] 125 | else: 126 | return 1 127 | 128 | def _get_device(self) -> torch.device: 129 | return self.point_embeddings[0].weight.device 130 | 131 | def forward( 132 | self, 133 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 134 | boxes: Optional[torch.Tensor], 135 | masks: Optional[torch.Tensor], 136 | texts = None 137 | ) -> Tuple[torch.Tensor, torch.Tensor]: 138 | """ 139 | Embeds different types of prompts, returning both sparse and dense 140 | embeddings. 141 | 142 | Arguments: 143 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 144 | and labels to embed. 145 | boxes (torch.Tensor or none): boxes to embed 146 | masks (torch.Tensor or none): masks to embed 147 | 148 | Returns: 149 | torch.Tensor: sparse embeddings for the points and boxes, with shape 150 | BxNx(embed_dim), where N is determined by the number of input points 151 | and boxes. 
152 | torch.Tensor: dense embeddings for the masks, in the shape 153 | Bx(embed_dim)x(embed_H)x(embed_W) 154 | """ 155 | bs = self._get_batch_size(points, boxes, masks, texts) 156 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 157 | if points is not None: 158 | coords, labels = points 159 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 160 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 161 | if boxes is not None: 162 | box_embeddings = self._embed_boxes(boxes) 163 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 164 | if texts is not None : 165 | sparse_embeddings = torch.cat([sparse_embeddings, texts], dim=1) 166 | 167 | if masks is not None: 168 | dense_embeddings = self._embed_masks(masks) 169 | else: 170 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 171 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 172 | ) 173 | 174 | return sparse_embeddings, dense_embeddings 175 | 176 | 177 | class PositionEmbeddingRandom(nn.Module): 178 | """ 179 | Positional encoding using random spatial frequencies. 180 | """ 181 | 182 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 183 | super().__init__() 184 | if scale is None or scale <= 0.0: 185 | scale = 1.0 186 | self.register_buffer( 187 | "positional_encoding_gaussian_matrix", 188 | scale * torch.randn((2, num_pos_feats)), 189 | ) 190 | 191 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 192 | """Positionally encode points that are normalized to [0,1].""" 193 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 194 | coords = 2 * coords - 1 195 | coords = coords @ self.positional_encoding_gaussian_matrix 196 | coords = 2 * np.pi * coords 197 | # outputs d_1 x ... x d_n x C shape 198 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 199 | 200 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 201 | """Generate positional encoding for a grid of the specified size.""" 202 | h, w = size 203 | device: Any = self.positional_encoding_gaussian_matrix.device 204 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 205 | y_embed = grid.cumsum(dim=0) - 0.5 206 | x_embed = grid.cumsum(dim=1) - 0.5 207 | y_embed = y_embed / h 208 | x_embed = x_embed / w 209 | 210 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 211 | return pe.permute(2, 0, 1) # C x H x W 212 | 213 | def forward_with_coords( 214 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 215 | ) -> torch.Tensor: 216 | """Positionally encode points that are not normalized to [0,1].""" 217 | coords = coords_input.clone() 218 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 219 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 220 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 221 | -------------------------------------------------------------------------------- /model/segment_anything_training/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
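# A minimal usage sketch for the PromptEncoder defined above (illustrative only).
# Shapes assume the SAM-style defaults shown here; the import path assumes the
# repository root is on PYTHONPATH.
import torch
from model.segment_anything_training.modeling.prompt_encoder import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
point_coords = torch.tensor([[[512.0, 384.0]]])  # B=1, one (x, y) point in input-image pixels
point_labels = torch.tensor([[1.0]])             # 1 = foreground, 0 = background
sparse, dense = prompt_encoder(points=(point_coords, point_labels), boxes=None, masks=None)
# sparse: (1, 2, 256) -- the point plus the padding token appended when no boxes are given
# dense:  (1, 256, 64, 64) -- the learned no-mask embedding broadcast over the embedding grid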
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | def forward( 54 | self, 55 | batched_input: List[Dict[str, Any]], 56 | multimask_output: bool = False, 57 | only_forward_img: bool = False, 58 | image_embed_output = None, 59 | not_mask_decode = False 60 | ) -> List[Dict[str, torch.Tensor]]: 61 | """ 62 | Predicts masks end-to-end from provided images and prompts. 63 | If prompts are not known in advance, using SamPredictor is 64 | recommended over calling the model directly. 65 | 66 | Arguments: 67 | batched_input (list(dict)): A list over input images, each a 68 | dictionary with the following keys. A prompt key can be 69 | excluded if it is not present. 70 | 'image': The image as a torch tensor in 3xHxW format, 71 | already transformed for input to the model. 72 | 'original_size': (tuple(int, int)) The original size of 73 | the image before transformation, as (H, W). 74 | 'point_coords': (torch.Tensor) Batched point prompts for 75 | this image, with shape BxNx2. Already transformed to the 76 | input frame of the model. 77 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 78 | with shape BxN. 79 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 80 | Already transformed to the input frame of the model. 81 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 82 | in the form Bx1xHxW. 83 | multimask_output (bool): Whether the model should predict multiple 84 | disambiguating masks, or return a single mask. 85 | 86 | Returns: 87 | (list(dict)): A list over input images, where each element is 88 | as dictionary with the following keys. 89 | 'masks': (torch.Tensor) Batched binary mask predictions, 90 | with shape BxCxHxW, where B is the number of input promts, 91 | C is determiend by multimask_output, and (H, W) is the 92 | original size of the image. 93 | 'iou_predictions': (torch.Tensor) The model's predictions 94 | of mask quality, in shape BxC. 
95 | 'low_res_logits': (torch.Tensor) Low resolution logits with 96 | shape BxCxHxW, where H=W=256. Can be passed as mask input 97 | to subsequent iterations of prediction. 98 | """ 99 | 100 | if image_embed_output is None : 101 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 102 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 103 | else : 104 | image_embeddings, interm_embeddings = image_embed_output 105 | 106 | if only_forward_img: 107 | return image_embeddings, interm_embeddings 108 | 109 | outputs = [] 110 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 111 | if "point_coords" in image_record: 112 | points = (image_record["point_coords"], image_record["point_labels"]) 113 | else: 114 | points = None 115 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 116 | points=points, 117 | boxes=image_record.get("boxes", None), 118 | masks=image_record.get("mask_inputs", None), 119 | texts=image_record.get("texts", None) 120 | ) 121 | 122 | if not not_mask_decode : 123 | low_res_masks, iou_predictions = self.mask_decoder( 124 | image_embeddings=curr_embedding.unsqueeze(0), 125 | image_pe=self.prompt_encoder.get_dense_pe(), 126 | sparse_prompt_embeddings=sparse_embeddings, 127 | dense_prompt_embeddings=dense_embeddings, 128 | multimask_output=multimask_output 129 | ) 130 | 131 | masks = self.postprocess_masks( 132 | low_res_masks, 133 | input_size=image_record["image"].shape[-2:], 134 | original_size=image_record["original_size"], 135 | ) 136 | masks = masks > self.mask_threshold 137 | else : 138 | masks, iou_predictions, low_res_masks = [None] * 3 139 | 140 | outputs.append( 141 | { 142 | "masks": masks, 143 | "iou_predictions": iou_predictions, 144 | "low_res_logits": low_res_masks, 145 | "encoder_embedding": curr_embedding.unsqueeze(0), 146 | "image_pe": self.prompt_encoder.get_dense_pe(), 147 | "sparse_embeddings":sparse_embeddings, 148 | "dense_embeddings":dense_embeddings, 149 | } 150 | ) 151 | 152 | return outputs, interm_embeddings 153 | 154 | def postprocess_masks( 155 | self, 156 | masks: torch.Tensor, 157 | input_size: Tuple[int, ...], 158 | original_size: Tuple[int, ...], 159 | ) -> torch.Tensor: 160 | """ 161 | Remove padding and upscale masks to the original image size. 162 | 163 | Arguments: 164 | masks (torch.Tensor): Batched masks from the mask_decoder, 165 | in BxCxHxW format. 166 | input_size (tuple(int, int)): The size of the image input to the 167 | model, in (H, W) format. Used to remove padding. 168 | original_size (tuple(int, int)): The original size of the image 169 | before resizing for input to the model, in (H, W) format. 170 | 171 | Returns: 172 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 173 | is given by original_size. 
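Example (illustrative, assuming the default 1024-pixel input size): for a
1536x2048 image resized so its longest side is 1024, input_size is
(768, 1024); a Bx1x256x256 low-res mask is upsampled to 1024x1024,
cropped to the unpadded 768x1024 region, then resized back to 1536x2048.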
174 | """ 175 | masks = F.interpolate( 176 | masks, 177 | (self.image_encoder.img_size, self.image_encoder.img_size), 178 | mode="bilinear", 179 | align_corners=False, 180 | ) 181 | masks = masks[..., : input_size[0], : input_size[1]] 182 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 183 | return masks 184 | 185 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 186 | """Normalize pixel values and pad to a square input.""" 187 | # Normalize colors 188 | x = (x - self.pixel_mean) / self.pixel_std 189 | 190 | # Pad 191 | h, w = x.shape[-2:] 192 | padh = self.image_encoder.img_size - h 193 | padw = self.image_encoder.img_size - w 194 | x = F.pad(x, (0, padw, 0, padh)) 195 | return x 196 | -------------------------------------------------------------------------------- /model/segment_anything_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /model/segment_anything_training/utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO. Midified by Shilong Liu. 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | """ 10 | COCO evaluator that works in distributed mode. 
11 | 12 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 13 | The difference is that there is less copy-pasting from pycocotools 14 | in the end of the file, as python3 can suppress prints with contextlib 15 | """ 16 | import contextlib 17 | import copy 18 | import os 19 | 20 | import numpy as np 21 | import pycocotools.mask as mask_util 22 | import torch 23 | from pycocotools.coco import COCO 24 | from pycocotools.cocoeval import COCOeval 25 | 26 | from utils.misc import all_gather 27 | 28 | 29 | class CocoGroundingEvaluator(object): 30 | def __init__(self, coco_gt, iou_types, useCats=True): 31 | assert isinstance(iou_types, (list, tuple)) 32 | coco_gt = copy.deepcopy(coco_gt) 33 | self.coco_gt = coco_gt 34 | 35 | self.iou_types = iou_types 36 | self.coco_eval = {} 37 | for iou_type in iou_types: 38 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 39 | self.coco_eval[iou_type].useCats = useCats 40 | 41 | self.img_ids = [] 42 | self.eval_imgs = {k: [] for k in iou_types} 43 | self.useCats = useCats 44 | 45 | def update(self, predictions): 46 | img_ids = list(np.unique(list(predictions.keys()))) 47 | self.img_ids.extend(img_ids) 48 | 49 | for iou_type in self.iou_types: 50 | results = self.prepare(predictions, iou_type) 51 | 52 | # suppress pycocotools prints 53 | with open(os.devnull, "w") as devnull: 54 | with contextlib.redirect_stdout(devnull): 55 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 56 | 57 | coco_eval = self.coco_eval[iou_type] 58 | 59 | coco_eval.cocoDt = coco_dt 60 | coco_eval.params.imgIds = list(img_ids) 61 | coco_eval.params.useCats = self.useCats 62 | img_ids, eval_imgs = evaluate(coco_eval) 63 | 64 | self.eval_imgs[iou_type].append(eval_imgs) 65 | 66 | return results 67 | 68 | def synchronize_between_processes(self): 69 | for iou_type in self.iou_types: 70 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 71 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 72 | 73 | def accumulate(self): 74 | for coco_eval in self.coco_eval.values(): 75 | coco_eval.accumulate() 76 | 77 | def summarize(self): 78 | for iou_type, coco_eval in self.coco_eval.items(): 79 | print("IoU metric: {}".format(iou_type)) 80 | coco_eval.summarize() 81 | 82 | def prepare(self, predictions, iou_type): 83 | if iou_type == "bbox": 84 | return self.prepare_for_coco_detection(predictions) 85 | elif iou_type == "segm": 86 | return self.prepare_for_coco_segmentation(predictions) 87 | elif iou_type == "keypoints": 88 | return self.prepare_for_coco_keypoint(predictions) 89 | else: 90 | raise ValueError("Unknown iou type {}".format(iou_type)) 91 | 92 | def prepare_for_coco_detection(self, predictions): 93 | coco_results = [] 94 | for original_id, prediction in predictions.items(): 95 | if len(prediction) == 0: 96 | continue 97 | 98 | boxes = prediction["boxes"] 99 | boxes = convert_to_xywh(boxes).tolist() 100 | scores = prediction["scores"].tolist() 101 | labels = prediction["labels"].tolist() 102 | 103 | coco_results.extend( 104 | [ 105 | { 106 | "image_id": original_id, 107 | "category_id": labels[k], 108 | "bbox": box, 109 | "score": scores[k], 110 | } 111 | for k, box in enumerate(boxes) 112 | ] 113 | ) 114 | return coco_results 115 | 116 | def prepare_for_coco_segmentation(self, predictions): 117 | coco_results = [] 118 | for original_id, prediction in predictions.items(): 119 | if len(prediction) == 0: 120 | continue 121 | 122 | scores = 
prediction["scores"] 123 | labels = prediction["labels"] 124 | masks = prediction["masks"] 125 | 126 | masks = masks > 0.5 127 | 128 | scores = prediction["scores"].tolist() 129 | labels = prediction["labels"].tolist() 130 | 131 | rles = [ 132 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 133 | for mask in masks 134 | ] 135 | for rle in rles: 136 | rle["counts"] = rle["counts"].decode("utf-8") 137 | 138 | coco_results.extend( 139 | [ 140 | { 141 | "image_id": original_id, 142 | "category_id": labels[k], 143 | "segmentation": rle, 144 | "score": scores[k], 145 | } 146 | for k, rle in enumerate(rles) 147 | ] 148 | ) 149 | return coco_results 150 | 151 | def prepare_for_coco_keypoint(self, predictions): 152 | coco_results = [] 153 | for original_id, prediction in predictions.items(): 154 | if len(prediction) == 0: 155 | continue 156 | 157 | boxes = prediction["boxes"] 158 | boxes = convert_to_xywh(boxes).tolist() 159 | scores = prediction["scores"].tolist() 160 | labels = prediction["labels"].tolist() 161 | keypoints = prediction["keypoints"] 162 | keypoints = keypoints.flatten(start_dim=1).tolist() 163 | 164 | coco_results.extend( 165 | [ 166 | { 167 | "image_id": original_id, 168 | "category_id": labels[k], 169 | "keypoints": keypoint, 170 | "score": scores[k], 171 | } 172 | for k, keypoint in enumerate(keypoints) 173 | ] 174 | ) 175 | return coco_results 176 | 177 | 178 | def convert_to_xywh(boxes): 179 | xmin, ymin, xmax, ymax = boxes.unbind(1) 180 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 181 | 182 | 183 | def merge(img_ids, eval_imgs): 184 | all_img_ids = all_gather(img_ids) 185 | all_eval_imgs = all_gather(eval_imgs) 186 | 187 | merged_img_ids = [] 188 | for p in all_img_ids: 189 | merged_img_ids.extend(p) 190 | 191 | merged_eval_imgs = [] 192 | for p in all_eval_imgs: 193 | merged_eval_imgs.append(p) 194 | 195 | merged_img_ids = np.array(merged_img_ids) 196 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 197 | 198 | # keep only unique (and in sorted order) images 199 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 200 | merged_eval_imgs = merged_eval_imgs[..., idx] 201 | 202 | return merged_img_ids, merged_eval_imgs 203 | 204 | 205 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 206 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 207 | img_ids = list(img_ids) 208 | eval_imgs = list(eval_imgs.flatten()) 209 | 210 | coco_eval.evalImgs = eval_imgs 211 | coco_eval.params.imgIds = img_ids 212 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 213 | 214 | 215 | ################################################################# 216 | # From pycocotools, just removed the prints and fixed 217 | # a Python3 bug about unicode not defined 218 | ################################################################# 219 | 220 | 221 | def evaluate(self): 222 | """ 223 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 224 | :return: None 225 | """ 226 | # tic = time.time() 227 | # print('Running per image evaluation...') 228 | p = self.params 229 | # add backward compatibility if useSegm is specified in params 230 | if p.useSegm is not None: 231 | p.iouType = "segm" if p.useSegm == 1 else "bbox" 232 | print("useSegm (deprecated) is not None. 
Running {} evaluation".format(p.iouType)) 233 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 234 | p.imgIds = list(np.unique(p.imgIds)) 235 | if p.useCats: 236 | p.catIds = list(np.unique(p.catIds)) 237 | p.maxDets = sorted(p.maxDets) 238 | self.params = p 239 | 240 | self._prepare() 241 | # loop through images, area range, max detection number 242 | catIds = p.catIds if p.useCats else [-1] 243 | 244 | if p.iouType == "segm" or p.iouType == "bbox": 245 | computeIoU = self.computeIoU 246 | elif p.iouType == "keypoints": 247 | computeIoU = self.computeOks 248 | self.ious = { 249 | (imgId, catId): computeIoU(imgId, catId) 250 | for imgId in p.imgIds 251 | for catId in catIds} 252 | 253 | evaluateImg = self.evaluateImg 254 | maxDet = p.maxDets[-1] 255 | evalImgs = [ 256 | evaluateImg(imgId, catId, areaRng, maxDet) 257 | for catId in catIds 258 | for areaRng in p.areaRng 259 | for imgId in p.imgIds 260 | ] 261 | # this is NOT in the pycocotools code, but could be done outside 262 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 263 | self._paramsEval = copy.deepcopy(self.params) 264 | # toc = time.time() 265 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 266 | return p.imgIds, evalImgs 267 | 268 | 269 | ################################################################# 270 | # end of straight copy from pycocotools, just removing the prints 271 | ################################################################# 272 | -------------------------------------------------------------------------------- /model/segment_anything_training/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. 
Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /model/segment_anything_training/utils/transforms_gdino.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 4 | """ 5 | import os 6 | import random 7 | import torchvision 8 | from typing import List, Optional 9 | 10 | import PIL 11 | import torch 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as F 14 | 15 | 16 | from torch import Tensor 17 | 18 | __torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7 19 | if __torchvision_need_compat_flag: 20 | from torchvision.ops import _new_empty_tensor 21 | from torchvision.ops.misc import _output_size 22 | 23 | 24 | def box_xyxy_to_cxcywh(x): 25 | x0, y0, x1, y1 = x.unbind(-1) 26 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 27 | return torch.stack(b, dim=-1) 28 | 29 | 30 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 31 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 32 | """ 33 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 34 | This will eventually be supported natively by PyTorch, and this 35 | class can go away. 
36 | """ 37 | if __torchvision_need_compat_flag < 0.7: 38 | if input.numel() > 0: 39 | return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) 40 | 41 | output_shape = _output_size(2, input, size, scale_factor) 42 | output_shape = list(input.shape[:-2]) + list(output_shape) 43 | return _new_empty_tensor(input, output_shape) 44 | else: 45 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 46 | 47 | 48 | 49 | def crop(image, target, region): 50 | cropped_image = F.crop(image, *region) 51 | 52 | target = target.copy() 53 | i, j, h, w = region 54 | 55 | # should we do something wrt the original size? 56 | target["size"] = torch.tensor([h, w]) 57 | 58 | fields = ["labels", "area", "iscrowd", "positive_map"] 59 | 60 | if "boxes" in target: 61 | boxes = target["boxes"] 62 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 63 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 64 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 65 | cropped_boxes = cropped_boxes.clamp(min=0) 66 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 67 | target["boxes"] = cropped_boxes.reshape(-1, 4) 68 | target["area"] = area 69 | fields.append("boxes") 70 | 71 | if "masks" in target: 72 | # FIXME should we update the area here if there are no boxes? 73 | target["masks"] = target["masks"][:, i : i + h, j : j + w] 74 | fields.append("masks") 75 | 76 | # remove elements for which the boxes or masks that have zero area 77 | if "boxes" in target or "masks" in target: 78 | # favor boxes selection when defining which elements to keep 79 | # this is compatible with previous implementation 80 | if "boxes" in target: 81 | cropped_boxes = target["boxes"].reshape(-1, 2, 2) 82 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 83 | else: 84 | keep = target["masks"].flatten(1).any(1) 85 | 86 | for field in fields: 87 | if field in target: 88 | target[field] = target[field][keep] 89 | 90 | if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO": 91 | # for debug and visualization only. 
92 | if "strings_positive" in target: 93 | target["strings_positive"] = [ 94 | _i for _i, _j in zip(target["strings_positive"], keep) if _j 95 | ] 96 | 97 | return cropped_image, target 98 | 99 | 100 | def hflip(image, target): 101 | flipped_image = F.hflip(image) 102 | 103 | w, h = image.size 104 | 105 | target = target.copy() 106 | if "boxes" in target: 107 | boxes = target["boxes"] 108 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( 109 | [w, 0, w, 0] 110 | ) 111 | target["boxes"] = boxes 112 | 113 | if "masks" in target: 114 | target["masks"] = target["masks"].flip(-1) 115 | 116 | return flipped_image, target 117 | 118 | 119 | def resize(image, target, size, max_size=None): 120 | # size can be min_size (scalar) or (w, h) tuple 121 | 122 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 123 | w, h = image_size 124 | if max_size is not None: 125 | min_original_size = float(min((w, h))) 126 | max_original_size = float(max((w, h))) 127 | if max_original_size / min_original_size * size > max_size: 128 | size = int(round(max_size * min_original_size / max_original_size)) 129 | 130 | if (w <= h and w == size) or (h <= w and h == size): 131 | return (h, w) 132 | 133 | if w < h: 134 | ow = size 135 | oh = int(size * h / w) 136 | else: 137 | oh = size 138 | ow = int(size * w / h) 139 | 140 | return (oh, ow) 141 | 142 | def get_size(image_size, size, max_size=None): 143 | if isinstance(size, (list, tuple)): 144 | return size[::-1] 145 | else: 146 | return get_size_with_aspect_ratio(image_size, size, max_size) 147 | 148 | size = get_size(image.size, size, max_size) 149 | rescaled_image = F.resize(image, size) 150 | 151 | if target is None: 152 | return rescaled_image, None 153 | 154 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 155 | ratio_width, ratio_height = ratios 156 | 157 | target = target.copy() 158 | if "boxes" in target: 159 | boxes = target["boxes"] 160 | scaled_boxes = boxes * torch.as_tensor( 161 | [ratio_width, ratio_height, ratio_width, ratio_height] 162 | ) 163 | target["boxes"] = scaled_boxes 164 | 165 | if "area" in target: 166 | area = target["area"] 167 | scaled_area = area * (ratio_width * ratio_height) 168 | target["area"] = scaled_area 169 | 170 | h, w = size 171 | target["size"] = torch.tensor([h, w]) 172 | 173 | if "masks" in target: 174 | target["masks"] = ( 175 | interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 176 | ) 177 | 178 | return rescaled_image, target 179 | 180 | 181 | def pad(image, target, padding): 182 | # assumes that we only pad on the bottom right corners 183 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 184 | if target is None: 185 | return padded_image, None 186 | target = target.copy() 187 | # should we do something wrt the original size? 
188 | target["size"] = torch.tensor(padded_image.size[::-1]) 189 | if "masks" in target: 190 | target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) 191 | return padded_image, target 192 | 193 | 194 | class ResizeDebug(object): 195 | def __init__(self, size): 196 | self.size = size 197 | 198 | def __call__(self, img, target): 199 | return resize(img, target, self.size) 200 | 201 | 202 | class RandomCrop(object): 203 | def __init__(self, size): 204 | self.size = size 205 | 206 | def __call__(self, img, target): 207 | region = T.RandomCrop.get_params(img, self.size) 208 | return crop(img, target, region) 209 | 210 | 211 | class RandomSizeCrop(object): 212 | def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False): 213 | # respect_boxes: True to keep all boxes 214 | # False to tolerence box filter 215 | self.min_size = min_size 216 | self.max_size = max_size 217 | self.respect_boxes = respect_boxes 218 | 219 | def __call__(self, img: PIL.Image.Image, target: dict): 220 | init_boxes = len(target["boxes"]) 221 | max_patience = 10 222 | for i in range(max_patience): 223 | w = random.randint(self.min_size, min(img.width, self.max_size)) 224 | h = random.randint(self.min_size, min(img.height, self.max_size)) 225 | region = T.RandomCrop.get_params(img, [h, w]) 226 | result_img, result_target = crop(img, target, region) 227 | if ( 228 | not self.respect_boxes 229 | or len(result_target["boxes"]) == init_boxes 230 | or i == max_patience - 1 231 | ): 232 | return result_img, result_target 233 | return result_img, result_target 234 | 235 | 236 | class CenterCrop(object): 237 | def __init__(self, size): 238 | self.size = size 239 | 240 | def __call__(self, img, target): 241 | image_width, image_height = img.size 242 | crop_height, crop_width = self.size 243 | crop_top = int(round((image_height - crop_height) / 2.0)) 244 | crop_left = int(round((image_width - crop_width) / 2.0)) 245 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 246 | 247 | 248 | class RandomHorizontalFlip(object): 249 | def __init__(self, p=0.5): 250 | self.p = p 251 | 252 | def __call__(self, img, target): 253 | if random.random() < self.p: 254 | return hflip(img, target) 255 | return img, target 256 | 257 | 258 | class RandomResize(object): 259 | def __init__(self, sizes, max_size=None): 260 | assert isinstance(sizes, (list, tuple)) 261 | self.sizes = sizes 262 | self.max_size = max_size 263 | 264 | def __call__(self, img, target=None): 265 | size = random.choice(self.sizes) 266 | return resize(img, target, size, self.max_size) 267 | 268 | 269 | class RandomPad(object): 270 | def __init__(self, max_pad): 271 | self.max_pad = max_pad 272 | 273 | def __call__(self, img, target): 274 | pad_x = random.randint(0, self.max_pad) 275 | pad_y = random.randint(0, self.max_pad) 276 | return pad(img, target, (pad_x, pad_y)) 277 | 278 | 279 | class RandomSelect(object): 280 | """ 281 | Randomly selects between transforms1 and transforms2, 282 | with probability p for transforms1 and (1 - p) for transforms2 283 | """ 284 | 285 | def __init__(self, transforms1, transforms2, p=0.5): 286 | self.transforms1 = transforms1 287 | self.transforms2 = transforms2 288 | self.p = p 289 | 290 | def __call__(self, img, target): 291 | if random.random() < self.p: 292 | return self.transforms1(img, target) 293 | return self.transforms2(img, target) 294 | 295 | 296 | class ToTensor(object): 297 | def __call__(self, img, target): 298 | return F.to_tensor(img), target 299 | 300 | 
301 | class RandomErasing(object): 302 | def __init__(self, *args, **kwargs): 303 | self.eraser = T.RandomErasing(*args, **kwargs) 304 | 305 | def __call__(self, img, target): 306 | return self.eraser(img), target 307 | 308 | 309 | class Normalize(object): 310 | def __init__(self, mean, std): 311 | self.mean = mean 312 | self.std = std 313 | 314 | def __call__(self, image, target=None): 315 | image = F.normalize(image, mean=self.mean, std=self.std) 316 | if target is None: 317 | return image, None 318 | target = target.copy() 319 | h, w = image.shape[-2:] 320 | if "boxes" in target: 321 | boxes = target["boxes"] 322 | boxes = box_xyxy_to_cxcywh(boxes) 323 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 324 | target["boxes"] = boxes 325 | return image, target 326 | 327 | 328 | class Compose(object): 329 | def __init__(self, transforms): 330 | self.transforms = transforms 331 | 332 | def __call__(self, image, target): 333 | for t in self.transforms: 334 | image, target = t(image, target) 335 | return image, target 336 | 337 | def __repr__(self): 338 | format_string = self.__class__.__name__ + "(" 339 | for t in self.transforms: 340 | format_string += "\n" 341 | format_string += " {0}".format(t) 342 | format_string += "\n)" 343 | return format_string 344 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/facebookresearch/detectron2.git 2 | opencv-python 3 | scikit-image 4 | diffusers==0.26.3 5 | xformers==0.0.16 6 | open-clip-torch==2.23.0 7 | transformers==4.32.1 -------------------------------------------------------------------------------- /scripts/segic_dist.sh: -------------------------------------------------------------------------------- 1 | NGPUS=${1:-8} 2 | encoder_model=${2:-'dinov2'} 3 | exp_name=${3:-'OUTPUT/all_exps/abs_backbone/dinov2_l'} 4 | 5 | echo $NGPUS $encoder_model $exp_name 6 | echo ${@:4} 7 | 8 | python -m torch.distributed.launch --master_port 12345 --nproc_per_node=$NGPUS train.py --output $exp_name \ 9 | --input_keys sem_corr point --eval_keys sem_corr --noised_inst --use_dual_aug --use_simm_prompt --open_ft --find_unused_params --use_dift \ 10 | --use_inst_proj --diff_text_prompt_ratio 0.75 --use_inst_train --reverse_context --learning_rate 0.0001 --use_cross_inst_prompt \ 11 | --encoder_model $encoder_model --inst_datasets coco lvis --sem_datasets coco ade20k --samples_per_epoch 80000 --auto_resume ${@:4} 12 | -------------------------------------------------------------------------------- /utils/coco80.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 
66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch.utils.data import Dataset 4 | 5 | class CustomConcatDataset(Dataset): 6 | def __init__(self, dataset_list, dataset_ratio=None, samples_per_epoch=160000): 7 | self.dataset_list = dataset_list 8 | if dataset_ratio is not None : 9 | assert len(dataset_ratio) == len(dataset_list) 10 | else : 11 | dataset_ratio = [1] * len(dataset_list) 12 | self.dataset_ratio = dataset_ratio 13 | self.samples_per_epoch = samples_per_epoch 14 | 15 | def __len__(self,): 16 | return self.samples_per_epoch 17 | 18 | def __getitem__(self, index): 19 | dataset_idx = random.choices(list(range(len(self.dataset_ratio))), weights=self.dataset_ratio, k=1)[0] 20 | dataset = self.dataset_list[dataset_idx] 21 | index = random.randint(0, len(dataset) - 1) 22 | return dataset[index] -------------------------------------------------------------------------------- /utils/dataset/ade20k_classes.json: -------------------------------------------------------------------------------- 1 | [ 2 | "wall", "building", "sky", "floor", "tree", "ceiling", "road", 3 | "bed", "windowpane", "grass", "cabinet", "sidewalk", 4 | "person", "earth", "door", "table", "mountain", "plant", 5 | "curtain", "chair", "car", "water", "painting", "sofa", 6 | "shelf", "house", "sea", "mirror", "rug", "field", "armchair", 7 | "seat", "fence", "desk", "rock", "wardrobe", "lamp", 8 | "bathtub", "railing", "cushion", "base", "box", "column", 9 | "signboard", "chest of drawers", "counter", "sand", "sink", 10 | "skyscraper", "fireplace", "refrigerator", "grandstand", 11 | "path", "stairs", "runway", "case", "pool table", "pillow", 12 | "screen door", "stairway", "river", "bridge", "bookcase", 13 | "blind", "coffee table", "toilet", "flower", "book", "hill", 14 | "bench", "countertop", "stove", "palm", "kitchen island", 15 | "computer", "swivel chair", "boat", "bar", "arcade machine", 16 | "hovel", "bus", "towel", "light", "truck", "tower", 17 | "chandelier", "awning", "streetlight", "booth", 18 | "television receiver", "airplane", "dirt track", "apparel", 19 | "pole", "land", "bannister", "escalator", "ottoman", "bottle", 20 | "buffet", "poster", "stage", "van", "ship", "fountain", 21 | "conveyer belt", "canopy", "washer", "plaything", 22 | "swimming pool", "stool", "barrel", "basket", "waterfall", 23 | "tent", "bag", "minibike", "cradle", "oven", "ball", "food", 24 | "step", "tank", "trade name", "microwave", "pot", "animal", 25 | "bicycle", "lake", "dishwasher", "screen", "blanket", 26 | "sculpture", "hood", "sconce", "vase", "traffic light", 27 | "tray", "ashcan", "fan", "pier", "crt screen", "plate", 28 | "monitor", "bulletin board", "shower", "radiator", "glass", 29 | "clock", "flag" 30 | ] -------------------------------------------------------------------------------- /utils/dataset/ade20k_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/ade20k_icl.pth -------------------------------------------------------------------------------- /utils/dataset/ade847_classes.json: 
-------------------------------------------------------------------------------- 1 | ["wall", "building, edifice", "sky", "tree", "road, route", "floor, flooring", "ceiling", "bed", "sidewalk, pavement", "earth, ground", "cabinet", "person, individual, someone, somebody, mortal, soul", "grass", "windowpane, window", "car, auto, automobile, machine, motorcar", "mountain, mount", "plant, flora, plant life", "table", "chair", "curtain, drape, drapery, mantle, pall", "door", "sofa, couch, lounge", "sea", "painting, picture", "water", "mirror", "house", "rug, carpet, carpeting", "shelf", "armchair", "fence, fencing", "field", "lamp", "rock, stone", "seat", "river", "desk", "bathtub, bathing tub, bath, tub", "railing, rail", "signboard, sign", "cushion", "path", "work surface", "stairs, steps", "column, pillar", "sink", "wardrobe, closet, press", "snow", "refrigerator, icebox", "base, pedestal, stand", "bridge, span", "blind, screen", "runway", "cliff, drop, drop-off", "sand", "fireplace, hearth, open fireplace", "pillow", "screen door, screen", "toilet, can, commode, crapper, pot, potty, stool, throne", "skyscraper", "grandstand, covered stand", "box", "pool table, billiard table, snooker table", "palm, palm tree", "double door", "coffee table, cocktail table", "counter", "countertop", "chest of drawers, chest, bureau, dresser", "kitchen island", "boat", "waterfall, falls", "stove, kitchen stove, range, kitchen range, cooking stove", "flower", "bookcase", "controls", "book", "stairway, staircase", "streetlight, street lamp", "computer, computing machine, computing device, data processor, electronic computer, information processing system", "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle", "swivel chair", "light, light source", "bench", "case, display case, showcase, vitrine", "towel", "fountain", "embankment", "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box", "van", "hill", "awning, sunshade, sunblind", "poster, posting, placard, notice, bill, card", "truck, motortruck", "airplane, aeroplane, plane", "pole", "tower", "court", "ball", "aircraft carrier, carrier, flattop, attack aircraft carrier", "buffet, counter, sideboard", "hovel, hut, hutch, shack, shanty", "apparel, wearing apparel, dress, clothes", "minibike, motorbike", "animal, animate being, beast, brute, creature, fauna", "chandelier, pendant, pendent", "step, stair", "booth, cubicle, stall, kiosk", "bicycle, bike, wheel, cycle", "doorframe, doorcase", "sconce", "pond", "trade name, brand name, brand, marque", "bannister, banister, balustrade, balusters, handrail", "bag", "traffic light, traffic signal, stoplight", "gazebo", "escalator, moving staircase, moving stairway", "land, ground, soil", "board, plank", "arcade machine", "eiderdown, duvet, continental quilt", "bar", "stall, stand, sales booth", "playground", "ship", "ottoman, pouf, pouffe, puff, hassock", "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "bottle", "cradle", "pot, flowerpot", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "train, railroad train", "stool", "lake", "tank, storage tank", "ice, water ice", "basket, handbasket", "manhole", "tent, collapsible shelter", "canopy", "microwave, microwave oven", "barrel, cask", "dirt track", "beam", "dishwasher, dish washer, dishwashing machine", "plate", "screen, crt screen", "ruins", "washer, automatic washer, washing machine", "blanket, 
cover", "plaything, toy", "food, solid food", "screen, silver screen, projection screen", "oven", "stage", "beacon, lighthouse, beacon light, pharos", "umbrella", "sculpture", "aqueduct", "container", "scaffolding, staging", "hood, exhaust hood", "curb, curbing, kerb", "roller coaster", "horse, equus caballus", "catwalk", "glass, drinking glass", "vase", "central reservation", "carousel", "radiator", "closet", "machine", "pier, wharf, wharfage, dock", "fan", "inflatable bounce game", "pitch", "paper", "arcade, colonnade", "hot tub", "helicopter", "tray", "partition, divider", "vineyard", "bowl", "bullring", "flag", "pot", "footbridge, overcrossing, pedestrian bridge", "shower", "bag, traveling bag, travelling bag, grip, suitcase", "bulletin board, notice board", "confessional booth", "trunk, tree trunk, bole", "forest", "elevator door", "laptop, laptop computer", "instrument panel", "bucket, pail", "tapestry, tapis", "platform", "jacket", "gate", "monitor, monitoring device", "telephone booth, phone booth, call box, telephone box, telephone kiosk", "spotlight, spot", "ring", "control panel", "blackboard, chalkboard", "air conditioner, air conditioning", "chest", "clock", "sand dune", "pipe, pipage, piping", "vault", "table football", "cannon", "swimming pool, swimming bath, natatorium", "fluorescent, fluorescent fixture", "statue", "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "exhibitor", "ladder", "carport", "dam", "pulpit", "skylight, fanlight", "water tower", "grill, grille, grillwork", "display board", "pane, pane of glass, window glass", "rubbish, trash, scrap", "ice rink", "fruit", "patio", "vending machine", "telephone, phone, telephone set", "net", "backpack, back pack, knapsack, packsack, rucksack, haversack", "jar", "track", "magazine", "shutter", "roof", "banner, streamer", "landfill", "post", "altarpiece, reredos", "hat, chapeau, lid", "arch, archway", "table game", "bag, handbag, pocketbook, purse", "document, written document, papers", "dome", "pier", "shanties", "forecourt", "crane", "dog, domestic dog, canis familiaris", "piano, pianoforte, forte-piano", "drawing", "cabin", "ad, advertisement, advertizement, advertising, advertizing, advert", "amphitheater, amphitheatre, coliseum", "monument", "henhouse", "cockpit", "heater, warmer", "windmill, aerogenerator, wind generator", "pool", "elevator, lift", "decoration, ornament, ornamentation", "labyrinth", "text, textual matter", "printer", "mezzanine, first balcony", "mattress", "straw", "stalls", "patio, terrace", "billboard, hoarding", "bus stop", "trouser, pant", "console table, console", "rack", "notebook", "shrine", "pantry", "cart", "steam shovel", "porch", "postbox, mailbox, letter box", "figurine, statuette", "recycling bin", "folding screen", "telescope", "deck chair, beach chair", "kennel", "coffee maker", "altar, communion table, lord's table", "fish", "easel", "artificial golf green", "iceberg", "candlestick, candle holder", "shower stall, shower bath", "television stand", "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle", "skeleton", "grand piano, grand", "candy, confect", "grille door", "pedestal, plinth, footstall", "jersey, t-shirt, tee shirt", "shoe", "gravestone, headstone, tombstone", "shanty", "structure", "rocking chair, rocker", "bird", "place mat", "tomb", "big top", "gas pump, gasoline pump, petrol pump, island dispenser", "lockers", "cage", "finger", "bleachers", "ferris wheel", "hairdresser chair", "mat", "stands", "aquarium, fish 
tank, marine museum", "streetcar, tram, tramcar, trolley, trolley car", "napkin, table napkin, serviette", "dummy", "booklet, brochure, folder, leaflet, pamphlet", "sand trap", "shop, store", "table cloth", "service station", "coffin", "drawer", "cages", "slot machine, coin machine", "balcony", "volleyball court", "table tennis", "control table", "shirt", "merchandise, ware, product", "railway", "parterre", "chimney", "can, tin, tin can", "tanks", "fabric, cloth, material, textile", "alga, algae", "system", "map", "greenhouse", "mug", "barbecue", "trailer", "toilet tissue, toilet paper, bathroom tissue", "organ", "dishrag, dishcloth", "island", "keyboard", "trench", "basket, basketball hoop, hoop", "steering wheel, wheel", "pitcher, ewer", "goal", "bread, breadstuff, staff of life", "beds", "wood", "file cabinet", "newspaper, paper", "motorboat", "rope", "guitar", "rubble", "scarf", "barrels", "cap", "leaves", "control tower", "dashboard", "bandstand", "lectern", "switch, electric switch, electrical switch", "baseboard, mopboard, skirting board", "shower room", "smoke", "faucet, spigot", "bulldozer", "saucepan", "shops", "meter", "crevasse", "gear", "candelabrum, candelabra", "sofa bed", "tunnel", "pallet", "wire, conducting wire", "kettle, boiler", "bidet", "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher", "music stand", "pipe, tube", "cup", "parking meter", "ice hockey rink", "shelter", "weeds", "temple", "patty, cake", "ski slope", "panel", "wallet", "wheel", "towel rack, towel horse", "roundabout", "canister, cannister, tin", "rod", "soap dispenser", "bell", "canvas", "box office, ticket office, ticket booth", "teacup", "trellis", "workbench", "valley, vale", "toaster", "knife", "podium", "ramp", "tumble dryer", "fireplug, fire hydrant, plug", "gym shoe, sneaker, tennis shoe", "lab bench", "equipment", "rocky formation", "plastic", "calendar", "caravan", "check-in-desk", "ticket counter", "brush", "mill", "covered bridge", "bowling alley", "hanger", "excavator", "trestle", "revolving door", "blast furnace", "scale, weighing machine", "projector", "soap", "locker", "tractor", "stretcher", "frame", "grating", "alembic", "candle, taper, wax light", "barrier", "cardboard", "cave", "puddle", "tarp", "price tag", "watchtower", "meters", "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb", "tracks", "hair dryer", "skirt", "viaduct", "paper towel", "coat", "sheet", "fire extinguisher, extinguisher, asphyxiator", "water wheel", "pottery, clayware", "magazine rack", "teapot", "microphone, mike", "support", "forklift", "canyon", "cash register, register", "leaf, leafage, foliage", "remote control, remote", "soap dish", "windshield, windscreen", "cat", "cue, cue stick, pool cue, pool stick", "vent, venthole, vent-hole, blowhole", "videos", "shovel", "eaves", "antenna, aerial, transmitting aerial", "shipyard", "hen, biddy", "traffic cone", "washing machines", "truck crane", "cds", "niche", "scoreboard", "briefcase", "boot", "sweater, jumper", "hay", "pack", "bottle rack", "glacier", "pergola", "building materials", "television camera", "first floor", "rifle", "tennis table", "stadium", "safety belt", "cover", "dish rack", "synthesizer", "pumpkin", "gutter", "fruit stand", "ice floe, floe", "handle, grip, handgrip, hold", "wheelchair", "mousepad, mouse mat", "diploma", "fairground ride", "radio", "hotplate", "junk", "wheelbarrow", "stream", "toll plaza", "punching bag", "trough", "throne", "chair desk", "weighbridge", 
"extractor fan", "hanging clothes", "dish, dish aerial, dish antenna, saucer", "alarm clock, alarm", "ski lift", "chain", "garage", "mechanical shovel", "wine rack", "tramway", "treadmill", "menu", "block", "well", "witness stand", "branch", "duck", "casserole", "frying pan", "desk organizer", "mast", "spectacles, specs, eyeglasses, glasses", "service elevator", "dollhouse", "hammock", "clothes hanging", "photocopier", "notepad", "golf cart", "footpath", "cross", "baptismal font", "boiler", "skip", "rotisserie", "tables", "water mill", "helmet", "cover curtain", "brick", "table runner", "ashtray", "street box", "stick", "hangers", "cells", "urinal", "centerpiece", "portable fridge", "dvds", "golf club", "skirting board", "water cooler", "clipboard", "camera, photographic camera", "pigeonhole", "chips", "food processor", "post box", "lid", "drum", "blender", "cave entrance", "dental chair", "obelisk", "canoe", "mobile", "monitors", "pool ball", "cue rack", "baggage carts", "shore", "fork", "paper filer", "bicycle rack", "coat rack", "garland", "sports bag", "fish tank", "towel dispenser", "carriage", "brochure", "plaque", "stringer", "iron", "spoon", "flag pole", "toilet brush", "book stand", "water faucet, water tap, tap, hydrant", "ticket office", "broom", "dvd", "ice bucket", "carapace, shell, cuticle, shield", "tureen", "folders", "chess", "root", "sewing machine", "model", "pen", "violin", "sweatshirt", "recycling materials", "mitten", "chopping board, cutting board", "mask", "log", "mouse, computer mouse", "grill", "hole", "target", "trash bag", "chalk", "sticks", "balloon", "score", "hair spray", "roll", "runner", "engine", "inflatable glove", "games", "pallets", "baskets", "coop", "dvd player", "rocking horse", "buckets", "bread rolls", "shawl", "watering can", "spotlights", "post-it", "bowls", "security camera", "runner cloth", "lock", "alarm, warning device, alarm system", "side", "roulette", "bone", "cutlery", "pool balls", "wheels", "spice rack", "plant pots", "towel ring", "bread box", "video", "funfair", "breads", "tripod", "ironing board", "skimmer", "hollow", "scratching post", "tricycle", "file box", "mountain pass", "tombstones", "cooker", "card game, cards", "golf bag", "towel paper", "chaise lounge", "sun", "toilet paper holder", "rake", "key", "umbrella stand", "dartboard", "transformer", "fireplace utensils", "sweatshirts", "cellular telephone, cellular phone, cellphone, cell, mobile phone", "tallboy", "stapler", "sauna", "test tube", "palette", "shopping carts", "tools", "push button, push, button", "star", "roof rack", "barbed wire", "spray", "ear", "sponge", "racket", "tins", "eyeglasses", "file", "scarfs", "sugar bowl", "flip flop", "headstones", "laptop bag", "leash", "climbing frame", "suit hanger", "floor spotlight", "plate rack", "sewer", "hard drive", "sprinkler", "tools box", "necklace", "bulbs", "steel industry", "club", "jack", "door bars", "control panel, instrument panel, control board, board, panel", "hairbrush", "napkin holder", "office", "smoke detector", "utensils", "apron", "scissors", "terminal", "grinder", "entry phone", "newspaper stand", "pepper shaker", "onions", "central processing unit, cpu, c p u , central processor, processor, mainframe", "tape", "bat", "coaster", "calculator", "potatoes", "luggage rack", "salt", "street number", "viewpoint", "sword", "cd", "rowing machine", "plug", "andiron, firedog, dog, dog-iron", "pepper", "tongs", "bonfire", "dog dish", "belt", "dumbbells", "videocassette recorder, vcr", "hook", "envelopes", "shower 
faucet", "watch", "padlock", "swimming pool ladder", "spanners", "gravy boat", "notice board", "trash bags", "fire alarm", "ladle", "stethoscope", "rocket", "funnel", "bowling pins", "valve", "thermometer", "cups", "spice jar", "night light", "soaps", "games table", "slotted spoon", "reel", "scourer", "sleeping robe", "desk mat", "dumbbell", "hammer", "tie", "typewriter", "shaker", "cheese dish", "sea star", "racquet", "butane gas cylinder", "paper weight", "shaving brush", "sunglasses", "gear shift", "towel rail", "adding machine, totalizer, totaliser"] -------------------------------------------------------------------------------- /utils/dataset/ade_icl.json: -------------------------------------------------------------------------------- 1 | { -------------------------------------------------------------------------------- /utils/dataset/ade_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/ade_icl.pth -------------------------------------------------------------------------------- /utils/dataset/pc459_classes.json: -------------------------------------------------------------------------------- 1 | ["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", 
"holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", "recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"] 2 | 3 | -------------------------------------------------------------------------------- /utils/dataset/sd_ade847_classes.json: -------------------------------------------------------------------------------- 1 | ["wall", "building, edifice", "sky", "tree", "road, route", "floor, flooring", "ceiling", "bed", "sidewalk, pavement", "earth, ground", "cabinet", "person, individual, someone, somebody, mortal, soul", "grass", "windowpane, window", "car, auto, automobile, machine, motorcar", "mountain, mount", "plant, flora, plant life", "table", "chair", "curtain, drape, drapery, mantle, pall", "door", "sofa, couch, lounge", "sea", "painting, picture", "water", "mirror", "house", "rug, carpet, carpeting", "shelf", "armchair", "fence, fencing", "field", "lamp", "rock, stone", 
"seat", "river", "desk", "bathtub, bathing tub, bath, tub", "railing, rail", "signboard, sign", "cushion", "path", "work surface", "stairs, steps", "column, pillar", "sink", "wardrobe, closet, press", "snow", "refrigerator, icebox", "base, pedestal, stand", "bridge, span", "blind, screen", "runway", "cliff, drop, drop-off", "sand", "fireplace, hearth, open fireplace", "pillow", "screen door, screen", "toilet, can, commode, crapper, pot, potty, stool, throne", "skyscraper", "grandstand, covered stand", "box", "pool table, billiard table, snooker table", "palm, palm tree", "double door", "coffee table, cocktail table", "counter", "countertop", "chest of drawers, chest, bureau, dresser", "kitchen island", "boat", "waterfall, falls", "stove, kitchen stove, range, kitchen range, cooking stove", "flower", "bookcase", "controls", "book", "stairway, staircase", "streetlight, street lamp", "computer, computing machine, computing device, data processor, electronic computer, information processing system", "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle", "swivel chair", "light, light source", "bench", "case, display case, showcase, vitrine", "towel", "fountain", "embankment", "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box", "van", "hill", "awning, sunshade, sunblind", "poster, posting, placard, notice, bill, card", "truck, motortruck", "airplane, aeroplane, plane", "pole", "tower", "court", "ball", "aircraft carrier, carrier, flattop, attack aircraft carrier", "buffet, counter, sideboard", "hovel, hut, hutch, shack, shanty", "apparel, wearing apparel, dress, clothes", "minibike, motorbike", "animal, animate being, beast, brute, creature, fauna", "chandelier, pendant, pendent", "step, stair", "booth, cubicle, stall, kiosk", "bicycle, bike, wheel, cycle", "doorframe, doorcase", "sconce", "pond", "trade name, brand name, brand, marque", "bannister, banister, balustrade, balusters, handrail", "bag", "traffic light, traffic signal, stoplight", "gazebo", "escalator, moving staircase, moving stairway", "land, ground, soil", "board, plank", "arcade machine", "eiderdown, duvet, continental quilt", "bar", "stall, stand, sales booth", "playground", "ship", "ottoman, pouf, pouffe, puff, hassock", "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "bottle", "cradle", "pot, flowerpot", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "train, railroad train", "stool", "lake", "tank, storage tank", "ice, water ice", "basket, handbasket", "manhole", "tent, collapsible shelter", "canopy", "microwave, microwave oven", "barrel, cask", "dirt track", "beam", "dishwasher, dish washer, dishwashing machine", "plate", "screen, crt screen", "ruins", "washer, automatic washer, washing machine", "blanket, cover", "plaything, toy", "food, solid food", "screen, silver screen, projection screen", "oven", "stage", "beacon, lighthouse, beacon light, pharos", "umbrella", "sculpture", "aqueduct", "container", "scaffolding, staging", "hood, exhaust hood", "curb, curbing, kerb", "roller coaster", "horse, equus caballus", "catwalk", "glass, drinking glass", "vase", "central reservation", "carousel", "radiator", "closet", "machine", "pier, wharf, wharfage, dock", "fan", "inflatable bounce game", "pitch", "paper", "arcade, colonnade", "hot tub", "helicopter", "tray", "partition, divider", "vineyard", "bowl", "bullring", "flag", "pot", "footbridge, overcrossing, 
pedestrian bridge", "shower", "bag, traveling bag, travelling bag, grip, suitcase", "bulletin board, notice board", "confessional booth", "trunk, tree trunk, bole", "forest", "elevator door", "laptop, laptop computer", "instrument panel", "bucket, pail", "tapestry, tapis", "platform", "jacket", "gate", "monitor, monitoring device", "telephone booth, phone booth, call box, telephone box, telephone kiosk", "spotlight, spot", "ring", "control panel", "blackboard, chalkboard", "air conditioner, air conditioning", "chest", "clock", "sand dune", "pipe, pipage, piping", "vault", "table football", "cannon", "swimming pool, swimming bath, natatorium", "fluorescent, fluorescent fixture", "statue", "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "exhibitor", "ladder", "carport", "dam", "pulpit", "skylight, fanlight", "water tower", "grill, grille, grillwork", "display board", "pane, pane of glass, window glass", "rubbish, trash, scrap", "ice rink", "fruit", "patio", "vending machine", "telephone, phone, telephone set", "net", "backpack, back pack, knapsack, packsack, rucksack, haversack", "jar", "track", "magazine", "shutter", "roof", "banner, streamer", "landfill", "post", "altarpiece, reredos", "hat, chapeau, lid", "arch, archway", "table game", "bag, handbag, pocketbook, purse", "document, written document, papers", "dome", "pier", "shanties", "forecourt", "crane", "dog, domestic dog, canis familiaris", "piano, pianoforte, forte-piano", "drawing", "cabin", "ad, advertisement, advertizement, advertising, advertizing, advert", "amphitheater, amphitheatre, coliseum", "monument", "henhouse", "cockpit", "heater, warmer", "windmill, aerogenerator, wind generator", "pool", "elevator, lift", "decoration, ornament, ornamentation", "labyrinth", "text, textual matter", "printer", "mezzanine, first balcony", "mattress", "straw", "stalls", "patio, terrace", "billboard, hoarding", "bus stop", "trouser, pant", "console table, console", "rack", "notebook", "shrine", "pantry", "cart", "steam shovel", "porch", "postbox, mailbox, letter box", "figurine, statuette", "recycling bin", "folding screen", "telescope", "deck chair, beach chair", "kennel", "coffee maker", "altar, communion table, lord's table", "fish", "easel", "artificial golf green", "iceberg", "candlestick, candle holder", "shower stall, shower bath", "television stand", "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle", "skeleton", "grand piano, grand", "candy, confect", "grille door", "pedestal, plinth, footstall", "jersey, t-shirt, tee shirt", "shoe", "gravestone, headstone, tombstone", "shanty", "structure", "rocking chair, rocker", "bird", "place mat", "tomb", "big top", "gas pump, gasoline pump, petrol pump, island dispenser", "lockers", "cage", "finger", "bleachers", "ferris wheel", "hairdresser chair", "mat", "stands", "aquarium, fish tank, marine museum", "streetcar, tram, tramcar, trolley, trolley car", "napkin, table napkin, serviette", "dummy", "booklet, brochure, folder, leaflet, pamphlet", "sand trap", "shop, store", "table cloth", "service station", "coffin", "drawer", "cages", "slot machine, coin machine", "balcony", "volleyball court", "table tennis", "control table", "shirt", "merchandise, ware, product", "railway", "parterre", "chimney", "can, tin, tin can", "tanks", "fabric, cloth, material, textile", "alga, algae", "system", "map", "greenhouse", "mug", "barbecue", "trailer", "toilet tissue, toilet paper, bathroom tissue", "organ", "dishrag, dishcloth", "island", 
"keyboard", "trench", "basket, basketball hoop, hoop", "steering wheel, wheel", "pitcher, ewer", "goal", "bread, breadstuff, staff of life", "beds", "wood", "file cabinet", "newspaper, paper", "motorboat", "rope", "guitar", "rubble", "scarf", "barrels", "cap", "leaves", "control tower", "dashboard", "bandstand", "lectern", "switch, electric switch, electrical switch", "baseboard, mopboard, skirting board", "shower room", "smoke", "faucet, spigot", "bulldozer", "saucepan", "shops", "meter", "crevasse", "gear", "candelabrum, candelabra", "sofa bed", "tunnel", "pallet", "wire, conducting wire", "kettle, boiler", "bidet", "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher", "music stand", "pipe, tube", "cup", "parking meter", "ice hockey rink", "shelter", "weeds", "temple", "patty, cake", "ski slope", "panel", "wallet", "wheel", "towel rack, towel horse", "roundabout", "canister, cannister, tin", "rod", "soap dispenser", "bell", "canvas", "box office, ticket office, ticket booth", "teacup", "trellis", "workbench", "valley, vale", "toaster", "knife", "podium", "ramp", "tumble dryer", "fireplug, fire hydrant, plug", "gym shoe, sneaker, tennis shoe", "lab bench", "equipment", "rocky formation", "plastic", "calendar", "caravan", "check-in-desk", "ticket counter", "brush", "mill", "covered bridge", "bowling alley", "hanger", "excavator", "trestle", "revolving door", "blast furnace", "scale, weighing machine", "projector", "soap", "locker", "tractor", "stretcher", "frame", "grating", "alembic", "candle, taper, wax light", "barrier", "cardboard", "cave", "puddle", "tarp", "price tag", "watchtower", "meters", "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb", "tracks", "hair dryer", "skirt", "viaduct", "paper towel", "coat", "sheet", "fire extinguisher, extinguisher, asphyxiator", "water wheel", "pottery, clayware", "magazine rack", "teapot", "microphone, mike", "support", "forklift", "canyon", "cash register, register", "leaf, leafage, foliage", "remote control, remote", "soap dish", "windshield, windscreen", "cat", "cue, cue stick, pool cue, pool stick", "vent, venthole, vent-hole, blowhole", "videos", "shovel", "eaves", "antenna, aerial, transmitting aerial", "shipyard", "hen, biddy", "traffic cone", "washing machines", "truck crane", "cds", "niche", "scoreboard", "briefcase", "boot", "sweater, jumper", "hay", "pack", "bottle rack", "glacier", "pergola", "building materials", "television camera", "first floor", "rifle", "tennis table", "stadium", "safety belt", "cover", "dish rack", "synthesizer", "pumpkin", "gutter", "fruit stand", "ice floe, floe", "handle, grip, handgrip, hold", "wheelchair", "mousepad, mouse mat", "diploma", "fairground ride", "radio", "hotplate", "junk", "wheelbarrow", "stream", "toll plaza", "punching bag", "trough", "throne", "chair desk", "weighbridge", "extractor fan", "hanging clothes", "dish, dish aerial, dish antenna, saucer", "alarm clock, alarm", "ski lift", "chain", "garage", "mechanical shovel", "wine rack", "tramway", "treadmill", "menu", "block", "well", "witness stand", "branch", "duck", "casserole", "frying pan", "desk organizer", "mast", "spectacles, specs, eyeglasses, glasses", "service elevator", "dollhouse", "hammock", "clothes hanging", "photocopier", "notepad", "golf cart", "footpath", "cross", "baptismal font", "boiler", "skip", "rotisserie", "tables", "water mill", "helmet", "cover curtain", "brick", "table runner", "ashtray", "street box", "stick", "hangers", "cells", 
"urinal", "centerpiece", "portable fridge", "dvds", "golf club", "skirting board", "water cooler", "clipboard", "camera, photographic camera", "pigeonhole", "chips", "food processor", "post box", "lid", "drum", "blender", "cave entrance", "dental chair", "obelisk", "canoe", "mobile", "monitors", "pool ball", "cue rack", "baggage carts", "shore", "fork", "paper filer", "bicycle rack", "coat rack", "garland", "sports bag", "fish tank", "towel dispenser", "carriage", "brochure", "plaque", "stringer", "iron", "spoon", "flag pole", "toilet brush", "book stand", "water faucet, water tap, tap, hydrant", "ticket office", "broom", "dvd", "ice bucket", "carapace, shell, cuticle, shield", "tureen", "folders", "chess", "root", "sewing machine", "model", "pen", "violin", "sweatshirt", "recycling materials", "mitten", "chopping board, cutting board", "mask", "log", "mouse, computer mouse", "grill", "hole", "target", "trash bag", "chalk", "sticks", "balloon", "score", "hair spray", "roll", "runner", "engine", "inflatable glove", "games", "pallets", "baskets", "coop", "dvd player", "rocking horse", "buckets", "bread rolls", "shawl", "watering can", "spotlights", "post-it", "bowls", "security camera", "runner cloth", "lock", "alarm, warning device, alarm system", "side", "roulette", "bone", "cutlery", "pool balls", "wheels", "spice rack", "plant pots", "towel ring", "bread box", "video", "funfair", "breads", "tripod", "ironing board", "skimmer", "hollow", "scratching post", "tricycle", "file box", "mountain pass", "tombstones", "cooker", "card game, cards", "golf bag", "towel paper", "chaise lounge", "sun", "toilet paper holder", "rake", "key", "umbrella stand", "dartboard", "transformer", "fireplace utensils", "sweatshirts", "cellular telephone, cellular phone, cellphone, cell, mobile phone", "tallboy", "stapler", "sauna", "test tube", "palette", "shopping carts", "tools", "push button, push, button", "star", "roof rack", "barbed wire", "spray", "ear", "sponge", "racket", "tins", "eyeglasses", "file", "scarfs", "sugar bowl", "flip flop", "headstones", "laptop bag", "leash", "climbing frame", "suit hanger", "floor spotlight", "plate rack", "sewer", "hard drive", "sprinkler", "tools box", "necklace", "bulbs", "steel industry", "club", "jack", "door bars", "control panel, instrument panel, control board, board, panel", "hairbrush", "napkin holder", "office", "smoke detector", "utensils", "apron", "scissors", "terminal", "grinder", "entry phone", "newspaper stand", "pepper shaker", "onions", "central processing unit, cpu, c p u , central processor, processor, mainframe", "tape", "bat", "coaster", "calculator", "potatoes", "luggage rack", "salt", "street number", "viewpoint", "sword", "cd", "rowing machine", "plug", "andiron, firedog, dog, dog-iron", "pepper", "tongs", "bonfire", "dog dish", "belt", "dumbbells", "videocassette recorder, vcr", "hook", "envelopes", "shower faucet", "watch", "padlock", "swimming pool ladder", "spanners", "gravy boat", "notice board", "trash bags", "fire alarm", "ladle", "stethoscope", "rocket", "funnel", "bowling pins", "valve", "thermometer", "cups", "spice jar", "night light", "soaps", "games table", "slotted spoon", "reel", "scourer", "sleeping robe", "desk mat", "dumbbell", "hammer", "tie", "typewriter", "shaker", "cheese dish", "sea star", "racquet", "butane gas cylinder", "paper weight", "shaving brush", "sunglasses", "gear shift", "towel rail", "adding machine, totalizer, totaliser"] -------------------------------------------------------------------------------- 
/utils/dataset/train_ade20k_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/train_ade20k_icl.pth -------------------------------------------------------------------------------- /utils/dataset/val_ade20k_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/val_ade20k_icl.pth -------------------------------------------------------------------------------- /utils/dataset/val_ade847_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/val_ade847_icl.pth -------------------------------------------------------------------------------- /utils/dataset/val_pc459_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/val_pc459_icl.pth -------------------------------------------------------------------------------- /utils/dataset/val_sd_ade20k_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/val_sd_ade20k_icl.pth -------------------------------------------------------------------------------- /utils/dataset/val_sd_ade847_icl.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/utils/dataset/val_sd_ade847_icl.pth -------------------------------------------------------------------------------- /utils/fss_inst.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from torch.utils.data import Dataset 5 | import random 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | from pycocotools.coco import COCO 11 | 12 | from torchvision.transforms import Compose 13 | from utils.dataloader import DualAug, LargeScaleJitter, DefaultBundle 14 | from utils.inst_aug import ColorJitter, RandomResizedCrop, RandomApply,RandomHorizontalFlip, Norm, DeNorm 15 | 16 | 17 | class InstCOCO(Dataset): 18 | def __init__(self, base_image_dir, transform, is_train=True, dataset_name='coco', max_inst=20): 19 | self.transform = transform 20 | split = 'train2017' if is_train else 'val2017' 21 | json_path = 'annotations/instances_{}.json'.format(split) 22 | self.is_lvis = False 23 | self.is_lip = False 24 | self.is_box_mask = False 25 | if dataset_name == 'lvis' : 26 | self.is_lvis = True 27 | split = 'train2017' if is_train else 'val2017' 28 | split_json = 'train' if is_train else 'val' 29 | json_path = 'lvis_v1_{}.json'.format(split_json) 30 | self.img_root = base_image_dir 31 | elif dataset_name == 'coco' : 32 | split = 'train2017' if is_train else 'val2017' 33 | json_path = 'annotations/instances_{}.json'.format(split) 34 | self.img_root = os.path.join(base_image_dir, split) 35 | elif dataset_name == 'paco_lvis' : 36 | split = 'train' if is_train else 'val' 37 | json_path = 'annotations/paco_lvis_v1_{}.json'.format(split) 38 | self.img_root = os.path.join(base_image_dir) 39 | elif 
dataset_name == 'o365' : 40 | self.is_box_mask = True 41 | split = 'train' if is_train else 'val' 42 | json_path = 'objects365_{}.json'.format(split) 43 | self.img_root = os.path.join(base_image_dir, split) 44 | elif dataset_name == 'lip': 45 | self.is_lip = True 46 | self.img_root = os.path.join(base_image_dir, 'train_images') 47 | self.anno_root = os.path.join(base_image_dir, 'TrainVal_parsing_annotations/train_segmentations') 48 | else : 49 | raise NotImplementedError 50 | 51 | self.ids = [] 52 | self.max_inst = max_inst 53 | if dataset_name == 'lip': 54 | for name in os.listdir(self.img_root): 55 | if name.endswith('.jpg'): 56 | self.ids.append(name) 57 | else : 58 | self.coco = COCO(os.path.join(base_image_dir, json_path)) 59 | ids = list(sorted(self.coco.imgs.keys())) 60 | for idx in ids : 61 | if len(self.coco.getAnnIds(idx)): 62 | self.ids.append(idx) 63 | 64 | def __len__(self,): 65 | return len(self.ids) 66 | 67 | def __getitem__(self, index): 68 | idx = self.ids[index] 69 | if self.is_lip : 70 | image_path, idx = idx, index 71 | image = Image.open(os.path.join(self.img_root, image_path)).convert('RGB') 72 | labels = Image.open(os.path.join(self.anno_root, image_path.replace('.jpg','.png'))).convert('L') 73 | labels = np.array(labels) 74 | masks = [] 75 | for cid in np.unique(labels)[1:]: #ignore bg 76 | masks.append(labels==cid) 77 | if len(masks) == 0: 78 | return self[index+1] 79 | masks = np.stack(masks) 80 | else : 81 | if self.is_lvis : 82 | coco_url = self.coco.loadImgs(idx)[0]["coco_url"] 83 | image_path = os.path.join(*coco_url.split('/')[-2:]) 84 | else : 85 | image_path = self.coco.loadImgs(idx)[0]["file_name"] 86 | 87 | image = Image.open(os.path.join(self.img_root, image_path)).convert('RGB') 88 | annos = self.coco.loadAnns(self.coco.getAnnIds(idx)) 89 | if self.is_box_mask : 90 | def bbox_to_mask(bbox): 91 | x,y,w,h = bbox 92 | x1,y1 = x, y 93 | x2,y2 = x+w, y 94 | x3,y3 = x+w, y+h 95 | x4,y4 = x, y+h 96 | return [[x1,y1,x2,y2,x3,y3,x4,y4]] 97 | for i, ann in enumerate(annos): 98 | annos[i]['segmentation'] = bbox_to_mask(ann['bbox']) 99 | 100 | # if len(annos) > self.max_inst : 101 | # annos = np.random.choice( 102 | # annos, size=self.max_inst, replace=False 103 | # ).tolist() 104 | masks = np.stack([self.coco.annToMask(x) for x in annos]) 105 | 106 | def to_tensor(x): 107 | return torch.tensor(np.array(x), dtype=torch.float32).permute(2,0,1) 108 | image = to_tensor(image) 109 | masks = torch.tensor(masks).float() 110 | 111 | sample = {'image': image, 112 | 'label': masks * 255, 113 | "imidx": torch.from_numpy(np.array(idx)), 114 | "shape": torch.tensor(image.shape[-2:]), 115 | "class_name": 'instance', 116 | 'is_inst': True, 117 | 'is_box_mask': self.is_box_mask 118 | } 119 | 120 | if self.transform: 121 | sample = self.transform(sample) 122 | 123 | mask_sum = sample['label_dual'].flatten(1).sum(-1) > 100 124 | mask_sum_reverse = sample['label'].flatten(1).sum(-1) > 100 125 | mask_sum = mask_sum_reverse & mask_sum 126 | non_empty_idx = mask_sum.nonzero()[:,0] 127 | max_inst = random.randint(0, self.max_inst - 1) 128 | rand_idx = torch.randperm(non_empty_idx.shape[0])[:max_inst] 129 | select_idx = non_empty_idx[rand_idx] 130 | if len(select_idx) == 0 : 131 | select_idx = torch.tensor([0]) 132 | sample['label'] = sample['label'][select_idx] 133 | sample['label_dual'] = sample['label_dual'][select_idx] 134 | 135 | if False : 136 | import cv2 137 | cv2.imwrite('tmp.jpg', image.permute(1,2,0).int().numpy()) 138 | cv2.imwrite('tmp.jpg', 
sample['image'].permute(1,2,0).int().numpy()) 139 | cv2.imwrite('tmp.jpg', sample['image_dual'].permute(1,2,0).int().numpy()) 140 | pass 141 | return sample 142 | 143 | 144 | def get_inst_aug(img_size): 145 | aug_list = [ 146 | Norm(), 147 | RandomResizedCrop(img_size, scale=(0.3, 1.0), interpolation=3), # 3 is bicubic 148 | RandomApply([ 149 | ColorJitter(0.4, 0.4, 0.2, 0.1) 150 | ], p=0.2), 151 | RandomHorizontalFlip(0.1), 152 | DeNorm(), 153 | DefaultBundle() 154 | ] 155 | return Compose(aug_list) 156 | 157 | 158 | # from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter, DualAug, ResizeVOS 159 | # from torchvision import transforms 160 | # import cv2 161 | 162 | # # aug_list = [RandomHFlip(), LargeScaleJitter(output_size=896)] 163 | # aug_list = get_inst_aug(896) 164 | # aug_list = DualAug([aug_list]) 165 | # # dataset = InstCOCO('data/ade20k', transforms.Compose(aug_list)) 166 | # # dataset = InstCOCO('data/lip', aug_list, dataset_name='lip') 167 | # dataset = InstCOCO('data/paco', aug_list, dataset_name='paco_lvis') 168 | # def show(idx): 169 | # aa = dataset[idx] 170 | # aa['class_name'] 171 | # image, image_dual, label, label_dual = aa['image'], aa['image_dual'], aa['label'], aa['label_dual'] 172 | # xx,yy = torch.cat([image, image_dual], dim=2), torch.cat([image*label[:1]/255, image_dual*label_dual[:1]/255], dim=2) 173 | # vis = torch.cat([xx, yy], dim=1) 174 | # cv2.imwrite('tmp.jpg', vis.permute(1,2,0).numpy()[...,::-1]) 175 | 176 | # show(3) 177 | # import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /utils/inst_aug.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import Tensor 5 | import torchvision.transforms as transforms 6 | 7 | import torchvision.transforms.functional as F 8 | from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode 9 | from PIL import Image, ImageFilter, ImageOps 10 | 11 | 12 | class Norm(): 13 | def __call__(self, sample): 14 | sample['image'] /= 255 15 | return sample 16 | 17 | class DeNorm(): 18 | def __call__(self, sample): 19 | sample['image'] *= 255 20 | sample['ori_size'] = torch.tensor(sample['image'].shape[-2:]) 21 | return sample 22 | 23 | 24 | class RandomApply(transforms.RandomApply): 25 | """Apply randomly a list of transformations with a given probability. 26 | .. note:: 27 | In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of 28 | transforms as shown below: 29 | >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ 30 | >>> transforms.ColorJitter(), 31 | >>> ]), p=0.3) 32 | >>> scripted_transforms = torch.jit.script(transforms) 33 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require 34 | `lambda` functions or ``PIL.Image``. 35 | Args: 36 | transforms (sequence or torch.nn.Module): list of transformations 37 | p (float): probability 38 | """ 39 | 40 | def __init__(self, transforms, p=0.5): 41 | super().__init__(transforms, p=p) 42 | 43 | def forward(self, sample, interpolation1=None, interpolation2=None): 44 | if self.p < torch.rand(1): 45 | return sample 46 | for t in self.transforms: 47 | sample = t(sample) 48 | return sample 49 | 50 | 51 | class ColorJitter(transforms.ColorJitter): 52 | """Randomly change the brightness, contrast, saturation and hue of an image. 
53 | If the image is torch Tensor, it is expected 54 | to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. 55 | If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. 56 | Args: 57 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 58 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 59 | or the given [min, max]. Should be non negative numbers. 60 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 61 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 62 | or the given [min, max]. Should be non negative numbers. 63 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 64 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 65 | or the given [min, max]. Should be non negative numbers. 66 | hue (float or tuple of float (min, max)): How much to jitter hue. 67 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 68 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 69 | To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; 70 | thus it does not work if you normalize your image to an interval with negative values, 71 | or use an interpolation that generates negative values before using this function. 72 | """ 73 | 74 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 75 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 76 | 77 | def forward(self, sample, interpolation1=None, interpolation2=None): 78 | """ 79 | Args: 80 | img (PIL Image or Tensor): Input image. 81 | Returns: 82 | PIL Image or Tensor: Color jittered image. 83 | """ 84 | img = sample['image'] 85 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( 86 | self.brightness, self.contrast, self.saturation, self.hue 87 | ) 88 | 89 | for fn_id in fn_idx: 90 | if fn_id == 0 and brightness_factor is not None: 91 | img = F.adjust_brightness(img, brightness_factor) 92 | elif fn_id == 1 and contrast_factor is not None: 93 | img = F.adjust_contrast(img, contrast_factor) 94 | elif fn_id == 2 and saturation_factor is not None: 95 | img = F.adjust_saturation(img, saturation_factor) 96 | elif fn_id == 3 and hue_factor is not None: 97 | img = F.adjust_hue(img, hue_factor) 98 | sample['image'] = img 99 | return sample 100 | 101 | 102 | class RandomResizedCrop(transforms.RandomResizedCrop): 103 | """Crop a random portion of image and resize it to a given size. 104 | If the image is torch Tensor, it is expected 105 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 106 | A crop of the original image is made: the crop has a random area (H * W) 107 | and a random aspect ratio. This crop is finally resized to the given 108 | size. This is popularly used to train the Inception networks. 109 | Args: 110 | size (int or sequence): expected output size of the crop, for each edge. If size is an 111 | int instead of sequence like (h, w), a square output size ``(size, size)`` is 112 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 113 | .. note:: 114 | In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. 
115 | scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, 116 | before resizing. The scale is defined with respect to the area of the original image. 117 | ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before 118 | resizing. 119 | interpolation (InterpolationMode): Desired interpolation enum defined by 120 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 121 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and 122 | ``InterpolationMode.BICUBIC`` are supported. 123 | For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, 124 | but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. 125 | """ 126 | 127 | def __init__( 128 | self, 129 | size, 130 | scale=(0.08, 1.0), 131 | ratio=(3.0 / 4.0, 4.0 / 3.0), 132 | interpolation=InterpolationMode.BILINEAR, 133 | ): 134 | super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation) 135 | 136 | def forward(self, sample, interpolation1=None, interpolation2=None): 137 | """ 138 | Args: 139 | img (PIL Image or Tensor): Image to be cropped and resized. 140 | Returns: 141 | PIL Image or Tensor: Randomly cropped and resized image. 142 | """ 143 | img, tgt = sample['image'], sample['label'] 144 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 145 | # if interpolation1 == 'nearest': 146 | # interpolation1 = InterpolationMode.NEAREST 147 | # else: 148 | # interpolation1 = InterpolationMode.BICUBIC 149 | interpolation1 = InterpolationMode.BILINEAR 150 | interpolation2 = InterpolationMode.NEAREST 151 | 152 | img, tgt = F.resized_crop(img, i, j, h, w, self.size, interpolation1), \ 153 | F.resized_crop(tgt, i, j, h, w, self.size, interpolation2) 154 | 155 | sample.update(image=img, label=tgt) 156 | return sample 157 | 158 | 159 | 160 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 161 | """Horizontally flip the given image randomly with a given probability. 162 | If the image is torch Tensor, it is expected 163 | to have [..., H, W] shape, where ... means an arbitrary number of leading 164 | dimensions 165 | Args: 166 | p (float): probability of the image being flipped. Default value is 0.5 167 | """ 168 | 169 | def __init__(self, p=0.5): 170 | super().__init__(p=p) 171 | 172 | def forward(self, sample, interpolation1=None, interpolation2=None): 173 | """ 174 | Args: 175 | img (PIL Image or Tensor): Image to be flipped. 176 | Returns: 177 | PIL Image or Tensor: Randomly flipped image. 178 | """ 179 | img, tgt = sample['image'], sample['label'] 180 | if torch.rand(1) < self.p: 181 | img, tgt = F.hflip(img), F.hflip(tgt) 182 | sample.update(image=img, label=tgt) 183 | return sample 184 | 185 | return sample 186 | -------------------------------------------------------------------------------- /utils/instance_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
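The transforms defined in utils/inst_aug.py above operate on a sample dict (float tensors under 'image' and 'label') rather than on bare images, which is what keeps the image and its instance masks aligned through cropping and flipping. Below is a minimal sketch of composing them on a dummy sample, mirroring the order used by get_inst_aug in utils/fss_inst.py (minus the final DefaultBundle step, whose definition is not shown here); the tensor shapes and the 896 output size are illustrative assumptions.

import torch
from torchvision.transforms import Compose
from utils.inst_aug import (Norm, DeNorm, RandomApply, ColorJitter,
                            RandomResizedCrop, RandomHorizontalFlip)

aug = Compose([
    Norm(),                                                     # scale 'image' to [0, 1] so color jitter is valid
    RandomResizedCrop(896, scale=(0.3, 1.0), interpolation=3),  # same crop applied to image and masks
    RandomApply([ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.2),
    RandomHorizontalFlip(0.1),
    DeNorm(),                                                   # back to [0, 255]; also records 'ori_size'
])

sample = {
    "image": torch.rand(3, 640, 480) * 255,                     # dummy RGB image tensor
    "label": (torch.rand(5, 640, 480) > 0.5).float() * 255,     # five dummy instance masks
}
out = aug(sample)
print(out["image"].shape, out["label"].shape, out["ori_size"])  # spatial size of image and masks becomes 896 x 896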
2 | import contextlib 3 | import copy 4 | import io 5 | import itertools 6 | import json 7 | import logging 8 | import numpy as np 9 | import os 10 | import pickle 11 | from collections import OrderedDict 12 | import pycocotools.mask as mask_util 13 | import torch 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.config import CfgNode 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.data.datasets.coco import convert_to_coco_json 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco 23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt 24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 25 | from detectron2.utils.file_io import PathManager 26 | from detectron2.utils.logger import create_small_table 27 | 28 | 29 | # modified from COCOEvaluator for instance segmetnat 30 | class InstanceSegEvaluator(COCOEvaluator): 31 | """ 32 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP 33 | for keypoint detection outputs using COCO's metrics. 34 | See http://cocodataset.org/#detection-eval and 35 | http://cocodataset.org/#keypoints-eval to understand its metrics. 36 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means 37 | the metric cannot be computed (e.g. due to no predictions made). 38 | 39 | In addition to COCO, this evaluator is able to support any bounding box detection, 40 | instance segmentation, or keypoint detection dataset. 41 | """ 42 | 43 | def _eval_predictions(self, predictions, img_ids=None): 44 | """ 45 | Evaluate predictions. Fill self._results with the metrics of the tasks. 46 | """ 47 | self._logger.info("Preparing results for COCO format ...") 48 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 49 | tasks = self._tasks or self._tasks_from_predictions(coco_results) 50 | 51 | # unmap the category ids for COCO 52 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 53 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id 54 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) 55 | # num_classes = len(all_contiguous_ids) 56 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 57 | 58 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} 59 | for result in coco_results: 60 | category_id = result["category_id"] 61 | # assert category_id < num_classes, ( 62 | # f"A prediction has class={category_id}, " 63 | # f"but the dataset only has {num_classes} classes and " 64 | # f"predicted class id should be in [0, {num_classes - 1}]." 65 | # ) 66 | assert category_id in reverse_id_mapping, ( 67 | f"A prediction has class={category_id}, " 68 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}." 
69 | ) 70 | result["category_id"] = reverse_id_mapping[category_id] 71 | 72 | if self._output_dir: 73 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 74 | self._logger.info("Saving results to {}".format(file_path)) 75 | with PathManager.open(file_path, "w") as f: 76 | f.write(json.dumps(coco_results)) 77 | f.flush() 78 | 79 | if not self._do_evaluation: 80 | self._logger.info("Annotations are not available for evaluation.") 81 | return 82 | 83 | self._logger.info( 84 | "Evaluating predictions with {} COCO API...".format( 85 | "unofficial" if self._use_fast_impl else "official" 86 | ) 87 | ) 88 | for task in sorted(tasks): 89 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 90 | coco_eval = ( 91 | _evaluate_predictions_on_coco( 92 | self._coco_api, 93 | coco_results, 94 | task, 95 | kpt_oks_sigmas=self._kpt_oks_sigmas, 96 | # use_fast_impl=self._use_fast_impl, 97 | img_ids=img_ids, 98 | max_dets_per_image=self._max_dets_per_image, 99 | ) 100 | if len(coco_results) > 0 101 | else None # cocoapi does not handle empty results very well 102 | ) 103 | 104 | res = self._derive_coco_results( 105 | coco_eval, task, class_names=self._metadata.get("thing_classes") 106 | ) 107 | self._results[task] = res -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 7 | """Initialize and get a logger by name. 8 | 9 | If the logger has not been initialized, this method will initialize the 10 | logger by adding one or two handlers, otherwise the initialized logger will 11 | be directly returned. During initialization, a StreamHandler will always be 12 | added. If `log_file` is specified and the process rank is 0, a FileHandler 13 | will also be added. 14 | 15 | Args: 16 | name (str): Logger name. 17 | log_file (str | None): The log filename. If specified, a FileHandler 18 | will be added to the logger. 19 | log_level (int): The logger level. Note that only the process of 20 | rank 0 is affected, and other processes will set the level to 21 | "Error" thus be silent most of the time. 22 | file_mode (str): The file mode used in opening log file. 23 | Defaults to 'w'. 24 | 25 | Returns: 26 | logging.Logger: The expected logger. 27 | """ 28 | logger = logging.getLogger(name) 29 | if name in logger_initialized: 30 | return logger 31 | # handle hierarchical names 32 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 33 | # initialization since it is a child of "a". 34 | for logger_name in logger_initialized: 35 | if name.startswith(logger_name): 36 | return logger 37 | 38 | # handle duplicate logs to the console 39 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 40 | # to the root logger. As logger.propagate is True by default, this root 41 | # level handler causes logging messages from rank>0 processes to 42 | # unexpectedly show up on the console, creating much unwanted clutter. 43 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 44 | # at the ERROR level. 
45 | for handler in logger.root.handlers: 46 | if type(handler) is logging.StreamHandler: 47 | handler.setLevel(logging.ERROR) 48 | 49 | stream_handler = logging.StreamHandler() 50 | handlers = [stream_handler] 51 | 52 | if dist.is_available() and dist.is_initialized(): 53 | rank = dist.get_rank() 54 | else: 55 | rank = 0 56 | 57 | # only rank 0 will add a FileHandler 58 | if rank == 0 and log_file is not None: 59 | # Here, the default behaviour of the official logger is 'a'. Thus, we 60 | # provide an interface to change the file mode to the default 61 | # behaviour. 62 | file_handler = logging.FileHandler(log_file, file_mode) 63 | handlers.append(file_handler) 64 | 65 | formatter = logging.Formatter( 66 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 67 | for handler in handlers: 68 | handler.setFormatter(formatter) 69 | handler.setLevel(log_level) 70 | logger.addHandler(handler) 71 | 72 | if rank == 0: 73 | logger.setLevel(log_level) 74 | else: 75 | logger.setLevel(logging.ERROR) 76 | 77 | logger_initialized[name] = True 78 | 79 | return logger -------------------------------------------------------------------------------- /utils/loss_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import List, Optional 4 | import utils.misc as misc 5 | 6 | def point_sample(input, point_coords, **kwargs): 7 | """ 8 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 9 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 10 | [0, 1] x [0, 1] square. 11 | Args: 12 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 13 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 14 | [0, 1] x [0, 1] normalized point coordinates. 15 | Returns: 16 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 17 | features for points in `point_coords`. The features are obtained via bilinear 18 | interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 19 | """ 20 | add_dim = False 21 | if point_coords.dim() == 3: 22 | add_dim = True 23 | point_coords = point_coords.unsqueeze(2) 24 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 25 | if add_dim: 26 | output = output.squeeze(3) 27 | return output 28 | 29 | def cat(tensors: List[torch.Tensor], dim: int = 0): 30 | """ 31 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 32 | """ 33 | assert isinstance(tensors, (list, tuple)) 34 | if len(tensors) == 1: 35 | return tensors[0] 36 | return torch.cat(tensors, dim) 37 | 38 | def get_uncertain_point_coords_with_randomness( 39 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 40 | ): 41 | """ 42 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties 43 | are calculated for each point using 'uncertainty_func' function that takes point's logit 44 | prediction as input. 45 | See PointRend paper for details. 46 | Args: 47 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 48 | class-specific or class-agnostic prediction. 
49 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 50 | contains logit predictions for P points and returns their uncertainties as a Tensor of 51 | shape (N, 1, P). 52 | num_points (int): The number of points P to sample. 53 | oversample_ratio (int): Oversampling parameter. 54 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 55 | Returns: 56 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 57 | sampled points. 58 | """ 59 | assert oversample_ratio >= 1 60 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 61 | num_boxes = coarse_logits.shape[0] 62 | num_sampled = int(num_points * oversample_ratio) 63 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 64 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 65 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 66 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 67 | # to incorrect results. 68 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 69 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 70 | # However, if we calculate uncertainties for the coarse predictions first, 71 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 72 | point_uncertainties = uncertainty_func(point_logits) 73 | num_uncertain_points = int(importance_sample_ratio * num_points) 74 | num_random_points = num_points - num_uncertain_points 75 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 76 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 77 | idx += shift[:, None] 78 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 79 | num_boxes, num_uncertain_points, 2 80 | ) 81 | if num_random_points > 0: 82 | point_coords = cat( 83 | [ 84 | point_coords, 85 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 86 | ], 87 | dim=1, 88 | ) 89 | return point_coords 90 | 91 | def dice_loss( 92 | inputs: torch.Tensor, 93 | targets: torch.Tensor, 94 | num_masks: float, 95 | mask_weight = None 96 | ): 97 | """ 98 | Compute the DICE loss, similar to generalized IOU for masks 99 | Args: 100 | inputs: A float tensor of arbitrary shape. 101 | The predictions for each example. 102 | targets: A float tensor with the same shape as inputs. Stores the binary 103 | classification label for each element in inputs 104 | (0 for the negative class and 1 for the positive class). 105 | """ 106 | inputs = inputs.sigmoid() 107 | inputs = inputs.flatten(1) 108 | numerator = 2 * (inputs * targets).sum(-1) 109 | denominator = inputs.sum(-1) + targets.sum(-1) 110 | loss = 1 - (numerator + 1) / (denominator + 1) 111 | if mask_weight is not None: 112 | loss = (loss * mask_weight).sum() / num_masks 113 | return loss.sum() / num_masks 114 | 115 | 116 | dice_loss_jit = torch.jit.script( 117 | dice_loss 118 | ) # type: torch.jit.ScriptModule 119 | 120 | 121 | def sigmoid_ce_loss( 122 | inputs: torch.Tensor, 123 | targets: torch.Tensor, 124 | num_masks: float, 125 | mask_weight = None 126 | ): 127 | """ 128 | Args: 129 | inputs: A float tensor of arbitrary shape. 130 | The predictions for each example. 131 | targets: A float tensor with the same shape as inputs. 
Stores the binary 132 | classification label for each element in inputs 133 | (0 for the negative class and 1 for the positive class). 134 | Returns: 135 | Loss tensor 136 | """ 137 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask_weight) 138 | 139 | return loss.mean(1).sum() / num_masks 140 | 141 | 142 | sigmoid_ce_loss_jit = torch.jit.script( 143 | sigmoid_ce_loss 144 | ) # type: torch.jit.ScriptModule 145 | 146 | 147 | def calculate_uncertainty(logits): 148 | """ 149 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 150 | foreground class in `classes`. 151 | Args: 152 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 153 | class-agnostic, where R is the total number of predicted masks in all images and C is 154 | the number of foreground classes. The values are logits. 155 | Returns: 156 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 157 | the most uncertain locations having the highest uncertainty score. 158 | """ 159 | assert logits.shape[1] == 1 160 | gt_class_logits = logits.clone() 161 | return -(torch.abs(gt_class_logits)) 162 | 163 | def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0, is_box_mask=False, 164 | mask_weight=None): 165 | """Compute the losses related to the masks: the focal loss and the dice loss. 166 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 167 | """ 168 | 169 | # No need to upsample predictions as we are using normalized coordinates :) 170 | 171 | if is_box_mask.sum() == 0 : 172 | is_box_mask = None 173 | 174 | if mask_weight is None : 175 | mask_weight = src_masks.new_ones(src_masks.shape[0], 1) 176 | 177 | if is_box_mask is not None: 178 | target_masks_resize = F.interpolate(target_masks, src_masks.shape[-2:]) 179 | target_masks_y, target_masks_x = target_masks_resize.max(-1)[0], target_masks_resize.max(-2)[0] 180 | src_masks_y, src_masks_x = src_masks.max(-1)[0], src_masks.max(-2)[0] 181 | 182 | target_masks_box = torch.cat([target_masks_y, target_masks_x], dim=-1).squeeze(1)[is_box_mask] 183 | src_masks_box = torch.cat([src_masks_y, src_masks_x], dim=-1).squeeze(1)[is_box_mask] 184 | 185 | loss_mask_box = sigmoid_ce_loss_jit(src_masks_box, target_masks_box, num_masks, mask_weight) 186 | loss_dice_box = dice_loss_jit(src_masks_box, target_masks_box, num_masks, mask_weight) 187 | src_masks, target_masks = src_masks[~is_box_mask], target_masks[~is_box_mask] 188 | 189 | with torch.no_grad(): 190 | # sample point_coords 191 | point_coords = get_uncertain_point_coords_with_randomness( 192 | src_masks, 193 | lambda logits: calculate_uncertainty(logits), 194 | 112 * 112, 195 | oversample_ratio, 196 | 0.75, 197 | ) 198 | # get gt labels 199 | point_labels = point_sample( 200 | target_masks, 201 | point_coords, 202 | align_corners=False, 203 | ).squeeze(1) 204 | 205 | point_logits = point_sample( 206 | src_masks, 207 | point_coords, 208 | align_corners=False, 209 | ).squeeze(1) 210 | 211 | loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks, mask_weight) 212 | loss_dice = dice_loss_jit(point_logits, point_labels, num_masks, mask_weight) 213 | 214 | if is_box_mask is not None : 215 | loss_mask = loss_mask + loss_mask_box 216 | loss_dice = loss_dice + loss_dice_box 217 | 218 | del src_masks 219 | del target_masks 220 | return loss_mask, loss_dice 221 | 222 | 223 | 224 | 
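# --- Usage sketch (added for illustration, not part of the original file) ---
# A minimal, hedged example of how loss_masks might be called on dummy data.
# All shapes and values below are assumptions. Note that is_box_mask must be a
# boolean tensor (the function calls .sum() and ~ on it), despite the False default.
if __name__ == "__main__":
    pred_logits = torch.randn(4, 1, 256, 256)                 # (N, 1, H, W) predicted mask logits
    gt_masks = torch.randint(0, 2, (4, 1, 256, 256)).float()  # binary ground-truth masks
    is_box_mask = torch.zeros(4, dtype=torch.bool)            # no box-only supervision in this batch
    loss_ce, loss_dice = loss_masks(pred_logits, gt_masks, num_masks=4.0, is_box_mask=is_box_mask)
    print(f"sigmoid-CE loss: {loss_ce.item():.4f}, dice loss: {loss_dice.item():.4f}")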
-------------------------------------------------------------------------------- /utils/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def adjust_learning_rate(optimizer, epoch, args): 4 | """Decay the learning rate with half-cycle cosine after warmup""" 5 | if epoch < args.warmup_epochs: 6 | lr = args.learning_rate * epoch / args.warmup_epochs 7 | else: 8 | lr = args.min_lr + (args.learning_rate - args.min_lr) * 0.5 * \ 9 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.max_epoch_num - args.warmup_epochs))) 10 | for param_group in optimizer.param_groups: 11 | if "lr_scale" in param_group: 12 | param_group["lr"] = lr * param_group["lr_scale"] 13 | else: 14 | param_group["lr"] = lr 15 | return lr -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import misc 3 | 4 | class AverageMeter: 5 | r""" Stores loss, evaluation results """ 6 | def __init__(self, class_ids, logger): 7 | self.logger = logger 8 | self.class_ids_interest = class_ids 9 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda() 10 | 11 | self.nclass = 2000 12 | 13 | self.intersection_buf = torch.zeros([self.nclass]).float().cuda() 14 | self.union_buf = torch.zeros([self.nclass]).float().cuda() 15 | self.ones = torch.ones_like(self.union_buf) 16 | self.loss_buf = [] 17 | self.valid_ids = torch.zeros([self.nclass]).float().cuda() 18 | 19 | def update(self, inter_b, union_b, class_id, loss): 20 | if isinstance(class_id, torch.Tensor): 21 | self.valid_ids[class_id.unique()] +=1 22 | elif isinstance(class_id, int): 23 | self.valid_ids[class_id] +=1 24 | 25 | self.intersection_buf[class_id] += inter_b 26 | self.union_buf[class_id] += union_b 27 | if loss is None: 28 | loss = torch.tensor(0.0) 29 | self.loss_buf.append(loss) 30 | 31 | def compute_iou(self): 32 | intersection_buf = torch.stack(misc.all_gather(self.intersection_buf.cpu())).sum(0).cuda() 33 | union_buf = torch.stack(misc.all_gather(self.union_buf.cpu())).sum(0).cuda() 34 | 35 | iou = intersection_buf.float() / union_buf.float().clip(min=1) 36 | # iou = iou[self.class_ids_interest] 37 | valid_ids = torch.stack(misc.all_gather(self.valid_ids.cpu())).sum(0).cuda() 38 | iou = iou[valid_ids > 0] 39 | print('shsape', iou.shape) 40 | miou = iou.mean() * 100 41 | 42 | return miou, None 43 | 44 | def write_result(self, split=0): 45 | iou, fb_iou = self.compute_iou() 46 | 47 | loss_buf = torch.stack(self.loss_buf) 48 | msg = '\n*** %s ' % split 49 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 50 | msg += 'mIoU: %5.2f ' % iou 51 | 52 | msg += '***\n' 53 | self.logger.info(msg) 54 | 55 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20): 56 | if batch_idx % write_batch_idx == 0: 57 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else '' 58 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen) 59 | iou, fb_iou = self.compute_iou() 60 | if epoch != -1: 61 | loss_buf = torch.stack(self.loss_buf) 62 | msg += 'L: %6.5f ' % loss_buf[-1] 63 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 64 | msg += 'mIoU: %5.2f | ' % iou 65 | msg += 'FB-IoU: %5.2f' % fb_iou 66 | self.logger.info(msg) -------------------------------------------------------------------------------- /utils/register_seginw_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | import json 3 | import os 4 | import collections 5 | 6 | from detectron2.data import DatasetCatalog, MetadataCatalog 7 | from detectron2.data.datasets import load_sem_seg 8 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 9 | from detectron2.utils.file_io import PathManager 10 | 11 | _CATEGORIES = ['Elephants', 'Hand-Metal', 'Watermelon', 'House-Parts', 'HouseHold-Items', 'Strawberry', 'Fruits', 'Nutterfly-Squireel', 12 | 'Hand', 'Garbage', 'Chicken', 'Rail', 'Airplane-Parts', 'Brain-Tumor', 'Poles', 'Electric-Shaver', 'Bottles', 13 | 'Toolkits', 'Trash', 'Salmon-Fillet', 'Puppies', 'Tablets', 'Phones', 'Cows', 'Ginger-Garlic'] 14 | 15 | _PREDEFINED_SPLITS_SEGINW = { 16 | "seginw_{}_val".format(cat): ( 17 | "valid", 18 | "seginw/{}".format(cat), # image_root 19 | "_annotations_min1cat.coco.json", # annot_root 20 | ) for cat in _CATEGORIES 21 | } 22 | _PREDEFINED_SPLITS_SEGINW.update({ 23 | "seginw_{}_train".format(cat): ( 24 | "train", 25 | "seginw/{}".format(cat), # image_root 26 | "_annotations_min1cat.coco.json", # annot_root 27 | ) for cat in _CATEGORIES 28 | }) 29 | 30 | _PREDEFINED_SPLITS_SEGINW.update({ 31 | "seginw_{}_train10shot".format(cat): ( 32 | "train_10shot", 33 | "seginw/{}".format(cat), # image_root 34 | "_annotations_min1cat.coco.json", # annot_root 35 | ) for cat in _CATEGORIES 36 | }) 37 | 38 | 39 | def get_metadata(): 40 | # meta = {"thing_dataset_id_to_contiguous_id": {}} 41 | meta = {} 42 | return meta 43 | 44 | 45 | def load_seginw_json(name, image_root, annot_json, metadata): 46 | """ 47 | Args: 48 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 49 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 50 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 51 | Returns: 52 | list[dict]: a list of dicts in Detectron2 standard format. (See 53 | `Using Custom Datasets `_ ) 54 | """ 55 | 56 | with PathManager.open(annot_json) as f: 57 | json_info = json.load(f) 58 | 59 | # build dictionary for grounding 60 | grd_dict = collections.defaultdict(list) 61 | for grd_ann in json_info['annotations']: 62 | image_id = int(grd_ann["image_id"]) 63 | grd_dict[image_id].append(grd_ann) 64 | 65 | ret = [] 66 | for image in json_info["images"]: 67 | image_id = int(image["id"]) 68 | image_file = os.path.join(image_root, image['file_name']) 69 | grounding_anno = grd_dict[image_id] 70 | 71 | if 'train' in name and len(grounding_anno) == 0: 72 | continue 73 | 74 | ret.append( 75 | { 76 | "file_name": image_file, 77 | "image_id": image_id, 78 | "inst_info": grounding_anno, 79 | } 80 | ) 81 | 82 | assert len(ret), f"No images found in {image_root}!" 
83 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] 84 | return ret 85 | 86 | 87 | def register_seginw( 88 | name, metadata, image_root, annot_json): 89 | DatasetCatalog.register( 90 | name, 91 | lambda: load_seginw_json(name, image_root, annot_json, metadata), 92 | ) 93 | MetadataCatalog.get(name).set( 94 | image_root=image_root, 95 | json_file=annot_json, 96 | evaluator_type="seginw", 97 | ignore_label=255, 98 | label_divisor=1000, 99 | **metadata, 100 | ) 101 | 102 | 103 | def register_all_seginw(root): 104 | for ( 105 | prefix, 106 | (split, folder_name, annot_name), 107 | ) in _PREDEFINED_SPLITS_SEGINW.items(): 108 | register_seginw( 109 | prefix, 110 | get_metadata(), 111 | os.path.join(root, folder_name, split), 112 | os.path.join(root, folder_name, split, annot_name), 113 | ) 114 | 115 | 116 | _root = os.getenv("DATASET", "datasets") 117 | register_all_seginw(_root) -------------------------------------------------------------------------------- /utils/seginw_data_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py 3 | import copy 4 | import random 5 | 6 | import scipy.io 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | 11 | from torchvision import transforms 12 | 13 | from pycocotools import mask 14 | from detectron2.structures import BitMasks, Boxes, Instances 15 | from detectron2.data import detection_utils as utils 16 | from detectron2.data import transforms as T 17 | 18 | # from xdecoder.utils import configurable 19 | from detectron2.config import configurable 20 | 21 | __all__ = ["SeginWDatasetMapper"] 22 | 23 | def build_transform_gen(cfg, is_train, empty=True): 24 | """ 25 | Create a list of default :class:`Augmentation` from config. 26 | Now it includes resizing and flipping. 27 | Returns: 28 | list[Augmentation] 29 | """ 30 | assert is_train, "Only support training augmentation" 31 | # cfg_input = cfg['INPUT'] 32 | # image_size = cfg_input['IMAGE_SIZE'] 33 | # min_scale = cfg_input['MIN_SCALE'] 34 | # max_scale = cfg_input['MAX_SCALE'] 35 | cfg_input = {'RANDOM_FLIP':"horizontal"} 36 | image_size = 1024 37 | min_scale = 0.8 38 | max_scale = 1.5 39 | 40 | augmentation = [] 41 | 42 | 43 | augmentation.append( 44 | T.Resize(image_size) 45 | ) 46 | 47 | # if cfg_input['RANDOM_FLIP'] != "none" and not empty: 48 | # augmentation.append( 49 | # T.RandomFlip( 50 | # horizontal=cfg_input['RANDOM_FLIP'] == "horizontal", 51 | # vertical=cfg_input['RANDOM_FLIP'] == "vertical", 52 | # ) 53 | # ) 54 | 55 | # augmentation.extend([ 56 | # T.ResizeShortestEdge([800, 1024], 1024, ), 57 | # # T.ResizeScale( 58 | # # min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size 59 | # # ), 60 | # # T.FixedSizeCrop(crop_size=(image_size, image_size)), 61 | # ]) 62 | 63 | if empty : 64 | augmentation = [] 65 | 66 | return augmentation 67 | 68 | 69 | # This is specifically designed for the COCO dataset. 70 | class SeginWDatasetMapper: 71 | """ 72 | A callable which takes a dataset dict in Detectron2 Dataset format, 73 | and map it into a format used by MaskFormer. 74 | 75 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 76 | 77 | The callable currently does the following: 78 | 79 | 1. Read the image from "file_name" 80 | 2. 
Applies geometric transforms to the image and annotation 81 | 3. Find and applies suitable cropping to the image and annotation 82 | 4. Prepare image and annotation to Tensors 83 | """ 84 | 85 | @configurable 86 | def __init__( 87 | self, 88 | is_train=True, 89 | tfm_gens=None, 90 | img_format=None, 91 | min_size_test=None, 92 | max_size_test=None, 93 | mean=None, 94 | std=None, 95 | ): 96 | """ 97 | NOTE: this interface is experimental. 98 | Args: 99 | is_train: for training or inference 100 | augmentations: a list of augmentations or deterministic transforms to apply 101 | tfm_gens: data augmentation 102 | image_format: an image format supported by :func:`detection_utils.read_image`. 103 | """ 104 | self.tfm_gens = tfm_gens 105 | self.img_format = img_format 106 | 107 | self.is_train = is_train 108 | self.min_size_test = min_size_test 109 | self.max_size_test = max_size_test 110 | self.pixel_mean = torch.tensor(mean)[:,None,None] 111 | self.pixel_std = torch.tensor(std)[:,None,None] 112 | 113 | t = [] 114 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 115 | self.transform = transforms.Compose(t) 116 | 117 | @classmethod 118 | def from_config(cls, cfg, is_train=True): 119 | # Build augmentation 120 | if is_train: 121 | tfm_gens = build_transform_gen(cfg, is_train) 122 | else: 123 | tfm_gens = None 124 | 125 | ret = { "is_train": is_train, 126 | "tfm_gens": tfm_gens, 127 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 128 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 129 | "mean": cfg['INPUT']['PIXEL_MEAN'], 130 | "std": cfg['INPUT']['PIXEL_STD'], 131 | "img_format": cfg['INPUT']['FORMAT']} 132 | return ret 133 | 134 | def __call__(self, dataset_dict): 135 | """ 136 | Args: 137 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
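In this pipeline it is typically one of the dicts produced by load_seginw_json in utils/register_seginw_dataset.py, i.e. it carries "file_name", "image_id" and the per-image COCO annotations under "inst_info".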
138 | 139 | Returns: 140 | dict: a format that builtin models in detectron2 accept 141 | """ 142 | if self.is_train == False: 143 | file_name = dataset_dict['file_name'] 144 | image = Image.open(file_name).convert('RGB') 145 | dataset_dict['width'] = image.size[0] 146 | dataset_dict['height'] = image.size[1] 147 | image = self.transform(image) 148 | image = torch.from_numpy(np.asarray(image).copy()) 149 | image = image.permute(2,0,1) 150 | dataset_dict['image'] = image 151 | else: 152 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 153 | utils.check_image_size(dataset_dict, image) 154 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 155 | image_shape = image.shape[:2] # h, w 156 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 157 | grounding_anno = dataset_dict['inst_info'] 158 | # assert len(grounding_anno) > 0 159 | masks_grd = [] 160 | class_grd = [] 161 | for ann in grounding_anno: 162 | rle = mask.frPyObjects( 163 | ann['segmentation'], dataset_dict['height'], dataset_dict['width']) 164 | m = mask.decode(rle) 165 | # sometimes there are multiple binary map (corresponding to multiple segs) 166 | m = np.sum(m, axis=2) 167 | m = m.astype(np.uint8) # convert to np.uint8 168 | m = transforms.apply_segmentation(m[:,:,None])[:,:,0] 169 | masks_grd += [m] 170 | class_grd.append(ann['category_id']) 171 | 172 | is_things = [1 for idx in range(len(class_grd))] 173 | instances = Instances(image_shape) 174 | 175 | if len(masks_grd) == 0: 176 | # Some image does not have annotation (all ignored) 177 | instances.gt_masks = torch.zeros((0, image.shape[0], image.shape[1])) 178 | instances.gt_boxes = Boxes(torch.zeros((0, 4))) 179 | else: 180 | masks = BitMasks( 181 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks_grd]) 182 | ) 183 | instances.gt_masks = masks.tensor 184 | instances.gt_boxes = masks.get_bounding_boxes() 185 | 186 | instances.gt_classes = torch.tensor(class_grd, dtype=torch.int64) 187 | instances.is_things = torch.tensor(is_things, dtype=torch.int64) 188 | 189 | dataset_dict["instances"] = instances 190 | 191 | return dataset_dict -------------------------------------------------------------------------------- /utils/vos_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path, replace 3 | 4 | import torch 5 | import json 6 | import numpy as np 7 | from torch.utils.data.dataset import Dataset 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from torchvision.transforms import InterpolationMode 11 | from PIL import Image 12 | 13 | class YouTubeVOSTestDataset(Dataset): 14 | def __init__(self, data_root, split, transform): 15 | # self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages') 16 | self.image_dir = path.join(data_root, split, 'JPEGImages') 17 | self.mask_dir = path.join(data_root, split, 'Annotations') 18 | self.transform = transform 19 | 20 | self.vid_list = sorted(os.listdir(self.image_dir)) 21 | self.req_frame_list = {} 22 | 23 | with open(path.join(data_root, split, 'meta.json')) as f: 24 | # read meta.json to know which frame is required for evaluation 25 | meta = json.load(f)['videos'] 26 | 27 | for vid in self.vid_list: 28 | req_frames = [] 29 | objects = meta[vid]['objects'] 30 | for value in objects.values(): 31 | req_frames.extend(value['frames']) 32 | 33 | req_frames = list(set(req_frames)) 34 | self.req_frame_list[vid] 
= req_frames 35 | 36 | def load_video(self, vid): 37 | image_dir_this = os.path.join(self.image_dir, vid) 38 | mask_dir_this = os.path.join(self.mask_dir, vid) 39 | frames = sorted(os.listdir(image_dir_this)) 40 | first_gt_path = path.join(mask_dir_this, sorted(os.listdir(mask_dir_this))[0]) 41 | 42 | 43 | frame_list = [] 44 | for name in frames : 45 | frame_img = Image.open(os.path.join(image_dir_this, name)).convert('RGB') 46 | frame_list.append(frame_img) 47 | mask = Image.open(first_gt_path) 48 | mask = np.array(mask.convert('P'), dtype=np.uint8) 49 | 50 | return frame_list, mask, first_gt_path, frames 51 | 52 | def __getitem__(self, idx): 53 | # query_name, support_names, class_sample = self.sample_episode(idx) 54 | # query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 55 | vid = self.vid_list[idx] 56 | frame_list, mask, mask_path, frames = self.load_video(vid) 57 | 58 | def to_tensor(x): 59 | return torch.tensor(np.array(x), dtype=torch.float32).permute(2,0,1) 60 | images = [to_tensor(x) for x in frame_list] 61 | images = torch.stack(images) 62 | mask = torch.tensor(mask) 63 | mask_ids = mask.unique()[1:] 64 | inst_masks = mask.unsqueeze(0).expand(len(mask_ids), -1, -1) == mask_ids.view(-1, 1, 1).int() 65 | 66 | sample = {'images': images, 67 | 'inst_masks': inst_masks * 255, 68 | 'inst_ids': mask_ids, 69 | 'vid': vid, 70 | 'frame_ids': frames, 71 | 'mask_path': mask_path 72 | } 73 | 74 | if self.transform: 75 | sample = self.transform(sample) 76 | 77 | sample.update({ 78 | 'ori_inst_masks':inst_masks * 255, 79 | }) 80 | 81 | return sample 82 | 83 | def __len__(self): 84 | return len(self.vid_list) 85 | 86 | 87 | 88 | class DAVISTestDataset(Dataset): 89 | def __init__(self, data_root, imset, transform): 90 | if False: 91 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution') 92 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution') 93 | if not path.exists(self.image_dir): 94 | print(f'{self.image_dir} not found. 
Look at other options.') 95 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p') 96 | self.mask_dir = path.join(data_root, 'Annotations', '1080p') 97 | assert path.exists(self.image_dir), 'path not found' 98 | else : 99 | self.image_dir = path.join(data_root, 'JPEGImages', '480p') 100 | self.mask_dir = path.join(data_root, 'Annotations', '480p') 101 | 102 | self.transform = transform 103 | 104 | 105 | with open(path.join(data_root, 'ImageSets', imset)) as f: 106 | self.vid_list = sorted([line.strip() for line in f]) 107 | # self.vid_list = sorted([line.strip().split('/')[-2] for line in f]) 108 | # self.vid_list = sorted(list(set(self.vid_list))) 109 | 110 | def load_video(self, vid): 111 | image_dir_this = os.path.join(self.image_dir, vid) 112 | mask_dir_this = os.path.join(self.mask_dir, vid) 113 | frames = sorted(os.listdir(image_dir_this)) 114 | first_gt_path = path.join(mask_dir_this, sorted(os.listdir(mask_dir_this))[0]) 115 | 116 | 117 | frame_list = [] 118 | for name in frames : 119 | frame_img = Image.open(os.path.join(image_dir_this, name)).convert('RGB') 120 | frame_list.append(frame_img) 121 | mask = Image.open(first_gt_path) 122 | mask = np.array(mask.convert('P'), dtype=np.uint8) 123 | 124 | return frame_list, mask, first_gt_path, frames 125 | 126 | def __getitem__(self, idx): 127 | # query_name, support_names, class_sample = self.sample_episode(idx) 128 | # query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 129 | vid = self.vid_list[idx] 130 | frame_list, mask, mask_path, frames = self.load_video(vid) 131 | 132 | def to_tensor(x): 133 | return torch.tensor(np.array(x), dtype=torch.float32).permute(2,0,1) 134 | images = [to_tensor(x) for x in frame_list] 135 | images = torch.stack(images) 136 | mask = torch.tensor(mask) 137 | mask_ids = mask.unique()[1:] 138 | inst_masks = mask.unsqueeze(0).expand(len(mask_ids), -1, -1) == mask_ids.view(-1, 1, 1).int() 139 | 140 | sample = {'images': images, 141 | 'inst_masks': inst_masks * 255, 142 | 'inst_ids': mask_ids, 143 | 'vid': vid, 144 | 'frame_ids': frames, 145 | 'mask_path': mask_path 146 | } 147 | 148 | if self.transform: 149 | sample = self.transform(sample) 150 | 151 | sample.update({ 152 | 'ori_inst_masks':inst_masks * 255, 153 | }) 154 | 155 | return sample 156 | 157 | def __len__(self): 158 | return len(self.vid_list) -------------------------------------------------------------------------------- /vos_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/SEGIC/73bbcc340825e89f18e1b607cdbaaa09adc9619a/vos_benchmark/__init__.py -------------------------------------------------------------------------------- /vos_benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import time 4 | from multiprocessing import Pool 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import tqdm 9 | 10 | from .evaluator import Evaluator 11 | 12 | 13 | class VideoEvaluator: 14 | """ 15 | A processing function object. 16 | This returns metrics for a single video. 
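Given a video (folder) name, it pairs each ground-truth frame with the corresponding predicted mask, feeds them to an Evaluator, and returns (video name, per-object J/IoU dict, per-object boundary-F dict).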
17 | """ 18 | 19 | def __init__(self, gt_root, mask_root, skip_first_and_last=True): 20 | self.gt_root = gt_root 21 | self.mask_root = mask_root 22 | self.skip_first_and_last = skip_first_and_last 23 | 24 | def __call__(self, vid_name): 25 | vid_gt_path = path.join(self.gt_root, vid_name) 26 | vid_mask_path = path.join(self.mask_root, vid_name) 27 | 28 | frames = sorted(os.listdir(vid_gt_path)) 29 | if self.skip_first_and_last: 30 | # the first and the last frames are skipped in DAVIS semi-supervised evaluation 31 | frames = frames[1:-1] 32 | evaluator = Evaluator(name=vid_name) 33 | for f in frames: 34 | try: 35 | gt_array = np.array(Image.open(path.join(vid_gt_path, f))) 36 | mask_array = np.array(Image.open(path.join(vid_mask_path, f))) 37 | assert gt_array.shape[-2:] == mask_array.shape[-2:], \ 38 | f'Dimensions mismatch: GT: {gt_array.shape}, predicted: {mask_array.shape}. '\ 39 | f'GT path: {path.join(vid_gt_path, f)}; ' \ 40 | f'predicted path: {path.join(vid_mask_path, f)}' 41 | except FileNotFoundError: 42 | print(f'{f} not found in {vid_mask_path}.') 43 | exit(1) 44 | 45 | evaluator.feed_frame(mask_array, gt_array) 46 | iou, boundary_f = evaluator.conclude() 47 | return vid_name, iou, boundary_f 48 | 49 | 50 | def benchmark(gt_roots, 51 | mask_roots, 52 | strict=True, 53 | num_processes=None, 54 | *, 55 | verbose=True, 56 | skip_first_and_last=True): 57 | """ 58 | gt_roots: a list of paths to datasets, i.e., [path_to_DatasetA, path_to_DatasetB, ...] 59 | with the below directory structure 60 | DatasetA - 61 | Video 1 - 62 | xxxx.png 63 | ... 64 | Video 2 - 65 | xxxx.png 66 | ... 67 | ... 68 | DatasetB - 69 | ... 70 | mask_roots: same as above, but the .png are masks predicted by the model 71 | strict: when True, all videos in the dataset must have corresponding predictions. 72 | Setting it to False is useful in cases where the ground-truth contains both train/val 73 | sets, but the model only predicts the val subset. 74 | Either way, if a video is predicted (i.e., the corresponding folder exists), 75 | then it must at least contain all the masks in the ground truth annotations. 76 | Masks that are in the prediction but not in the ground-truth 77 | (i.e., sparse annotations) are ignored. 78 | skip_first_and_last: whether we should skip the first and the last frame in evaluation. 79 | This is used by DAVIS 2017 in their semi-supervised evaluation. 80 | It should be disabled for unsupervised evaluation. 81 | """ 82 | 83 | assert len(gt_roots) == len(mask_roots) 84 | single_dataset = (len(gt_roots) == 1) 85 | 86 | if verbose: 87 | if skip_first_and_last: 88 | print( 89 | 'We are *SKIPPING* the evaluation of the first and the last frame (standard for semi-supervised video object segmentation).' 90 | ) 91 | else: 92 | print( 93 | 'We are *NOT SKIPPING* the evaluation of the first and the last frame (*NOT STANDARD* for semi-supervised video object segmentation).' 
94 | ) 95 | 96 | pool = Pool(num_processes) 97 | start = time.time() 98 | to_wait = [] 99 | for gt_root, mask_root in zip(gt_roots, mask_roots): 100 | #Validate folders 101 | validated = True 102 | gt_videos = os.listdir(gt_root) 103 | mask_videos = os.listdir(mask_root) 104 | 105 | # if the user passed the root directory instead of Annotations 106 | if len(gt_videos) != len(mask_videos): 107 | if 'Annotations' in gt_videos: 108 | if '.png' not in os.listdir(path.join(gt_root, 'Annotations'))[0]: 109 | gt_root = path.join(gt_root, 'Annotations') 110 | gt_videos = os.listdir(gt_root) 111 | 112 | # remove non-folder items 113 | gt_videos = list(filter(lambda x: path.isdir(path.join(gt_root, x)), gt_videos)) 114 | mask_videos = list(filter(lambda x: path.isdir(path.join(mask_root, x)), mask_videos)) 115 | 116 | if not strict: 117 | videos = sorted(list(set(gt_videos) & set(mask_videos))) 118 | else: 119 | gt_extras = set(gt_videos) - set(mask_videos) 120 | mask_extras = set(mask_videos) - set(gt_videos) 121 | 122 | if len(gt_extras) > 0: 123 | print(f'Videos that are in {gt_root} but not in {mask_root}: {gt_extras}') 124 | validated = False 125 | if len(mask_extras) > 0: 126 | print(f'Videos that are in {mask_root} but not in {gt_root}: {mask_extras}') 127 | validated = False 128 | if not validated: 129 | print('Validation failed. Exiting.') 130 | exit(1) 131 | 132 | videos = sorted(gt_videos) 133 | 134 | if verbose: 135 | print(f'In dataset {gt_root}, we are evaluating on {len(videos)} videos: {videos}') 136 | 137 | if single_dataset: 138 | if verbose: 139 | results = tqdm.tqdm(pool.imap( 140 | VideoEvaluator(gt_root, mask_root, skip_first_and_last=skip_first_and_last), 141 | videos), 142 | total=len(videos)) 143 | else: 144 | results = pool.map( 145 | VideoEvaluator(gt_root, mask_root, skip_first_and_last=skip_first_and_last), 146 | videos) 147 | else: 148 | to_wait.append( 149 | pool.map_async( 150 | VideoEvaluator(gt_root, mask_root, skip_first_and_last=skip_first_and_last), 151 | videos)) 152 | 153 | pool.close() 154 | 155 | all_global_jf, all_global_j, all_global_f = [], [], [] 156 | all_object_metrics = [] 157 | for i, mask_root in enumerate(mask_roots): 158 | if not single_dataset: 159 | results = to_wait[i].get() 160 | 161 | all_iou = [] 162 | all_boundary_f = [] 163 | object_metrics = {} 164 | for name, iou, boundary_f in results: 165 | all_iou.extend(list(iou.values())) 166 | all_boundary_f.extend(list(boundary_f.values())) 167 | object_metrics[name] = (iou, boundary_f) 168 | 169 | global_j = np.array(all_iou).mean() 170 | global_f = np.array(all_boundary_f).mean() 171 | global_jf = (global_j + global_f) / 2 172 | 173 | time_taken = (time.time() - start) 174 | """ 175 | Build string for reporting results 176 | """ 177 | # find max length for padding 178 | ml = max(*[len(n) for n in object_metrics.keys()], len('Global score')) 179 | # build header 180 | out_string = f'{"sequence":<{ml}},{"obj":>3}, {"J&F":>4}, {"J":>4}, {"F":>4}\n' 181 | out_string += f'{"Global score":<{ml}},{"":>3}, {global_jf:.1f}, {global_j:.1f}, {global_f:.1f}\n' 182 | # append one line for each object 183 | for name, (iou, boundary_f) in object_metrics.items(): 184 | for object_idx in iou.keys(): 185 | j, f = iou[object_idx], boundary_f[object_idx] 186 | jf = (j + f) / 2 187 | out_string += f'{name:<{ml}},{object_idx:03}, {jf:>4.1f}, {j:>4.1f}, {f:>4.1f}\n' 188 | 189 | # print to console 190 | if verbose: 191 | print(out_string.replace(',', ' '), end='') 192 | print('\nSummary:') 193 | print(f'Global score: 
J&F: {global_jf:.1f} J: {global_j:.1f} F: {global_f:.1f}') 194 | print(f'Time taken: {time_taken:.2f}s') 195 | 196 | # print to file 197 | with open(path.join(mask_root, 'results.csv'), 'w') as f: 198 | f.write(out_string) 199 | 200 | all_global_jf.append(global_jf) 201 | all_global_j.append(global_j) 202 | all_global_f.append(global_f) 203 | all_object_metrics.append(object_metrics) 204 | 205 | return all_global_jf, all_global_j, all_global_f, all_object_metrics 206 | -------------------------------------------------------------------------------- /vos_benchmark/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | import cv2 4 | from skimage.morphology import disk 5 | 6 | from .utils import _seg2bmap 7 | 8 | 9 | def get_iou(intersection, pixel_sum): 10 | # handle edge cases without resorting to epsilon 11 | if intersection == pixel_sum: 12 | # both mask and gt have zero pixels in them 13 | assert intersection == 0 14 | return 1 15 | 16 | return intersection / (pixel_sum - intersection) 17 | 18 | 19 | class Evaluator: 20 | 21 | def __init__(self, boundary=0.008, name=None): 22 | # boundary: used in computing boundary F-score 23 | self.boundary = boundary 24 | self.name = name 25 | self.objects_in_gt = set() 26 | self.objects_in_masks = set() 27 | 28 | self.object_iou = defaultdict(list) 29 | self.boundary_f = defaultdict(list) 30 | 31 | def feed_frame(self, mask: np.ndarray, gt: np.ndarray): 32 | """ 33 | Compute and accumulate metrics for a single frame (mask/gt pair) 34 | """ 35 | 36 | # get all objects in the ground-truth 37 | gt_objects = np.unique(gt) 38 | gt_objects = gt_objects[gt_objects != 0].tolist() 39 | 40 | # get all objects in the predicted mask 41 | mask_objects = np.unique(mask) 42 | mask_objects = mask_objects[mask_objects != 0].tolist() 43 | 44 | self.objects_in_gt.update(set(gt_objects)) 45 | self.objects_in_masks.update(set(mask_objects)) 46 | 47 | all_objects = self.objects_in_gt.union(self.objects_in_masks) 48 | 49 | # boundary disk for boundary F-score. It is the same for all objects. 
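# Its radius is a fixed fraction of the image diagonal: ceil(self.boundary * ||(H, W)||).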
50 | bound_pix = np.ceil(self.boundary * np.linalg.norm(mask.shape)) 51 | boundary_disk = disk(bound_pix) 52 | 53 | for obj_idx in all_objects: 54 | obj_mask = (mask == obj_idx) 55 | obj_gt = (gt == obj_idx) 56 | 57 | # object iou 58 | self.object_iou[obj_idx].append( 59 | get_iou((obj_mask * obj_gt).sum(), 60 | obj_mask.sum() + obj_gt.sum())) 61 | """ 62 | # boundary f-score 63 | This part is copied from davis2017-evaluation 64 | """ 65 | mask_boundary = _seg2bmap(obj_mask) 66 | gt_boundary = _seg2bmap(obj_gt) 67 | mask_dilated = cv2.dilate(mask_boundary.astype(np.uint8), boundary_disk) 68 | gt_dilated = cv2.dilate(gt_boundary.astype(np.uint8), boundary_disk) 69 | 70 | # Get the intersection 71 | gt_match = gt_boundary * mask_dilated 72 | fg_match = mask_boundary * gt_dilated 73 | 74 | # Area of the intersection 75 | n_fg = np.sum(mask_boundary) 76 | n_gt = np.sum(gt_boundary) 77 | 78 | # Compute precision and recall 79 | if n_fg == 0 and n_gt > 0: 80 | precision = 1 81 | recall = 0 82 | elif n_fg > 0 and n_gt == 0: 83 | precision = 0 84 | recall = 1 85 | elif n_fg == 0 and n_gt == 0: 86 | precision = 1 87 | recall = 1 88 | else: 89 | precision = np.sum(fg_match) / float(n_fg) 90 | recall = np.sum(gt_match) / float(n_gt) 91 | 92 | # Compute F measure 93 | if precision + recall == 0: 94 | F = 0 95 | else: 96 | F = 2 * precision * recall / (precision + recall) 97 | self.boundary_f[obj_idx].append(F) 98 | 99 | def conclude(self): 100 | all_iou = {} 101 | all_boundary_f = {} 102 | 103 | for object_id in self.objects_in_gt: 104 | all_iou[object_id] = np.mean(self.object_iou[object_id]) * 100 105 | all_boundary_f[object_id] = np.mean(self.boundary_f[object_id]) * 100 106 | 107 | return all_iou, all_boundary_f 108 | -------------------------------------------------------------------------------- /vos_benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | def _seg2bmap(seg, width=None, height=None): 6 | """ 7 | From a segmentation, compute a binary boundary map with 1 pixel wide 8 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 9 | origin from the actual segment boundary. 10 | Arguments: 11 | seg : Segments labeled from 1..k. 12 | width : Width of desired bmap <= seg.shape[1] 13 | height : Height of desired bmap <= seg.shape[0] 14 | Returns: 15 | bmap (ndarray): Binary boundary map. 16 | David Martin 17 | January 2003 18 | """ 19 | 20 | seg = seg.astype(bool) 21 | seg[seg > 0] = 1 22 | 23 | assert np.atleast_3d(seg).shape[2] == 1 24 | 25 | width = seg.shape[1] if width is None else width 26 | height = seg.shape[0] if height is None else height 27 | 28 | h, w = seg.shape[:2] 29 | 30 | ar1 = float(width) / float(height) 31 | ar2 = float(w) / float(h) 32 | 33 | assert not (width > w | height > h | abs(ar1 - ar2) > 34 | 0.01), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) 35 | 36 | e = np.zeros_like(seg) 37 | s = np.zeros_like(seg) 38 | se = np.zeros_like(seg) 39 | 40 | e[:, :-1] = seg[:, 1:] 41 | s[:-1, :] = seg[1:, :] 42 | se[:-1, :-1] = seg[1:, 1:] 43 | 44 | b = seg ^ e | seg ^ s | seg ^ se 45 | b[-1, :] = seg[-1, :] ^ e[-1, :] 46 | b[:, -1] = seg[:, -1] ^ s[:, -1] 47 | b[-1, -1] = 0 48 | 49 | if w == width and h == height: 50 | bmap = b 51 | else: 52 | bmap = np.zeros((height, width)) 53 | for x in range(w): 54 | for y in range(h): 55 | if b[y, x]: 56 | j = 1 + math.floor((y - 1) * height / h) 57 | i = 1 + math.floor((x - 1) * width / w) 58 | bmap[j, i] = 1 59 | 60 | return bmap --------------------------------------------------------------------------------
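As a usage reference (not part of the repository files above), the VOS metrics can also be computed directly with the benchmark helper from vos_benchmark/benchmark.py. This is a minimal sketch; the ground-truth and prediction paths are hypothetical placeholders and must match your local layout.
```
# Sketch only: paths are placeholders, not paths shipped with the repo.
from vos_benchmark.benchmark import benchmark

if __name__ == '__main__':
    global_jf, global_j, global_f, per_object = benchmark(
        gt_roots=['data/DAVIS/2017/Annotations/480p'],   # ground-truth mask folders
        mask_roots=['OUTPUT/davis17_predictions'],       # predicted mask folders
        strict=False,               # GT may contain train+val while only val is predicted
        num_processes=8,
        skip_first_and_last=True,   # DAVIS semi-supervised convention
    )
    print(f'J&F: {global_jf[0]:.1f}  J: {global_j[0]:.1f}  F: {global_f[0]:.1f}')
```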