├── .gitignore
├── README.md
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   ├── base_config.py
│   ├── cfg_ade20k.py
│   ├── cfg_city_scapes.py
│   ├── cfg_coco_object.py
│   ├── cfg_coco_stuff164k.py
│   ├── cfg_context59.py
│   ├── cfg_context60.py
│   ├── cfg_voc20.py
│   ├── cfg_voc21.py
│   ├── cls_ade20k.txt
│   ├── cls_city_scapes.txt
│   ├── cls_coco_object.txt
│   ├── cls_coco_stuff.txt
│   ├── cls_context59.txt
│   ├── cls_context60.txt
│   ├── cls_voc20.txt
│   └── cls_voc21.txt
├── custom_datasets.py
├── datasets
│   └── cvt_coco_object.py
├── demo.py
├── dist_test.sh
├── eval.py
├── figs
│   ├── demo.jpg
│   └── scclip.jpg
├── pamr.py
├── prompts
│   └── imagenet_template.py
└── scclip_segmentor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /outputs
2 | /.dist_test
3 | **__pycache__**
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Calibrated CLIP for Training-Free Open-Vocabulary Segmentation
2 |
3 | [Sule Bai*](https://sulebai.github.io/), [Yong Liu*](https://yongliu20.github.io/), [Yifei Han](https://github.com/LambdaGuard), [Haoji Zhang](https://zhang9302002.github.io/), [Yansong Tang](https://andytang15.github.io/)
4 | (* denotes equal contribution)
5 |
6 | **Official PyTorch Implementation of [Self-Calibrated CLIP for Training-Free Open-Vocabulary Segmentation](https://arxiv.org/abs/2411.15869)**
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Abstract
15 | > Recent advancements in pre-trained vision-language models like CLIP have enabled the task of open-vocabulary segmentation. CLIP demonstrates impressive zero-shot capabilities in various downstream tasks that require holistic image understanding. However, due to its image-level pre-training, CLIP struggles to capture local details, resulting in poor performance in segmentation tasks. Our analysis reveals that anomaly tokens emerge during the forward pass, drawing excessive attention from normal patch tokens, thereby diminishing spatial awareness. To address this issue, we propose Self-Calibrated CLIP (SC-CLIP), a training-free method that calibrates CLIP to produce finer-grained representations while preserving its original generalization ability, without introducing new parameters or relying on additional backbones. Specifically, we first identify and resolve the anomaly tokens to mitigate their negative impact. Next, we enhance feature discriminability and attention correlation by leveraging the semantic consistency found in CLIP's intermediate features. Furthermore, we employ multi-level feature fusion to enrich details. Collectively, these strategies enhance CLIP's feature representation with greater granularity and coherence. Experimental results demonstrate the effectiveness of SC-CLIP, achieving state-of-the-art results across eight semantic segmentation datasets and surpassing previous methods by 9.5%. Notably, SC-CLIP boosts the performance of vanilla CLIP ViT-L/14 by 6.8 times.
16 |
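The anomaly tokens mentioned above are scored with a Local Outlier Factor computed directly in PyTorch (`lof_pytorch` in `clip/model.py`). A condensed sketch of that scoring, with an illustrative helper name and default arguments:

```
import torch

def lof_scores(tokens: torch.Tensor, n_neighbors: int = 30) -> torch.Tensor:
    # tokens: [N, C] patch embeddings from an intermediate CLIP layer
    d = torch.cdist(tokens, tokens) ** 2                    # pairwise squared distances
    knn_d, knn_i = torch.topk(d, k=n_neighbors + 1, largest=False)
    knn_d, knn_i = knn_d[:, 1:], knn_i[:, 1:]               # drop the self-distance
    k_d = knn_d[:, -1:].expand_as(knn_d)                    # k-distance of each token
    reach = torch.max(knn_d, k_d)                           # reachability distances
    lrd = n_neighbors / reach.mean(dim=1).clamp_min(1e-6)   # local reachability density
    return (lrd[knn_i] / lrd.unsqueeze(1)).mean(dim=1)      # high score = likely anomaly token
```

Tokens whose score falls in the top few percent (5% by default in `lof_pytorch`) are treated as anomalies and replaced by the mean of their non-anomalous spatial neighbors before the final transformer block.
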
17 | ## Dependencies
18 |
19 | ```
20 | git clone https://github.com/SuleBai/SC-CLIP.git
21 | cd SC-CLIP
22 |
23 | conda create -n scclip python=3.9
24 | conda activate scclip
25 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
26 | pip install openmim
27 | mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1
28 | pip install ftfy regex numpy==1.26 yapf==0.40.1
29 | ```
30 |
31 | ## Datasets
32 | We provide the dataset configurations in this repository, following [SCLIP](https://github.com/wangf3014/SCLIP).
33 |
34 | Please follow the [MMSeg data preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download and pre-process the datasets. The COCO-Object dataset can be converted from COCO-Stuff164k by executing the following command:
35 |
36 | ```
37 | python ./datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO_OBJECT
38 | ```
39 |
40 | ## Quick Inference
41 | ```
42 | python demo.py
43 | ```
44 |
45 | ## Model Evaluation
46 | Single-GPU running:
47 |
48 | ```
49 | python eval.py --config ./configs/cfg_DATASET.py --workdir YOUR_WORK_DIR
50 | ```
51 |
52 | Multi-GPU running:
53 | ```
54 | bash ./dist_test.sh
55 | ```
56 |
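For example, evaluating on ADE20K with results written under `./outputs` (the work directory name is just an illustration):

```
python eval.py --config ./configs/cfg_ade20k.py --workdir ./outputs/ade20k
```
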
57 | ## Acknowledgement
58 | This implementation is based on [CLIP](https://github.com/openai/CLIP), [SCLIP](https://github.com/wangf3014/SCLIP), [CLIP-DINOiser](https://github.com/wysoczanska/clip_dinoiser) and [ClearCLIP](https://github.com/mc-lan/ClearCLIP). Thanks for the awesome work.
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | from .model import *
3 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | ### CLIP source code from OpenAI:
2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py
3 |
4 | import hashlib
5 | import os
6 | import urllib
7 | import warnings
8 | from typing import Any, Union, List
9 | from pkg_resources import packaging
10 |
11 | import torch
12 | from PIL import Image
13 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
14 | from tqdm import tqdm
15 |
16 | from .model import build_model
17 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
18 |
19 | try:
20 | from torchvision.transforms import InterpolationMode
21 | BICUBIC = InterpolationMode.BICUBIC
22 | except ImportError:
23 | BICUBIC = Image.BICUBIC
24 |
25 |
26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
28 |
29 |
30 | __all__ = ["available_models", "load", "tokenize"]
31 | _tokenizer = _Tokenizer()
32 |
33 | _MODELS = {
34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
42 | }
43 |
44 |
45 | def _download(url: str, root: str):
46 | os.makedirs(root, exist_ok=True)
47 | filename = os.path.basename(url)
48 |
49 | expected_sha256 = url.split("/")[-2]
50 | download_target = os.path.join(root, filename)
51 |
52 | if os.path.exists(download_target) and not os.path.isfile(download_target):
53 | raise RuntimeError(f"{download_target} exists and is not a regular file")
54 |
55 | if os.path.isfile(download_target):
56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
57 | return download_target
58 | else:
59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60 |
61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
72 |         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
73 |
74 | return download_target
75 |
76 |
77 | def _convert_image_to_rgb(image):
78 | return image.convert("RGB")
79 |
80 |
81 | def _transform(n_px):
82 | return Compose([
83 | Resize(n_px, interpolation=BICUBIC),
84 | CenterCrop(n_px),
85 | _convert_image_to_rgb,
86 | ToTensor(),
87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
88 | ])
89 |
90 |
91 | def available_models() -> List[str]:
92 | """Returns the names of available CLIP models"""
93 | return list(_MODELS.keys())
94 |
95 |
96 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
97 | """Load a CLIP model
98 |
99 | Parameters
100 | ----------
101 | name : str
102 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
103 |
104 | device : Union[str, torch.device]
105 | The device to put the loaded model
106 |
107 | jit : bool
108 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
109 |
110 | download_root: str
111 | path to download the model files; by default, it uses "~/.cache/clip"
112 |
113 | Returns
114 | -------
115 | model : torch.nn.Module
116 | The CLIP model
117 |
118 | preprocess : Callable[[PIL.Image], torch.Tensor]
119 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
120 | """
121 | if name in _MODELS:
122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
123 | elif os.path.isfile(name):
124 | model_path = name
125 | else:
126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
127 |
128 | try:
129 | # loading JIT archive
130 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
131 | state_dict = None
132 | except RuntimeError:
133 | # loading saved state dict
134 | if jit:
135 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
136 | jit = False
137 | state_dict = torch.load(model_path, map_location="cpu")
138 |
139 | if not jit:
140 | model = build_model(state_dict or model.state_dict()).to(device)
141 | if str(device) == "cpu":
142 | model.float()
143 | return model, _transform(model.visual.input_resolution)
144 |
145 | # patch the device names
146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
148 |
149 | def patch_device(module):
150 | try:
151 | graphs = [module.graph] if hasattr(module, "graph") else []
152 | except RuntimeError:
153 | graphs = []
154 |
155 | if hasattr(module, "forward1"):
156 | graphs.append(module.forward1.graph)
157 |
158 | for graph in graphs:
159 | for node in graph.findAllNodes("prim::Constant"):
160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
161 | node.copyAttributes(device_node)
162 |
163 | model.apply(patch_device)
164 | patch_device(model.encode_image)
165 | patch_device(model.encode_text)
166 |
167 | # patch dtype to float32 on CPU
168 | if str(device) == "cpu":
169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
171 | float_node = float_input.node()
172 |
173 | def patch_float(module):
174 | try:
175 | graphs = [module.graph] if hasattr(module, "graph") else []
176 | except RuntimeError:
177 | graphs = []
178 |
179 | if hasattr(module, "forward1"):
180 | graphs.append(module.forward1.graph)
181 |
182 | for graph in graphs:
183 | for node in graph.findAllNodes("aten::to"):
184 | inputs = list(node.inputs())
185 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
186 | if inputs[i].node()["value"] == 5:
187 | inputs[i].node().copyAttributes(float_node)
188 |
189 | model.apply(patch_float)
190 | patch_float(model.encode_image)
191 | patch_float(model.encode_text)
192 |
193 | model.float()
194 |
195 | return model, _transform(model.input_resolution.item())
196 |
197 |
198 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
199 | """
200 | Returns the tokenized representation of given input string(s)
201 |
202 | Parameters
203 | ----------
204 | texts : Union[str, List[str]]
205 | An input string or a list of input strings to tokenize
206 |
207 | context_length : int
208 | The context length to use; all CLIP models use 77 as the context length
209 |
210 | truncate: bool
211 | Whether to truncate the text in case its encoding is longer than the context length
212 |
213 | Returns
214 | -------
215 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
216 | """
217 | if isinstance(texts, str):
218 | texts = [texts]
219 |
220 | sot_token = _tokenizer.encoder["<|startoftext|>"]
221 | eot_token = _tokenizer.encoder["<|endoftext|>"]
222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
224 |
225 | for i, tokens in enumerate(all_tokens):
226 | if len(tokens) > context_length:
227 | if truncate:
228 | tokens = tokens[:context_length]
229 | tokens[-1] = eot_token
230 | else:
231 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
232 | result[i, :len(tokens)] = torch.tensor(tokens)
233 |
234 | return result
--------------------------------------------------------------------------------
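
This file keeps OpenAI's public interface (`clip.load`, `clip.tokenize`, `available_models`); the model it returns additionally accepts `return_all=True` in `encode_image` (see `clip/model.py`) to expose dense patch features. A minimal usage sketch; the image path points at the repository's demo figure and the prompts are placeholders:

```
import torch
from PIL import Image
import clip  # this package

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("figs/demo.jpg")).unsqueeze(0).to(device)
texts = clip.tokenize(["a photo of a cat", "a photo of a dog"]).to(device)

with torch.no_grad():
    # return_all=True yields projected features for every patch token,
    # which can be compared against the text embeddings for dense prediction
    patch_features = model.encode_image(image, return_all=True)
    text_features = model.encode_text(texts)
```
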
/clip/model.py:
--------------------------------------------------------------------------------
1 | ### CLIP source code from OpenAI:
2 | # https://github.com/openai/CLIP/blob/main/clip/model.py
3 |
4 | from collections import OrderedDict
5 | from typing import Tuple, Union
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import nn
11 | import torchvision.transforms.functional as VF
12 |
13 | class Bottleneck(nn.Module):
14 | expansion = 4
15 |
16 | def __init__(self, inplanes, planes, stride=1):
17 | super().__init__()
18 |
19 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
20 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 |
23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
27 |
28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
30 |
31 | self.relu = nn.ReLU(inplace=True)
32 | self.downsample = None
33 | self.stride = stride
34 |
35 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
36 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
37 | self.downsample = nn.Sequential(OrderedDict([
38 | ("-1", nn.AvgPool2d(stride)),
39 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
40 | ("1", nn.BatchNorm2d(planes * self.expansion))
41 | ]))
42 |
43 | def forward(self, x: torch.Tensor):
44 | identity = x
45 |
46 | out = self.relu(self.bn1(self.conv1(x)))
47 | out = self.relu(self.bn2(self.conv2(out)))
48 | out = self.avgpool(out)
49 | out = self.bn3(self.conv3(out))
50 |
51 | if self.downsample is not None:
52 | identity = self.downsample(x)
53 |
54 | out += identity
55 | out = self.relu(out)
56 | return out
57 |
58 |
59 | class AttentionPool2d(nn.Module):
60 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
61 | super().__init__()
62 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
63 | self.k_proj = nn.Linear(embed_dim, embed_dim)
64 | self.q_proj = nn.Linear(embed_dim, embed_dim)
65 | self.v_proj = nn.Linear(embed_dim, embed_dim)
66 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
67 | self.num_heads = num_heads
68 |
69 | def forward(self, x, return_all_tokens=False):
70 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
71 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
72 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
73 | x, _ = F.multi_head_attention_forward(
74 | query=x, key=x, value=x,
75 | embed_dim_to_check=x.shape[-1],
76 | num_heads=self.num_heads,
77 | q_proj_weight=self.q_proj.weight,
78 | k_proj_weight=self.k_proj.weight,
79 | v_proj_weight=self.v_proj.weight,
80 | in_proj_weight=None,
81 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
82 | bias_k=None,
83 | bias_v=None,
84 | add_zero_attn=False,
85 | dropout_p=0,
86 | out_proj_weight=self.c_proj.weight,
87 | out_proj_bias=self.c_proj.bias,
88 | use_separate_proj_weight=True,
89 | training=self.training,
90 | need_weights=False
91 | )
92 | if return_all_tokens:
93 | return x
94 | else:
95 | return x[0]
96 |
97 |
98 | class ModifiedResNet(nn.Module):
99 | """
100 | A ResNet class that is similar to torchvision's but contains the following changes:
101 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
102 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
103 | - The final pooling layer is a QKV attention instead of an average pool
104 | """
105 |
106 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
107 | super().__init__()
108 | self.output_dim = output_dim
109 | self.input_resolution = input_resolution
110 |
111 | # the 3-layer stem
112 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
113 | self.bn1 = nn.BatchNorm2d(width // 2)
114 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
115 | self.bn2 = nn.BatchNorm2d(width // 2)
116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
117 | self.bn3 = nn.BatchNorm2d(width)
118 | self.avgpool = nn.AvgPool2d(2)
119 | self.relu = nn.ReLU(inplace=True)
120 |
121 | # residual layers
122 | self._inplanes = width # this is a *mutable* variable used during construction
123 | self.layer1 = self._make_layer(width, layers[0])
124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
127 |
128 | embed_dim = width * 32 # the ResNet feature dimension
129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
130 |
131 | def _make_layer(self, planes, blocks, stride=1):
132 | layers = [Bottleneck(self._inplanes, planes, stride)]
133 |
134 | self._inplanes = planes * Bottleneck.expansion
135 | for _ in range(1, blocks):
136 | layers.append(Bottleneck(self._inplanes, planes))
137 |
138 | return nn.Sequential(*layers)
139 |
140 | def forward(self, x, return_all_tokens=False):
141 | def stem(x):
142 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
143 | x = self.relu(bn(conv(x)))
144 | x = self.avgpool(x)
145 | return x
146 |
147 | x = x.type(self.conv1.weight.dtype)
148 | x = stem(x)
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 | x = self.attnpool(x, return_all_tokens)
154 |
155 | return x
156 |
157 |
158 | class LayerNorm(nn.LayerNorm):
159 | """Subclass torch's LayerNorm to handle fp16."""
160 |
161 | def forward(self, x: torch.Tensor):
162 | orig_type = x.dtype
163 | ret = super().forward(x.type(torch.float32))
164 | return ret.type(orig_type)
165 |
166 |
167 | class QuickGELU(nn.Module):
168 | def forward(self, x: torch.Tensor):
169 | return x * torch.sigmoid(1.702 * x)
170 |
171 |
172 | class ResidualAttentionBlock(nn.Module):
173 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
174 | super().__init__()
175 |
176 | self.attn = nn.MultiheadAttention(d_model, n_head)
177 | self.ln_1 = LayerNorm(d_model)
178 | self.mlp = nn.Sequential(OrderedDict([
179 | ("c_fc", nn.Linear(d_model, d_model * 4)),
180 | ("gelu", QuickGELU()),
181 | ("c_proj", nn.Linear(d_model * 4, d_model))
182 | ]))
183 | self.ln_2 = LayerNorm(d_model)
184 | self.attn_mask = attn_mask
185 |
186 | def attention(self, x: torch.Tensor):
187 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
188 | # pdb.set_trace()
189 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
190 |
191 | def forward(self, x: torch.Tensor):
192 | x = x + self.attention(self.ln_1(x))
193 | x = x + self.mlp(self.ln_2(x))
194 | return x
195 |
196 |
197 | class Transformer(nn.Module):
198 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
199 | super().__init__()
200 | self.width = width
201 | self.layers = layers
202 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
203 |
204 | def forward(self, x: torch.Tensor):
205 | return self.resblocks(x)
206 |
207 |
208 | def lof_pytorch(x, n_neighbors=30, contamination=0.05):
209 | distances = torch.norm(x[:, None] - x[None, :], dim=2, p=2) ** 2
210 |
211 | knn_distances, knn_indices = torch.topk(distances, k=n_neighbors+1, largest=False)
212 | knn_distances, knn_indices = knn_distances[:, 1:], knn_indices[:, 1:]
213 |
214 | k_distances = knn_distances[:, -1].unsqueeze(1).expand_as(knn_distances)
215 | reach_distances = torch.max(knn_distances, k_distances)
216 |
217 | LRD = n_neighbors / torch.nan_to_num(reach_distances.mean(dim=1), nan=1e-6)
218 |
219 | LRD_ratios = LRD[knn_indices] / LRD.unsqueeze(1)
220 | LOF_scores = LRD_ratios.mean(dim=1)
221 |
222 | threshold = torch.quantile(LOF_scores.to(torch.float32), 1 - contamination)
223 |
224 | outlier_mask = LOF_scores > threshold
225 | outlier_indices = torch.where(outlier_mask)[0]
226 |
227 | return outlier_indices, LOF_scores
228 |
229 |
230 | class VisionTransformer(nn.Module):
231 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
232 | super().__init__()
233 | self.input_resolution = input_resolution
234 | self.patch_size = patch_size
235 | self.output_dim = output_dim
236 | self.width = width
237 | self.heads = heads
238 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
239 | scale = width ** -0.5
240 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
241 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
242 | self.ln_pre = LayerNorm(width)
243 | self.transformer = Transformer(width, layers, heads)
244 | self.ln_post = LayerNorm(width)
245 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
246 |
247 |         self.beta = 0.4               # similarity threshold used when aggregating features
248 |         self.pre_adjust_idx = 8       # intermediate layer whose features calibrate the input of the final block
249 |         self.post_adjust_idx = 3      # intermediate layer whose features refine the final output
250 |         self.multi_start_idx = 3      # first layer included in multi-level feature fusion
251 |         self.multi_end_idx = 10       # one past the last layer included in multi-level feature fusion
252 |         self.res_cls = 0.3            # residual weight of the [CLS] token in the final block
253 |
254 | def forward(self, x: torch.Tensor, return_all=False):
255 | B, nc, w, h = x.shape
256 | x = self.conv1(x)
257 | feat_w, feat_h = x.shape[-2], x.shape[-1]
258 | x = x.reshape(x.shape[0], x.shape[1], -1)
259 | x = x.permute(0, 2, 1)
260 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
261 | if x.shape[1] != self.positional_embedding.shape[0]:
262 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
263 | else:
264 | x = x + self.positional_embedding.to(x.dtype)
265 | x = self.ln_pre(x)
266 |
267 | x = x.permute(1, 0, 2)
268 | feats_list = []
269 | for idx, blk in enumerate(self.transformer.resblocks[:-1], start=1):
270 | x = blk(x)
271 | feats_list.append(x)
272 |             if idx == len(self.transformer.resblocks) - 1:  # after the penultimate block: identify and resolve anomaly tokens
273 |                 cls_token = x[:1, ...]
274 |                 outlier_indices, LOF_scores = lof_pytorch(x[1:, ...].squeeze(1), n_neighbors=30, contamination=0.05)  # LOF over patch tokens (excluding [CLS])
275 |                 top_indices = [(torch.div(index, feat_w, rounding_mode='trunc'), index % feat_w) for index in outlier_indices]
276 |                 feature_map = x[1:, :, :].permute(1, 2, 0).reshape(B, self.width, feat_w, feat_h)
277 |                 feature_map = self.mean_interpolation(feature_map, top_indices)  # replace anomaly tokens with the mean of their non-anomalous neighbors
278 |                 x = feature_map.reshape(B, self.width, feat_w * feat_h).permute(2, 0, 1)
279 |
280 | feats = feats_list[self.pre_adjust_idx][1:, ...].clone()
281 | feats = feats.permute(1, 2, 0).reshape(B, self.width, feat_w, feat_h)
282 | feats = self.mean_interpolation(feats, top_indices)
283 | feats = feats.reshape(B, self.width, feat_w * feat_h).permute(2, 0, 1)
284 | feats = feats / feats.norm(dim=2, keepdim=True)
285 | before_simi = torch.matmul(feats.permute(1, 0, 2), feats.permute(1, 2, 0))
286 | mid_simi = before_simi.clone()
287 | before_simi[before_simi < self.beta] = 0.0
288 | x = self.adaptively_aggregate(x, before_simi)
289 |
290 |         for blk in self.transformer.resblocks[-1:]:
291 |             x = self.custom_attn(blk.attn, blk.ln_1(x), mid_simi=mid_simi) + self.res_cls * cls_token  # final block with calibrated attention plus a weighted residual [CLS] token
292 |
293 | feats = feats_list[self.post_adjust_idx][1:, ...].clone()
294 | feats = feats / feats.norm(dim=2, keepdim=True)
295 | after_simi = torch.matmul(feats.permute(1, 0, 2), feats.permute(1, 2, 0))
296 | after_simi[after_simi < self.beta] = 0.0
297 | x = self.adaptively_aggregate(x, after_simi)
298 |
299 |         re_feats = torch.zeros_like(feats_list[0])  # multi-level fusion: accumulate intermediate-layer features
300 | for i in range(self.multi_start_idx, self.multi_end_idx):
301 | re_feats += feats_list[i]
302 | cls_token = re_feats[:1, ...]
303 | blk = self.transformer.resblocks[-1]
304 | re_feats = self.custom_attn(blk.attn, blk.ln_1(re_feats[1:, ...]), mid_simi=mid_simi) + self.res_cls * cls_token
305 | re_feats = self.adaptively_aggregate(re_feats, after_simi)
306 | x += re_feats
307 |
308 | x = x.permute(1, 0, 2)
309 | if return_all:
310 | return self.ln_post(x) @ self.proj
311 |
312 | x = self.ln_post(x[:, 0, :])
313 | if self.proj is not None:
314 | x = x @ self.proj
315 |
316 | return x
317 |
318 | def custom_attn(self, attn_layer, x, mid_simi):
319 | num_heads = attn_layer.num_heads
320 | _, bsz, embed_dim = x.size()
321 | head_dim = embed_dim // num_heads
322 | scale = head_dim ** -0.5
323 |
324 | q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1)
325 | q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
326 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
327 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
328 |
329 | mid_simi = (mid_simi - torch.mean(mid_simi)) * 3.0
330 | mid_simi[mid_simi < 0.0] = float('-inf')
331 | mid_simi = mid_simi.repeat(num_heads, 1, 1)
332 | attn_weights = F.softmax(mid_simi, dim=-1)
333 | k_attn = torch.bmm(k, k.transpose(1, 2)) * scale
334 | attn_weights += F.softmax(k_attn, dim=-1)
335 | attn_weights /= 2
336 |
337 | attn_output = torch.bmm(attn_weights, v)
338 | attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim)
339 | attn_output = attn_layer.out_proj(attn_output)
340 |
341 | return attn_output
342 |
343 | def adaptively_aggregate(self, maskclip_feats: torch.Tensor, corrs: torch.Tensor):
344 | corrs_normalized = corrs / (corrs.sum(dim=-1, keepdim=True) + 1e-6)
345 | maskclip_feats_ref = torch.matmul(corrs_normalized, maskclip_feats.permute(1, 0, 2))
346 | return maskclip_feats_ref.permute(1, 0, 2)
347 |
348 | def mean_interpolation(self, feature_map, top_indices):
349 | B, C, H, W = feature_map.shape
350 | device = feature_map.device
351 | dtype = feature_map.dtype
352 |
353 | kernel = torch.ones(C, 1, 3, 3, device=device, dtype=dtype)
354 | kernel[:, 0, 1, 1] = 0
355 | mask = torch.ones((H, W), device=device, dtype=dtype)
356 | indices = torch.tensor(top_indices, dtype=torch.long, device=device)
357 | mask[indices[:, 0], indices[:, 1]] = 0
358 | mask = mask.unsqueeze(0).unsqueeze(0)
359 | masked_feature_map = feature_map * mask
360 | padded_feature_map = F.pad(masked_feature_map, (1, 1, 1, 1), mode='constant', value=0)
361 | padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
362 | neighbor_sum = F.conv2d(padded_feature_map, kernel, groups=C)
363 | valid_neighbors = F.conv2d(padded_mask, kernel[:, :1, :, :], groups=1)
364 | valid_neighbor_mask = (valid_neighbors > 0).to(dtype)
365 | safe_valid_neighbors = valid_neighbors.clone()
366 | safe_valid_neighbors[safe_valid_neighbors == 0] = 1
367 | mean_neighbors = neighbor_sum / safe_valid_neighbors
368 | top_indices_mask = torch.zeros((H, W), device=device, dtype=dtype)
369 | top_indices_mask[indices[:, 0], indices[:, 1]] = 1
370 | top_indices_mask = top_indices_mask.unsqueeze(0).unsqueeze(0)
371 | update_mask = top_indices_mask * valid_neighbor_mask
372 | feature_map = feature_map * (1 - update_mask) + mean_neighbors * update_mask
373 | return feature_map
374 |
375 | def interpolate_pos_encoding(self, x, w, h):
376 | npatch = x.shape[1] - 1
377 | N = self.positional_embedding.shape[0] - 1
378 | if npatch == N and w == h:
379 | return self.positional_embedding
380 | class_pos_embed = self.positional_embedding[[0]]
381 | patch_pos_embed = self.positional_embedding[1:]
382 | dim = x.shape[-1]
383 | w0 = w // self.patch_size
384 | h0 = h // self.patch_size
385 | w0, h0 = w0 + 0.1, h0 + 0.1
386 | patch_pos_embed = nn.functional.interpolate(
387 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
388 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
389 | mode='bicubic',
390 | )
391 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
392 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
393 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
394 |
395 |
396 | class CLIP(nn.Module):
397 | def __init__(self,
398 | embed_dim: int, # 512
399 | # vision
400 | image_resolution: int, # 224
401 | vision_layers: Union[Tuple[int, int, int, int], int], # 12
402 | vision_width: int, # 768
403 | vision_patch_size: int, # 16
404 | # text
405 | context_length: int, # 77
406 | vocab_size: int, # 49408
407 | transformer_width: int, # 512
408 | transformer_heads: int, # 8
409 | transformer_layers: int # 12
410 | ):
411 | super().__init__()
412 | self.context_length = context_length
413 |
414 | if isinstance(vision_layers, (tuple, list)):
415 | vision_heads = vision_width * 32 // 64
416 | self.visual = ModifiedResNet(
417 | layers=vision_layers,
418 | output_dim=embed_dim,
419 | heads=vision_heads,
420 | input_resolution=image_resolution,
421 | width=vision_width
422 | )
423 | else:
424 | vision_heads = vision_width // 64
425 | self.visual = VisionTransformer(
426 | input_resolution=image_resolution,
427 | patch_size=vision_patch_size,
428 | width=vision_width,
429 | layers=vision_layers,
430 | heads=vision_heads,
431 | output_dim=embed_dim
432 | )
433 |
434 | self.transformer = Transformer(
435 | width=transformer_width,
436 | layers=transformer_layers,
437 | heads=transformer_heads,
438 | attn_mask=self.build_attention_mask()
439 | )
440 |
441 | self.vocab_size = vocab_size
442 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
443 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
444 | self.ln_final = LayerNorm(transformer_width)
445 |
446 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
447 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
448 |
449 | self.initialize_parameters()
450 |
451 | def initialize_parameters(self):
452 | nn.init.normal_(self.token_embedding.weight, std=0.02)
453 | nn.init.normal_(self.positional_embedding, std=0.01)
454 |
455 | if isinstance(self.visual, ModifiedResNet):
456 | if self.visual.attnpool is not None:
457 | std = self.visual.attnpool.c_proj.in_features ** -0.5
458 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
459 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
460 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
461 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
462 |
463 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
464 | for name, param in resnet_block.named_parameters():
465 | if name.endswith("bn3.weight"):
466 | nn.init.zeros_(param)
467 |
468 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
469 | attn_std = self.transformer.width ** -0.5
470 | fc_std = (2 * self.transformer.width) ** -0.5
471 | for block in self.transformer.resblocks:
472 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
473 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
474 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
475 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
476 |
477 | if self.text_projection is not None:
478 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
479 |
480 | def build_attention_mask(self):
481 | # lazily create causal attention mask, with full attention between the vision tokens
482 | # pytorch uses additive attention mask; fill with -inf
483 | mask = torch.empty(self.context_length, self.context_length)
484 | mask.fill_(float("-inf"))
485 | mask.triu_(1) # zero out the lower diagonal
486 | return mask
487 |
488 | @property
489 | def dtype(self):
490 | return self.visual.conv1.weight.dtype
491 |
492 | def encode_image(self, image, return_all=False):
493 | return self.visual(image.type(self.dtype), return_all=return_all)
494 |
495 | def encode_text(self, text):
496 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
497 |
498 | x = x + self.positional_embedding.type(self.dtype)
499 | x = x.permute(1, 0, 2) # NLD -> LND
500 | x = self.transformer(x)
501 | x = x.permute(1, 0, 2) # LND -> NLD
502 | x = self.ln_final(x).type(self.dtype)
503 |
504 | return x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
505 |
506 | def forward(self, image, text):
507 | image_features = self.encode_image(image)
508 | text_features = self.encode_text(text)
509 |
510 | # normalized features
511 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
512 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
513 |
514 | # cosine similarity as logits
515 | logit_scale = self.logit_scale.exp()
516 | logits_per_image = logit_scale * image_features @ text_features.t()
517 | logits_per_text = logits_per_image.t()
518 |
519 | # shape = [global_batch_size, global_batch_size]
520 | return logits_per_image, logits_per_text
521 |
522 | def convert_weights(model: nn.Module):
523 | """Convert applicable model parameters to fp16"""
524 |
525 | def _convert_weights_to_fp16(l):
526 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
527 | l.weight.data = l.weight.data.half()
528 | if l.bias is not None:
529 | l.bias.data = l.bias.data.half()
530 |
531 | if isinstance(l, nn.MultiheadAttention):
532 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
533 | tensor = getattr(l, attr)
534 | if tensor is not None:
535 | tensor.data = tensor.data.half()
536 |
537 | for name in ["text_projection", "proj"]:
538 | if hasattr(l, name):
539 | attr = getattr(l, name)
540 | if attr is not None:
541 | attr.data = attr.data.half()
542 |
543 | model.apply(_convert_weights_to_fp16)
544 |
545 | def build_model(state_dict: dict):
546 | vit = "visual.proj" in state_dict
547 |
548 | if vit:
549 | vision_width = state_dict["visual.conv1.weight"].shape[0]
550 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
551 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
552 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
553 | image_resolution = vision_patch_size * grid_size
554 | else:
555 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
556 | vision_layers = tuple(counts)
557 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
558 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
559 | vision_patch_size = None
560 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
561 | image_resolution = output_width * 32
562 |
563 | embed_dim = state_dict["text_projection"].shape[1]
564 | context_length = state_dict["positional_embedding"].shape[0]
565 | vocab_size = state_dict["token_embedding.weight"].shape[0]
566 | transformer_width = state_dict["ln_final.weight"].shape[0]
567 | transformer_heads = transformer_width // 64
568 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
569 |
570 | model = CLIP(
571 | embed_dim,
572 | image_resolution, vision_layers, vision_width, vision_patch_size,
573 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
574 | )
575 |
576 | for key in ["input_resolution", "context_length", "vocab_size"]:
577 | if key in state_dict:
578 | del state_dict[key]
579 |
580 | convert_weights(model)
581 | model.load_state_dict(state_dict)
582 | return model.eval()
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | ### CLIP source code from OpenAI:
2 | # https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
3 |
4 | import gzip
5 | import html
6 | import os
7 | from functools import lru_cache
8 |
9 | import ftfy
10 | import regex as re
11 |
12 |
13 | @lru_cache()
14 | def default_bpe():
15 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
16 |
17 |
18 | @lru_cache()
19 | def bytes_to_unicode():
20 | """
21 | Returns list of utf-8 byte and a corresponding list of unicode strings.
22 | The reversible bpe codes work on unicode strings.
23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25 | This is a significant percentage of your normal, say, 32K bpe vocab.
26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
27 | And avoids mapping to whitespace/control characters the bpe code barfs on.
28 | """
29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
30 | cs = bs[:]
31 | n = 0
32 | for b in range(2**8):
33 | if b not in bs:
34 | bs.append(b)
35 | cs.append(2**8+n)
36 | n += 1
37 | cs = [chr(n) for n in cs]
38 | return dict(zip(bs, cs))
39 |
40 |
41 | def get_pairs(word):
42 | """Return set of symbol pairs in a word.
43 | Word is represented as tuple of symbols (symbols being variable-length strings).
44 | """
45 | pairs = set()
46 | prev_char = word[0]
47 | for char in word[1:]:
48 | pairs.add((prev_char, char))
49 | prev_char = char
50 | return pairs
51 |
52 |
53 | def basic_clean(text):
54 | text = ftfy.fix_text(text)
55 | text = html.unescape(html.unescape(text))
56 | return text.strip()
57 |
58 |
59 | def whitespace_clean(text):
60 | text = re.sub(r'\s+', ' ', text)
61 | text = text.strip()
62 | return text
63 |
64 |
65 | class SimpleTokenizer(object):
66 | def __init__(self, bpe_path: str = default_bpe()):
67 | self.byte_encoder = bytes_to_unicode()
68 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
69 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
70 | merges = merges[1:49152-256-2+1]
71 | merges = [tuple(merge.split()) for merge in merges]
72 | vocab = list(bytes_to_unicode().values())
73 |         vocab = vocab + [v+'</w>' for v in vocab]
74 | for merge in merges:
75 | vocab.append(''.join(merge))
76 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
77 | self.encoder = dict(zip(vocab, range(len(vocab))))
78 | self.decoder = {v: k for k, v in self.encoder.items()}
79 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
82 |
83 | def bpe(self, token):
84 | if token in self.cache:
85 | return self.cache[token]
86 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
87 |         pairs = get_pairs(word)
88 |
89 |         if not pairs:
90 |             return token+'</w>'
91 |
92 | while True:
93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
94 | if bigram not in self.bpe_ranks:
95 | break
96 | first, second = bigram
97 | new_word = []
98 | i = 0
99 | while i < len(word):
100 | try:
101 | j = word.index(first, i)
102 | new_word.extend(word[i:j])
103 | i = j
104 | except:
105 | new_word.extend(word[i:])
106 | break
107 |
108 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
109 | new_word.append(first+second)
110 | i += 2
111 | else:
112 | new_word.append(word[i])
113 | i += 1
114 | new_word = tuple(new_word)
115 | word = new_word
116 | if len(word) == 1:
117 | break
118 | else:
119 | pairs = get_pairs(word)
120 | word = ' '.join(word)
121 | self.cache[token] = word
122 | return word
123 |
124 | def encode(self, text):
125 | bpe_tokens = []
126 | text = whitespace_clean(basic_clean(text)).lower()
127 | for token in re.findall(self.pat, text):
128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
130 | return bpe_tokens
131 |
132 | def decode(self, tokens):
133 | text = ''.join([self.decoder[token] for token in tokens])
134 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135 | return text
136 |
--------------------------------------------------------------------------------
/configs/base_config.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='SCCLIPForSegmentation',
3 | clip_path='ViT-B/16',
4 | pre_adjust_idx=8,
5 | post_adjust_idx=3,
6 | multi_start_idx=3,
7 | multi_end_idx=10,
8 | res_cls=0.3
9 | )
10 |
11 | test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
12 |
13 | default_scope = 'mmseg'
14 | env_cfg = dict(
15 | cudnn_benchmark=True,
16 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
17 | dist_cfg=dict(backend='nccl'),
18 | )
19 | vis_backends = [dict(type='LocalVisBackend')]
20 | visualizer = dict(
21 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
22 | log_processor = dict(by_epoch=False)
23 | log_level = 'INFO'
24 | load_from = None
25 | resume = False
26 |
27 | test_cfg = dict(type='TestLoop')
28 |
29 | default_hooks = dict(
30 | timer=dict(type='IterTimerHook'),
31 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
32 | param_scheduler=dict(type='ParamSchedulerHook'),
33 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
34 | sampler_seed=dict(type='DistSamplerSeedHook'),
35 | visualization=dict(type='SegVisualizationHook', interval=1))
--------------------------------------------------------------------------------
/configs/cfg_ade20k.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_ade20k.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'ADE20KDataset'
10 | data_root = './datasets/ade/ADEChallengeData2016'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
15 | dict(type='LoadAnnotations', reduce_zero_label=True),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='images/validation',
29 | seg_map_path='annotations/validation'),
30 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_city_scapes.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_city_scapes.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'CityscapesDataset'
10 | data_root = './datasets/cityscapes'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 560), keep_ratio=True),
15 | # add loading annotation after ``Resize`` because ground truth
16 | # does not need to do resize data transform
17 | dict(type='LoadAnnotations'),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
31 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_coco_object.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_coco_object.txt',
6 | logit_scale=55, prob_thd=0.35
7 | )
8 |
9 | # dataset settings
10 | dataset_type = 'COCOObjectDataset'
11 | data_root = './datasets/coco_object'
12 |
13 | test_pipeline = [
14 | dict(type='LoadImageFromFile'),
15 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
16 | # add loading annotation after ``Resize`` because ground truth
17 | # does not need to do resize data transform
18 | dict(type='LoadAnnotations'),
19 | dict(type='PackSegInputs')
20 | ]
21 |
22 | test_dataloader = dict(
23 | batch_size=1,
24 | num_workers=4,
25 | persistent_workers=True,
26 | sampler=dict(type='DefaultSampler', shuffle=False),
27 | dataset=dict(
28 | type=dataset_type,
29 | data_root=data_root,
30 | reduce_zero_label=False,
31 | data_prefix=dict(
32 | img_path='images/val2017', seg_map_path='annotations/val2017'),
33 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_coco_stuff164k.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_coco_stuff.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'COCOStuffDataset'
10 | data_root = './datasets/coco_stuff164k'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
15 | dict(type='LoadAnnotations'),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='images/val2017', seg_map_path='annotations/val2017'),
29 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_context59.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_context59.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'PascalContext59Dataset'
10 | data_root = './datasets/VOCdevkit/VOC2010'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
15 | dict(type='LoadAnnotations', reduce_zero_label=True),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
29 | ann_file='ImageSets/SegmentationContext/val.txt',
30 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_context60.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_context60.txt',
6 | prob_thd=0.15
7 | )
8 |
9 | # dataset settings
10 | dataset_type = 'PascalContext60Dataset'
11 | data_root = './datasets/VOCdevkit/VOC2010'
12 |
13 | test_pipeline = [
14 | dict(type='LoadImageFromFile'),
15 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
16 | dict(type='LoadAnnotations'),
17 | dict(type='PackSegInputs')
18 | ]
19 |
20 | test_dataloader = dict(
21 | batch_size=1,
22 | num_workers=4,
23 | persistent_workers=True,
24 | sampler=dict(type='DefaultSampler', shuffle=False),
25 | dataset=dict(
26 | type=dataset_type,
27 | data_root=data_root,
28 | data_prefix=dict(
29 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
30 | ann_file='ImageSets/SegmentationContext/val.txt',
31 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_voc20.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_voc20.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'PascalVOC20Dataset'
10 | data_root = './datasets/VOCdevkit/VOC2012'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
15 | dict(type='LoadAnnotations'),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='JPEGImages', seg_map_path='SegmentationClass'),
29 | ann_file='ImageSets/Segmentation/val.txt',
30 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cfg_voc21.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./configs/cls_voc21.txt',
6 | area_thd=0.1,
7 | logit_scale=50, prob_thd=0.15
8 | )
9 |
10 | # dataset settings
11 | dataset_type = 'PascalVOCDataset'
12 | data_root = './datasets/VOCdevkit/VOC2012'
13 |
14 | test_pipeline = [
15 | dict(type='LoadImageFromFile'),
16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
17 | dict(type='LoadAnnotations'),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='JPEGImages', seg_map_path='SegmentationClass'),
31 | ann_file='ImageSets/Segmentation/val.txt',
32 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/configs/cls_ade20k.txt:
--------------------------------------------------------------------------------
1 | wall
2 | building
3 | sky
4 | floor
5 | tree
6 | ceiling
7 | road
8 | bed
9 | windowpane
10 | grass
11 | cabinet
12 | sidewalk
13 | person
14 | earth
15 | door
16 | table
17 | mountain
18 | plant
19 | curtain
20 | chair
21 | car
22 | water
23 | painting
24 | sofa
25 | shelf
26 | house
27 | sea
28 | mirror
29 | rug
30 | field
31 | armchair
32 | seat
33 | fence
34 | desk
35 | rock
36 | wardrobe
37 | lamp
38 | bathtub
39 | railing
40 | cushion
41 | base
42 | box
43 | column
44 | signboard
45 | chestofdrawers
46 | counter
47 | sand
48 | sink
49 | skyscraper
50 | fireplace
51 | refrigerator
52 | grandstand
53 | path
54 | stairs
55 | runway
56 | case
57 | pooltable
58 | pillow
59 | screendoor
60 | stairway
61 | river
62 | bridge
63 | bookcase
64 | blind
65 | coffeetable
66 | toilet
67 | flower
68 | book
69 | hill
70 | bench
71 | countertop
72 | stove
73 | palm
74 | kitchenisland
75 | computer
76 | swivelchair
77 | boat
78 | bar
79 | arcademachine
80 | hovel
81 | bus
82 | towel
83 | light
84 | truck
85 | tower
86 | chandelier
87 | awning
88 | streetlight
89 | booth
90 | televisionreceiver
91 | airplane
92 | dirttrack
93 | apparel
94 | pole
95 | land
96 | bannister
97 | escalator
98 | ottoman
99 | bottle
100 | buffet
101 | poster
102 | stage
103 | van
104 | ship
105 | fountain
106 | conveyerbelt
107 | canopy
108 | washer
109 | plaything
110 | swimmingpool
111 | stool
112 | barrel
113 | basket
114 | waterfall
115 | tent
116 | bag
117 | minibike
118 | cradle
119 | oven
120 | ball
121 | food
122 | step
123 | tank
124 | tradename
125 | microwave
126 | pot
127 | animal
128 | bicycle
129 | lake
130 | dishwasher
131 | screen
132 | blanket
133 | sculpture
134 | hood
135 | sconce
136 | vase
137 | trafficlight
138 | tray
139 | ashcan
140 | fan
141 | pier
142 | crtscreen
143 | plate
144 | monitor
145 | bulletinboard
146 | shower
147 | radiator
148 | glass
149 | clock
150 | flag
--------------------------------------------------------------------------------
/configs/cls_city_scapes.txt:
--------------------------------------------------------------------------------
1 | road
2 | sidewalk
3 | building
4 | wall
5 | fence
6 | pole
7 | trafficlight
8 | trafficsign
9 | vegetation
10 | terrain
11 | sky
12 | person
13 | rider
14 | car
15 | truck
16 | bus
17 | train
18 | motorcycle
19 | bicycle
--------------------------------------------------------------------------------
/configs/cls_coco_object.txt:
--------------------------------------------------------------------------------
1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
2 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body
3 | bicycle
4 | car
5 | motorcycle
6 | airplane
7 | bus
8 | train
9 | truck
10 | boat
11 | traffic light
12 | fire hydrant
13 | stop sign
14 | parking meter
15 | bench
16 | bird
17 | cat
18 | dog
19 | horse
20 | sheep
21 | cow
22 | elephant
23 | bear
24 | zebra
25 | giraffe
26 | backpack
27 | umbrella
28 | handbag
29 | tie
30 | suitcase
31 | frisbee
32 | skis
33 | snowboard
34 | sports ball
35 | kite
36 | baseball bat
37 | baseball glove
38 | skateboard
39 | surfboard
40 | tennis racket
41 | bottle
42 | wine glass
43 | cup
44 | fork
45 | knife
46 | spoon
47 | bowl
48 | banana
49 | apple
50 | sandwich
51 | orange
52 | broccoli
53 | carrot
54 | hot dog
55 | pizza
56 | donut
57 | cake
58 | chair
59 | couch
60 | potted plant
61 | bed
62 | dining table
63 | toilet
64 | tv
65 | laptop
66 | mouse
67 | remote
68 | keyboard
69 | cell phone
70 | microwave
71 | oven
72 | toaster
73 | sink
74 | refrigerator
75 | book
76 | clock
77 | vase
78 | scissors
79 | teddy bear
80 | hair drier
81 | toothbrush
--------------------------------------------------------------------------------
/configs/cls_coco_stuff.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | trafficlight
11 | firehydrant
12 | stopsign
13 | parkingmeter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sportsball
34 | kite
35 | baseballbat
36 | baseballglove
37 | skateboard
38 | surfboard
39 | tennisracket
40 | bottle
41 | wineglass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hotdog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tv
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cellphone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddybear
79 | hairdrier
80 | toothbrush
81 | banner
82 | blanket
83 | branch
84 | bridge
85 | building-other
86 | bush
87 | cabinet
88 | cage
89 | cardboard
90 | carpet
91 | ceiling-other
92 | ceiling-tile
93 | cloth
94 | clothes
95 | clouds
96 | counter
97 | cupboard
98 | curtain
99 | desk-stuff
100 | dirt
101 | door-stuff
102 | fence
103 | floor-marble
104 | floor-other
105 | floor-stone
106 | floor-tile
107 | floor-wood
108 | flower
109 | fog
110 | food-other
111 | fruit
112 | furniture-other
113 | grass
114 | gravel
115 | ground-other
116 | hill
117 | house
118 | leaves
119 | light
120 | mat
121 | metal
122 | mirror-stuff
123 | moss
124 | mountain
125 | mud
126 | napkin
127 | net
128 | paper
129 | pavement
130 | pillow
131 | plant-other
132 | plastic
133 | platform
134 | playingfield
135 | railing
136 | railroad
137 | river
138 | road
139 | rock
140 | roof
141 | rug
142 | salad
143 | sand
144 | sea
145 | shelf
146 | sky-other
147 | skyscraper
148 | snow
149 | solid-other
150 | stairs
151 | stone
152 | straw
153 | structural-other
154 | table
155 | tent
156 | textile-other
157 | towel
158 | tree
159 | vegetable
160 | wall-brick
161 | wall-concrete
162 | wall-other
163 | wall-panel
164 | wall-stone
165 | wall-tile
166 | wall-wood
167 | water-other
168 | waterdrops
169 | window-blind
170 | window-other
171 | wood
--------------------------------------------------------------------------------
/configs/cls_context59.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bag
3 | bed
4 | bedclothes
5 | bench
6 | bicycle
7 | bird
8 | boat
9 | book
10 | bottle
11 | building
12 | bus
13 | cabinet
14 | car
15 | cat
16 | ceiling
17 | chair
18 | cloth
19 | computer
20 | cow
21 | cup
22 | curtain
23 | dog
24 | door
25 | fence
26 | floor
27 | flower
28 | food
29 | grass
30 | ground
31 | horse
32 | keyboard
33 | light
34 | motorbike
35 | mountain
36 | mouse
37 | person
38 | plate
39 | platform
40 | pottedplant
41 | road
42 | rock
43 | sheep
44 | shelves
45 | sidewalk
46 | sign
47 | sky
48 | snow
49 | sofa
50 | table
51 | track
52 | train
53 | tree
54 | truck
55 | tvmonitor
56 | wall
57 | water
58 | window
59 | wood
--------------------------------------------------------------------------------
/configs/cls_context60.txt:
--------------------------------------------------------------------------------
1 | background
2 | aeroplane
3 | bag
4 | bed
5 | bedclothes
6 | bench
7 | bicycle
8 | bird
9 | boat
10 | book
11 | bottle
12 | building
13 | bus
14 | cabinet
15 | car
16 | cat
17 | ceiling
18 | chair
19 | cloth
20 | computer
21 | cow
22 | cup
23 | curtain
24 | dog
25 | door
26 | fence
27 | floor
28 | flower
29 | food
30 | grass
31 | ground
32 | horse
33 | keyboard
34 | light
35 | motorbike
36 | mountain
37 | mouse
38 | person
39 | plate
40 | platform
41 | pottedplant
42 | road
43 | rock
44 | sheep
45 | shelves
46 | sidewalk
47 | sign
48 | sky
49 | snow
50 | sofa
51 | table
52 | track
53 | train
54 | tree
55 | truck
56 | tvmonitor
57 | wall
58 | water
59 | window
60 | wood
--------------------------------------------------------------------------------
/configs/cls_voc20.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | ship
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | table
12 | dog
13 | horse
14 | motorbike
15 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | television monitor, tv monitor, monitor, television, screen
--------------------------------------------------------------------------------
/configs/cls_voc21.txt:
--------------------------------------------------------------------------------
1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
2 | aeroplane
3 | bicycle
4 | bird
5 | ship
6 | bottle
7 | bus
8 | car
9 | cat
10 | chair
11 | cow
12 | table
13 | dog
14 | horse
15 | motorbike
16 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
17 | pottedplant
18 | sheep
19 | sofa
20 | train
21 | television monitor, tv monitor, monitor, television, screen
--------------------------------------------------------------------------------
/custom_datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import mmengine.fileio as fileio
3 |
4 | from mmseg.registry import DATASETS
5 | from mmseg.datasets import BaseSegDataset
6 |
7 | @DATASETS.register_module()
8 | class PascalVOC20Dataset(BaseSegDataset):
9 | """Pascal VOC dataset.
10 |
11 | Args:
12 | split (str): Split txt file for Pascal VOC.
13 | """
14 | METAINFO = dict(
15 | classes=('aeroplane', 'bicycle', 'bird', 'boat',
16 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
17 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
18 | 'sofa', 'train', 'tvmonitor'),
19 | palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192],
20 | [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64],
21 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
22 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
23 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
24 | [0, 64, 128]])
25 |
26 | def __init__(self,
27 | ann_file,
28 | img_suffix='.jpg',
29 | seg_map_suffix='.png',
30 | reduce_zero_label=True,
31 | **kwargs) -> None:
32 | super().__init__(
33 | img_suffix=img_suffix,
34 | seg_map_suffix=seg_map_suffix,
35 | reduce_zero_label=reduce_zero_label,
36 | ann_file=ann_file,
37 | **kwargs)
38 | assert fileio.exists(self.data_prefix['img_path'],
39 | self.backend_args) and osp.isfile(self.ann_file)
40 |
41 | @DATASETS.register_module()
42 | class COCOObjectDataset(BaseSegDataset):
43 | """
44 | Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT)
45 | COCO-Object dataset.
46 | 1 bg class + first 80 classes from the COCO-Stuff dataset.
47 | """
48 |
49 | METAINFO = dict(
50 |
51 | classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
52 | 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
53 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
54 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
55 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
56 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
57 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
58 | 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
59 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
60 |
61 | palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
62 | [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
63 | [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
64 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
65 | [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
66 | [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
67 | [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
68 | [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
69 | [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
70 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
71 | [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
72 | [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]])
73 |
74 | def __init__(self, **kwargs):
75 | super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)
76 |
77 | @DATASETS.register_module()
78 | class PascalContext60Dataset(BaseSegDataset):
79 | METAINFO = dict(
80 | classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
81 | 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
82 | 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
83 | 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
84 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
85 | 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
86 | 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
87 | 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
88 | 'sofa', 'table', 'track', 'train', 'tree', 'truck',
89 | 'tvmonitor', 'wall', 'water', 'window', 'wood'),
90 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
91 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
92 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
93 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
94 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
95 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
96 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
97 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
98 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
99 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
100 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
101 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
102 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
103 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
104 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
105 |
106 | def __init__(self,
107 | ann_file: str,
108 | img_suffix='.jpg',
109 | seg_map_suffix='.png',
110 | **kwargs) -> None:
111 | super().__init__(
112 | img_suffix=img_suffix,
113 | seg_map_suffix=seg_map_suffix,
114 | ann_file=ann_file,
115 | reduce_zero_label=False,
116 | **kwargs)
117 |
118 |
119 | @DATASETS.register_module()
120 | class PascalContext59Dataset(BaseSegDataset):
121 | METAINFO = dict(
122 | classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
123 | 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
124 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
125 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
126 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
127 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
128 | 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
129 | 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
130 | 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
131 | 'wall', 'water', 'window', 'wood'),
132 | palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
133 | [120, 120, 80], [140, 140, 140], [204, 5, 255],
134 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
135 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
136 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
137 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
138 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
139 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
140 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
141 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
142 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
143 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
144 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
145 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
146 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
147 |
148 | def __init__(self,
149 | ann_file: str,
150 | img_suffix='.jpg',
151 | seg_map_suffix='.png',
152 | reduce_zero_label=True,
153 | **kwargs):
154 | super().__init__(
155 | img_suffix=img_suffix,
156 | seg_map_suffix=seg_map_suffix,
157 | ann_file=ann_file,
158 | reduce_zero_label=reduce_zero_label,
159 | **kwargs)
--------------------------------------------------------------------------------
/datasets/cvt_coco_object.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # GroupViT (https://github.com/NVlabs/GroupViT)
3 | # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
4 | # ------------------------------------------------------------------------------
5 |
6 | import argparse
7 | import os.path as osp
8 | import shutil
9 | from functools import partial
10 | from glob import glob
11 |
12 | import mmcv
13 | import numpy as np
14 | from PIL import Image
15 |
16 | COCO_LEN = 123287
17 |
18 | clsID_to_trID = {
19 | 0: 0,
20 | 1: 1,
21 | 2: 2,
22 | 3: 3,
23 | 4: 4,
24 | 5: 5,
25 | 6: 6,
26 | 7: 7,
27 | 8: 8,
28 | 9: 9,
29 | 10: 10,
30 | 12: 11,
31 | 13: 12,
32 | 14: 13,
33 | 15: 14,
34 | 16: 15,
35 | 17: 16,
36 | 18: 17,
37 | 19: 18,
38 | 20: 19,
39 | 21: 20,
40 | 22: 21,
41 | 23: 22,
42 | 24: 23,
43 | 26: 24,
44 | 27: 25,
45 | 30: 26,
46 | 31: 27,
47 | 32: 28,
48 | 33: 29,
49 | 34: 30,
50 | 35: 31,
51 | 36: 32,
52 | 37: 33,
53 | 38: 34,
54 | 39: 35,
55 | 40: 36,
56 | 41: 37,
57 | 42: 38,
58 | 43: 39,
59 | 45: 40,
60 | 46: 41,
61 | 47: 42,
62 | 48: 43,
63 | 49: 44,
64 | 50: 45,
65 | 51: 46,
66 | 52: 47,
67 | 53: 48,
68 | 54: 49,
69 | 55: 50,
70 | 56: 51,
71 | 57: 52,
72 | 58: 53,
73 | 59: 54,
74 | 60: 55,
75 | 61: 56,
76 | 62: 57,
77 | 63: 58,
78 | 64: 59,
79 | 66: 60,
80 | 69: 61,
81 | 71: 62,
82 | 72: 63,
83 | 73: 64,
84 | 74: 65,
85 | 75: 66,
86 | 76: 67,
87 | 77: 68,
88 | 78: 69,
89 | 79: 70,
90 | 80: 71,
91 | 81: 72,
92 | 83: 73,
93 | 84: 74,
94 | 85: 75,
95 | 86: 76,
96 | 87: 77,
97 | 88: 78,
98 | 89: 79,
99 | 91: 80,
100 | 92: 81,
101 | 93: 82,
102 | 94: 83,
103 | 95: 84,
104 | 96: 85,
105 | 97: 86,
106 | 98: 87,
107 | 99: 88,
108 | 100: 89,
109 | 101: 90,
110 | 102: 91,
111 | 103: 92,
112 | 104: 93,
113 | 105: 94,
114 | 106: 95,
115 | 107: 96,
116 | 108: 97,
117 | 109: 98,
118 | 110: 99,
119 | 111: 100,
120 | 112: 101,
121 | 113: 102,
122 | 114: 103,
123 | 115: 104,
124 | 116: 105,
125 | 117: 106,
126 | 118: 107,
127 | 119: 108,
128 | 120: 109,
129 | 121: 110,
130 | 122: 111,
131 | 123: 112,
132 | 124: 113,
133 | 125: 114,
134 | 126: 115,
135 | 127: 116,
136 | 128: 117,
137 | 129: 118,
138 | 130: 119,
139 | 131: 120,
140 | 132: 121,
141 | 133: 122,
142 | 134: 123,
143 | 135: 124,
144 | 136: 125,
145 | 137: 126,
146 | 138: 127,
147 | 139: 128,
148 | 140: 129,
149 | 141: 130,
150 | 142: 131,
151 | 143: 132,
152 | 144: 133,
153 | 145: 134,
154 | 146: 135,
155 | 147: 136,
156 | 148: 137,
157 | 149: 138,
158 | 150: 139,
159 | 151: 140,
160 | 152: 141,
161 | 153: 142,
162 | 154: 143,
163 | 155: 144,
164 | 156: 145,
165 | 157: 146,
166 | 158: 147,
167 | 159: 148,
168 | 160: 149,
169 | 161: 150,
170 | 162: 151,
171 | 163: 152,
172 | 164: 153,
173 | 165: 154,
174 | 166: 155,
175 | 167: 156,
176 | 168: 157,
177 | 169: 158,
178 | 170: 159,
179 | 171: 160,
180 | 172: 161,
181 | 173: 162,
182 | 174: 163,
183 | 175: 164,
184 | 176: 165,
185 | 177: 166,
186 | 178: 167,
187 | 179: 168,
188 | 180: 169,
189 | 181: 170,
190 | 255: 255
191 | }
192 |
193 | # shift every train ID by 1 so that 0 becomes background; IDs above 90 (stuff classes and the 255 ignore label) are then mapped to background
194 | for k, v in clsID_to_trID.items():
195 | clsID_to_trID[k] = v + 1
196 | if k > 90:
197 | clsID_to_trID[k] = 0
198 |
199 |
200 | def convert_to_trainID(maskpath, out_mask_dir, is_train):
201 | mask = np.array(Image.open(maskpath))
202 | mask_copy = mask.copy()
203 | for clsID, trID in clsID_to_trID.items():
204 | mask_copy[mask == clsID] = trID
205 | seg_filename = osp.join(
206 | out_mask_dir, 'train2017',
207 | osp.basename(maskpath).split('.')[0] +
208 | '_instanceTrainIds.png') if is_train else osp.join(
209 | out_mask_dir, 'val2017',
210 | osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
211 | Image.fromarray(mask_copy).save(seg_filename, 'PNG')
212 |
213 |
214 | def parse_args():
215 | parser = argparse.ArgumentParser(description='Convert COCO Stuff 164k annotations to COCO Objects') # noqa
216 | parser.add_argument('coco_path', help='coco stuff path')
217 | parser.add_argument('-o', '--out_dir', help='output path')
218 | parser.add_argument(
219 | '--nproc', default=16, type=int, help='number of process')
220 | args = parser.parse_args()
221 | return args
222 |
223 |
224 | def main():
225 | args = parse_args()
226 | coco_path = args.coco_path
227 | nproc = args.nproc
228 |
229 | out_dir = args.out_dir or coco_path
230 | out_img_dir = osp.join(out_dir, 'images')
231 | out_mask_dir = osp.join(out_dir, 'annotations')
232 |
233 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
234 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
235 |
236 | if out_dir != coco_path:
237 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
238 |
239 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
240 | train_list = [file for file in train_list if 'TrainIds' not in file]
241 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
242 | test_list = [file for file in test_list if 'TrainIds' not in file]
243 | assert (len(train_list) + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
244 | len(train_list), len(test_list))
245 |
246 | if args.nproc > 1:
247 | mmcv.track_parallel_progress(
248 | partial(
249 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
250 | train_list,
251 | nproc=nproc)
252 | mmcv.track_parallel_progress(
253 | partial(
254 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
255 | test_list,
256 | nproc=nproc)
257 | else:
258 | mmcv.track_progress(
259 | partial(
260 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
261 | train_list)
262 | mmcv.track_progress(
263 | partial(
264 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
265 | test_list)
266 |
267 | print('Done!')
268 |
269 |
270 | if __name__ == '__main__':
271 | main()
272 |
--------------------------------------------------------------------------------
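The conversion script above shifts every COCO train ID up by one so that 0 can serve as background, and collapses everything outside the 80 thing classes (original category IDs above 90, including the 255 ignore label) into that background ID. A minimal sketch of the effect on a toy mask, using a hand-picked subset of `clsID_to_trID` (the pixel values are illustrative only):

```
import numpy as np

# Toy 2x2 annotation: category 0 (a thing class), 91 (a stuff class), 255 (ignore).
mask = np.array([[0, 91], [255, 0]], dtype=np.uint8)

# After the remapping loop in cvt_coco_object.py, these three IDs become:
# 0 -> 1 (shifted by one), 91 -> 0 (background), 255 -> 0 (background).
toy_mapping = {0: 1, 91: 0, 255: 0}

mask_copy = mask.copy()
for clsID, trID in toy_mapping.items():
    mask_copy[mask == clsID] = trID

print(mask_copy)  # [[1 0]
                  #  [0 1]]
```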
/demo.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | from torchvision import transforms
4 | from scclip_segmentor import SCCLIPForSegmentation
5 |
6 | img = Image.open('figs/demo.jpg')
7 | name_list = ['skiing man', 'tree with snow', 'sky', 'snow']
8 |
9 | with open('my_name.txt', 'w') as writers:
10 | for i in range(len(name_list)):
11 | if i == len(name_list)-1:
12 | writers.write(name_list[i])
13 | else:
14 | writers.write(name_list[i] + '\n')
15 | writers.close()
16 |
17 | img_tensor = transforms.Compose([
18 | transforms.Lambda(lambda img: img.convert('RGB')),
19 | transforms.ToTensor(),
20 | transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
21 | ])(img)
22 |
23 | img_tensor = img_tensor.unsqueeze(0).cuda()
24 |
25 | model = SCCLIPForSegmentation(
26 | clip_path='ViT-B/16',
27 | name_path='my_name.txt',
28 | pamr_steps=0,
29 | pamr_stride=(8, 16),
30 | slide_crop=224,
31 | slide_stride=112
32 | )
33 |
34 | seg_pred = model.predict(img_tensor, data_samples=None)
35 | seg_pred = seg_pred.data.cpu().numpy().squeeze(0)
36 |
37 | fig, ax = plt.subplots(1, 3, figsize=(18, 6))
38 | ax[0].imshow(img)
39 | ax[0].axis('off')
40 | ax[1].imshow(seg_pred, cmap='viridis')
41 | ax[1].axis('off')
42 | ax[2].imshow(img)
43 | ax[2].axis('off')
44 | ax[2].imshow(seg_pred, cmap='viridis', alpha=0.8)
45 | plt.tight_layout()
46 | plt.savefig('seg_ours.png', bbox_inches='tight')
--------------------------------------------------------------------------------
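A small, optional follow-up to `demo.py`: since each line written to `my_name.txt` corresponds to one class index, the predicted map can be translated back into the queried names. This sketch assumes `seg_pred` and `name_list` from the script above:

```
import numpy as np

# seg_pred holds one class index per pixel, in the order of name_list.
for idx in np.unique(seg_pred):
    print(f'class {int(idx)}: {name_list[int(idx)]}')
```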
/dist_test.sh:
--------------------------------------------------------------------------------
1 | outputs="./outputs/base"
2 |
3 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_voc20.py --work-dir $outputs --launcher pytorch
4 |
5 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_voc21.py --work-dir $outputs --launcher pytorch
6 |
7 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_ade20k.py --work-dir $outputs --launcher pytorch
8 |
9 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_city_scapes.py --work-dir $outputs --launcher pytorch
10 |
11 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_context59.py --work-dir $outputs --launcher pytorch
12 |
13 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_context60.py --work-dir $outputs --launcher pytorch
14 |
15 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_coco_object.py --work-dir $outputs --launcher pytorch
16 |
17 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 eval.py --config configs/cfg_coco_stuff164k.py --work-dir $outputs --launcher pytorch
18 |
19 |
20 | cd $outputs
21 | find . -type f -name "*.log" | while read logfile
22 | do
23 | grep "data_root =" "$logfile"
24 | grep "dataset_type =" "$logfile"
25 | grep -o "mIoU: [0-9.]*" "$logfile"
26 | echo ""
27 | done
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import scclip_segmentor  # noqa: F401  (registers SCCLIPForSegmentation with the MMSeg MODELS registry)
4 | import custom_datasets  # noqa: F401  (registers the custom evaluation datasets)
5 |
6 | from mmengine.config import Config
7 | from mmengine.runner import Runner
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(
11 | description='SC-CLIP evaluation with MMSeg')
12 | parser.add_argument('--config', default='')
13 | parser.add_argument('--work-dir', default='./work_logs/')
14 | parser.add_argument(
15 | '--show', action='store_true', help='show prediction results')
16 | parser.add_argument(
17 | '--show_dir',
18 | default='',
19 |         help='directory to save visualization images')
20 | parser.add_argument(
21 | '--launcher',
22 | choices=['none', 'pytorch', 'slurm', 'mpi'],
23 | default='none',
24 | help='job launcher')
25 | # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
26 |     # will pass the `--local-rank` parameter to this script instead
27 | # of `--local_rank`.
28 | parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
29 | args = parser.parse_args()
30 | if 'LOCAL_RANK' not in os.environ:
31 | os.environ['LOCAL_RANK'] = str(args.local_rank)
32 |
33 | return args
34 |
35 | def trigger_visualization_hook(cfg, args):
36 | default_hooks = cfg.default_hooks
37 | if 'visualization' in default_hooks:
38 | visualization_hook = default_hooks['visualization']
39 | # Turn on visualization
40 | visualization_hook['draw'] = True
41 | if args.show:
42 | visualization_hook['show'] = True
43 | visualization_hook['wait_time'] = args.wait_time
44 | if args.show_dir:
45 | visualizer = cfg.visualizer
46 | visualizer['save_dir'] = args.show_dir
47 | else:
48 | raise RuntimeError(
49 | 'VisualizationHook must be included in default_hooks.'
50 | 'refer to usage '
51 | '"visualization=dict(type=\'VisualizationHook\')"')
52 |
53 | return cfg
54 |
55 | def main():
56 | args = parse_args()
57 |
58 | cfg = Config.fromfile(args.config)
59 | cfg.launcher = args.launcher
60 | cfg.work_dir = args.work_dir
61 |
62 | runner = Runner.from_cfg(cfg)
63 | runner.test()
64 |
65 | if __name__ == '__main__':
66 | main()
--------------------------------------------------------------------------------
/figs/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/figs/demo.jpg
--------------------------------------------------------------------------------
/figs/scclip.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SuleBai/SC-CLIP/0417ba92851e9dd7432d608f10a0804d01a23062/figs/scclip.jpg
--------------------------------------------------------------------------------
/pamr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 TU Darmstadt
2 | # License: Apache 2.0.
3 | # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 |
8 | from functools import partial
9 |
10 | #
11 | # Helper modules
12 | #
13 | class LocalAffinity(nn.Module):
14 |
15 | def __init__(self, dilations=[1]):
16 | super(LocalAffinity, self).__init__()
17 | self.dilations = dilations
18 | weight = self._init_aff()
19 | self.register_buffer('kernel', weight)
20 |
21 | def _init_aff(self):
22 | # initialising the shift kernel
23 | weight = torch.zeros(8, 1, 3, 3)
24 |
25 | for i in range(weight.size(0)):
26 | weight[i, 0, 1, 1] = 1
27 |
28 | weight[0, 0, 0, 0] = -1
29 | weight[1, 0, 0, 1] = -1
30 | weight[2, 0, 0, 2] = -1
31 |
32 | weight[3, 0, 1, 0] = -1
33 | weight[4, 0, 1, 2] = -1
34 |
35 | weight[5, 0, 2, 0] = -1
36 | weight[6, 0, 2, 1] = -1
37 | weight[7, 0, 2, 2] = -1
38 |
39 | self.weight_check = weight.clone()
40 |
41 | return weight
42 |
43 | def forward(self, x):
44 |
45 | self.weight_check = self.weight_check.type_as(x)
46 | assert torch.all(self.weight_check.eq(self.kernel))
47 |
48 | B,K,H,W = x.size()
49 | x = x.view(B*K,1,H,W)
50 |
51 | x_affs = []
52 | for d in self.dilations:
53 | x_pad = F.pad(x, [d]*4, mode='replicate')
54 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
55 | x_affs.append(x_aff)
56 |
57 | x_aff = torch.cat(x_affs, 1)
58 | return x_aff.view(B,K,-1,H,W)
59 |
60 | class LocalAffinityCopy(LocalAffinity):
61 |
62 | def _init_aff(self):
63 | # initialising the shift kernel
64 | weight = torch.zeros(8, 1, 3, 3)
65 |
66 | weight[0, 0, 0, 0] = 1
67 | weight[1, 0, 0, 1] = 1
68 | weight[2, 0, 0, 2] = 1
69 |
70 | weight[3, 0, 1, 0] = 1
71 | weight[4, 0, 1, 2] = 1
72 |
73 | weight[5, 0, 2, 0] = 1
74 | weight[6, 0, 2, 1] = 1
75 | weight[7, 0, 2, 2] = 1
76 |
77 | self.weight_check = weight.clone()
78 | return weight
79 |
80 | class LocalStDev(LocalAffinity):
81 |
82 | def _init_aff(self):
83 | weight = torch.zeros(9, 1, 3, 3)
84 | weight.zero_()
85 |
86 | weight[0, 0, 0, 0] = 1
87 | weight[1, 0, 0, 1] = 1
88 | weight[2, 0, 0, 2] = 1
89 |
90 | weight[3, 0, 1, 0] = 1
91 | weight[4, 0, 1, 1] = 1
92 | weight[5, 0, 1, 2] = 1
93 |
94 | weight[6, 0, 2, 0] = 1
95 | weight[7, 0, 2, 1] = 1
96 | weight[8, 0, 2, 2] = 1
97 |
98 | self.weight_check = weight.clone()
99 | return weight
100 |
101 | def forward(self, x):
102 | # returns (B,K,P,H,W), where P is the number
103 | # of locations
104 | x = super(LocalStDev, self).forward(x)
105 |
106 | return x.std(2, keepdim=True)
107 |
108 | class LocalAffinityAbs(LocalAffinity):
109 |
110 | def forward(self, x):
111 | x = super(LocalAffinityAbs, self).forward(x)
112 | return torch.abs(x)
113 |
114 | #
115 | # PAMR module
116 | #
117 | class PAMR(nn.Module):
118 |
119 | def __init__(self, num_iter=1, dilations=[1]):
120 | super(PAMR, self).__init__()
121 |
122 | self.num_iter = num_iter
123 | self.aff_x = LocalAffinityAbs(dilations)
124 | self.aff_m = LocalAffinityCopy(dilations)
125 | self.aff_std = LocalStDev(dilations)
126 |
127 | def forward(self, x, mask):
128 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)
129 |
130 | # x: [BxKxHxW]
131 | # mask: [BxCxHxW]
132 | B,K,H,W = x.size()
133 | _,C,_,_ = mask.size()
134 |
135 | x_std = self.aff_std(x)
136 |
137 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
138 | x = x.mean(1, keepdim=True)
139 | x = F.softmax(x, 2)
140 |
141 | for _ in range(self.num_iter):
142 | m = self.aff_m(mask) # [BxCxPxHxW]
143 | mask = (m * x).sum(2)
144 |
145 | # xvals: [BxCxHxW]
146 | return mask
--------------------------------------------------------------------------------
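PAMR refines a coarse mask with pixel-affinity kernels computed from the image itself; in this repository it is applied to the upsampled segmentation logits when `pamr_steps > 0` (see `forward_slide` in `scclip_segmentor.py`). A minimal sketch of the call with random tensors, where the iteration count is illustrative and the dilations match the segmentor's default `pamr_stride=(8, 16)`:

```
import torch
from pamr import PAMR

pamr = PAMR(num_iter=10, dilations=[8, 16])

img = torch.rand(1, 3, 224, 224)    # image tensor (B, 3, H, W)
logits = torch.rand(1, 21, 56, 56)  # coarse per-class logits (B, C, h, w)

# The logits are bilinearly resized to the image resolution inside PAMR,
# then iteratively re-weighted by local affinities of the image.
refined = pamr(img, logits)
print(refined.shape)  # torch.Size([1, 21, 224, 224])
```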
/prompts/imagenet_template.py:
--------------------------------------------------------------------------------
1 |
2 | imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
3 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
4 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
5 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
6 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
7 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
8 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
9 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
10 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
11 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
12 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
13 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
14 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
15 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
16 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
17 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
18 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
19 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
20 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
21 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
22 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
23 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
24 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
25 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
26 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
27 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
28 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
29 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
30 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
31 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
32 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
33 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
34 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
35 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
36 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
37 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
38 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
39 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
40 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
41 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
42 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
43 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
44 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
45 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
46 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
47 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
48 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
49 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
50 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
51 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
52 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
53 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
54 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
55 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
56 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
57 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
58 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
59 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
60 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
61 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
62 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
63 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
64 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
65 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
66 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
67 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
68 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
69 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
70 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
71 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
72 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
73 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
74 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
75 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
76 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
77 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
78 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
79 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
80 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
81 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
82 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
83 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
84 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
85 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
86 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
87 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
88 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
89 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
90 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
91 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
92 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
93 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
94 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
95 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
96 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
97 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
98 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
99 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
100 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
101 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
102 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
103 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
104 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
105 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
106 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
107 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
108 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
109 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
110 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
111 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
112 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
113 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
114 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
115 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
116 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
117 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
118 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
119 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
120 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
121 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
122 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
123 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
124 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
125 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
126 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
127 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
128 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
129 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
130 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
131 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
132 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
133 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
134 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
135 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
136 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
137 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
138 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
139 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
140 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
141 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
142 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
143 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
144 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
145 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
146 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
147 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
148 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
149 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
150 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
151 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
152 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
153 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
154 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
155 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
156 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
157 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
158 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
159 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
160 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
161 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
162 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
163 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
164 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
165 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
166 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
167 |
168 |
169 | openai_imagenet_template = [
170 | lambda c: f'a bad photo of a {c}.',
171 | lambda c: f'a photo of many {c}.',
172 | lambda c: f'a sculpture of a {c}.',
173 | lambda c: f'a photo of the hard to see {c}.',
174 | lambda c: f'a low resolution photo of the {c}.',
175 | lambda c: f'a rendering of a {c}.',
176 | lambda c: f'graffiti of a {c}.',
177 | lambda c: f'a bad photo of the {c}.',
178 | lambda c: f'a cropped photo of the {c}.',
179 | lambda c: f'a tattoo of a {c}.',
180 | lambda c: f'the embroidered {c}.',
181 | lambda c: f'a photo of a hard to see {c}.',
182 | lambda c: f'a bright photo of a {c}.',
183 | lambda c: f'a photo of a clean {c}.',
184 | lambda c: f'a photo of a dirty {c}.',
185 | lambda c: f'a dark photo of the {c}.',
186 | lambda c: f'a drawing of a {c}.',
187 | lambda c: f'a photo of my {c}.',
188 | lambda c: f'the plastic {c}.',
189 | lambda c: f'a photo of the cool {c}.',
190 | lambda c: f'a close-up photo of a {c}.',
191 | lambda c: f'a black and white photo of the {c}.',
192 | lambda c: f'a painting of the {c}.',
193 | lambda c: f'a painting of a {c}.',
194 | lambda c: f'a pixelated photo of the {c}.',
195 | lambda c: f'a sculpture of the {c}.',
196 | lambda c: f'a bright photo of the {c}.',
197 | lambda c: f'a cropped photo of a {c}.',
198 | lambda c: f'a plastic {c}.',
199 | lambda c: f'a photo of the dirty {c}.',
200 | lambda c: f'a jpeg corrupted photo of a {c}.',
201 | lambda c: f'a blurry photo of the {c}.',
202 | lambda c: f'a photo of the {c}.',
203 | lambda c: f'a good photo of the {c}.',
204 | lambda c: f'a rendering of the {c}.',
205 | lambda c: f'a {c} in a video game.',
206 | lambda c: f'a photo of one {c}.',
207 | lambda c: f'a doodle of a {c}.',
208 | lambda c: f'a close-up photo of the {c}.',
209 | lambda c: f'a photo of a {c}.',
210 | lambda c: f'the origami {c}.',
211 | lambda c: f'the {c} in a video game.',
212 | lambda c: f'a sketch of a {c}.',
213 | lambda c: f'a doodle of the {c}.',
214 | lambda c: f'a origami {c}.',
215 | lambda c: f'a low resolution photo of a {c}.',
216 | lambda c: f'the toy {c}.',
217 | lambda c: f'a rendition of the {c}.',
218 | lambda c: f'a photo of the clean {c}.',
219 | lambda c: f'a photo of a large {c}.',
220 | lambda c: f'a rendition of a {c}.',
221 | lambda c: f'a photo of a nice {c}.',
222 | lambda c: f'a photo of a weird {c}.',
223 | lambda c: f'a blurry photo of a {c}.',
224 | lambda c: f'a cartoon {c}.',
225 | lambda c: f'art of a {c}.',
226 | lambda c: f'a sketch of the {c}.',
227 | lambda c: f'a embroidered {c}.',
228 | lambda c: f'a pixelated photo of a {c}.',
229 | lambda c: f'itap of the {c}.',
230 | lambda c: f'a jpeg corrupted photo of the {c}.',
231 | lambda c: f'a good photo of a {c}.',
232 | lambda c: f'a plushie {c}.',
233 | lambda c: f'a photo of the nice {c}.',
234 | lambda c: f'a photo of the small {c}.',
235 | lambda c: f'a photo of the weird {c}.',
236 | lambda c: f'the cartoon {c}.',
237 | lambda c: f'art of the {c}.',
238 | lambda c: f'a drawing of the {c}.',
239 | lambda c: f'a photo of the large {c}.',
240 | lambda c: f'a black and white photo of a {c}.',
241 | lambda c: f'the plushie {c}.',
242 | lambda c: f'a dark photo of a {c}.',
243 | lambda c: f'itap of a {c}.',
244 | lambda c: f'graffiti of the {c}.',
245 | lambda c: f'a toy {c}.',
246 | lambda c: f'itap of my {c}.',
247 | lambda c: f'a photo of a cool {c}.',
248 | lambda c: f'a photo of a small {c}.',
249 | lambda c: f'a tattoo of the {c}.',
250 | ]
--------------------------------------------------------------------------------
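The templates above are consumed as a prompt ensemble: every class name is rendered into each template, the resulting text embeddings are L2-normalized, averaged, and normalized again, which is exactly what `SCCLIPForSegmentation.__init__` below does for every entry of the class-name file. A stripped-down sketch for a single class name, assuming the environment from the README and that downloading the ViT-B/16 weights via `clip.load` is acceptable:

```
import torch
import clip
from prompts.imagenet_template import openai_imagenet_template

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net, _ = clip.load('ViT-B/16', device=device, jit=False)

with torch.no_grad():
    tokens = clip.tokenize([t('cat') for t in openai_imagenet_template]).to(device)
    feats = net.encode_text(tokens)                  # one embedding per template
    feats = feats / feats.norm(dim=-1, keepdim=True)
    cls_feat = feats.mean(dim=0)                     # average over the ensemble
    cls_feat = cls_feat / cls_feat.norm()            # re-normalize
print(cls_feat.shape)  # torch.Size([512]) for ViT-B/16
```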
/scclip_segmentor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.append("..")
5 |
6 | import clip
7 | from prompts.imagenet_template import openai_imagenet_template
8 |
9 | from mmseg.models.segmentors import BaseSegmentor
10 | from mmseg.models.data_preprocessor import SegDataPreProcessor
11 | from mmengine.structures import PixelData
12 |
13 | from mmseg.registry import MODELS
14 |
15 | from pamr import PAMR
16 |
17 | @MODELS.register_module()
18 | class SCCLIPForSegmentation(BaseSegmentor):
19 | def __init__(self, clip_path, name_path, device=torch.device('cuda'),
20 | pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40,
21 | slide_stride=112, slide_crop=224, area_thd=None,
22 | pre_adjust_idx=8, post_adjust_idx=3, multi_start_idx=3, multi_end_idx=10, res_cls=0.3):
23 |
24 | data_preprocessor = SegDataPreProcessor(
25 | mean=[122.771, 116.746, 104.094],
26 | std=[68.501, 66.632, 70.323],
27 | rgb_to_bgr=True)
28 | super().__init__(data_preprocessor=data_preprocessor)
29 | self.net, _ = clip.load(clip_path, device=device, jit=False)
30 |
31 | self.net.visual.pre_adjust_idx = pre_adjust_idx
32 | self.net.visual.post_adjust_idx = post_adjust_idx
33 | self.net.visual.multi_start_idx = multi_start_idx
34 | self.net.visual.multi_end_idx = multi_end_idx
35 | self.net.visual.res_cls = res_cls
36 |
37 | query_words, self.query_idx = get_cls_idx(name_path)
38 | self.num_queries = len(query_words)
39 | self.num_classes = max(self.query_idx) + 1
40 | self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device)
41 |
42 | query_features = []
43 | with torch.no_grad():
44 | for qw in query_words:
45 | query = clip.tokenize([temp(qw) for temp in openai_imagenet_template]).to(device)
46 | feature = self.net.encode_text(query)
47 | feature /= feature.norm(dim=-1, keepdim=True)
48 | feature = feature.mean(dim=0)
49 | feature /= feature.norm()
50 | query_features.append(feature.unsqueeze(0))
51 | self.query_features = torch.cat(query_features, dim=0)
52 |
53 | self.dtype = self.query_features.dtype
54 | self.logit_scale = logit_scale
55 | self.prob_thd = prob_thd
56 | self.area_thd = area_thd
57 | self.slide_stride = slide_stride
58 | self.slide_crop = slide_crop
59 | self.align_corners = False
60 |
61 | if pamr_steps > 0:
62 | self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device)
63 | else:
64 | self.pamr = None
65 |
66 | def forward_feature(self, img, logit_size=None):
67 | if type(img) == list:
68 | img = img[0]
69 |
70 | image_features = self.net.encode_image(img, return_all=True)
71 | image_features /= image_features.norm(dim=-1, keepdim=True)
72 | logits = image_features @ self.query_features.T
73 |
74 | patch_size = self.net.visual.patch_size
75 | w, h = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size
76 | out_dim = logits.shape[-1]
77 | logits = logits.permute(0, 2, 1).reshape(-1, out_dim, w, h)
78 |
79 | if logit_size == None:
80 | logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear', align_corners=False)
81 | else:
82 | logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear', align_corners=False)
83 |
84 | return logits
85 |
86 | def forward_slide(self, img, img_metas, stride=112, crop_size=224):
87 | """Inference by sliding-window with overlap.
88 | If h_crop > h_img or w_crop > w_img, the small patch will be used to
89 | decode without padding.
90 | """
91 | if type(img) == list:
92 | img = img[0].unsqueeze(0)
93 | if type(stride) == int:
94 | stride = (stride, stride)
95 | if type(crop_size) == int:
96 | crop_size = (crop_size, crop_size)
97 |
98 | h_stride, w_stride = stride
99 | h_crop, w_crop = crop_size
100 | batch_size, _, h_img, w_img = img.shape
101 | out_channels = self.num_queries
102 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
103 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
104 | preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
105 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
106 | for h_idx in range(h_grids):
107 | for w_idx in range(w_grids):
108 | y1 = h_idx * h_stride
109 | x1 = w_idx * w_stride
110 | y2 = min(y1 + h_crop, h_img)
111 | x2 = min(x1 + w_crop, w_img)
112 | y1 = max(y2 - h_crop, 0)
113 | x1 = max(x2 - w_crop, 0)
114 | crop_img = img[:, :, y1:y2, x1:x2]
115 | crop_seg_logit = self.forward_feature(crop_img)
116 | preds += nn.functional.pad(crop_seg_logit,
117 | (int(x1), int(preds.shape[3] - x2), int(y1),
118 | int(preds.shape[2] - y2)))
119 |
120 | count_mat[:, :, y1:y2, x1:x2] += 1
121 | assert (count_mat == 0).sum() == 0
122 |
123 | preds = preds / count_mat
124 | img_size = img_metas[0]['ori_shape'][:2]
125 | logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear', align_corners=False)
126 |
127 | if self.pamr:
128 | img = nn.functional.interpolate(img, size=img_size, mode='bilinear')
129 | logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype)
130 |
131 | return logits
132 |
133 | def predict(self, inputs, data_samples):
134 | if data_samples is not None:
135 | batch_img_metas = [
136 | data_sample.metainfo for data_sample in data_samples
137 | ]
138 | else:
139 | batch_img_metas = [
140 | dict(
141 | ori_shape=inputs.shape[2:],
142 | img_shape=inputs.shape[2:],
143 | pad_shape=inputs.shape[2:],
144 | padding_size=[0, 0, 0, 0])
145 | ] * inputs.shape[0]
146 |
147 | if self.slide_crop > 0:
148 | seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop)
149 | else:
150 | seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape'])
151 |
152 | return self.postprocess_result(seg_logits, data_samples)
153 |
154 | def postprocess_result(self, seg_logits, data_samples):
155 | batch_size = seg_logits.shape[0]
156 | for i in range(batch_size):
157 | seg_logits = seg_logits[i] * self.logit_scale
158 | seg_logits = seg_logits.softmax(0) # n_queries * w * h
159 |
160 | num_cls, num_queries = max(self.query_idx) + 1, len(self.query_idx)
161 |         if num_cls != num_queries:  # some classes are queried with several synonymous prompts
162 |             seg_logits = seg_logits.unsqueeze(0)
163 |             cls_index = nn.functional.one_hot(self.query_idx)
164 |             cls_index = cls_index.T.view(num_cls, num_queries, 1, 1)
165 |             seg_logits = (seg_logits * cls_index).max(1)[0]  # merge synonyms: take the max logit per class
166 | seg_pred = seg_logits.argmax(0, keepdim=True)
167 |
168 | if self.area_thd is not None:
169 | # Force segmentations with area < self.area_thd to 0 (background)
170 | predictions = nn.functional.one_hot(seg_logits.argmax(0), num_cls).to(seg_logits.dtype)
171 |                 area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True)  # exclude the background channel
172 | area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logits.dtype)
173 | seg_logits[1:] *= area_pred.transpose(0, -1)
174 |
175 | seg_pred = seg_logits.argmax(0, keepdim=True)
176 | seg_pred[seg_logits.max(0, keepdim=True)[0] < self.prob_thd] = 0
177 |
178 | if data_samples is None:
179 | return seg_pred
180 | else:
181 | data_samples[i].set_data({
182 | 'seg_logits':
183 | PixelData(**{'data': seg_logits}),
184 | 'pred_sem_seg':
185 | PixelData(**{'data': seg_pred})
186 | })
187 |
188 | return data_samples
189 |
190 |     def _forward(self, data_samples):
191 | """
192 | """
193 |
194 | def inference(self, img, batch_img_metas):
195 | """
196 | """
197 |
198 | def encode_decode(self, inputs, batch_img_metas):
199 | """
200 | """
201 |
202 | def extract_feat(self, inputs):
203 | """
204 | """
205 |
206 | def loss(self, inputs, data_samples):
207 | """
208 | """
209 |
210 | def get_cls_idx(path):
211 | with open(path, 'r') as f:
212 | name_sets = f.readlines()
213 | num_cls = len(name_sets)
214 |
215 | class_names, class_indices = [], []
216 | for idx in range(num_cls):
217 | names_i = name_sets[idx].split(', ')
218 | class_names += names_i
219 | class_indices += [idx for _ in range(len(names_i))]
220 | class_names = [item.replace('\n', '') for item in class_names]
221 | return class_names, class_indices
--------------------------------------------------------------------------------
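`get_cls_idx` is the glue between the `configs/cls_*.txt` files and the segmentor: every line is one class, and comma-separated entries on a line are synonymous prompts that `postprocess_result` merges back into a single class via the one-hot/max step above. A self-contained sketch of the parsing, using two lines in the same format as `cls_voc21.txt`:

```
from scclip_segmentor import get_cls_idx

with open('toy_names.txt', 'w') as f:
    f.write('television monitor, tv monitor, monitor\n')
    f.write('aeroplane\n')

names, idx = get_cls_idx('toy_names.txt')
print(names)  # ['television monitor', 'tv monitor', 'monitor', 'aeroplane']
print(idx)    # [0, 0, 0, 1]  -> three prompts for class 0, one for class 1
```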