├── .gitignore ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── base_config.py ├── cfg_ade20k.py ├── cfg_city_scapes.py ├── cfg_coco_object.py ├── cfg_coco_stuff164k.py ├── cfg_context59.py ├── cfg_context60.py ├── cfg_voc20.py ├── cfg_voc21.py ├── cls_ade20k.txt ├── cls_city_scapes.txt ├── cls_coco_object.txt ├── cls_coco_stuff.txt ├── cls_context59.txt ├── cls_context60.txt ├── cls_voc20.txt └── cls_voc21.txt ├── custom_datasets.py ├── datasets └── cvt_coco_object.py ├── demo.py ├── dist_test.sh ├── eval.py ├── figs ├── demo.jpg └── scclip.jpg ├── pamr.py ├── prompts └── imagenet_template.py └── scclip_segmentor.py /.gitignore: -------------------------------------------------------------------------------- 1 | /outputs 2 | /.dist_test 3 | **__pycache__** 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Calibrated CLIP for Training-Free Open-Vocabulary Segmentation 2 | 3 | [Sule Bai*](https://sulebai.github.io/), [Yong Liu*](https://yongliu20.github.io/), [Yifei Han](https://github.com/LambdaGuard), [Haoji Zhang](https://zhang9302002.github.io/), [Yansong Tang](https://andytang15.github.io/) 4 | (* denotes equal contribution) 5 | 6 | **Official PyTorch Implementation of [Self-Calibrated CLIP for Training-Free Open-Vocabulary Segmentation](https://arxiv.org/abs/2411.15869)** 7 | 8 | 9 | 10 |
11 | 12 |
13 | 14 | ## Abstract 15 | > Recent advancements in pre-trained vision-language models like CLIP have enabled the task of open-vocabulary segmentation. CLIP demonstrates impressive zero-shot capabilities in various downstream tasks that require holistic image understanding. However, due to its image-level pre-training, CLIP struggles to capture local details, resulting in poor performance in segmentation tasks. Our analysis reveals that anomaly tokens emerge during the forward pass, drawing excessive attention from normal patch tokens, thereby diminishing spatial awareness. To address this issue, we propose Self-Calibrated CLIP (SC-CLIP), a training-free method that calibrates CLIP to produce finer-grained representations while preserving its original generalization ability, without introducing new parameters or relying on additional backbones. Specifically, we first identify and resolve the anomaly tokens to mitigate their negative impact. Next, we enhance feature discriminability and attention correlation by leveraging the semantic consistency found in CLIP's intermediate features. Furthermore, we employ multi-level feature fusion to enrich details. Collectively, these strategies enhance CLIP's feature representation with greater granularity and coherence. Experimental results demonstrate the effectiveness of SC-CLIP, achieving state-of-the-art results across eight semantic segmentation datasets and surpassing previous methods by 9.5%. Notably, SC-CLIP boosts the performance of vanilla CLIP ViT-L/14 by 6.8 times. 16 | 17 | ## Dependencies 18 | 19 | ``` 20 | git clone https://github.com/SuleBai/SC-CLIP.git 21 | cd SC-CLIP 22 | 23 | conda create -n scclip python=3.9 24 | conda activate scclip 25 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html 26 | pip install openmim 27 | mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1 28 | pip install ftfy regex numpy==1.26 yapf==0.40.1 29 | ``` 30 | 31 | ## Datasets 32 | We provide the dataset configurations in this repository, following [SCLIP](https://github.com/wangf3014/SCLIP). 33 | 34 | Please follow the [MMSeg data preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download and pre-process the datasets. The COCO-Object dataset can be converted from COCO-Stuff164k by executing the following command: 35 | 36 | ``` 37 | python ./datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO_OBJECT 38 | ``` 39 | 40 | ## Quick Inference 41 | ``` 42 | python demo.py 43 | ``` 44 | 45 | ## Model Evaluation 46 | Single-GPU evaluation (replace `cfg_DATASET.py` with one of the dataset configs under `./configs`, e.g. `cfg_voc20.py` or `cfg_ade20k.py`): 47 | 48 | ``` 49 | python eval.py --config ./configs/cfg_DATASET.py --workdir YOUR_WORK_DIR 50 | ``` 51 | 52 | Multi-GPU evaluation: 53 | ``` 54 | bash ./dist_test.sh 55 | ``` 56 | 57 | ## Acknowledgement 58 | This implementation is based on [CLIP](https://github.com/openai/CLIP), [SCLIP](https://github.com/wangf3014/SCLIP), [CLIP-DINOiser](https://github.com/wysoczanska/clip_dinoiser) and [ClearCLIP](https://github.com/mc-lan/ClearCLIP). Thanks for the awesome work.
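For a quick mental model of what the training-free pipeline computes, the sketch below runs the calibrated encoder on one image and classifies every patch token against a handful of text embeddings. It is a deliberately simplified illustration, not `demo.py`: the actual `SCCLIPForSegmentation` additionally applies the prompt templates from `prompts/imagenet_template.py`, sliding-window inference, PAMR refinement (`pamr.py`), and the per-dataset logit scales and probability thresholds set in `configs/`. The class names used here are placeholders.

```
# Minimal sketch (not demo.py): per-patch CLIP classification with the
# calibrated encoder from this repo. Class names are illustrative only.
import torch
import torch.nn.functional as F
from PIL import Image

import clip  # the local ./clip package

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

class_names = ["sky", "tree", "person", "car"]  # placeholder vocabulary
texts = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)
image = preprocess(Image.open("figs/demo.jpg")).unsqueeze(0).to(device)

with torch.no_grad():
    # In this repo's modified VisionTransformer, return_all=True yields the
    # calibrated patch tokens (one per 16x16 patch) projected into CLIP space.
    patch_feats = model.encode_image(image, return_all=True)  # [1, HW, 512]
    text_feats = model.encode_text(texts)                     # [K, 512]

patch_feats = F.normalize(patch_feats.float(), dim=-1)
text_feats = F.normalize(text_feats.float(), dim=-1)

logits = patch_feats @ text_feats.t()            # [1, HW, K] cosine similarities
side = int(logits.shape[1] ** 0.5)               # 224 / 16 = 14 patches per side
seg = logits.argmax(-1).reshape(1, side, side)   # coarse per-patch class map

# Upsample the coarse map to the input resolution for visualization.
seg = F.interpolate(seg[None].float(), size=image.shape[-2:], mode="nearest")
print(seg.shape)  # torch.Size([1, 1, 224, 224])
```

In the full pipeline, this kind of coarse map is further refined (e.g., by PAMR and the thresholds in the configs); to reproduce the reported numbers, use `eval.py` with the provided configs rather than this sketch.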
-------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import * 3 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | import hashlib 5 | import os 6 | import urllib 7 | import warnings 8 | from typing import Any, Union, List 9 | from pkg_resources import packaging 10 | 11 | import torch 12 | from PIL import Image 13 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 14 | from tqdm import tqdm 15 | 16 | from .model import build_model 17 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 28 | 29 | 30 | __all__ = ["available_models", "load", "tokenize"] 31 | _tokenizer = _Tokenizer() 32 | 33 | _MODELS = { 34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 60 | 61 | 
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 72 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 73 | 74 | return download_target 75 | 76 | 77 | def _convert_image_to_rgb(image): 78 | return image.convert("RGB") 79 | 80 | 81 | def _transform(n_px): 82 | return Compose([ 83 | Resize(n_px, interpolation=BICUBIC), 84 | CenterCrop(n_px), 85 | _convert_image_to_rgb, 86 | ToTensor(), 87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 88 | ]) 89 | 90 | 91 | def available_models() -> List[str]: 92 | """Returns the names of available CLIP models""" 93 | return list(_MODELS.keys()) 94 | 95 | 96 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 97 | """Load a CLIP model 98 | 99 | Parameters 100 | ---------- 101 | name : str 102 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 103 | 104 | device : Union[str, torch.device] 105 | The device to put the loaded model 106 | 107 | jit : bool 108 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 109 | 110 | download_root: str 111 | path to download the model files; by default, it uses "~/.cache/clip" 112 | 113 | Returns 114 | ------- 115 | model : torch.nn.Module 116 | The CLIP model 117 | 118 | preprocess : Callable[[PIL.Image], torch.Tensor] 119 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 120 | """ 121 | if name in _MODELS: 122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 123 | elif os.path.isfile(name): 124 | model_path = name 125 | else: 126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 127 | 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 136 | jit = False 137 | state_dict = torch.load(model_path, map_location="cpu") 138 | 139 | if not jit: 140 | model = build_model(state_dict or model.state_dict()).to(device) 141 | if str(device) == "cpu": 142 | model.float() 143 | return model, _transform(model.visual.input_resolution) 144 | 145 | # patch the device names 146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 148 | 149 | def patch_device(module): 150 | try: 151 | graphs = [module.graph] if hasattr(module, "graph") else [] 152 | except RuntimeError: 153 | graphs = [] 154 | 155 | if hasattr(module, "forward1"): 156 | graphs.append(module.forward1.graph) 157 | 158 | for graph in graphs: 159 | for node in graph.findAllNodes("prim::Constant"): 160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 161 | node.copyAttributes(device_node) 162 | 163 | model.apply(patch_device) 164 | patch_device(model.encode_image) 165 | patch_device(model.encode_text) 166 | 167 | # patch dtype to float32 on CPU 168 | if str(device) == "cpu": 169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 171 | float_node = float_input.node() 172 | 173 | def patch_float(module): 174 | try: 175 | graphs = [module.graph] if hasattr(module, "graph") else [] 176 | except RuntimeError: 177 | graphs = [] 178 | 179 | if hasattr(module, "forward1"): 180 | graphs.append(module.forward1.graph) 181 | 182 | for graph in graphs: 183 | for node in graph.findAllNodes("aten::to"): 184 | inputs = list(node.inputs()) 185 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 186 | if inputs[i].node()["value"] == 5: 187 | inputs[i].node().copyAttributes(float_node) 188 | 189 | model.apply(patch_float) 190 | patch_float(model.encode_image) 191 | patch_float(model.encode_text) 192 | 193 | model.float() 194 | 195 | return model, _transform(model.input_resolution.item()) 196 | 197 | 198 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 199 | """ 200 | Returns the tokenized representation of given input string(s) 201 | 202 | Parameters 203 | ---------- 204 | texts : Union[str, List[str]] 205 | An input string or a list of input strings to tokenize 206 | 207 | context_length : int 208 | The context length to use; all CLIP models use 77 as the context length 209 | 210 | truncate: bool 211 | Whether to truncate the text in case its encoding is longer than the context length 212 | 213 | Returns 214 | ------- 215 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 224 | 225 | for i, tokens in enumerate(all_tokens): 226 | if len(tokens) > context_length: 227 | if truncate: 228 | tokens = tokens[:context_length] 229 | tokens[-1] = eot_token 230 | else: 231 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 232 | result[i, 
:len(tokens)] = torch.tensor(tokens) 233 | 234 | return result -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | import torchvision.transforms.functional as VF 12 | 13 | class Bottleneck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__(self, inplanes, planes, stride=1): 17 | super().__init__() 18 | 19 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 20 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | 23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 27 | 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | 31 | self.relu = nn.ReLU(inplace=True) 32 | self.downsample = None 33 | self.stride = stride 34 | 35 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 36 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 37 | self.downsample = nn.Sequential(OrderedDict([ 38 | ("-1", nn.AvgPool2d(stride)), 39 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 40 | ("1", nn.BatchNorm2d(planes * self.expansion)) 41 | ])) 42 | 43 | def forward(self, x: torch.Tensor): 44 | identity = x 45 | 46 | out = self.relu(self.bn1(self.conv1(x))) 47 | out = self.relu(self.bn2(self.conv2(out))) 48 | out = self.avgpool(out) 49 | out = self.bn3(self.conv3(out)) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | return out 57 | 58 | 59 | class AttentionPool2d(nn.Module): 60 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 61 | super().__init__() 62 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 63 | self.k_proj = nn.Linear(embed_dim, embed_dim) 64 | self.q_proj = nn.Linear(embed_dim, embed_dim) 65 | self.v_proj = nn.Linear(embed_dim, embed_dim) 66 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 67 | self.num_heads = num_heads 68 | 69 | def forward(self, x, return_all_tokens=False): 70 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 71 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 72 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 73 | x, _ = F.multi_head_attention_forward( 74 | query=x, key=x, value=x, 75 | embed_dim_to_check=x.shape[-1], 76 | num_heads=self.num_heads, 77 | q_proj_weight=self.q_proj.weight, 78 | k_proj_weight=self.k_proj.weight, 79 | v_proj_weight=self.v_proj.weight, 80 | in_proj_weight=None, 81 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 82 | bias_k=None, 83 | bias_v=None, 84 | add_zero_attn=False, 85 | dropout_p=0, 86 | out_proj_weight=self.c_proj.weight, 87 | out_proj_bias=self.c_proj.bias, 88 | use_separate_proj_weight=True, 89 | 
training=self.training, 90 | need_weights=False 91 | ) 92 | if return_all_tokens: 93 | return x 94 | else: 95 | return x[0] 96 | 97 | 98 | class ModifiedResNet(nn.Module): 99 | """ 100 | A ResNet class that is similar to torchvision's but contains the following changes: 101 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 102 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 103 | - The final pooling layer is a QKV attention instead of an average pool 104 | """ 105 | 106 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 107 | super().__init__() 108 | self.output_dim = output_dim 109 | self.input_resolution = input_resolution 110 | 111 | # the 3-layer stem 112 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 113 | self.bn1 = nn.BatchNorm2d(width // 2) 114 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 115 | self.bn2 = nn.BatchNorm2d(width // 2) 116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 117 | self.bn3 = nn.BatchNorm2d(width) 118 | self.avgpool = nn.AvgPool2d(2) 119 | self.relu = nn.ReLU(inplace=True) 120 | 121 | # residual layers 122 | self._inplanes = width # this is a *mutable* variable used during construction 123 | self.layer1 = self._make_layer(width, layers[0]) 124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 127 | 128 | embed_dim = width * 32 # the ResNet feature dimension 129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 130 | 131 | def _make_layer(self, planes, blocks, stride=1): 132 | layers = [Bottleneck(self._inplanes, planes, stride)] 133 | 134 | self._inplanes = planes * Bottleneck.expansion 135 | for _ in range(1, blocks): 136 | layers.append(Bottleneck(self._inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x, return_all_tokens=False): 141 | def stem(x): 142 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 143 | x = self.relu(bn(conv(x))) 144 | x = self.avgpool(x) 145 | return x 146 | 147 | x = x.type(self.conv1.weight.dtype) 148 | x = stem(x) 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | x = self.attnpool(x, return_all_tokens) 154 | 155 | return x 156 | 157 | 158 | class LayerNorm(nn.LayerNorm): 159 | """Subclass torch's LayerNorm to handle fp16.""" 160 | 161 | def forward(self, x: torch.Tensor): 162 | orig_type = x.dtype 163 | ret = super().forward(x.type(torch.float32)) 164 | return ret.type(orig_type) 165 | 166 | 167 | class QuickGELU(nn.Module): 168 | def forward(self, x: torch.Tensor): 169 | return x * torch.sigmoid(1.702 * x) 170 | 171 | 172 | class ResidualAttentionBlock(nn.Module): 173 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 174 | super().__init__() 175 | 176 | self.attn = nn.MultiheadAttention(d_model, n_head) 177 | self.ln_1 = LayerNorm(d_model) 178 | self.mlp = nn.Sequential(OrderedDict([ 179 | ("c_fc", nn.Linear(d_model, d_model * 4)), 180 | ("gelu", QuickGELU()), 181 | ("c_proj", nn.Linear(d_model * 4, d_model)) 182 | ])) 183 | self.ln_2 = LayerNorm(d_model) 184 | self.attn_mask = attn_mask 185 | 186 | def attention(self, 
x: torch.Tensor): 187 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 188 | # pdb.set_trace() 189 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 190 | 191 | def forward(self, x: torch.Tensor): 192 | x = x + self.attention(self.ln_1(x)) 193 | x = x + self.mlp(self.ln_2(x)) 194 | return x 195 | 196 | 197 | class Transformer(nn.Module): 198 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 199 | super().__init__() 200 | self.width = width 201 | self.layers = layers 202 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 203 | 204 | def forward(self, x: torch.Tensor): 205 | return self.resblocks(x) 206 | 207 | 208 | def lof_pytorch(x, n_neighbors=30, contamination=0.05): 209 | distances = torch.norm(x[:, None] - x[None, :], dim=2, p=2) ** 2 210 | 211 | knn_distances, knn_indices = torch.topk(distances, k=n_neighbors+1, largest=False) 212 | knn_distances, knn_indices = knn_distances[:, 1:], knn_indices[:, 1:] 213 | 214 | k_distances = knn_distances[:, -1].unsqueeze(1).expand_as(knn_distances) 215 | reach_distances = torch.max(knn_distances, k_distances) 216 | 217 | LRD = n_neighbors / torch.nan_to_num(reach_distances.mean(dim=1), nan=1e-6) 218 | 219 | LRD_ratios = LRD[knn_indices] / LRD.unsqueeze(1) 220 | LOF_scores = LRD_ratios.mean(dim=1) 221 | 222 | threshold = torch.quantile(LOF_scores.to(torch.float32), 1 - contamination) 223 | 224 | outlier_mask = LOF_scores > threshold 225 | outlier_indices = torch.where(outlier_mask)[0] 226 | 227 | return outlier_indices, LOF_scores 228 | 229 | 230 | class VisionTransformer(nn.Module): 231 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 232 | super().__init__() 233 | self.input_resolution = input_resolution 234 | self.patch_size = patch_size 235 | self.output_dim = output_dim 236 | self.width = width 237 | self.heads = heads 238 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 239 | scale = width ** -0.5 240 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 241 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 242 | self.ln_pre = LayerNorm(width) 243 | self.transformer = Transformer(width, layers, heads) 244 | self.ln_post = LayerNorm(width) 245 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 246 | 247 | self.beta = 0.4 248 | self.pre_adjust_idx= 8 249 | self.post_adjust_idx = 3 250 | self.multi_start_idx = 3 251 | self.multi_end_idx = 10 252 | self.res_cls = 0.3 253 | 254 | def forward(self, x: torch.Tensor, return_all=False): 255 | B, nc, w, h = x.shape 256 | x = self.conv1(x) 257 | feat_w, feat_h = x.shape[-2], x.shape[-1] 258 | x = x.reshape(x.shape[0], x.shape[1], -1) 259 | x = x.permute(0, 2, 1) 260 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) 261 | if x.shape[1] != self.positional_embedding.shape[0]: 262 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) 263 | else: 264 | x = x + self.positional_embedding.to(x.dtype) 265 | x = self.ln_pre(x) 266 | 267 | x = x.permute(1, 0, 2) 268 | feats_list = [] 269 | for idx, blk in enumerate(self.transformer.resblocks[:-1], start=1): 270 | x = blk(x) 271 | feats_list.append(x) 272 | if idx == 
len(self.transformer.resblocks) - 1: 273 | cls_token = x[:1, ...] 274 | outlier_indices, LOF_scores = lof_pytorch(x[1:, ...].squeeze(1), n_neighbors=30, contamination=0.05) 275 | top_indices = [(torch.div(index, feat_w, rounding_mode='trunc'), index % feat_w) for index in outlier_indices] 276 | feature_map = x[1:, :, :].permute(1, 2, 0).reshape(B, self.width, feat_w, feat_h) 277 | feature_map = self.mean_interpolation(feature_map, top_indices) 278 | x = feature_map.reshape(B, self.width, feat_w * feat_h).permute(2, 0, 1) 279 | 280 | feats = feats_list[self.pre_adjust_idx][1:, ...].clone() 281 | feats = feats.permute(1, 2, 0).reshape(B, self.width, feat_w, feat_h) 282 | feats = self.mean_interpolation(feats, top_indices) 283 | feats = feats.reshape(B, self.width, feat_w * feat_h).permute(2, 0, 1) 284 | feats = feats / feats.norm(dim=2, keepdim=True) 285 | before_simi = torch.matmul(feats.permute(1, 0, 2), feats.permute(1, 2, 0)) 286 | mid_simi = before_simi.clone() 287 | before_simi[before_simi < self.beta] = 0.0 288 | x = self.adaptively_aggregate(x, before_simi) 289 | 290 | for blk in self.transformer.resblocks[-1:]: 291 | x = self.custom_attn(blk.attn, blk.ln_1(x), mid_simi=mid_simi) + self.res_cls * cls_token 292 | 293 | feats = feats_list[self.post_adjust_idx][1:, ...].clone() 294 | feats = feats / feats.norm(dim=2, keepdim=True) 295 | after_simi = torch.matmul(feats.permute(1, 0, 2), feats.permute(1, 2, 0)) 296 | after_simi[after_simi < self.beta] = 0.0 297 | x = self.adaptively_aggregate(x, after_simi) 298 | 299 | re_feats = torch.zeros_like(feats_list[0]) 300 | for i in range(self.multi_start_idx, self.multi_end_idx): 301 | re_feats += feats_list[i] 302 | cls_token = re_feats[:1, ...] 303 | blk = self.transformer.resblocks[-1] 304 | re_feats = self.custom_attn(blk.attn, blk.ln_1(re_feats[1:, ...]), mid_simi=mid_simi) + self.res_cls * cls_token 305 | re_feats = self.adaptively_aggregate(re_feats, after_simi) 306 | x += re_feats 307 | 308 | x = x.permute(1, 0, 2) 309 | if return_all: 310 | return self.ln_post(x) @ self.proj 311 | 312 | x = self.ln_post(x[:, 0, :]) 313 | if self.proj is not None: 314 | x = x @ self.proj 315 | 316 | return x 317 | 318 | def custom_attn(self, attn_layer, x, mid_simi): 319 | num_heads = attn_layer.num_heads 320 | _, bsz, embed_dim = x.size() 321 | head_dim = embed_dim // num_heads 322 | scale = head_dim ** -0.5 323 | 324 | q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) 325 | q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 326 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 327 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 328 | 329 | mid_simi = (mid_simi - torch.mean(mid_simi)) * 3.0 330 | mid_simi[mid_simi < 0.0] = float('-inf') 331 | mid_simi = mid_simi.repeat(num_heads, 1, 1) 332 | attn_weights = F.softmax(mid_simi, dim=-1) 333 | k_attn = torch.bmm(k, k.transpose(1, 2)) * scale 334 | attn_weights += F.softmax(k_attn, dim=-1) 335 | attn_weights /= 2 336 | 337 | attn_output = torch.bmm(attn_weights, v) 338 | attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) 339 | attn_output = attn_layer.out_proj(attn_output) 340 | 341 | return attn_output 342 | 343 | def adaptively_aggregate(self, maskclip_feats: torch.Tensor, corrs: torch.Tensor): 344 | corrs_normalized = corrs / (corrs.sum(dim=-1, keepdim=True) + 1e-6) 345 | maskclip_feats_ref = torch.matmul(corrs_normalized, maskclip_feats.permute(1, 0, 2)) 346 | return 
maskclip_feats_ref.permute(1, 0, 2) 347 | 348 | def mean_interpolation(self, feature_map, top_indices): 349 | B, C, H, W = feature_map.shape 350 | device = feature_map.device 351 | dtype = feature_map.dtype 352 | 353 | kernel = torch.ones(C, 1, 3, 3, device=device, dtype=dtype) 354 | kernel[:, 0, 1, 1] = 0 355 | mask = torch.ones((H, W), device=device, dtype=dtype) 356 | indices = torch.tensor(top_indices, dtype=torch.long, device=device) 357 | mask[indices[:, 0], indices[:, 1]] = 0 358 | mask = mask.unsqueeze(0).unsqueeze(0) 359 | masked_feature_map = feature_map * mask 360 | padded_feature_map = F.pad(masked_feature_map, (1, 1, 1, 1), mode='constant', value=0) 361 | padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) 362 | neighbor_sum = F.conv2d(padded_feature_map, kernel, groups=C) 363 | valid_neighbors = F.conv2d(padded_mask, kernel[:, :1, :, :], groups=1) 364 | valid_neighbor_mask = (valid_neighbors > 0).to(dtype) 365 | safe_valid_neighbors = valid_neighbors.clone() 366 | safe_valid_neighbors[safe_valid_neighbors == 0] = 1 367 | mean_neighbors = neighbor_sum / safe_valid_neighbors 368 | top_indices_mask = torch.zeros((H, W), device=device, dtype=dtype) 369 | top_indices_mask[indices[:, 0], indices[:, 1]] = 1 370 | top_indices_mask = top_indices_mask.unsqueeze(0).unsqueeze(0) 371 | update_mask = top_indices_mask * valid_neighbor_mask 372 | feature_map = feature_map * (1 - update_mask) + mean_neighbors * update_mask 373 | return feature_map 374 | 375 | def interpolate_pos_encoding(self, x, w, h): 376 | npatch = x.shape[1] - 1 377 | N = self.positional_embedding.shape[0] - 1 378 | if npatch == N and w == h: 379 | return self.positional_embedding 380 | class_pos_embed = self.positional_embedding[[0]] 381 | patch_pos_embed = self.positional_embedding[1:] 382 | dim = x.shape[-1] 383 | w0 = w // self.patch_size 384 | h0 = h // self.patch_size 385 | w0, h0 = w0 + 0.1, h0 + 0.1 386 | patch_pos_embed = nn.functional.interpolate( 387 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 388 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 389 | mode='bicubic', 390 | ) 391 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 392 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 393 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 394 | 395 | 396 | class CLIP(nn.Module): 397 | def __init__(self, 398 | embed_dim: int, # 512 399 | # vision 400 | image_resolution: int, # 224 401 | vision_layers: Union[Tuple[int, int, int, int], int], # 12 402 | vision_width: int, # 768 403 | vision_patch_size: int, # 16 404 | # text 405 | context_length: int, # 77 406 | vocab_size: int, # 49408 407 | transformer_width: int, # 512 408 | transformer_heads: int, # 8 409 | transformer_layers: int # 12 410 | ): 411 | super().__init__() 412 | self.context_length = context_length 413 | 414 | if isinstance(vision_layers, (tuple, list)): 415 | vision_heads = vision_width * 32 // 64 416 | self.visual = ModifiedResNet( 417 | layers=vision_layers, 418 | output_dim=embed_dim, 419 | heads=vision_heads, 420 | input_resolution=image_resolution, 421 | width=vision_width 422 | ) 423 | else: 424 | vision_heads = vision_width // 64 425 | self.visual = VisionTransformer( 426 | input_resolution=image_resolution, 427 | patch_size=vision_patch_size, 428 | width=vision_width, 429 | layers=vision_layers, 430 | heads=vision_heads, 431 | output_dim=embed_dim 432 | ) 433 | 434 | self.transformer = 
Transformer( 435 | width=transformer_width, 436 | layers=transformer_layers, 437 | heads=transformer_heads, 438 | attn_mask=self.build_attention_mask() 439 | ) 440 | 441 | self.vocab_size = vocab_size 442 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 443 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 444 | self.ln_final = LayerNorm(transformer_width) 445 | 446 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 447 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 448 | 449 | self.initialize_parameters() 450 | 451 | def initialize_parameters(self): 452 | nn.init.normal_(self.token_embedding.weight, std=0.02) 453 | nn.init.normal_(self.positional_embedding, std=0.01) 454 | 455 | if isinstance(self.visual, ModifiedResNet): 456 | if self.visual.attnpool is not None: 457 | std = self.visual.attnpool.c_proj.in_features ** -0.5 458 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 459 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 460 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 461 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 462 | 463 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 464 | for name, param in resnet_block.named_parameters(): 465 | if name.endswith("bn3.weight"): 466 | nn.init.zeros_(param) 467 | 468 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 469 | attn_std = self.transformer.width ** -0.5 470 | fc_std = (2 * self.transformer.width) ** -0.5 471 | for block in self.transformer.resblocks: 472 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 473 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 474 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 475 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 476 | 477 | if self.text_projection is not None: 478 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 479 | 480 | def build_attention_mask(self): 481 | # lazily create causal attention mask, with full attention between the vision tokens 482 | # pytorch uses additive attention mask; fill with -inf 483 | mask = torch.empty(self.context_length, self.context_length) 484 | mask.fill_(float("-inf")) 485 | mask.triu_(1) # zero out the lower diagonal 486 | return mask 487 | 488 | @property 489 | def dtype(self): 490 | return self.visual.conv1.weight.dtype 491 | 492 | def encode_image(self, image, return_all=False): 493 | return self.visual(image.type(self.dtype), return_all=return_all) 494 | 495 | def encode_text(self, text): 496 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 497 | 498 | x = x + self.positional_embedding.type(self.dtype) 499 | x = x.permute(1, 0, 2) # NLD -> LND 500 | x = self.transformer(x) 501 | x = x.permute(1, 0, 2) # LND -> NLD 502 | x = self.ln_final(x).type(self.dtype) 503 | 504 | return x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 505 | 506 | def forward(self, image, text): 507 | image_features = self.encode_image(image) 508 | text_features = self.encode_text(text) 509 | 510 | # normalized features 511 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 512 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 513 | 514 | # cosine similarity as logits 515 | logit_scale = self.logit_scale.exp() 516 | logits_per_image = 
logit_scale * image_features @ text_features.t() 517 | logits_per_text = logits_per_image.t() 518 | 519 | # shape = [global_batch_size, global_batch_size] 520 | return logits_per_image, logits_per_text 521 | 522 | def convert_weights(model: nn.Module): 523 | """Convert applicable model parameters to fp16""" 524 | 525 | def _convert_weights_to_fp16(l): 526 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 527 | l.weight.data = l.weight.data.half() 528 | if l.bias is not None: 529 | l.bias.data = l.bias.data.half() 530 | 531 | if isinstance(l, nn.MultiheadAttention): 532 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 533 | tensor = getattr(l, attr) 534 | if tensor is not None: 535 | tensor.data = tensor.data.half() 536 | 537 | for name in ["text_projection", "proj"]: 538 | if hasattr(l, name): 539 | attr = getattr(l, name) 540 | if attr is not None: 541 | attr.data = attr.data.half() 542 | 543 | model.apply(_convert_weights_to_fp16) 544 | 545 | def build_model(state_dict: dict): 546 | vit = "visual.proj" in state_dict 547 | 548 | if vit: 549 | vision_width = state_dict["visual.conv1.weight"].shape[0] 550 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 551 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 552 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 553 | image_resolution = vision_patch_size * grid_size 554 | else: 555 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 556 | vision_layers = tuple(counts) 557 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 558 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 559 | vision_patch_size = None 560 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 561 | image_resolution = output_width * 32 562 | 563 | embed_dim = state_dict["text_projection"].shape[1] 564 | context_length = state_dict["positional_embedding"].shape[0] 565 | vocab_size = state_dict["token_embedding.weight"].shape[0] 566 | transformer_width = state_dict["ln_final.weight"].shape[0] 567 | transformer_heads = transformer_width // 64 568 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 569 | 570 | model = CLIP( 571 | embed_dim, 572 | image_resolution, vision_layers, vision_width, vision_patch_size, 573 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 574 | ) 575 | 576 | for key in ["input_resolution", "context_length", "vocab_size"]: 577 | if key in state_dict: 578 | del state_dict[key] 579 | 580 | convert_weights(model) 581 | model.load_state_dict(state_dict) 582 | return model.eval() -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | import gzip 5 | import html 6 | import os 7 | from functools import lru_cache 8 | 9 | import ftfy 10 | import regex as re 11 | 12 | 13 | @lru_cache() 14 | def default_bpe(): 15 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 16 | 17 | 18 | @lru_cache() 19 | def bytes_to_unicode(): 20 | """ 21 | 
Returns list of utf-8 byte and a corresponding list of unicode strings. 22 | The reversible bpe codes work on unicode strings. 23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 25 | This is a signficant percentage of your normal, say, 32K bpe vocab. 26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 27 | And avoids mapping to whitespace/control characters the bpe code barfs on. 28 | """ 29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8+n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | 40 | 41 | def get_pairs(word): 42 | """Return set of symbol pairs in a word. 43 | Word is represented as tuple of symbols (symbols being variable-length strings). 44 | """ 45 | pairs = set() 46 | prev_char = word[0] 47 | for char in word[1:]: 48 | pairs.add((prev_char, char)) 49 | prev_char = char 50 | return pairs 51 | 52 | 53 | def basic_clean(text): 54 | text = ftfy.fix_text(text) 55 | text = html.unescape(html.unescape(text)) 56 | return text.strip() 57 | 58 | 59 | def whitespace_clean(text): 60 | text = re.sub(r'\s+', ' ', text) 61 | text = text.strip() 62 | return text 63 | 64 | 65 | class SimpleTokenizer(object): 66 | def __init__(self, bpe_path: str = default_bpe()): 67 | self.byte_encoder = bytes_to_unicode() 68 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 69 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 70 | merges = merges[1:49152-256-2+1] 71 | merges = [tuple(merge.split()) for merge in merges] 72 | vocab = list(bytes_to_unicode().values()) 73 | vocab = vocab + [v+'' for v in vocab] 74 | for merge in merges: 75 | vocab.append(''.join(merge)) 76 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 77 | self.encoder = dict(zip(vocab, range(len(vocab)))) 78 | self.decoder = {v: k for k, v in self.encoder.items()} 79 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = 
whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='SCCLIPForSegmentation', 3 | clip_path='ViT-B/16', 4 | pre_adjust_idx=8, 5 | post_adjust_idx=3, 6 | multi_start_idx=3, 7 | multi_end_idx=10, 8 | res_cls=0.3 9 | ) 10 | 11 | test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 12 | 13 | default_scope = 'mmseg' 14 | env_cfg = dict( 15 | cudnn_benchmark=True, 16 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 17 | dist_cfg=dict(backend='nccl'), 18 | ) 19 | vis_backends = [dict(type='LocalVisBackend')] 20 | visualizer = dict( 21 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') 22 | log_processor = dict(by_epoch=False) 23 | log_level = 'INFO' 24 | load_from = None 25 | resume = False 26 | 27 | test_cfg = dict(type='TestLoop') 28 | 29 | default_hooks = dict( 30 | timer=dict(type='IterTimerHook'), 31 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 32 | param_scheduler=dict(type='ParamSchedulerHook'), 33 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), 34 | sampler_seed=dict(type='DistSamplerSeedHook'), 35 | visualization=dict(type='SegVisualizationHook', interval=1)) -------------------------------------------------------------------------------- /configs/cfg_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_ade20k.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'ADE20KDataset' 10 | data_root = './datasets/ade/ADEChallengeData2016' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 15 | dict(type='LoadAnnotations', reduce_zero_label=True), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='images/validation', 29 | seg_map_path='annotations/validation'), 30 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_city_scapes.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_city_scapes.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'CityscapesDataset' 10 | data_root = './datasets/cityscapes' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 560), keep_ratio=True), 15 | # add loading annotation after ``Resize`` because ground truth 16 | # does not need to do resize data transform 17 | 
dict(type='LoadAnnotations'), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='leftImg8bit/val', seg_map_path='gtFine/val'), 31 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_coco_object.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_coco_object.txt', 6 | logit_scale=55, prob_thd=0.35 7 | ) 8 | 9 | # dataset settings 10 | dataset_type = 'COCOObjectDataset' 11 | data_root = './datasets/coco_object' 12 | 13 | test_pipeline = [ 14 | dict(type='LoadImageFromFile'), 15 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 16 | # add loading annotation after ``Resize`` because ground truth 17 | # does not need to do resize data transform 18 | dict(type='LoadAnnotations'), 19 | dict(type='PackSegInputs') 20 | ] 21 | 22 | test_dataloader = dict( 23 | batch_size=1, 24 | num_workers=4, 25 | persistent_workers=True, 26 | sampler=dict(type='DefaultSampler', shuffle=False), 27 | dataset=dict( 28 | type=dataset_type, 29 | data_root=data_root, 30 | reduce_zero_label=False, 31 | data_prefix=dict( 32 | img_path='images/val2017', seg_map_path='annotations/val2017'), 33 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_coco_stuff.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'COCOStuffDataset' 10 | data_root = './datasets/coco_stuff164k' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 15 | dict(type='LoadAnnotations'), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='images/val2017', seg_map_path='annotations/val2017'), 29 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_context59.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_context59.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'PascalContext59Dataset' 10 | data_root = './datasets/VOCdevkit/VOC2010' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 15 | dict(type='LoadAnnotations', reduce_zero_label=True), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 29 | 
ann_file='ImageSets/SegmentationContext/val.txt', 30 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_context60.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_context60.txt', 6 | prob_thd=0.15 7 | ) 8 | 9 | # dataset settings 10 | dataset_type = 'PascalContext60Dataset' 11 | data_root = './datasets/VOCdevkit/VOC2010' 12 | 13 | test_pipeline = [ 14 | dict(type='LoadImageFromFile'), 15 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 16 | dict(type='LoadAnnotations'), 17 | dict(type='PackSegInputs') 18 | ] 19 | 20 | test_dataloader = dict( 21 | batch_size=1, 22 | num_workers=4, 23 | persistent_workers=True, 24 | sampler=dict(type='DefaultSampler', shuffle=False), 25 | dataset=dict( 26 | type=dataset_type, 27 | data_root=data_root, 28 | data_prefix=dict( 29 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 30 | ann_file='ImageSets/SegmentationContext/val.txt', 31 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_voc20.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_voc20.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'PascalVOC20Dataset' 10 | data_root = './datasets/VOCdevkit/VOC2012' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 15 | dict(type='LoadAnnotations'), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='JPEGImages', seg_map_path='SegmentationClass'), 29 | ann_file='ImageSets/Segmentation/val.txt', 30 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cfg_voc21.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./configs/cls_voc21.txt', 6 | area_thd=0.1, 7 | logit_scale=50, prob_thd=0.15 8 | ) 9 | 10 | # dataset settings 11 | dataset_type = 'PascalVOCDataset' 12 | data_root = './datasets/VOCdevkit/VOC2012' 13 | 14 | test_pipeline = [ 15 | dict(type='LoadImageFromFile'), 16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 17 | dict(type='LoadAnnotations'), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='JPEGImages', seg_map_path='SegmentationClass'), 31 | ann_file='ImageSets/Segmentation/val.txt', 32 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /configs/cls_ade20k.txt: -------------------------------------------------------------------------------- 1 | wall 2 | building 3 | sky 4 | floor 5 | tree 6 | ceiling 7 | road 8 | bed 9 | windowpane 10 | grass 11 | cabinet 12 | sidewalk 13 | person 
14 | earth 15 | door 16 | table 17 | mountain 18 | plant 19 | curtain 20 | chair 21 | car 22 | water 23 | painting 24 | sofa 25 | shelf 26 | house 27 | sea 28 | mirror 29 | rug 30 | field 31 | armchair 32 | seat 33 | fence 34 | desk 35 | rock 36 | wardrobe 37 | lamp 38 | bathtub 39 | railing 40 | cushion 41 | base 42 | box 43 | column 44 | signboard 45 | chestofdrawers 46 | counter 47 | sand 48 | sink 49 | skyscraper 50 | fireplace 51 | refrigerator 52 | grandstand 53 | path 54 | stairs 55 | runway 56 | case 57 | pooltable 58 | pillow 59 | screendoor 60 | stairway 61 | river 62 | bridge 63 | bookcase 64 | blind 65 | coffeetable 66 | toilet 67 | flower 68 | book 69 | hill 70 | bench 71 | countertop 72 | stove 73 | palm 74 | kitchenisland 75 | computer 76 | swivelchair 77 | boat 78 | bar 79 | arcademachine 80 | hovel 81 | bus 82 | towel 83 | light 84 | truck 85 | tower 86 | chandelier 87 | awning 88 | streetlight 89 | booth 90 | televisionreceiver 91 | airplane 92 | dirttrack 93 | apparel 94 | pole 95 | land 96 | bannister 97 | escalator 98 | ottoman 99 | bottle 100 | buffet 101 | poster 102 | stage 103 | van 104 | ship 105 | fountain 106 | conveyerbelt 107 | canopy 108 | washer 109 | plaything 110 | swimmingpool 111 | stool 112 | barrel 113 | basket 114 | waterfall 115 | tent 116 | bag 117 | minibike 118 | cradle 119 | oven 120 | ball 121 | food 122 | step 123 | tank 124 | tradename 125 | microwave 126 | pot 127 | animal 128 | bicycle 129 | lake 130 | dishwasher 131 | screen 132 | blanket 133 | sculpture 134 | hood 135 | sconce 136 | vase 137 | trafficlight 138 | tray 139 | ashcan 140 | fan 141 | pier 142 | crtscreen 143 | plate 144 | monitor 145 | bulletinboard 146 | shower 147 | radiator 148 | glass 149 | clock 150 | flag -------------------------------------------------------------------------------- /configs/cls_city_scapes.txt: -------------------------------------------------------------------------------- 1 | road 2 | sidewalk 3 | building 4 | wall 5 | fence 6 | pole 7 | trafficlight 8 | trafficsign 9 | vegetation 10 | terrain 11 | sky 12 | person 13 | rider 14 | car 15 | truck 16 | bus 17 | train 18 | motorcycle 19 | bicycle -------------------------------------------------------------------------------- /configs/cls_coco_object.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence 2 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body 3 | bicycle 4 | car 5 | motorcycle 6 | airplane 7 | bus 8 | train 9 | truck 10 | boat 11 | traffic light 12 | fire hydrant 13 | stop sign 14 | parking meter 15 | bench 16 | bird 17 | cat 18 | dog 19 | horse 20 | sheep 21 | cow 22 | elephant 23 | bear 24 | zebra 25 | giraffe 26 | backpack 27 | umbrella 28 | handbag 29 | tie 30 | suitcase 31 | frisbee 32 | skis 33 | snowboard 34 | sports ball 35 | kite 36 | baseball bat 37 | baseball glove 38 | skateboard 39 | surfboard 40 | tennis racket 41 | bottle 42 | wine glass 43 | cup 44 | fork 45 | knife 46 | spoon 47 | bowl 48 | banana 49 | apple 50 | sandwich 51 | orange 52 | broccoli 53 | carrot 54 | hot dog 55 | pizza 56 | donut 57 | cake 58 | chair 59 | couch 60 | potted plant 61 | bed 62 | dining table 63 | toilet 64 | tv 65 | laptop 66 | mouse 67 | remote 68 | keyboard 69 | cell phone 70 | microwave 71 | oven 72 | toaster 73 | 
sink 74 | refrigerator 75 | book 76 | clock 77 | vase 78 | scissors 79 | teddy bear 80 | hair drier 81 | toothbrush -------------------------------------------------------------------------------- /configs/cls_coco_stuff.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | trafficlight 11 | firehydrant 12 | stopsign 13 | parkingmeter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sportsball 34 | kite 35 | baseballbat 36 | baseballglove 37 | skateboard 38 | surfboard 39 | tennisracket 40 | bottle 41 | wineglass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hotdog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tv 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cellphone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddybear 79 | hairdrier 80 | toothbrush 81 | banner 82 | blanket 83 | branch 84 | bridge 85 | building-other 86 | bush 87 | cabinet 88 | cage 89 | cardboard 90 | carpet 91 | ceiling-other 92 | ceiling-tile 93 | cloth 94 | clothes 95 | clouds 96 | counter 97 | cupboard 98 | curtain 99 | desk-stuff 100 | dirt 101 | door-stuff 102 | fence 103 | floor-marble 104 | floor-other 105 | floor-stone 106 | floor-tile 107 | floor-wood 108 | flower 109 | fog 110 | food-other 111 | fruit 112 | furniture-other 113 | grass 114 | gravel 115 | ground-other 116 | hill 117 | house 118 | leaves 119 | light 120 | mat 121 | metal 122 | mirror-stuff 123 | moss 124 | mountain 125 | mud 126 | napkin 127 | net 128 | paper 129 | pavement 130 | pillow 131 | plant-other 132 | plastic 133 | platform 134 | playingfield 135 | railing 136 | railroad 137 | river 138 | road 139 | rock 140 | roof 141 | rug 142 | salad 143 | sand 144 | sea 145 | shelf 146 | sky-other 147 | skyscraper 148 | snow 149 | solid-other 150 | stairs 151 | stone 152 | straw 153 | structural-other 154 | table 155 | tent 156 | textile-other 157 | towel 158 | tree 159 | vegetable 160 | wall-brick 161 | wall-concrete 162 | wall-other 163 | wall-panel 164 | wall-stone 165 | wall-tile 166 | wall-wood 167 | water-other 168 | waterdrops 169 | window-blind 170 | window-other 171 | wood -------------------------------------------------------------------------------- /configs/cls_context59.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bag 3 | bed 4 | bedclothes 5 | bench 6 | bicycle 7 | bird 8 | boat 9 | book 10 | bottle 11 | building 12 | bus 13 | cabinet 14 | car 15 | cat 16 | ceiling 17 | chair 18 | cloth 19 | computer 20 | cow 21 | cup 22 | curtain 23 | dog 24 | door 25 | fence 26 | floor 27 | flower 28 | food 29 | grass 30 | ground 31 | horse 32 | keyboard 33 | light 34 | motorbike 35 | mountain 36 | mouse 37 | person 38 | plate 39 | platform 40 | pottedplant 41 | road 42 | rock 43 | sheep 44 | shelves 45 | sidewalk 46 | sign 47 | sky 48 | snow 49 | sofa 50 | table 51 | track 52 | train 53 | tree 54 | truck 55 | tvmonitor 56 | wall 57 | water 58 | window 59 | wood -------------------------------------------------------------------------------- 
/configs/cls_context60.txt: -------------------------------------------------------------------------------- 1 | background 2 | aeroplane 3 | bag 4 | bed 5 | bedclothes 6 | bench 7 | bicycle 8 | bird 9 | boat 10 | book 11 | bottle 12 | building 13 | bus 14 | cabinet 15 | car 16 | cat 17 | ceiling 18 | chair 19 | cloth 20 | computer 21 | cow 22 | cup 23 | curtain 24 | dog 25 | door 26 | fence 27 | floor 28 | flower 29 | food 30 | grass 31 | ground 32 | horse 33 | keyboard 34 | light 35 | motorbike 36 | mountain 37 | mouse 38 | person 39 | plate 40 | platform 41 | pottedplant 42 | road 43 | rock 44 | sheep 45 | shelves 46 | sidewalk 47 | sign 48 | sky 49 | snow 50 | sofa 51 | table 52 | track 53 | train 54 | tree 55 | truck 56 | tvmonitor 57 | wall 58 | water 59 | window 60 | wood -------------------------------------------------------------------------------- /configs/cls_voc20.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | ship 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | table 12 | dog 13 | horse 14 | motorbike 15 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | television monitor, tv monitor, monitor, television, screen -------------------------------------------------------------------------------- /configs/cls_voc21.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence 2 | aeroplane 3 | bicycle 4 | bird 5 | ship 6 | bottle 7 | bus 8 | car 9 | cat 10 | chair 11 | cow 12 | table 13 | dog 14 | horse 15 | motorbike 16 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket 17 | pottedplant 18 | sheep 19 | sofa 20 | train 21 | television monitor, tv monitor, monitor, television, screen -------------------------------------------------------------------------------- /custom_datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from mmseg.datasets import BaseSegDataset 6 | 7 | @DATASETS.register_module() 8 | class PascalVOC20Dataset(BaseSegDataset): 9 | """Pascal VOC dataset. 10 | 11 | Args: 12 | split (str): Split txt file for Pascal VOC. 
13 | """ 14 | METAINFO = dict( 15 | classes=('aeroplane', 'bicycle', 'bird', 'boat', 16 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 17 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 18 | 'sofa', 'train', 'tvmonitor'), 19 | palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192], 20 | [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64], 21 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 22 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 23 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 24 | [0, 64, 128]]) 25 | 26 | def __init__(self, 27 | ann_file, 28 | img_suffix='.jpg', 29 | seg_map_suffix='.png', 30 | reduce_zero_label=True, 31 | **kwargs) -> None: 32 | super().__init__( 33 | img_suffix=img_suffix, 34 | seg_map_suffix=seg_map_suffix, 35 | reduce_zero_label=reduce_zero_label, 36 | ann_file=ann_file, 37 | **kwargs) 38 | assert fileio.exists(self.data_prefix['img_path'], 39 | self.backend_args) and osp.isfile(self.ann_file) 40 | 41 | @DATASETS.register_module() 42 | class COCOObjectDataset(BaseSegDataset): 43 | """ 44 | Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT) 45 | COCO-Object dataset. 46 | 1 bg class + first 80 classes from the COCO-Stuff dataset. 47 | """ 48 | 49 | METAINFO = dict( 50 | 51 | classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 52 | 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 53 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 54 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 55 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 56 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 57 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 58 | 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 59 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), 60 | 61 | palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224], 62 | [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64], 63 | [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], 64 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0], 65 | [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32], 66 | [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], 67 | [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32], 68 | [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], 69 | [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], 70 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160], 71 | [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0], 72 | [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]) 73 | 74 
| def __init__(self, **kwargs): 75 | super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs) 76 | 77 | @DATASETS.register_module() 78 | class PascalContext60Dataset(BaseSegDataset): 79 | METAINFO = dict( 80 | classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 81 | 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 82 | 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 83 | 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 84 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 85 | 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 86 | 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', 87 | 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 88 | 'sofa', 'table', 'track', 'train', 'tree', 'truck', 89 | 'tvmonitor', 'wall', 'water', 'window', 'wood'), 90 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 91 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 92 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 93 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 94 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 95 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 96 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 97 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 98 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 99 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 100 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 101 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 102 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 103 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 104 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 105 | 106 | def __init__(self, 107 | ann_file: str, 108 | img_suffix='.jpg', 109 | seg_map_suffix='.png', 110 | **kwargs) -> None: 111 | super().__init__( 112 | img_suffix=img_suffix, 113 | seg_map_suffix=seg_map_suffix, 114 | ann_file=ann_file, 115 | reduce_zero_label=False, 116 | **kwargs) 117 | 118 | 119 | @DATASETS.register_module() 120 | class PascalContext59Dataset(BaseSegDataset): 121 | METAINFO = dict( 122 | classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 123 | 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 124 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 125 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 126 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 127 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 128 | 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 129 | 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 130 | 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor', 131 | 'wall', 'water', 'window', 'wood'), 132 | palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], 133 | [120, 120, 80], [140, 140, 140], [204, 5, 255], 134 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 135 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 136 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 137 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 138 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 139 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 140 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 141 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 
194, 7], 142 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 143 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 144 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 145 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 146 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 147 | 148 | def __init__(self, 149 | ann_file: str, 150 | img_suffix='.jpg', 151 | seg_map_suffix='.png', 152 | reduce_zero_label=True, 153 | **kwargs): 154 | super().__init__( 155 | img_suffix=img_suffix, 156 | seg_map_suffix=seg_map_suffix, 157 | ann_file=ann_file, 158 | reduce_zero_label=reduce_zero_label, 159 | **kwargs) -------------------------------------------------------------------------------- /datasets/cvt_coco_object.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # GroupViT (https://github.com/NVlabs/GroupViT) 3 | # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import argparse 7 | import os.path as osp 8 | import shutil 9 | from functools import partial 10 | from glob import glob 11 | 12 | import mmcv 13 | import numpy as np 14 | from PIL import Image 15 | 16 | COCO_LEN = 123287 17 | 18 | clsID_to_trID = { 19 | 0: 0, 20 | 1: 1, 21 | 2: 2, 22 | 3: 3, 23 | 4: 4, 24 | 5: 5, 25 | 6: 6, 26 | 7: 7, 27 | 8: 8, 28 | 9: 9, 29 | 10: 10, 30 | 12: 11, 31 | 13: 12, 32 | 14: 13, 33 | 15: 14, 34 | 16: 15, 35 | 17: 16, 36 | 18: 17, 37 | 19: 18, 38 | 20: 19, 39 | 21: 20, 40 | 22: 21, 41 | 23: 22, 42 | 24: 23, 43 | 26: 24, 44 | 27: 25, 45 | 30: 26, 46 | 31: 27, 47 | 32: 28, 48 | 33: 29, 49 | 34: 30, 50 | 35: 31, 51 | 36: 32, 52 | 37: 33, 53 | 38: 34, 54 | 39: 35, 55 | 40: 36, 56 | 41: 37, 57 | 42: 38, 58 | 43: 39, 59 | 45: 40, 60 | 46: 41, 61 | 47: 42, 62 | 48: 43, 63 | 49: 44, 64 | 50: 45, 65 | 51: 46, 66 | 52: 47, 67 | 53: 48, 68 | 54: 49, 69 | 55: 50, 70 | 56: 51, 71 | 57: 52, 72 | 58: 53, 73 | 59: 54, 74 | 60: 55, 75 | 61: 56, 76 | 62: 57, 77 | 63: 58, 78 | 64: 59, 79 | 66: 60, 80 | 69: 61, 81 | 71: 62, 82 | 72: 63, 83 | 73: 64, 84 | 74: 65, 85 | 75: 66, 86 | 76: 67, 87 | 77: 68, 88 | 78: 69, 89 | 79: 70, 90 | 80: 71, 91 | 81: 72, 92 | 83: 73, 93 | 84: 74, 94 | 85: 75, 95 | 86: 76, 96 | 87: 77, 97 | 88: 78, 98 | 89: 79, 99 | 91: 80, 100 | 92: 81, 101 | 93: 82, 102 | 94: 83, 103 | 95: 84, 104 | 96: 85, 105 | 97: 86, 106 | 98: 87, 107 | 99: 88, 108 | 100: 89, 109 | 101: 90, 110 | 102: 91, 111 | 103: 92, 112 | 104: 93, 113 | 105: 94, 114 | 106: 95, 115 | 107: 96, 116 | 108: 97, 117 | 109: 98, 118 | 110: 99, 119 | 111: 100, 120 | 112: 101, 121 | 113: 102, 122 | 114: 103, 123 | 115: 104, 124 | 116: 105, 125 | 117: 106, 126 | 118: 107, 127 | 119: 108, 128 | 120: 109, 129 | 121: 110, 130 | 122: 111, 131 | 123: 112, 132 | 124: 113, 133 | 125: 114, 134 | 126: 115, 135 | 127: 116, 136 | 128: 117, 137 | 129: 118, 138 | 130: 119, 139 | 131: 120, 140 | 132: 121, 141 | 133: 122, 142 | 134: 123, 143 | 135: 124, 144 | 136: 125, 145 | 137: 126, 146 | 138: 127, 147 | 139: 128, 148 | 140: 129, 149 | 141: 130, 150 | 142: 131, 151 | 143: 132, 152 | 144: 133, 153 | 145: 134, 154 | 146: 135, 155 | 147: 136, 156 | 148: 137, 157 | 149: 138, 158 | 150: 139, 159 | 151: 140, 160 | 152: 141, 161 | 153: 142, 162 | 154: 143, 163 | 155: 144, 164 | 156: 145, 165 | 157: 146, 166 | 158: 147, 167 | 159: 148, 168 | 160: 149, 169 | 161: 150, 170 | 162: 
151, 171 | 163: 152, 172 | 164: 153, 173 | 165: 154, 174 | 166: 155, 175 | 167: 156, 176 | 168: 157, 177 | 169: 158, 178 | 170: 159, 179 | 171: 160, 180 | 172: 161, 181 | 173: 162, 182 | 174: 163, 183 | 175: 164, 184 | 176: 165, 185 | 177: 166, 186 | 178: 167, 187 | 179: 168, 188 | 180: 169, 189 | 181: 170, 190 | 255: 255 191 | } 192 | 193 | # set to background 194 | for k, v in clsID_to_trID.items(): 195 | clsID_to_trID[k] = v + 1 196 | if k > 90: 197 | clsID_to_trID[k] = 0 198 | 199 | 200 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 201 | mask = np.array(Image.open(maskpath)) 202 | mask_copy = mask.copy() 203 | for clsID, trID in clsID_to_trID.items(): 204 | mask_copy[mask == clsID] = trID 205 | seg_filename = osp.join( 206 | out_mask_dir, 'train2017', 207 | osp.basename(maskpath).split('.')[0] + 208 | '_instanceTrainIds.png') if is_train else osp.join( 209 | out_mask_dir, 'val2017', 210 | osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png') 211 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 212 | 213 | 214 | def parse_args(): 215 | parser = argparse.ArgumentParser(description='Convert COCO Stuff 164k annotations to COCO Objects') # noqa 216 | parser.add_argument('coco_path', help='coco stuff path') 217 | parser.add_argument('-o', '--out_dir', help='output path') 218 | parser.add_argument( 219 | '--nproc', default=16, type=int, help='number of process') 220 | args = parser.parse_args() 221 | return args 222 | 223 | 224 | def main(): 225 | args = parse_args() 226 | coco_path = args.coco_path 227 | nproc = args.nproc 228 | 229 | out_dir = args.out_dir or coco_path 230 | out_img_dir = osp.join(out_dir, 'images') 231 | out_mask_dir = osp.join(out_dir, 'annotations') 232 | 233 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 234 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 235 | 236 | if out_dir != coco_path: 237 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 238 | 239 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 240 | train_list = [file for file in train_list if 'TrainIds' not in file] 241 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 242 | test_list = [file for file in test_list if 'TrainIds' not in file] 243 | assert (len(train_list) + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 244 | len(train_list), len(test_list)) 245 | 246 | if args.nproc > 1: 247 | mmcv.track_parallel_progress( 248 | partial( 249 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 250 | train_list, 251 | nproc=nproc) 252 | mmcv.track_parallel_progress( 253 | partial( 254 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 255 | test_list, 256 | nproc=nproc) 257 | else: 258 | mmcv.track_progress( 259 | partial( 260 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 261 | train_list) 262 | mmcv.track_progress( 263 | partial( 264 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 265 | test_list) 266 | 267 | print('Done!') 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | from torchvision import transforms 4 | from scclip_segmentor import SCCLIPForSegmentation 5 | 6 | img = Image.open('figs/demo.jpg') 7 | name_list = ['skiiing man', 'tree with snow', 'sky', 'snow'] 
8 | 9 | with open('my_name.txt', 'w') as writers: 10 | for i in range(len(name_list)): 11 | if i == len(name_list)-1: 12 | writers.write(name_list[i]) 13 | else: 14 | writers.write(name_list[i] + '\n') 15 | writers.close() 16 | 17 | img_tensor = transforms.Compose([ 18 | transforms.Lambda(lambda img: img.convert('RGB')), 19 | transforms.ToTensor(), 20 | transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 21 | ])(img) 22 | 23 | img_tensor = img_tensor.unsqueeze(0).cuda() 24 | 25 | model = SCCLIPForSegmentation( 26 | clip_path='ViT-B/16', 27 | name_path='my_name.txt', 28 | pamr_steps=0, 29 | pamr_stride=(8, 16), 30 | slide_crop=224, 31 | slide_stride=112 32 | ) 33 | 34 | seg_pred = model.predict(img_tensor, data_samples=None) 35 | seg_pred = seg_pred.data.cpu().numpy().squeeze(0) 36 | 37 | fig, ax = plt.subplots(1, 3, figsize=(18, 6)) 38 | ax[0].imshow(img) 39 | ax[0].axis('off') 40 | ax[1].imshow(seg_pred, cmap='viridis') 41 | ax[1].axis('off') 42 | ax[2].imshow(img) 43 | ax[2].axis('off') 44 | ax[2].imshow(seg_pred, cmap='viridis', alpha=0.8) 45 | plt.tight_layout() 46 | plt.savefig('seg_ours.png', bbox_inches='tight') -------------------------------------------------------------------------------- /dist_test.sh: -------------------------------------------------------------------------------- 1 | outputs="./outputs/base" 2 | 3 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_voc20.py --work-dir $outputs --launcher pytorch 4 | 5 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_voc21.py --work-dir $outputs --launcher pytorch 6 | 7 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_ade20k.py --work-dir $outputs --launcher pytorch 8 | 9 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_city_scapes.py --work-dir $outputs --launcher pytorch 10 | 11 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_context59.py --work-dir $outputs --launcher pytorch 12 | 13 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_context60.py --work-dir $outputs --launcher pytorch 14 | 15 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_coco_object.py --work-dir $outputs --launcher pytorch 16 | 17 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_coco_stuff164k.py --work-dir $outputs --launcher pytorch 18 | 19 | 20 | cd $outputs 21 | find . 
-type f -name "*.log" | while read logfile 22 | do 23 | grep "data_root =" "$logfile" 24 | grep "dataset_type =" "$logfile" 25 | grep -o "mIoU: [0-9.]*" "$logfile" 26 | echo "" 27 | done -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import scclip_segmentor 4 | import custom_datasets 5 | 6 | from mmengine.config import Config 7 | from mmengine.runner import Runner 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='SC-CLIP evaluation with MMSeg') 12 | parser.add_argument('--config', default='') 13 | parser.add_argument('--work-dir', default='./work_logs/') 14 | parser.add_argument( 15 | '--show', action='store_true', help='show prediction results') 16 | parser.add_argument( 17 | '--show_dir', 18 | default='', 19 | help='directory to save visualization images') 20 | parser.add_argument( 21 | '--launcher', 22 | choices=['none', 'pytorch', 'slurm', 'mpi'], 23 | default='none', 24 | help='job launcher') 25 | # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` 26 | # will pass the `--local-rank` parameter to `tools/train.py` instead 27 | # of `--local_rank`. 28 | parser.add_argument('--local_rank', '--local-rank', type=int, default=0) 29 | args = parser.parse_args() 30 | if 'LOCAL_RANK' not in os.environ: 31 | os.environ['LOCAL_RANK'] = str(args.local_rank) 32 | 33 | return args 34 | 35 | def trigger_visualization_hook(cfg, args): 36 | default_hooks = cfg.default_hooks 37 | if 'visualization' in default_hooks: 38 | visualization_hook = default_hooks['visualization'] 39 | # Turn on visualization 40 | visualization_hook['draw'] = True 41 | if args.show: 42 | visualization_hook['show'] = True 43 | visualization_hook['wait_time'] = args.wait_time 44 | if args.show_dir: 45 | visualizer = cfg.visualizer 46 | visualizer['save_dir'] = args.show_dir 47 | else: 48 | raise RuntimeError( 49 | 'VisualizationHook must be included in default_hooks.' 50 | 'refer to usage ' 51 | '"visualization=dict(type=\'VisualizationHook\')"') 52 | 53 | return cfg 54 | 55 | def main(): 56 | args = parse_args() 57 | 58 | cfg = Config.fromfile(args.config) 59 | cfg.launcher = args.launcher 60 | cfg.work_dir = args.work_dir 61 | 62 | runner = Runner.from_cfg(cfg) 63 | runner.test() 64 | 65 | if __name__ == '__main__': 66 | main() -------------------------------------------------------------------------------- /figs/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/figs/demo.jpg -------------------------------------------------------------------------------- /figs/scclip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/figs/scclip.jpg -------------------------------------------------------------------------------- /pamr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 TU Darmstadt 2 | # License: Apache 2.0 License. 
3 | # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | from functools import partial 9 | 10 | # 11 | # Helper modules 12 | # 13 | class LocalAffinity(nn.Module): 14 | 15 | def __init__(self, dilations=[1]): 16 | super(LocalAffinity, self).__init__() 17 | self.dilations = dilations 18 | weight = self._init_aff() 19 | self.register_buffer('kernel', weight) 20 | 21 | def _init_aff(self): 22 | # initialising the shift kernel 23 | weight = torch.zeros(8, 1, 3, 3) 24 | 25 | for i in range(weight.size(0)): 26 | weight[i, 0, 1, 1] = 1 27 | 28 | weight[0, 0, 0, 0] = -1 29 | weight[1, 0, 0, 1] = -1 30 | weight[2, 0, 0, 2] = -1 31 | 32 | weight[3, 0, 1, 0] = -1 33 | weight[4, 0, 1, 2] = -1 34 | 35 | weight[5, 0, 2, 0] = -1 36 | weight[6, 0, 2, 1] = -1 37 | weight[7, 0, 2, 2] = -1 38 | 39 | self.weight_check = weight.clone() 40 | 41 | return weight 42 | 43 | def forward(self, x): 44 | 45 | self.weight_check = self.weight_check.type_as(x) 46 | assert torch.all(self.weight_check.eq(self.kernel)) 47 | 48 | B,K,H,W = x.size() 49 | x = x.view(B*K,1,H,W) 50 | 51 | x_affs = [] 52 | for d in self.dilations: 53 | x_pad = F.pad(x, [d]*4, mode='replicate') 54 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d) 55 | x_affs.append(x_aff) 56 | 57 | x_aff = torch.cat(x_affs, 1) 58 | return x_aff.view(B,K,-1,H,W) 59 | 60 | class LocalAffinityCopy(LocalAffinity): 61 | 62 | def _init_aff(self): 63 | # initialising the shift kernel 64 | weight = torch.zeros(8, 1, 3, 3) 65 | 66 | weight[0, 0, 0, 0] = 1 67 | weight[1, 0, 0, 1] = 1 68 | weight[2, 0, 0, 2] = 1 69 | 70 | weight[3, 0, 1, 0] = 1 71 | weight[4, 0, 1, 2] = 1 72 | 73 | weight[5, 0, 2, 0] = 1 74 | weight[6, 0, 2, 1] = 1 75 | weight[7, 0, 2, 2] = 1 76 | 77 | self.weight_check = weight.clone() 78 | return weight 79 | 80 | class LocalStDev(LocalAffinity): 81 | 82 | def _init_aff(self): 83 | weight = torch.zeros(9, 1, 3, 3) 84 | weight.zero_() 85 | 86 | weight[0, 0, 0, 0] = 1 87 | weight[1, 0, 0, 1] = 1 88 | weight[2, 0, 0, 2] = 1 89 | 90 | weight[3, 0, 1, 0] = 1 91 | weight[4, 0, 1, 1] = 1 92 | weight[5, 0, 1, 2] = 1 93 | 94 | weight[6, 0, 2, 0] = 1 95 | weight[7, 0, 2, 1] = 1 96 | weight[8, 0, 2, 2] = 1 97 | 98 | self.weight_check = weight.clone() 99 | return weight 100 | 101 | def forward(self, x): 102 | # returns (B,K,P,H,W), where P is the number 103 | # of locations 104 | x = super(LocalStDev, self).forward(x) 105 | 106 | return x.std(2, keepdim=True) 107 | 108 | class LocalAffinityAbs(LocalAffinity): 109 | 110 | def forward(self, x): 111 | x = super(LocalAffinityAbs, self).forward(x) 112 | return torch.abs(x) 113 | 114 | # 115 | # PAMR module 116 | # 117 | class PAMR(nn.Module): 118 | 119 | def __init__(self, num_iter=1, dilations=[1]): 120 | super(PAMR, self).__init__() 121 | 122 | self.num_iter = num_iter 123 | self.aff_x = LocalAffinityAbs(dilations) 124 | self.aff_m = LocalAffinityCopy(dilations) 125 | self.aff_std = LocalStDev(dilations) 126 | 127 | def forward(self, x, mask): 128 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True) 129 | 130 | # x: [BxKxHxW] 131 | # mask: [BxCxHxW] 132 | B,K,H,W = x.size() 133 | _,C,_,_ = mask.size() 134 | 135 | x_std = self.aff_std(x) 136 | 137 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std) 138 | x = x.mean(1, keepdim=True) 139 | x = F.softmax(x, 2) 140 | 141 | for _ in range(self.num_iter): 142 | m = self.aff_m(mask) # [BxCxPxHxW] 143 | mask = (m * x).sum(2) 144 | 145 | # xvals: 
[BxCxHxW] 146 | return mask -------------------------------------------------------------------------------- /prompts/imagenet_template.py: -------------------------------------------------------------------------------- 1 | 2 | imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 3 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 4 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 5 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 6 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 7 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 8 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 9 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 10 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 11 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 12 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 13 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 14 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 15 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 16 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 17 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 18 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 19 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 20 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 21 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 22 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 23 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 24 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 25 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 26 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 27 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 28 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 29 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 30 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 31 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 32 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 33 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 34 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 35 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 36 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 37 | "Australian Terrier", "Dandie Dinmont 
Terrier", "Boston Terrier", "Miniature Schnauzer", 38 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 39 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 40 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 41 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 42 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 43 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 44 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 45 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 46 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 47 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 48 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 49 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 50 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 51 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 52 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 53 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 54 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 55 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 56 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 57 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 58 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 59 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 60 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 61 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 62 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 63 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 64 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 65 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 66 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 67 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 68 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 69 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 70 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 71 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 72 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 73 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 74 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 75 | "accordion", "acoustic guitar", 
"aircraft carrier", "airliner", "airship", "altar", "ambulance", 76 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 77 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 78 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 79 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 80 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 81 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 82 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 83 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 84 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 85 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 86 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 87 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 88 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 89 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 90 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 91 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 92 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 93 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 94 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 95 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 96 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 97 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 98 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 99 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 100 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 101 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 102 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 103 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 104 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 105 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 106 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 107 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 108 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 109 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 110 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 111 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 112 | "mailbox", "tights", "one-piece bathing 
suit", "manhole cover", "maraca", "marimba", "mask", 113 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 114 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 115 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 116 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 117 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 118 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 119 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 120 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 121 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 122 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 123 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 124 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 125 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 126 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 127 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 128 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 129 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 130 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 131 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 132 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 133 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 134 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 135 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 136 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 137 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 138 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 139 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 140 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 141 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 142 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 143 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 144 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 145 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 146 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 147 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 148 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 149 | "umbrella", 
"unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 150 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 151 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 152 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 153 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 154 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 155 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 156 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 157 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 158 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 159 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 160 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 161 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 162 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 163 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 164 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 165 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 166 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 167 | 168 | 169 | openai_imagenet_template = [ 170 | lambda c: f'a bad photo of a {c}.', 171 | lambda c: f'a photo of many {c}.', 172 | lambda c: f'a sculpture of a {c}.', 173 | lambda c: f'a photo of the hard to see {c}.', 174 | lambda c: f'a low resolution photo of the {c}.', 175 | lambda c: f'a rendering of a {c}.', 176 | lambda c: f'graffiti of a {c}.', 177 | lambda c: f'a bad photo of the {c}.', 178 | lambda c: f'a cropped photo of the {c}.', 179 | lambda c: f'a tattoo of a {c}.', 180 | lambda c: f'the embroidered {c}.', 181 | lambda c: f'a photo of a hard to see {c}.', 182 | lambda c: f'a bright photo of a {c}.', 183 | lambda c: f'a photo of a clean {c}.', 184 | lambda c: f'a photo of a dirty {c}.', 185 | lambda c: f'a dark photo of the {c}.', 186 | lambda c: f'a drawing of a {c}.', 187 | lambda c: f'a photo of my {c}.', 188 | lambda c: f'the plastic {c}.', 189 | lambda c: f'a photo of the cool {c}.', 190 | lambda c: f'a close-up photo of a {c}.', 191 | lambda c: f'a black and white photo of the {c}.', 192 | lambda c: f'a painting of the {c}.', 193 | lambda c: f'a painting of a {c}.', 194 | lambda c: f'a pixelated photo of the {c}.', 195 | lambda c: f'a sculpture of the {c}.', 196 | lambda c: f'a bright photo of the {c}.', 197 | lambda c: f'a cropped photo of a {c}.', 198 | lambda c: f'a plastic {c}.', 199 | lambda c: f'a photo of the dirty {c}.', 200 | lambda c: f'a jpeg corrupted photo of a {c}.', 201 | lambda c: f'a blurry photo of the {c}.', 202 | lambda c: f'a photo of the {c}.', 203 | lambda c: f'a good photo of the {c}.', 204 | lambda c: f'a rendering of the {c}.', 205 | lambda c: f'a {c} in a video game.', 206 | lambda c: f'a photo of one {c}.', 207 | lambda c: f'a doodle of a {c}.', 208 | lambda c: f'a close-up photo of the {c}.', 209 | lambda c: f'a photo of a {c}.', 
210 | lambda c: f'the origami {c}.', 211 | lambda c: f'the {c} in a video game.', 212 | lambda c: f'a sketch of a {c}.', 213 | lambda c: f'a doodle of the {c}.', 214 | lambda c: f'a origami {c}.', 215 | lambda c: f'a low resolution photo of a {c}.', 216 | lambda c: f'the toy {c}.', 217 | lambda c: f'a rendition of the {c}.', 218 | lambda c: f'a photo of the clean {c}.', 219 | lambda c: f'a photo of a large {c}.', 220 | lambda c: f'a rendition of a {c}.', 221 | lambda c: f'a photo of a nice {c}.', 222 | lambda c: f'a photo of a weird {c}.', 223 | lambda c: f'a blurry photo of a {c}.', 224 | lambda c: f'a cartoon {c}.', 225 | lambda c: f'art of a {c}.', 226 | lambda c: f'a sketch of the {c}.', 227 | lambda c: f'a embroidered {c}.', 228 | lambda c: f'a pixelated photo of a {c}.', 229 | lambda c: f'itap of the {c}.', 230 | lambda c: f'a jpeg corrupted photo of the {c}.', 231 | lambda c: f'a good photo of a {c}.', 232 | lambda c: f'a plushie {c}.', 233 | lambda c: f'a photo of the nice {c}.', 234 | lambda c: f'a photo of the small {c}.', 235 | lambda c: f'a photo of the weird {c}.', 236 | lambda c: f'the cartoon {c}.', 237 | lambda c: f'art of the {c}.', 238 | lambda c: f'a drawing of the {c}.', 239 | lambda c: f'a photo of the large {c}.', 240 | lambda c: f'a black and white photo of a {c}.', 241 | lambda c: f'the plushie {c}.', 242 | lambda c: f'a dark photo of a {c}.', 243 | lambda c: f'itap of a {c}.', 244 | lambda c: f'graffiti of the {c}.', 245 | lambda c: f'a toy {c}.', 246 | lambda c: f'itap of my {c}.', 247 | lambda c: f'a photo of a cool {c}.', 248 | lambda c: f'a photo of a small {c}.', 249 | lambda c: f'a tattoo of the {c}.', 250 | ] -------------------------------------------------------------------------------- /scclip_segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.append("..") 5 | 6 | import clip 7 | from prompts.imagenet_template import openai_imagenet_template 8 | 9 | from mmseg.models.segmentors import BaseSegmentor 10 | from mmseg.models.data_preprocessor import SegDataPreProcessor 11 | from mmengine.structures import PixelData 12 | 13 | from mmseg.registry import MODELS 14 | 15 | from pamr import PAMR 16 | 17 | @MODELS.register_module() 18 | class SCCLIPForSegmentation(BaseSegmentor): 19 | def __init__(self, clip_path, name_path, device=torch.device('cuda'), 20 | pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40, 21 | slide_stride=112, slide_crop=224, area_thd=None, 22 | pre_adjust_idx=8, post_adjust_idx=3, multi_start_idx=3, multi_end_idx=10, res_cls=0.3): 23 | 24 | data_preprocessor = SegDataPreProcessor( 25 | mean=[122.771, 116.746, 104.094], 26 | std=[68.501, 66.632, 70.323], 27 | rgb_to_bgr=True) 28 | super().__init__(data_preprocessor=data_preprocessor) 29 | self.net, _ = clip.load(clip_path, device=device, jit=False) 30 | 31 | self.net.visual.pre_adjust_idx = pre_adjust_idx 32 | self.net.visual.post_adjust_idx = post_adjust_idx 33 | self.net.visual.multi_start_idx = multi_start_idx 34 | self.net.visual.multi_end_idx = multi_end_idx 35 | self.net.visual.res_cls = res_cls 36 | 37 | query_words, self.query_idx = get_cls_idx(name_path) 38 | self.num_queries = len(query_words) 39 | self.num_classes = max(self.query_idx) + 1 40 | self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device) 41 | 42 | query_features = [] 43 | with torch.no_grad(): 44 | for qw in query_words: 45 | query = clip.tokenize([temp(qw) for temp in 
openai_imagenet_template]).to(device) 46 | feature = self.net.encode_text(query) 47 | feature /= feature.norm(dim=-1, keepdim=True) 48 | feature = feature.mean(dim=0) 49 | feature /= feature.norm() 50 | query_features.append(feature.unsqueeze(0)) 51 | self.query_features = torch.cat(query_features, dim=0) 52 | 53 | self.dtype = self.query_features.dtype 54 | self.logit_scale = logit_scale 55 | self.prob_thd = prob_thd 56 | self.area_thd = area_thd 57 | self.slide_stride = slide_stride 58 | self.slide_crop = slide_crop 59 | self.align_corners = False 60 | 61 | if pamr_steps > 0: 62 | self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device) 63 | else: 64 | self.pamr = None 65 | 66 | def forward_feature(self, img, logit_size=None): 67 | if type(img) == list: 68 | img = img[0] 69 | 70 | image_features = self.net.encode_image(img, return_all=True) 71 | image_features /= image_features.norm(dim=-1, keepdim=True) 72 | logits = image_features @ self.query_features.T 73 | 74 | patch_size = self.net.visual.patch_size 75 | w, h = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size 76 | out_dim = logits.shape[-1] 77 | logits = logits.permute(0, 2, 1).reshape(-1, out_dim, w, h) 78 | 79 | if logit_size == None: 80 | logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear', align_corners=False) 81 | else: 82 | logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear', align_corners=False) 83 | 84 | return logits 85 | 86 | def forward_slide(self, img, img_metas, stride=112, crop_size=224): 87 | """Inference by sliding-window with overlap. 88 | If h_crop > h_img or w_crop > w_img, the small patch will be used to 89 | decode without padding. 90 | """ 91 | if type(img) == list: 92 | img = img[0].unsqueeze(0) 93 | if type(stride) == int: 94 | stride = (stride, stride) 95 | if type(crop_size) == int: 96 | crop_size = (crop_size, crop_size) 97 | 98 | h_stride, w_stride = stride 99 | h_crop, w_crop = crop_size 100 | batch_size, _, h_img, w_img = img.shape 101 | out_channels = self.num_queries 102 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 103 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 104 | preds = img.new_zeros((batch_size, out_channels, h_img, w_img)) 105 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) 106 | for h_idx in range(h_grids): 107 | for w_idx in range(w_grids): 108 | y1 = h_idx * h_stride 109 | x1 = w_idx * w_stride 110 | y2 = min(y1 + h_crop, h_img) 111 | x2 = min(x1 + w_crop, w_img) 112 | y1 = max(y2 - h_crop, 0) 113 | x1 = max(x2 - w_crop, 0) 114 | crop_img = img[:, :, y1:y2, x1:x2] 115 | crop_seg_logit = self.forward_feature(crop_img) 116 | preds += nn.functional.pad(crop_seg_logit, 117 | (int(x1), int(preds.shape[3] - x2), int(y1), 118 | int(preds.shape[2] - y2))) 119 | 120 | count_mat[:, :, y1:y2, x1:x2] += 1 121 | assert (count_mat == 0).sum() == 0 122 | 123 | preds = preds / count_mat 124 | img_size = img_metas[0]['ori_shape'][:2] 125 | logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear', align_corners=False) 126 | 127 | if self.pamr: 128 | img = nn.functional.interpolate(img, size=img_size, mode='bilinear') 129 | logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype) 130 | 131 | return logits 132 | 133 | def predict(self, inputs, data_samples): 134 | if data_samples is not None: 135 | batch_img_metas = [ 136 | data_sample.metainfo for data_sample in data_samples 137 | ] 138 | else: 139 | batch_img_metas = [ 140 | dict( 141 | ori_shape=inputs.shape[2:], 
142 | img_shape=inputs.shape[2:], 143 | pad_shape=inputs.shape[2:], 144 | padding_size=[0, 0, 0, 0]) 145 | ] * inputs.shape[0] 146 | 147 | if self.slide_crop > 0: 148 | seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop) 149 | else: 150 | seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape']) 151 | 152 | return self.postprocess_result(seg_logits, data_samples) 153 | 154 | def postprocess_result(self, seg_logits, data_samples): 155 | batch_size = seg_logits.shape[0] 156 | for i in range(batch_size): 157 | seg_logits = seg_logits[i] * self.logit_scale 158 | seg_logits = seg_logits.softmax(0) # n_queries * w * h 159 | 160 | num_cls, num_queries = max(self.query_idx) + 1, len(self.query_idx) 161 | if num_cls != num_queries: 162 | seg_logits = seg_logits.unsqueeze(0) 163 | cls_index = nn.functional.one_hot(self.query_idx) 164 | cls_index = cls_index.T.view(num_cls, num_queries, 1, 1) 165 | seg_logits = (seg_logits * cls_index).max(1)[0] 166 | seg_pred = seg_logits.argmax(0, keepdim=True) 167 | 168 | if self.area_thd is not None: 169 | # Force segmentations with area < self.area_thd to 0 (background) 170 | predictions = nn.functional.one_hot(seg_logits.argmax(0), num_cls).to(seg_logits.dtype) 171 | area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True) # prune background 172 | area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logits.dtype) 173 | seg_logits[1:] *= area_pred.transpose(0, -1) 174 | 175 | seg_pred = seg_logits.argmax(0, keepdim=True) 176 | seg_pred[seg_logits.max(0, keepdim=True)[0] < self.prob_thd] = 0 177 | 178 | if data_samples is None: 179 | return seg_pred 180 | else: 181 | data_samples[i].set_data({ 182 | 'seg_logits': 183 | PixelData(**{'data': seg_logits}), 184 | 'pred_sem_seg': 185 | PixelData(**{'data': seg_pred}) 186 | }) 187 | 188 | return data_samples 189 | 190 | def _forward(self, data_samples): 191 | """ 192 | """ 193 | 194 | def inference(self, img, batch_img_metas): 195 | """ 196 | """ 197 | 198 | def encode_decode(self, inputs, batch_img_metas): 199 | """ 200 | """ 201 | 202 | def extract_feat(self, inputs): 203 | """ 204 | """ 205 | 206 | def loss(self, inputs, data_samples): 207 | """ 208 | """ 209 | 210 | def get_cls_idx(path): 211 | with open(path, 'r') as f: 212 | name_sets = f.readlines() 213 | num_cls = len(name_sets) 214 | 215 | class_names, class_indices = [], [] 216 | for idx in range(num_cls): 217 | names_i = name_sets[idx].split(', ') 218 | class_names += names_i 219 | class_indices += [idx for _ in range(len(names_i))] 220 | class_names = [item.replace('\n', '') for item in class_names] 221 | return class_names, class_indices --------------------------------------------------------------------------------
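A note on the class-name files above (e.g. `cls_voc21.txt`, `cls_coco_object.txt`): each line defines one class, and a comma-separated line groups several synonym prompts under the same class index. `get_cls_idx` expands every synonym into its own text query, and `postprocess_result` later folds the per-query logits back into one score per class by taking the maximum over each synonym group. The snippet below is a minimal standalone sketch of that folding; the class lines and logits are made up for illustration and are not part of the repository.

```python
# Sketch only: mirrors get_cls_idx() and the one-hot folding in
# SCCLIPForSegmentation.postprocess_result(); names and logits are made up.
import torch
import torch.nn.functional as F

# Toy class-name lines: synonyms on one line share a class index.
name_lines = [
    'sky, wall, tree',              # class 0 (background-like group)
    'aeroplane',                    # class 1
    'person, person in shirt',      # class 2
]

class_names, class_indices = [], []
for idx, line in enumerate(name_lines):
    names_i = [n.strip() for n in line.split(',')]
    class_names += names_i
    class_indices += [idx] * len(names_i)

query_idx = torch.tensor(class_indices)            # one entry per text query
num_queries, num_cls = len(class_names), int(query_idx.max()) + 1

# Fake per-query logits for a 2x2 feature map: (num_queries, H, W).
seg_logits = torch.randn(num_queries, 2, 2).softmax(0)

# Fold synonym queries into their class via a one-hot mask and a max.
cls_index = F.one_hot(query_idx, num_cls).T.view(num_cls, num_queries, 1, 1)
cls_logits = (seg_logits.unsqueeze(0) * cls_index).max(1)[0]  # (num_cls, H, W)
seg_pred = cls_logits.argmax(0)                               # per-pixel class id
print(cls_logits.shape, seg_pred.shape)  # torch.Size([3, 2, 2]) torch.Size([2, 2])
```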