├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── base_config.py ├── cfg_cityscapes.py ├── cfg_coco_object.py ├── cfg_coco_stuff164k.py ├── cfg_context60.py ├── cfg_voc21.py ├── cls_cityscapes.txt ├── cls_coco_object.txt ├── cls_coco_stuff.txt ├── cls_context60.txt └── cls_voc21.txt ├── custom_datasets.py ├── datasets └── cvt_coco_object.py ├── demo.ipynb ├── demo.jpg ├── eval.py ├── figs └── overview.png ├── itaclip_segmentor.py ├── llama3_definition_generation.py ├── llama3_synonym_generation.py ├── llama_generated_texts ├── cityscapes_synonyms.txt ├── coco_object_synonyms.txt ├── coco_stuff_definitions.txt ├── context60_definitions.txt └── voc21_definitions.txt ├── pamr.py └── prompts └── imagenet_template.py /README.md: -------------------------------------------------------------------------------- 1 | # ITACLIP: Boosting Training-Free Semantic Segmentation with Image, Text, and Architectural Enhancements [CVPRW 2025] 2 | 3 | [[`paper`](https://arxiv.org/abs/2411.12044)] 4 | 5 | > **Abstract:** *Recent advances in foundational Vision Language Models (VLMs) have reshaped the evaluation paradigm in computer vision tasks. These foundational models, especially CLIP, have accelerated research in open-vocabulary computer vision tasks, including Open-Vocabulary Semantic Segmentation (OVSS). Although the initial results are promising, the dense prediction capabilities of VLMs still require further improvement. In this study, we enhance the semantic segmentation performance of CLIP by introducing new modules and modifications: 1) architectural changes in the last layer of ViT and the incorporation of attention maps from the middle layers with the last layer, 2) Image Engineering: applying data augmentations to enrich input image representations, and 3) using Large Language Models (LLMs) to generate definitions and synonyms for each class name to leverage CLIP's open-vocabulary capabilities. Our training-free method, ITACLIP, outperforms current state-of-the-art approaches on segmentation benchmarks such as COCO-Stuff, COCO-Object, Pascal Context, and Pascal VOC.* 6 | 7 |
 8 | [Overview figure: figs/overview.png] 9 | 
 10 | 11 | ## :tada: News 12 | **`2024/11/18` Our paper and code are publicly available.** 13 | **`2025/04/01` Our paper has been accepted to CVPRW 2025. :tada: :tada:** 14 | 15 | ## Dependencies 16 | Our code is built on top of [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). Please follow the [instructions](https://mmsegmentation.readthedocs.io/en/main/get_started.html) to install MMSegmentation. We used ```Python=3.9.17```, ```torch=2.0.1```, ```mmcv=2.1.0```, and ```mmseg=1.2.2``` in our experiments. 17 | 18 | ## Datasets 19 | We support four segmentation benchmarks: COCO-Stuff, COCO-Object, Pascal Context, and Pascal VOC. For dataset preparation, please follow the [MMSeg Dataset Preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md). The COCO-Object dataset can be derived from COCO-Stuff by running the following command: 20 | 21 | ``` 22 | python datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO164K 23 | ``` 24 | 25 | Additional datasets can be seamlessly integrated following the same dataset preparation document. Please modify the dataset (```data_root```) and class name (```name_path```) paths in the config files; an example config snippet is provided at the end of this README. 26 | 27 | ## LLaMA-Generated Texts 28 | For reproducibility, we provide the LLM-generated auxiliary texts. Please update the auxiliary text path (```auxiliary_text_path```) in the config files. We also provide the definition and synonym generation scripts (```llama3_definition_generation.py``` and ```llama3_synonym_generation.py```). For the supported datasets, running these scripts is unnecessary, as we have already included the LLaMA-generated texts. 29 | ## Evaluation 30 | To evaluate ITACLIP on a dataset, run the following command, replacing ```dataset_name``` with the desired dataset: 31 | ``` 32 | python eval.py --config ./configs/cfg_{dataset_name}.py 33 | ``` 34 | ## Demo 35 | To evaluate ITACLIP on a single image, run the ```demo.ipynb``` Jupyter Notebook. 36 | ## Results 37 | With the default configurations, you should achieve the following results (mIoU). 38 | 39 | | Dataset | mIoU | 40 | | --------------------- | ----- | 41 | | COCO-Stuff | 27.0 | 42 | | COCO-Object | 37.7 | 43 | | PASCAL VOC | 67.9 | 44 | | PASCAL Context | 37.5 | 45 | | Cityscapes | 40.2 | 46 | 47 | ## Citation 48 | If you find our project helpful, please consider citing our work. 49 | 50 | ``` 51 | @article{aydin2024itaclip, 52 | title={ITACLIP: Boosting Training-Free Semantic Segmentation with Image, Text, and Architectural Enhancements}, 53 | author={Ayd{\i}n, M Arda and {\c{C}}{\i}rpar, Efe Mert and Abdinli, Elvin and Unal, Gozde and Sahin, Yusuf H}, 54 | journal={arXiv preprint arXiv:2411.12044}, 55 | year={2024} 56 | } 57 | ``` 58 | 59 | ## Acknowledgments 60 | This implementation builds upon [CLIP](https://github.com/openai/CLIP), [SCLIP](https://github.com/wangf3014/SCLIP), and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). We gratefully acknowledge their valuable contributions. 
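## Example: Updating Config Paths

The dataset configs under ```configs/``` only need their path fields pointed at your local setup. Below is a minimal sketch for Pascal VOC; the ```/path/to/...``` values are placeholders of ours, not repository defaults, so replace them with your own clone and dataset locations.

```
# Excerpt from configs/cfg_voc21.py -- only the fields that must be edited.
# The placeholder paths below are examples; adjust them to your machine.
model = dict(
    type='ITACLIP_Segmentor',
    auxiliary_text_path='/path/to/ITACLIP/llama_generated_texts/voc21_definitions.txt',
    name_path='/path/to/ITACLIP/configs/cls_voc21.txt',
)

data_root = '/path/to/VOCdevkit/VOC2012'  # dataset root prepared following the MMSeg guide
```

With these paths set, ```python eval.py --config ./configs/cfg_voc21.py``` should reproduce the Pascal VOC result reported above.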
61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import * 3 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-arda-aydn/ITACLIP/030dccb8d9524ccecb8598528df583728ff61a15/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | import hashlib 5 | import os 6 | import urllib 7 | import warnings 8 | from typing import Any, Union, List 9 | from pkg_resources import packaging 10 | 11 | import torch 12 | from PIL import Image 13 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 14 | from tqdm import tqdm 15 | 16 | from .model import build_model 17 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 28 | 29 | 30 | __all__ = ["available_models", "load", "tokenize"] 31 | _tokenizer = _Tokenizer() 32 | 33 | _MODELS = { 34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; 
re-downloading the file") 60 | 61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 72 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 73 | 74 | return download_target 75 | 76 | 77 | def _convert_image_to_rgb(image): 78 | return image.convert("RGB") 79 | 80 | 81 | def _transform(n_px): 82 | return Compose([ 83 | Resize(n_px, interpolation=BICUBIC), 84 | CenterCrop(n_px), 85 | _convert_image_to_rgb, 86 | ToTensor(), 87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 88 | ]) 89 | 90 | 91 | def available_models() -> List[str]: 92 | """Returns the names of available CLIP models""" 93 | return list(_MODELS.keys()) 94 | 95 | 96 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 97 | """Load a CLIP model 98 | 99 | Parameters 100 | ---------- 101 | name : str 102 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 103 | 104 | device : Union[str, torch.device] 105 | The device to put the loaded model 106 | 107 | jit : bool 108 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 109 | 110 | download_root: str 111 | path to download the model files; by default, it uses "~/.cache/clip" 112 | 113 | Returns 114 | ------- 115 | model : torch.nn.Module 116 | The CLIP model 117 | 118 | preprocess : Callable[[PIL.Image], torch.Tensor] 119 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 120 | """ 121 | if name in _MODELS: 122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 123 | elif os.path.isfile(name): 124 | model_path = name 125 | else: 126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 127 | 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 136 | jit = False 137 | state_dict = torch.load(model_path, map_location="cpu") 138 | 139 | if not jit: 140 | model = build_model(state_dict or model.state_dict()).to(device) 141 | if str(device) == "cpu": 142 | model.float() 143 | return model, _transform(model.visual.input_resolution) 144 | 145 | # patch the device names 146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 148 | 149 | def patch_device(module): 150 | try: 151 | graphs = [module.graph] if hasattr(module, "graph") else [] 152 | except RuntimeError: 153 | graphs = [] 154 | 155 | if hasattr(module, "forward1"): 156 | graphs.append(module.forward1.graph) 157 | 158 | for graph in graphs: 159 | for node in graph.findAllNodes("prim::Constant"): 160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 161 | node.copyAttributes(device_node) 162 | 163 | model.apply(patch_device) 164 | patch_device(model.encode_image) 165 | patch_device(model.encode_text) 166 | 167 | # patch dtype to float32 on CPU 168 | if str(device) == "cpu": 169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 171 | float_node = float_input.node() 172 | 173 | def patch_float(module): 174 | try: 175 | graphs = [module.graph] if hasattr(module, "graph") else [] 176 | except RuntimeError: 177 | graphs = [] 178 | 179 | if hasattr(module, "forward1"): 180 | graphs.append(module.forward1.graph) 181 | 182 | for graph in graphs: 183 | for node in graph.findAllNodes("aten::to"): 184 | inputs = list(node.inputs()) 185 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 186 | if inputs[i].node()["value"] == 5: 187 | inputs[i].node().copyAttributes(float_node) 188 | 189 | model.apply(patch_float) 190 | patch_float(model.encode_image) 191 | patch_float(model.encode_text) 192 | 193 | model.float() 194 | 195 | return model, _transform(model.input_resolution.item()) 196 | 197 | 198 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 199 | """ 200 | Returns the tokenized representation of given input string(s) 201 | 202 | Parameters 203 | ---------- 204 | texts : Union[str, List[str]] 205 | An input string or a list of input strings to tokenize 206 | 207 | context_length : int 208 | The context length to use; all CLIP models use 77 as the context length 209 | 210 | truncate: bool 211 | Whether to truncate the text in case its encoding is longer than the context length 212 | 213 | Returns 214 | ------- 215 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 224 | 225 | for i, tokens in enumerate(all_tokens): 226 | if len(tokens) > context_length: 227 | if truncate: 228 | tokens = tokens[:context_length] 229 | tokens[-1] = eot_token 230 | else: 231 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 232 | result[i, 
:len(tokens)] = torch.tensor(tokens) 233 | 234 | return result -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | import torchvision.transforms.functional as VF 13 | 14 | class Bottleneck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, stride=1): 18 | super().__init__() 19 | 20 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 21 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | 24 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 28 | 29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 31 | 32 | self.relu = nn.ReLU(inplace=True) 33 | self.downsample = None 34 | self.stride = stride 35 | 36 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 38 | self.downsample = nn.Sequential(OrderedDict([ 39 | ("-1", nn.AvgPool2d(stride)), 40 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 41 | ("1", nn.BatchNorm2d(planes * self.expansion)) 42 | ])) 43 | 44 | def forward(self, x: torch.Tensor): 45 | identity = x 46 | 47 | out = self.relu(self.bn1(self.conv1(x))) 48 | out = self.relu(self.bn2(self.conv2(out))) 49 | out = self.avgpool(out) 50 | out = self.bn3(self.conv3(out)) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu(out) 57 | return out 58 | 59 | 60 | class AttentionPool2d(nn.Module): 61 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 62 | super().__init__() 63 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 64 | self.k_proj = nn.Linear(embed_dim, embed_dim) 65 | self.q_proj = nn.Linear(embed_dim, embed_dim) 66 | self.v_proj = nn.Linear(embed_dim, embed_dim) 67 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 68 | self.num_heads = num_heads 69 | 70 | def forward(self, x, return_all_tokens=False): 71 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 72 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 73 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 74 | x, _ = F.multi_head_attention_forward( 75 | query=x, key=x, value=x, 76 | embed_dim_to_check=x.shape[-1], 77 | num_heads=self.num_heads, 78 | q_proj_weight=self.q_proj.weight, 79 | k_proj_weight=self.k_proj.weight, 80 | v_proj_weight=self.v_proj.weight, 81 | in_proj_weight=None, 82 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 83 | bias_k=None, 84 | bias_v=None, 85 | add_zero_attn=False, 86 | dropout_p=0, 87 | out_proj_weight=self.c_proj.weight, 88 | out_proj_bias=self.c_proj.bias, 89 | use_separate_proj_weight=True, 90 | 
training=self.training, 91 | need_weights=False 92 | ) 93 | if return_all_tokens: 94 | return x 95 | else: 96 | return x[0] 97 | 98 | 99 | class ModifiedResNet(nn.Module): 100 | """ 101 | A ResNet class that is similar to torchvision's but contains the following changes: 102 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 103 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 104 | - The final pooling layer is a QKV attention instead of an average pool 105 | """ 106 | 107 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 108 | super().__init__() 109 | self.output_dim = output_dim 110 | self.input_resolution = input_resolution 111 | 112 | # the 3-layer stem 113 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 114 | self.bn1 = nn.BatchNorm2d(width // 2) 115 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 116 | self.bn2 = nn.BatchNorm2d(width // 2) 117 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 118 | self.bn3 = nn.BatchNorm2d(width) 119 | self.avgpool = nn.AvgPool2d(2) 120 | self.relu = nn.ReLU(inplace=True) 121 | 122 | # residual layers 123 | self._inplanes = width # this is a *mutable* variable used during construction 124 | self.layer1 = self._make_layer(width, layers[0]) 125 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 126 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 127 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 128 | 129 | embed_dim = width * 32 # the ResNet feature dimension 130 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, return_all_tokens=False): 142 | def stem(x): 143 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 144 | x = self.relu(bn(conv(x))) 145 | x = self.avgpool(x) 146 | return x 147 | 148 | x = x.type(self.conv1.weight.dtype) 149 | x = stem(x) 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | x = self.attnpool(x, return_all_tokens) 155 | 156 | return x 157 | 158 | 159 | class LayerNorm(nn.LayerNorm): 160 | """Subclass torch's LayerNorm to handle fp16.""" 161 | 162 | def forward(self, x: torch.Tensor): 163 | orig_type = x.dtype 164 | ret = super().forward(x.type(torch.float32)) 165 | return ret.type(orig_type) 166 | 167 | 168 | class QuickGELU(nn.Module): 169 | def forward(self, x: torch.Tensor): 170 | return x * torch.sigmoid(1.702 * x) 171 | 172 | 173 | class ResidualAttentionBlock(nn.Module): 174 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 175 | super().__init__() 176 | 177 | self.attn = nn.MultiheadAttention(d_model, n_head) 178 | self.ln_1 = LayerNorm(d_model) 179 | self.mlp = nn.Sequential(OrderedDict([ 180 | ("c_fc", nn.Linear(d_model, d_model * 4)), 181 | ("gelu", QuickGELU()), 182 | ("c_proj", nn.Linear(d_model * 4, d_model)) 183 | ])) 184 | self.ln_2 = LayerNorm(d_model) 185 | self.attn_mask = attn_mask 186 | 187 | def attention(self, 
x: torch.Tensor): 188 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 189 | # pdb.set_trace() 190 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 191 | 192 | def forward(self, x: torch.Tensor): 193 | x = x + self.attention(self.ln_1(x)) 194 | x = x + self.mlp(self.ln_2(x)) 195 | return x 196 | 197 | 198 | class Transformer(nn.Module): 199 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 200 | super().__init__() 201 | self.width = width 202 | self.layers = layers 203 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 204 | 205 | def forward(self, x: torch.Tensor): 206 | return self.resblocks(x) 207 | 208 | 209 | class VisionTransformer(nn.Module): 210 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 211 | super().__init__() 212 | self.input_resolution = input_resolution 213 | self.patch_size = patch_size 214 | self.output_dim = output_dim 215 | self.width = width 216 | self.heads = heads 217 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 218 | 219 | scale = width ** -0.5 220 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 221 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 222 | self.ln_pre = LayerNorm(width) 223 | 224 | self.transformer = Transformer(width, layers, heads) 225 | 226 | self.ln_post = LayerNorm(width) 227 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 228 | self.addition_cache = dict() 229 | 230 | def forward(self, x: torch.Tensor, device, return_all=False, attn_self=False, return_attn_map = False, layer='final'): 231 | x = x.to(device) 232 | _, _, w, h = x.shape 233 | if return_attn_map: 234 | return self.get_attn(x, layer=layer, attn_self=attn_self) 235 | 236 | x = self.conv1(x) # shape = [*, width, grid, grid] 237 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 238 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 239 | 240 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 241 | 242 | if x.shape[1] != self.positional_embedding.shape[0]: 243 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) 244 | else: 245 | x = x + self.positional_embedding.to(x.dtype) 246 | 247 | x = self.ln_pre(x) 248 | 249 | x = x.permute(1, 0, 2) # NLD -> LND 250 | 251 | layer_count = 0 252 | attn_list = [] 253 | selected_intermediate_layers = [7,8,10] 254 | for blk in self.transformer.resblocks[:-1]: 255 | layer_count += 1 256 | if layer_count in selected_intermediate_layers: 257 | saved_attn = self.custom_attn(blk.attn, blk.ln_1(x), return_attn=True, attn_self=attn_self) 258 | attn_list.append(saved_attn) # attention maps from intermediate layers 259 | x = blk(x) 260 | else: 261 | x = blk(x) 262 | 263 | avg_attn = torch.mean(torch.stack(attn_list), dim=0) 264 | for blk in self.transformer.resblocks[-1:]: 265 | custom_attn = self.custom_attn(blk.attn, blk.ln_1(x), return_attn=True, attn_self=attn_self) 266 | avg_attn = 0.5 * custom_attn + 0.5 * avg_attn 267 | x = x + self.use_saved_attn(blk.attn, blk.ln_1(x), avg_attn) 268 | 269 | x = x.permute(1, 0, 2) # LND -> NLD 270 | 271 | if return_all: 272 | return 
self.ln_post(x) @ self.proj 273 | 274 | x = self.ln_post(x[:, 0, :]) 275 | if self.proj is not None: 276 | x = x @ self.proj 277 | 278 | return x 279 | 280 | def interpolate_pos_encoding(self, x, w, h): 281 | npatch = x.shape[1] - 1 282 | N = self.positional_embedding.shape[0] - 1 283 | if npatch == N and w == h: 284 | return self.positional_embedding 285 | class_pos_embed = self.positional_embedding[[0]] 286 | patch_pos_embed = self.positional_embedding[1:] 287 | dim = x.shape[-1] 288 | w0 = w // self.patch_size 289 | h0 = h // self.patch_size 290 | w0, h0 = w0 + 0.1, h0 + 0.1 291 | patch_pos_embed = nn.functional.interpolate( 292 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 293 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 294 | mode='bicubic', 295 | ) 296 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 297 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 298 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 299 | 300 | def use_saved_attn(self, attn_layer, x, saved_attn): 301 | num_heads = attn_layer.num_heads 302 | _, bsz, embed_dim = x.size() 303 | head_dim = embed_dim // num_heads 304 | _, _, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) 305 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 306 | 307 | attn_output = torch.bmm(saved_attn, v) 308 | attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) 309 | attn_output = attn_layer.out_proj(attn_output) 310 | 311 | return attn_output 312 | 313 | def custom_attn(self, attn_layer, x, return_attn=False, with_attn=False, attn_self=False): 314 | 315 | num_heads = attn_layer.num_heads 316 | _, bsz, embed_dim = x.size() 317 | head_dim = embed_dim // num_heads 318 | scale = head_dim ** -0.5 319 | 320 | q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) 321 | q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 322 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 323 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 324 | 325 | if attn_self: 326 | q_attn = torch.bmm(q, q.transpose(1, 2)) * scale 327 | k_attn = torch.bmm(k, k.transpose(1, 2)) * scale 328 | attn_weights = F.softmax(q_attn, dim=-1) + F.softmax(k_attn, dim=-1) 329 | else: 330 | attn_weights = torch.bmm(q * scale, k.transpose(1, 2)) 331 | attn_weights = F.softmax(attn_weights, dim=-1) 332 | 333 | if return_attn: 334 | return attn_weights 335 | 336 | attn_output = torch.bmm(attn_weights, v) 337 | attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) 338 | attn_output = attn_layer.out_proj(attn_output) 339 | 340 | if with_attn: 341 | return attn_output, attn_weights 342 | 343 | return attn_output 344 | 345 | def get_attn(self, x, layer='all', attn_self=False): 346 | 347 | B, nc, w, h = x.shape 348 | 349 | x = self.conv1(x.type(self.conv1.weight.dtype)) # shape = [*, width, grid, grid] 350 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 351 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 352 | 353 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 354 | 355 | if x.shape[1] != self.positional_embedding.shape[0]: 356 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) 
357 | else: 358 | x = x + self.positional_embedding.to(x.dtype) 359 | 360 | x = self.ln_pre(x) 361 | 362 | x = x.permute(1, 0, 2) # NLD -> LND 363 | 364 | if layer == 'final': 365 | for blk in self.transformer.resblocks[:-1]: 366 | x = blk(x) 367 | attn_map = self.custom_attn(self.transformer.resblocks[-1].attn, 368 | self.transformer.resblocks[-1].ln_1(x), 369 | attn_self=attn_self, return_attn=True) 370 | return attn_map 371 | elif layer == 'all': 372 | attn_map = [] 373 | for blk in self.transformer.resblocks[:-1]: 374 | x_i, attn_i = self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True) 375 | x = x + x_i 376 | x = x + blk.mlp(blk.ln_2(x)) 377 | attn_map.append(attn_i) 378 | for blk in self.transformer.resblocks[-1:]: 379 | x_i, attn_i = self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True, attn_self=attn_self) 380 | x = x + x_i 381 | x = x + blk.mlp(blk.ln_2(x)) 382 | attn_map.append(attn_i) 383 | return attn_map 384 | else: 385 | raise ValueError('layer should be final or all') 386 | 387 | 388 | class CLIP(nn.Module): 389 | def __init__(self, 390 | embed_dim: int, # 512 391 | # vision 392 | image_resolution: int, # 224 393 | vision_layers: Union[Tuple[int, int, int, int], int], # 12 394 | vision_width: int, # 768 395 | vision_patch_size: int, # 16 396 | # text 397 | context_length: int, # 77 398 | vocab_size: int, # 49408 399 | transformer_width: int, # 512 400 | transformer_heads: int, # 8 401 | transformer_layers: int # 12 402 | ): 403 | super().__init__() 404 | self.context_length = context_length 405 | 406 | if isinstance(vision_layers, (tuple, list)): 407 | vision_heads = vision_width * 32 // 64 408 | self.visual = ModifiedResNet( 409 | layers=vision_layers, 410 | output_dim=embed_dim, 411 | heads=vision_heads, 412 | input_resolution=image_resolution, 413 | width=vision_width 414 | ) 415 | else: 416 | vision_heads = vision_width // 64 417 | self.visual = VisionTransformer( 418 | input_resolution=image_resolution, 419 | patch_size=vision_patch_size, 420 | width=vision_width, 421 | layers=vision_layers, 422 | heads=vision_heads, 423 | output_dim=embed_dim 424 | ) 425 | 426 | self.transformer = Transformer( 427 | width=transformer_width, 428 | layers=transformer_layers, 429 | heads=transformer_heads, 430 | attn_mask=self.build_attention_mask() 431 | ) 432 | 433 | self.vocab_size = vocab_size 434 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 435 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 436 | self.ln_final = LayerNorm(transformer_width) 437 | 438 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 439 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 440 | 441 | self.initialize_parameters() 442 | 443 | def initialize_parameters(self): 444 | nn.init.normal_(self.token_embedding.weight, std=0.02) 445 | nn.init.normal_(self.positional_embedding, std=0.01) 446 | 447 | if isinstance(self.visual, ModifiedResNet): 448 | if self.visual.attnpool is not None: 449 | std = self.visual.attnpool.c_proj.in_features ** -0.5 450 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 451 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 452 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 453 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 454 | 455 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 456 | for name, param in resnet_block.named_parameters(): 457 | if 
name.endswith("bn3.weight"): 458 | nn.init.zeros_(param) 459 | 460 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 461 | attn_std = self.transformer.width ** -0.5 462 | fc_std = (2 * self.transformer.width) ** -0.5 463 | for block in self.transformer.resblocks: 464 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 465 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 466 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 467 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 468 | 469 | if self.text_projection is not None: 470 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 471 | 472 | def build_attention_mask(self): 473 | # lazily create causal attention mask, with full attention between the vision tokens 474 | # pytorch uses additive attention mask; fill with -inf 475 | mask = torch.empty(self.context_length, self.context_length) 476 | mask.fill_(float("-inf")) 477 | mask.triu_(1) # zero out the lower diagonal 478 | return mask 479 | 480 | @property 481 | def dtype(self): 482 | return self.visual.conv1.weight.dtype 483 | 484 | def encode_image(self, image, return_all=False, attn_self=False, return_attn_map = False, layer='final', device='cuda:0'): 485 | return self.visual(image.type(self.dtype), return_all=return_all, attn_self=attn_self, return_attn_map=return_attn_map, layer=layer, device=device) 486 | 487 | def encode_text(self, text): 488 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 489 | 490 | x = x + self.positional_embedding.type(self.dtype) 491 | x = x.permute(1, 0, 2) # NLD -> LND 492 | x = self.transformer(x) 493 | x = x.permute(1, 0, 2) # LND -> NLD 494 | x = self.ln_final(x).type(self.dtype) 495 | 496 | return x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 497 | 498 | def forward(self, image, text): 499 | image_features = self.encode_image(image) 500 | text_features = self.encode_text(text) 501 | 502 | # normalized features 503 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 504 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 505 | 506 | # cosine similarity as logits 507 | logit_scale = self.logit_scale.exp() 508 | logits_per_image = logit_scale * image_features @ text_features.t() 509 | logits_per_text = logits_per_image.t() 510 | 511 | # shape = [global_batch_size, global_batch_size] 512 | return logits_per_image, logits_per_text 513 | 514 | def convert_weights(model: nn.Module): 515 | """Convert applicable model parameters to fp16""" 516 | 517 | def _convert_weights_to_fp16(l): 518 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 519 | l.weight.data = l.weight.data.half() 520 | if l.bias is not None: 521 | l.bias.data = l.bias.data.half() 522 | 523 | if isinstance(l, nn.MultiheadAttention): 524 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 525 | tensor = getattr(l, attr) 526 | if tensor is not None: 527 | tensor.data = tensor.data.half() 528 | 529 | for name in ["text_projection", "proj"]: 530 | if hasattr(l, name): 531 | attr = getattr(l, name) 532 | if attr is not None: 533 | attr.data = attr.data.half() 534 | 535 | model.apply(_convert_weights_to_fp16) 536 | 537 | def build_model(state_dict: dict): 538 | vit = "visual.proj" in state_dict 539 | 540 | if vit: 541 | vision_width = state_dict["visual.conv1.weight"].shape[0] 542 | vision_layers = len([k for k in state_dict.keys() if 
k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 543 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 544 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 545 | image_resolution = vision_patch_size * grid_size 546 | else: 547 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 548 | vision_layers = tuple(counts) 549 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 550 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 551 | vision_patch_size = None 552 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 553 | image_resolution = output_width * 32 554 | 555 | embed_dim = state_dict["text_projection"].shape[1] 556 | context_length = state_dict["positional_embedding"].shape[0] 557 | vocab_size = state_dict["token_embedding.weight"].shape[0] 558 | transformer_width = state_dict["ln_final.weight"].shape[0] 559 | transformer_heads = transformer_width // 64 560 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 561 | 562 | model = CLIP( 563 | embed_dim, 564 | image_resolution, vision_layers, vision_width, vision_patch_size, 565 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 566 | ) 567 | 568 | for key in ["input_resolution", "context_length", "vocab_size"]: 569 | if key in state_dict: 570 | del state_dict[key] 571 | 572 | convert_weights(model) 573 | model.load_state_dict(state_dict) 574 | return model.eval() -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | ### CLIP source code from OpenAI: 2 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | 4 | import gzip 5 | import html 6 | import os 7 | from functools import lru_cache 8 | 9 | import ftfy 10 | import regex as re 11 | 12 | 13 | @lru_cache() 14 | def default_bpe(): 15 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 16 | 17 | 18 | @lru_cache() 19 | def bytes_to_unicode(): 20 | """ 21 | Returns list of utf-8 byte and a corresponding list of unicode strings. 22 | The reversible bpe codes work on unicode strings. 23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 25 | This is a signficant percentage of your normal, say, 32K bpe vocab. 26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 27 | And avoids mapping to whitespace/control characters the bpe code barfs on. 28 | """ 29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8+n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | 40 | 41 | def get_pairs(word): 42 | """Return set of symbol pairs in a word. 43 | Word is represented as tuple of symbols (symbols being variable-length strings). 
44 | """ 45 | pairs = set() 46 | prev_char = word[0] 47 | for char in word[1:]: 48 | pairs.add((prev_char, char)) 49 | prev_char = char 50 | return pairs 51 | 52 | 53 | def basic_clean(text): 54 | text = ftfy.fix_text(text) 55 | text = html.unescape(html.unescape(text)) 56 | return text.strip() 57 | 58 | 59 | def whitespace_clean(text): 60 | text = re.sub(r'\s+', ' ', text) 61 | text = text.strip() 62 | return text 63 | 64 | 65 | class SimpleTokenizer(object): 66 | def __init__(self, bpe_path: str = default_bpe()): 67 | self.byte_encoder = bytes_to_unicode() 68 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 69 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 70 | merges = merges[1:49152-256-2+1] 71 | merges = [tuple(merge.split()) for merge in merges] 72 | vocab = list(bytes_to_unicode().values()) 73 | vocab = vocab + [v+'' for v in vocab] 74 | for merge in merges: 75 | vocab.append(''.join(merge)) 76 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 77 | self.encoder = dict(zip(vocab, range(len(vocab)))) 78 | self.decoder = {v: k for k, v in self.encoder.items()} 79 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | # base configurations 2 | model = dict( 3 | type='ITACLIP_Segmentor', 4 | model_name='ViT-B/16' 5 | ) 6 | 7 | test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index = 255) 8 | 9 | default_scope = 'mmseg' 10 | env_cfg = dict( 11 | cudnn_benchmark=True, 12 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 13 | 
dist_cfg=dict(backend='nccl'), 14 | ) 15 | vis_backends = [dict(type='LocalVisBackend')] 16 | visualizer = dict( 17 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') 18 | log_processor = dict(by_epoch=False) 19 | log_level = 'INFO' 20 | load_from = None 21 | resume = False 22 | 23 | test_cfg = dict(type='TestLoop') 24 | 25 | default_hooks = dict( 26 | timer=dict(type='IterTimerHook'), 27 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 28 | param_scheduler=dict(type='ParamSchedulerHook'), 29 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), 30 | sampler_seed=dict(type='DistSamplerSeedHook'), 31 | visualization=dict(type='SegVisualizationHook', interval=1)) 32 | -------------------------------------------------------------------------------- /configs/cfg_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | dataset_name = 'cityscapes' 4 | # model settings 5 | model = dict( 6 | type='ITACLIP_Segmentor', 7 | model_name = 'ViT-B/16', 8 | img_engineering = True, 9 | dataset_name = dataset_name, 10 | auxiliary_text_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_synonyms.txt', 11 | slide_stride = 224, 12 | attn_self = True, 13 | def_coefficient = 0.05, 14 | img_eng_coefficient = 0.7, 15 | pamr_steps = 10, 16 | device = 'cuda:0', 17 | name_path=f'/ITACLIP/configs/cls_{dataset_name}.txt', 18 | logit_scale = 40, 19 | ) 20 | 21 | # dataset settings 22 | dataset_type = 'CityscapesDataset' 23 | data_root = ' ' 24 | 25 | test_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict(type='Resize', scale=(2048, 560), keep_ratio=True), 28 | # add loading annotation after ``Resize`` because ground truth 29 | # does not need to do resize data transform 30 | dict(type='LoadAnnotations'), 31 | dict(type='PackSegInputs') 32 | ] 33 | 34 | test_dataloader = dict( 35 | batch_size=1, 36 | num_workers=4, 37 | persistent_workers=True, 38 | sampler=dict(type='DefaultSampler', shuffle=False), 39 | dataset=dict( 40 | type=dataset_type, 41 | data_root=data_root, 42 | data_prefix=dict( 43 | img_path='leftImg8bit/val', seg_map_path='gtFine/val'), 44 | pipeline=test_pipeline)) 45 | -------------------------------------------------------------------------------- /configs/cfg_coco_object.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | dataset_name = 'coco_object' 4 | # model settings 5 | model = dict( 6 | type='ITACLIP_Segmentor', 7 | model_name = 'ViT-B/16', 8 | img_engineering = True, 9 | dataset_name = dataset_name, 10 | auxiliary_text_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_synonyms.txt', 11 | slide_stride = 28, 12 | attn_self = True, 13 | def_coefficient = 0.1, 14 | img_eng_coefficient = 0.75, 15 | pamr_steps = 10, 16 | width_chunk_size = 250, # This variable helps reduce GPU memory consumption when the text expansion technique is applied. The default values are optimized for a 24 GB GPU under this configuration. 
17 | device = 'cuda:0', 18 | name_path=f'/ITACLIP/configs/cls_{dataset_name}.txt', 19 | logit_scale=50, 20 | prob_thd=0.1 21 | ) 22 | 23 | # dataset settings 24 | dataset_type = 'COCOObjectDataset' 25 | data_root = ' ' 26 | 27 | test_pipeline = [ 28 | dict(type='LoadImageFromFile'), 29 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 30 | # add loading annotation after ``Resize`` because ground truth 31 | # does not need to do resize data transform 32 | dict(type='LoadAnnotations'), 33 | dict(type='PackSegInputs') 34 | ] 35 | 36 | test_dataloader = dict( 37 | batch_size=1, 38 | num_workers=4, 39 | persistent_workers=True, 40 | sampler=dict(type='DefaultSampler', shuffle=False), 41 | dataset=dict( 42 | type=dataset_type, 43 | data_root=data_root, 44 | reduce_zero_label=False, 45 | data_prefix=dict( 46 | img_path='images/val2017', seg_map_path='annotations/val2017'), 47 | pipeline=test_pipeline)) 48 | -------------------------------------------------------------------------------- /configs/cfg_coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | dataset_name = 'coco_stuff' 4 | # model settings 5 | model = dict( 6 | type='ITACLIP_Segmentor', 7 | model_name = 'ViT-B/16', 8 | img_engineering = True, 9 | auxiliary_text_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_definitions.txt', 10 | dataset_name = dataset_name, 11 | slide_stride = 28, 12 | attn_self = True, 13 | def_coefficient = 0.2, 14 | img_eng_coefficient = 0.75, 15 | width_chunk_size = 150, # This variable helps reduce GPU memory consumption when the text expansion technique is applied. The default values are optimized for a 24 GB GPU under this configuration. 16 | pamr_steps = 10, 17 | device = 'cuda:0', 18 | name_path=f'/ITACLIP/configs/cls_{dataset_name}.txt', 19 | logit_scale = 40, 20 | ) 21 | 22 | # dataset settings 23 | dataset_type = 'COCOStuffDataset' 24 | data_root = ' ' 25 | 26 | test_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='Resize', scale=(2048, 448), keep_ratio=True), 29 | dict(type='LoadAnnotations'), 30 | dict(type='PackSegInputs') 31 | ] 32 | 33 | test_dataloader = dict( 34 | batch_size=1, 35 | num_workers=4, 36 | persistent_workers=True, 37 | sampler=dict(type='DefaultSampler', shuffle=False), 38 | dataset=dict( 39 | type=dataset_type, 40 | data_root=data_root, 41 | data_prefix=dict( 42 | img_path='images/val2017', seg_map_path='annotations/val2017'), 43 | pipeline=test_pipeline)) 44 | 45 | default_hooks = dict( 46 | timer=dict(type='IterTimerHook'), 47 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 48 | param_scheduler=dict(type='ParamSchedulerHook'), 49 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), 50 | sampler_seed=dict(type='DistSamplerSeedHook'), 51 | visualization=dict(type='SegVisualizationHook', interval=1)) 52 | -------------------------------------------------------------------------------- /configs/cfg_context60.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | dataset_name = 'context60' 4 | # model settings 5 | model = dict( 6 | type='ITACLIP_Segmentor', 7 | model_name = 'ViT-B/16', 8 | img_engineering = True, 9 | dataset_name = dataset_name, 10 | auxiliary_text_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_definitions.txt', 11 | slide_stride = 28, 12 | attn_self = True, 13 | def_coefficient = 0.15, 14 | img_eng_coefficient = 0.75, 15 | 
pamr_steps = 10, 16 | device = 'cuda:0', 17 | name_path=f'/ITACLIP/configs/cls_{dataset_name}.txt', 18 | logit_scale = 55, 19 | prob_thd = 0.1 20 | ) 21 | 22 | # dataset settings 23 | dataset_type = 'PascalContext60Dataset' 24 | data_root = ' ' 25 | 26 | test_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 29 | dict(type='LoadAnnotations'), 30 | dict(type='PackSegInputs') 31 | ] 32 | 33 | test_dataloader = dict( 34 | batch_size=1, 35 | num_workers=4, 36 | persistent_workers=True, 37 | sampler=dict(type='DefaultSampler', shuffle=False), 38 | dataset=dict( 39 | type=dataset_type, 40 | data_root=data_root, 41 | data_prefix=dict( 42 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 43 | ann_file='ImageSets/SegmentationContext/val.txt', 44 | pipeline=test_pipeline)) 45 | -------------------------------------------------------------------------------- /configs/cfg_voc21.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | dataset_name = 'voc21' 4 | # model settings 5 | model = dict( 6 | type='ITACLIP_Segmentor', 7 | model_name = 'ViT-B/16', 8 | img_engineering = True, 9 | dataset_name = dataset_name, 10 | auxiliary_text_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_definitions.txt', 11 | slide_stride = 28, 12 | attn_self = True, 13 | def_coefficient = 0.05, 14 | img_eng_coefficient = 0.7, 15 | pamr_steps = 10, 16 | device = 'cuda:0', 17 | name_path=f'/ITACLIP/configs/cls_{dataset_name}.txt', 18 | logit_scale = 60, 19 | prob_thd = 0.1, 20 | area_thd = 0.1 21 | ) 22 | 23 | # dataset settings 24 | dataset_type = 'PascalVOCDataset' 25 | data_root = ' ' 26 | 27 | test_pipeline = [ 28 | dict(type='LoadImageFromFile'), 29 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 30 | dict(type='LoadAnnotations'), 31 | dict(type='PackSegInputs') 32 | ] 33 | 34 | test_dataloader = dict( 35 | batch_size=1, 36 | num_workers=4, 37 | persistent_workers=True, 38 | sampler=dict(type='DefaultSampler', shuffle=False), 39 | dataset=dict( 40 | type=dataset_type, 41 | data_root=data_root, 42 | data_prefix=dict( 43 | img_path='JPEGImages', seg_map_path='SegmentationClass'), 44 | ann_file='ImageSets/Segmentation/val.txt', 45 | pipeline=test_pipeline)) 46 | -------------------------------------------------------------------------------- /configs/cls_cityscapes.txt: -------------------------------------------------------------------------------- 1 | road 2 | sidewalk 3 | building 4 | wall 5 | fence 6 | pole 7 | trafficlight 8 | trafficsign 9 | vegetation 10 | terrain 11 | sky 12 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, person in hat 13 | rider 14 | car 15 | truck 16 | bus 17 | train 18 | motorcycle 19 | bicycle -------------------------------------------------------------------------------- /configs/cls_coco_object.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence 2 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body, person in hat 3 | bicycle 4 | car 5 | motorcycle 6 | airplane 7 | bus 8 | train 9 | truck 10 | boat 11 | traffic light 12 | fire hydrant 13 | stop sign 14 | parking meter 15 | bench 16 | 
bird 17 | cat 18 | dog 19 | horse 20 | sheep 21 | cow 22 | elephant 23 | bear 24 | zebra 25 | giraffe 26 | backpack 27 | umbrella 28 | handbag 29 | tie 30 | suitcase 31 | frisbee 32 | skis 33 | snowboard 34 | sports ball 35 | kite 36 | baseball bat 37 | baseball glove 38 | skateboard 39 | surfboard 40 | tennis racket 41 | bottle 42 | wine glass 43 | cup 44 | fork 45 | knife 46 | spoon 47 | bowl 48 | banana 49 | apple 50 | sandwich 51 | orange 52 | broccoli 53 | carrot 54 | hot dog 55 | pizza 56 | donut 57 | cake 58 | chair 59 | couch 60 | potted plant 61 | bed 62 | dining table 63 | toilet 64 | tv, tvmonitor, television monitor, monitor, television, screen 65 | laptop 66 | mouse 67 | remote 68 | keyboard 69 | cell phone 70 | microwave 71 | oven 72 | toaster 73 | sink 74 | refrigerator 75 | book 76 | clock 77 | vase 78 | scissors 79 | teddy bear 80 | hair drier 81 | toothbrush -------------------------------------------------------------------------------- /configs/cls_coco_stuff.txt: -------------------------------------------------------------------------------- 1 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, person in hat 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | trafficlight 11 | firehydrant 12 | stopsign 13 | parkingmeter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sportsball 34 | kite 35 | baseballbat 36 | baseballglove 37 | skateboard 38 | surfboard 39 | tennisracket 40 | bottle 41 | wineglass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hotdog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tv, tvmonitor, television monitor, monitor, television, screen 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cellphone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddybear 79 | hairdrier 80 | toothbrush 81 | banner 82 | blanket 83 | branch 84 | bridge 85 | building-other 86 | bush 87 | cabinet 88 | cage 89 | cardboard 90 | carpet 91 | ceiling-other 92 | ceiling-tile 93 | cloth 94 | clothes 95 | clouds 96 | counter 97 | cupboard 98 | curtain 99 | desk-stuff 100 | dirt 101 | door-stuff 102 | fence 103 | floor-marble 104 | floor-other 105 | floor-stone 106 | floor-tile 107 | floor-wood 108 | flower 109 | fog 110 | food-other 111 | fruit 112 | furniture-other 113 | grass 114 | gravel 115 | ground-other 116 | hill 117 | house 118 | leaves 119 | light 120 | mat 121 | metal 122 | mirror-stuff 123 | moss 124 | mountain 125 | mud 126 | napkin 127 | net 128 | paper 129 | pavement 130 | pillow 131 | plant-other 132 | plastic 133 | platform 134 | playingfield 135 | railing 136 | railroad 137 | river 138 | road 139 | rock 140 | roof 141 | rug 142 | salad 143 | sand 144 | sea 145 | shelf 146 | sky-other 147 | skyscraper 148 | snow 149 | solid-other 150 | stairs 151 | stone 152 | straw 153 | structural-other 154 | table 155 | tent 156 | textile-other 157 | towel 158 | tree 159 | vegetable 160 | wall-brick 161 | wall-concrete 162 | wall-other 163 | wall-panel 164 | wall-stone 165 | wall-tile 166 | wall-wood 167 | water-other 168 | waterdrops 169 | window-blind 170 | window-other 171 | wood 
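Note on the class-name files: each ```cls_*.txt``` (referenced by ```name_path``` in the configs) lists one class per line, and a single line may contain several comma-separated names (e.g., ```person, person in shirt, ...``` or ```tv, tvmonitor, television monitor, ...```) that all stand for the same class. The parsing actually used by the segmentor lives in ```itaclip_segmentor.py```, which is not included in this excerpt; the function below is only a hypothetical sketch of how such a file could be turned into per-alias class indices under that assumption.

```
# Hypothetical helper (not part of the repository): load a cls_*.txt file.
# Assumption: comma-separated names on one line are aliases of the same class.
def load_class_names(name_path):
    sub_names, class_idx = [], []
    with open(name_path) as f:
        for idx, line in enumerate(f):
            for name in line.strip().split(','):
                name = name.strip()
                if name:
                    sub_names.append(name)   # every alias becomes its own text prompt
                    class_idx.append(idx)    # but maps back to the original class id
    return sub_names, class_idx
```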
-------------------------------------------------------------------------------- /configs/cls_context60.txt: -------------------------------------------------------------------------------- 1 | background 2 | aeroplane 3 | bag 4 | bed 5 | bedclothes 6 | bench 7 | bicycle 8 | bird 9 | boat 10 | book 11 | bottle 12 | building 13 | bus 14 | cabinet 15 | car 16 | cat, feline 17 | ceiling 18 | chair, seat 19 | cloth 20 | computer 21 | cow 22 | cup 23 | curtain 24 | dog 25 | door 26 | fence 27 | floor 28 | flower 29 | food 30 | grass 31 | ground 32 | horse 33 | keyboard 34 | light 35 | motorbike, motorcycle 36 | mountain 37 | mouse 38 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, person in hat, body 39 | plate 40 | platform 41 | pottedplant 42 | road 43 | rock 44 | sheep 45 | shelves 46 | sidewalk 47 | sign 48 | sky 49 | snow 50 | sofa 51 | table 52 | track 53 | train 54 | tree 55 | truck 56 | tvmonitor, television monitor, monitor, television, screen 57 | wall 58 | water 59 | window 60 | wood -------------------------------------------------------------------------------- /configs/cls_voc21.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, mountain, sands, desk, bed, building, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, fence, ground, water, railway, helmet, house, bridge, sign, keyboard, refrigerator, bench, sink, laptop, clock, blanket, branch, bush, cabinet, clouds, cupboard, pavement, platform, roof, sea 2 | aeroplane 3 | bicycle 4 | bird, avian 5 | ship 6 | bottle 7 | bus 8 | car 9 | cat, feline 10 | chair, seat 11 | cow 12 | table 13 | dog 14 | horse 15 | motorbike, motorcycle 16 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, person in hat, body 17 | pottedplant 18 | sheep 19 | sofa 20 | train 21 | television monitor, tv monitor, monitor, television, screen -------------------------------------------------------------------------------- /custom_datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from mmseg.datasets import BaseSegDataset 6 | 7 | @DATASETS.register_module() 8 | class PascalVOC20Dataset(BaseSegDataset): 9 | """Pascal VOC dataset. 10 | 11 | Args: 12 | split (str): Split txt file for Pascal VOC. 
13 | """ 14 | METAINFO = dict( 15 | classes=('aeroplane', 'bicycle', 'bird', 'boat', 16 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 17 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 18 | 'sofa', 'train', 'tvmonitor'), 19 | palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192], 20 | [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64], 21 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 22 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 23 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 24 | [0, 64, 128]]) 25 | 26 | def __init__(self, 27 | ann_file, 28 | img_suffix='.jpg', 29 | seg_map_suffix='.png', 30 | reduce_zero_label=True, 31 | **kwargs) -> None: 32 | super().__init__( 33 | img_suffix=img_suffix, 34 | seg_map_suffix=seg_map_suffix, 35 | reduce_zero_label=reduce_zero_label, 36 | ann_file=ann_file, 37 | **kwargs) 38 | assert fileio.exists(self.data_prefix['img_path'], 39 | self.backend_args) and osp.isfile(self.ann_file) 40 | 41 | @DATASETS.register_module() 42 | class COCOObjectDataset(BaseSegDataset): 43 | """ 44 | Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT) 45 | COCO-Object dataset. 46 | 1 bg class + first 80 classes from the COCO-Stuff dataset. 47 | """ 48 | 49 | METAINFO = dict( 50 | 51 | classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 52 | 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 53 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 54 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 55 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 56 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 57 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 58 | 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 59 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), 60 | 61 | palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224], 62 | [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64], 63 | [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], 64 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0], 65 | [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32], 66 | [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], 67 | [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32], 68 | [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], 69 | [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], 70 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160], 71 | [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0], 72 | [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]) 73 | 74 
| def __init__(self, **kwargs): 75 | super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs) 76 | 77 | @DATASETS.register_module() 78 | class PascalContext60Dataset(BaseSegDataset): 79 | METAINFO = dict( 80 | classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 81 | 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 82 | 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 83 | 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 84 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 85 | 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 86 | 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', 87 | 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 88 | 'sofa', 'table', 'track', 'train', 'tree', 'truck', 89 | 'tvmonitor', 'wall', 'water', 'window', 'wood'), 90 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 91 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 92 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 93 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 94 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 95 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 96 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 97 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 98 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 99 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 100 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 101 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 102 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 103 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 104 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 105 | 106 | def __init__(self, 107 | ann_file: str, 108 | img_suffix='.jpg', 109 | seg_map_suffix='.png', 110 | **kwargs) -> None: 111 | super().__init__( 112 | img_suffix=img_suffix, 113 | seg_map_suffix=seg_map_suffix, 114 | ann_file=ann_file, 115 | reduce_zero_label=False, 116 | **kwargs) 117 | 118 | 119 | @DATASETS.register_module() 120 | class PascalContext59Dataset(BaseSegDataset): 121 | METAINFO = dict( 122 | classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 123 | 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 124 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 125 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 126 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 127 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 128 | 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 129 | 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 130 | 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor', 131 | 'wall', 'water', 'window', 'wood'), 132 | palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], 133 | [120, 120, 80], [140, 140, 140], [204, 5, 255], 134 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 135 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 136 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 137 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 138 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 139 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 140 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 141 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 
194, 7], 142 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 143 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 144 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 145 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 146 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 147 | 148 | def __init__(self, 149 | ann_file: str, 150 | img_suffix='.jpg', 151 | seg_map_suffix='.png', 152 | reduce_zero_label=True, 153 | **kwargs): 154 | super().__init__( 155 | img_suffix=img_suffix, 156 | seg_map_suffix=seg_map_suffix, 157 | ann_file=ann_file, 158 | reduce_zero_label=reduce_zero_label, 159 | **kwargs) 160 | -------------------------------------------------------------------------------- /datasets/cvt_coco_object.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # GroupViT (https://github.com/NVlabs/GroupViT) 3 | # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import argparse 7 | import os.path as osp 8 | import shutil 9 | from functools import partial 10 | from glob import glob 11 | 12 | import mmcv 13 | import numpy as np 14 | from PIL import Image 15 | 16 | COCO_LEN = 123287 17 | 18 | clsID_to_trID = { 19 | 0: 0, 20 | 1: 1, 21 | 2: 2, 22 | 3: 3, 23 | 4: 4, 24 | 5: 5, 25 | 6: 6, 26 | 7: 7, 27 | 8: 8, 28 | 9: 9, 29 | 10: 10, 30 | 12: 11, 31 | 13: 12, 32 | 14: 13, 33 | 15: 14, 34 | 16: 15, 35 | 17: 16, 36 | 18: 17, 37 | 19: 18, 38 | 20: 19, 39 | 21: 20, 40 | 22: 21, 41 | 23: 22, 42 | 24: 23, 43 | 26: 24, 44 | 27: 25, 45 | 30: 26, 46 | 31: 27, 47 | 32: 28, 48 | 33: 29, 49 | 34: 30, 50 | 35: 31, 51 | 36: 32, 52 | 37: 33, 53 | 38: 34, 54 | 39: 35, 55 | 40: 36, 56 | 41: 37, 57 | 42: 38, 58 | 43: 39, 59 | 45: 40, 60 | 46: 41, 61 | 47: 42, 62 | 48: 43, 63 | 49: 44, 64 | 50: 45, 65 | 51: 46, 66 | 52: 47, 67 | 53: 48, 68 | 54: 49, 69 | 55: 50, 70 | 56: 51, 71 | 57: 52, 72 | 58: 53, 73 | 59: 54, 74 | 60: 55, 75 | 61: 56, 76 | 62: 57, 77 | 63: 58, 78 | 64: 59, 79 | 66: 60, 80 | 69: 61, 81 | 71: 62, 82 | 72: 63, 83 | 73: 64, 84 | 74: 65, 85 | 75: 66, 86 | 76: 67, 87 | 77: 68, 88 | 78: 69, 89 | 79: 70, 90 | 80: 71, 91 | 81: 72, 92 | 83: 73, 93 | 84: 74, 94 | 85: 75, 95 | 86: 76, 96 | 87: 77, 97 | 88: 78, 98 | 89: 79, 99 | 91: 80, 100 | 92: 81, 101 | 93: 82, 102 | 94: 83, 103 | 95: 84, 104 | 96: 85, 105 | 97: 86, 106 | 98: 87, 107 | 99: 88, 108 | 100: 89, 109 | 101: 90, 110 | 102: 91, 111 | 103: 92, 112 | 104: 93, 113 | 105: 94, 114 | 106: 95, 115 | 107: 96, 116 | 108: 97, 117 | 109: 98, 118 | 110: 99, 119 | 111: 100, 120 | 112: 101, 121 | 113: 102, 122 | 114: 103, 123 | 115: 104, 124 | 116: 105, 125 | 117: 106, 126 | 118: 107, 127 | 119: 108, 128 | 120: 109, 129 | 121: 110, 130 | 122: 111, 131 | 123: 112, 132 | 124: 113, 133 | 125: 114, 134 | 126: 115, 135 | 127: 116, 136 | 128: 117, 137 | 129: 118, 138 | 130: 119, 139 | 131: 120, 140 | 132: 121, 141 | 133: 122, 142 | 134: 123, 143 | 135: 124, 144 | 136: 125, 145 | 137: 126, 146 | 138: 127, 147 | 139: 128, 148 | 140: 129, 149 | 141: 130, 150 | 142: 131, 151 | 143: 132, 152 | 144: 133, 153 | 145: 134, 154 | 146: 135, 155 | 147: 136, 156 | 148: 137, 157 | 149: 138, 158 | 150: 139, 159 | 151: 140, 160 | 152: 141, 161 | 153: 142, 162 | 154: 143, 163 | 155: 144, 164 | 156: 145, 165 | 157: 146, 166 | 158: 147, 167 | 159: 148, 168 | 160: 149, 169 | 161: 150, 170 | 
162: 151, 171 | 163: 152, 172 | 164: 153, 173 | 165: 154, 174 | 166: 155, 175 | 167: 156, 176 | 168: 157, 177 | 169: 158, 178 | 170: 159, 179 | 171: 160, 180 | 172: 161, 181 | 173: 162, 182 | 174: 163, 183 | 175: 164, 184 | 176: 165, 185 | 177: 166, 186 | 178: 167, 187 | 179: 168, 188 | 180: 169, 189 | 181: 170, 190 | 255: 255 191 | } 192 | 193 | # set to background 194 | for k, v in clsID_to_trID.items(): 195 | clsID_to_trID[k] = v + 1 196 | if k > 90: 197 | clsID_to_trID[k] = 0 198 | 199 | 200 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 201 | mask = np.array(Image.open(maskpath)) 202 | mask_copy = mask.copy() 203 | for clsID, trID in clsID_to_trID.items(): 204 | mask_copy[mask == clsID] = trID 205 | seg_filename = osp.join( 206 | out_mask_dir, 'train2017', 207 | osp.basename(maskpath).split('.')[0] + 208 | '_instanceTrainIds.png') if is_train else osp.join( 209 | out_mask_dir, 'val2017', 210 | osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png') 211 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 212 | 213 | 214 | def parse_args(): 215 | parser = argparse.ArgumentParser( 216 | description=\ 217 | 'Convert COCO Stuff 164k annotations to COCO Objects') # noqa 218 | parser.add_argument('coco_path', help='coco stuff path') 219 | parser.add_argument('-o', '--out_dir', help='output path') 220 | parser.add_argument( 221 | '--nproc', default=16, type=int, help='number of process') 222 | args = parser.parse_args() 223 | return args 224 | 225 | 226 | def main(): 227 | args = parse_args() 228 | coco_path = args.coco_path 229 | nproc = args.nproc 230 | 231 | out_dir = args.out_dir or coco_path 232 | out_img_dir = osp.join(out_dir, 'images') 233 | out_mask_dir = osp.join(out_dir, 'annotations') 234 | 235 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 236 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 237 | 238 | if out_dir != coco_path: 239 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 240 | 241 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 242 | train_list = [file for file in train_list if 'TrainIds' not in file] 243 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 244 | test_list = [file for file in test_list if 'TrainIds' not in file] 245 | assert (len(train_list) + 246 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 247 | len(train_list), len(test_list)) 248 | 249 | if args.nproc > 1: 250 | mmcv.track_parallel_progress( 251 | partial( 252 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 253 | train_list, 254 | nproc=nproc) 255 | mmcv.track_parallel_progress( 256 | partial( 257 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 258 | test_list, 259 | nproc=nproc) 260 | else: 261 | mmcv.track_progress( 262 | partial( 263 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 264 | train_list) 265 | mmcv.track_progress( 266 | partial( 267 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 268 | test_list) 269 | 270 | print('Done!') 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | 
"from mmengine.model.utils import revert_sync_batchnorm\n", 13 | "from mmseg.apis import init_model, inference_model\n", 14 | "from PIL import Image\n", 15 | "import torchmetrics\n", 16 | "import os\n", 17 | "from tqdm import tqdm\n", 18 | "import cv2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "palette_coco_stuff=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],\n", 28 | " [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],\n", 29 | " [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],\n", 30 | " [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],\n", 31 | " [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],\n", 32 | " [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],\n", 33 | " [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],\n", 34 | " [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],\n", 35 | " [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],\n", 36 | " [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],\n", 37 | " [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],\n", 38 | " [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],\n", 39 | " [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],\n", 40 | " [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],\n", 41 | " [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],\n", 42 | " [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],\n", 43 | " [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],\n", 44 | " [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],\n", 45 | " [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],\n", 46 | " [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],\n", 47 | " [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],\n", 48 | " [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],\n", 49 | " [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],\n", 50 | " [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],\n", 51 | " [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],\n", 52 | " [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],\n", 53 | " [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],\n", 54 | " [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],\n", 55 | " [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],\n", 56 | " [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],\n", 57 | " [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],\n", 58 | " [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],\n", 59 | " [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],\n", 60 | " [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],\n", 61 | " [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],\n", 62 | " [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],\n", 63 | " [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],\n", 64 | " [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],\n", 65 | " [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],\n", 66 | " [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],\n", 67 | " [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],\n", 68 | " [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],\n", 69 | " [64, 192, 96], [64, 160, 64], [64, 64, 0]]\n", 70 | "classes_coco_stuff = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',\n", 71 | " 'train', 'truck', 'boat', 'traffic 
light', 'fire hydrant',\n", 72 | " 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',\n", 73 | " 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',\n", 74 | " 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',\n", 75 | " 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',\n", 76 | " 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',\n", 77 | " 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',\n", 78 | " 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',\n", 79 | " 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',\n", 80 | " 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',\n", 81 | " 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',\n", 82 | " 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',\n", 83 | " 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',\n", 84 | " 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',\n", 85 | " 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',\n", 86 | " 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',\n", 87 | " 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',\n", 88 | " 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',\n", 89 | " 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',\n", 90 | " 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',\n", 91 | " 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',\n", 92 | " 'paper', 'pavement', 'pillow', 'plant-other', 'plastic',\n", 93 | " 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',\n", 94 | " 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',\n", 95 | " 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',\n", 96 | " 'stone', 'straw', 'structural-other', 'table', 'tent',\n", 97 | " 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',\n", 98 | " 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',\n", 99 | " 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',\n", 100 | " 'window-blind', 'window-other', 'wood']\n", 101 | "\n", 102 | "config_coco_stuff = '/mnt/disk2/arda_efe/graduation/ITACLIP/configs/cfg_coco_stuff164k.py'" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 3, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "model = init_model(config_coco_stuff, device='cuda')\n", 112 | "if not torch.cuda.is_available():\n", 113 | " model = revert_sync_batchnorm(model)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWYAAAGFCAYAAADO9lk6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAegUlEQVR4nO3d+Y8c533n8U9V33MfHJLD+6Z4i5JISopkXbYca+NDtuS1EyGbze4CQYBg/4Qgf8L+ssD+EmCRALtZIF5vDvlSJJGSLEuyRFEHb3LEa+6r77Pq2R9mNKYsihwOu7uern6/ABkWOdP9tTl8T81TVU85xhgjAIA13KAHAAB8EWEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMtEl/uBzta/aeQcANAWzMhf3/FjOGIGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMsQZgCwDGEGAMtEgx4AuC1jFDWeJKnmROTKKGJ8SVLEeDqWPaOYX9Ppzi2aivXd+iUWP1eOU5eRHGP0aOYT+XL1ds/eur0u8DnCDGvE/YruK1zVXLRb1xKrJUnDlVn9YOq4qm5Ub/XsV8ov64n5U5IkR0YdflmOpIey5xbiewv5SFJv9+zTdKxX1xNDXwpp1K9pT+GKcpEOjSTXLiO0Rg9mz6nmRBfCDNSZY4wxy/rArX/T6FnQ5rYVR/XSxC9VchOai3ap5MaViXbqcO6iJMmXo1wkpR6vsKLXz7sJXU6t0zvdeyRJhUhCs7FePTX3gR5Pf6SKE9PfrX1WNxJDt38hY7S6OicjZ+EonSNm3AUz8td3/BiOmBGYlFdSdHFZouTG9Afpj+VK6vDL6qiUv/TxrsyKoyxJnX5ZB/IjOpAfkSSd6dikN3sP6GD+shwtLHd8vmxyW46jyfjAiucA7oQwo6kc48uV0UPZczqSOau+Wk7GcXQ5uU5bS2NNnWV34ap2Fq4rooVvDqc7N2t1ZV5XEms4CkagCDOawxhtLY3rWOa0hiuz6vYKcrW4imak3cVrTR/JleQuRlmSjmbPKR3pUD6S1Gy0W+OJwabPBEiEGY1ijL43/ab6ajldTK3XpvKkNpfGlTC1oCe7rV6voB9Ova5sJKWZaI9ykZT+cehrMk5zryztqeWV8suajvXK+4qTmggvwnyXon5NscW4VNyYPLlKmKpc40tyVHTjCz8GG6OUX5HnuKq4sWCHbiLH+Oryinos/bF2F68p5Ve0uTyx8HsBz3Y3ur2iur2i0pFOxYynSpPD/MT8KR3Ondd/X/c9Tcf7mvreCB5hvhNjtKqa1vbSqK4mVuux9CfaUbwuSTrXsVGj8VU6mj2rTq8oz3H1ds8+VdyYJmL9enHqdV1KrddPhr4W8P+I5uivZvRA7oKOZs4obmpLIW6lIP++Hi+v706/qTd7D2hTeVJVJ6IPu3bKb3CoFxZ5HBnWutsSYb6NqF/TttKYnp19T6tqGeXdpDr80lJoDuZHdHDxDL8kyUjPzJ+UpMWPLStq+Y/uKzVYTau/mv3Cr20sT+nx9McBTdQYjqS9hSvaUppQp1+SL0dJv6pf9+5v6Pu+1n9Yb/fu03y0q6HvAzsR5ltI+BVtKk3o0cyn2lIaXwpxp19a9mt8/rGbSpN6buY3enngWCjO9Ef9mp6c/1B7C1c0UMve+RNCwNHv/jwXrig5qw6vpON996vqNuavUD6SUj6Sashrw36E+fdE/Zq+P3VCu4rX6/IjeJdf0r78iDzH1fXEkM6lNqp2l3+ZI8ZTxPiqOtHAfrTdk7+iTeUJ9dTy2lu40tLLE/dqoJbTH2Q+UdxU9fOBo/I5OYc6I8w3SXplHchf1tbSWF3D0+mX9UjmtKpOROPxAb3ds0+nOzYv+wj6a/On9FD2nP7n2m8GdmPD1tKYjmbPBvLeNlq4Dfy8JOkXA0e5cgJ11XZhXlWZV4f/5bvKJOnR9CfaXbzWsKPBmPG0sTyl4akTWtV3SG/2HljWSaR8JKXpWN9X7gXRMMZoXWVGUeOp+x7uuAurhWWNhTj/tvs+Tcb7A54IYdF2e2X8cPJV7S1cDXoM+XL0Wt/9+qB7l3VriRHjqb+a1RPpU9pduKZ4SE9g1lPOTerTzi36+cDRpl7z3OkV5Rpf2UiHurziwiyRlLq9wtIBRtmNqezGJUkpr6yYqSkXSTX8yhLcGntl3MInnVuVWjy5F9Gyvic1hCujp+ZP6oHcBf3v1U9rItZvzcnB4fKM/uP4z+XKb+u15LvR5Ze0q3hdr5iHVJXT+D9LYxSRrxemjmu4PKPfdu/WofwlSdKpzu16KHtuad+Pq8k1uppYrUupdfrW7LtaXZnTrwYeWtjMyZKvOXxR2x0xSwtXXRzKXdJzs+8EPYokaTLWp/+x7tvWrFO+NP5L7SiNBj1Gy/HlKB3t1Bu9B1V2Y7qeGFK6Dpe7bSqNq3vxaLjgJjSSWqc1lVn9YOqEBqoZRW+6rfx2Sk5MCVOVI6nqRPT3a76hK8m19zwf7g5HzF+h5kS0P3856DGWDFQzOpS7qA+6dwc6R28tpz+aeVsby5OBztGqXBn113L6zsyvJUk34oN6efBh+XI0E+td9h2gUb+moeq8JGlVNa3nZt9Ryq9IksbiA/qljug7M2+pv5a7q/mSprr032PG02A1w4ZNlmq7I+beWk5Pzn+o/fkRxZazxWOTjMf69WHXDr3bsyeQtb+jmTN6IHtea6pzLF/Uyc1/sf5h6CldTK1f1qWSayqz+ovRf1r695v/PG5+zXv9c5qJdutvh59T1Ym21bYBQVvOEXPbrf5/fe593Z+7aFWUJWltdU7Pzv1Wj6U/kmN8aXnfL++dMTqUu6gHs+e0lijXlXPTPz+YPqGD+Uu3/3M1Ro7xta049oXP/6rXvFcDtaz+6/V/1B9PvKJH0p9osJqWY/zFf4I7/4I2W8pIeSX11nLWxseV0dfmP9Le/BX9v1WPaawJ207uz4/o29O/XvY6JVYmZjw9M3dSG8tTC9ewL8pEO5Xwq0r4FW0vjmpn8bp6brqi4k4mY30aqs6v6GvakRQ3NW0pT2hLeeFO14oTXZrrjd4DGkkON31nPbRTmI3R9uKoNpWngp7ktqLytbY6p2OZ0/rnVY829ITgkcxZHcmeIcpN0umXdDh3celRWZKUiaQU92tfWP9drrH4gLKRlFZV0yq4cXV+xfX5y/X5CUZJGqxlNVSZ13/b8IKqhLnp2iPMxmhraUzfsuQqjOU4mL+kd3r2aCI+ULc15w6vpMfSHytifJ3v2KAHs+e0upquy2tjZXpuiuHdWluZ1VpJRTeun656TPvyn+n+xUvm6sGRUcR4qrZJJmzSFv+Px01Vz0+/ec9HFM3kSPoP47/Q2737dL
z30OIvruAH1sW1wg6/pBemTizdbv5A7vzynm8Ha33+1TAT69X3pt+q+06GHX5ZD2dO6/X+w3V9XdxZW4T5WObM0l1RrcLRwuVNj89/pF2Fa/qXwUc0llh1V68xVJnT/bmL2lEcVdTUNFDLLv1ltu3kJ1aur5ZT1YkovoLlkNtxJA3WMuqp5TVQzWgm1qNstLOu74FbC3WYo35Na6pzOpy7+Lvny7WYqHytr8zo30++pn9Y/dRCnI1ZPEFkVHQTqrgxRf3a0taUjjF6JPOp9hSu3NOPymgN3V5Rs9HuhtynuS//mWajPXoifUovDxzTuz176v4e+LLQhnlTaUJHM2cWt6hszSjfrNfL64dTr+v/DD2l2Vi3/nzsZXV7BZ3p2Kx/XvWotpTG9cPJ15Y+3pWx9uoTLI/Rwt2Ev791QNmJaSLer0033QjUqL2xHRk9nj4lSQuPTzOGG1KaILRhXl+e0v7CZ0GPUTeOpP5aTn868QvVnIg6vZJcGe0tfKYt18cVkR/o3h9ojAupDbrv954gHjM1ra3MNuX9b75m+rH0xxpLDHIbdxOE9jqYz5JrlXcTQY9Rdym/om6vuLQ042phA53Pb9lFeDjSl6IsLfw0FMSOf11+SccypxX12W2w0UIb5i2l8Za6CgNoBXsKV/Xk/IdBjxF6oVzKiPk1bS5NBD0GEDqOpPWVaaW8kh7IXZBjjE517VA22hH0aKESyjCvL09pR/FG0GMAobShPKW/GP2npVvHP0uuJcx1FsqljEP5S9xmDDRIzHjqvWk/jyPZs83bdKtNhC7MvbWcNpbYTxholqFqWi4HQnUVujB3eiUN1jJBjwG0jbWVWe2x4DmaYRKqMHd6RX197v2gxwDaSsWJqhjCS1ODFKqTf1uLY9pWGgt6DKCtTMd6dTk5HPQYoRKqI2YACINwHDFzRhhAiIQizF+fe19bS2NKcacf0HSrq/PaUbyhix0bgh4lNEKxlDFQy2p9ZUYDd/k4dwD3Lm5qOpY9owh7fNdNKMIMIFjD5RmerF1HhBnAPYsaT/cVrirm1/cpKu2KMAO4Z0lT1fPTb6jTKwU9SiiEIsznUxsa8lgdAAhCKMK8rTQWisdHAYAUkjBHjcfxMoDQCEWYAQTPldGDufNBjxEKoQhzzXFZyAAs0F9tzNO6200owvxK/4Maiw8GPQbQ1owcneg7FPQYoRCKMGeiXXqve7emYr1BjwK0LaOFLUBx70IRZkk62b1Lf7fmG5qO9gQ9CgDck9CEWVo4cj7VtZ31ZiAAjgwPQa6TUIVZki6m1gc9AtCWXIkw10nowlzlMTdAIIykdLQz6DFCIXRhno716nyKfWGBZjNy9H737qDHCIXQhVmOI+NwgzbQbI6Mnpg/FfQYoRC+MGvhuuY3ew8EPQbQdmZi3UGPEAqhDHM+klKBdWag6Vw2y6+LUIZZkjzHVcmJsR0o0ER9PN6tLkJ7m84HXbtk5OrhzKcarHH/PtBovhwd55bsughtmD3H1SOZTzVAlIGmONW1QzMx7ryth9AuZRg5Otm1M+gxgLZQdON6t/s++U4k6FFCIbRhluNoIt4f9BRAW/DkapYrMuomvGGWlIl0KBtJBT0GEHodfknHMmeCHiM0Qh3m+WiX0hFuEQUazZV0JHtW/2nsX7WuPB30OC0v1GHeVJ7UhspXf5F4clRlTQyoix6vqI3lKT0z9752F65KXNO8YqEO853kI0nNsH8zUFfbS2P6zvRbSphq0KO0rLYOc49X1NrqXNBjAKETM56GKvNBj9GyQh3mshvTTLSbDY2AJoubml6cel19PJx1RUId5muJ1fr7Nc/qeO8h1pKBJuvxCoqZWtBjtKTQ3vknSduKo3ph6rgkKWq8gKcB2g871axMqI+YXRklTVVJU+ULBAjAN+Z+q23FUfWyudFdCW+YjVGHVwp6CqBtOZJ2Fm/oTyd+qf35kaDHaSmhDbMjo8fTHwc9BgBJ24ujSnnloMdoGaEN81A1rYRfCXoMAJK2lsb09PwHQY/RMkIb5u3FG+ryWcoAbOBIivvccLJcoQyzY4wSfBEAVtlUnlR/NRP0GC0hXGE2Rq7x1OUV9XDmdNDTALhJXy2nA5wEXJbQXMfsGKPH0h/pwex5uTLcpw9YZuEqjes6weOn7igUYe70inph6rg2liYVlR/0OAC+Qsyvaagyp6lYn+Rwd8FXCcVSRtR42lSalLSwlScAO62pzum/jP2L9hauBD2K1UIR5rIT0zs9e/Svgw9rLD54x483i/8AaC5HUtx4Gq7MsF/zbbR8mJNeWRvKU3qve7dOdu+Sv4wfjzy5eqtnfxOmA3ArD2bP61uz7wQ9hrVaPsxrqnN6afIV/XjyVaW8kvKR1B2Phl352l281pT5AHxZh1/Wvvxnem7mNxw530LLh3k22q3r8VXq8oraUJ5S7Q7be15IrZfnRFRy45rm6SVAYLr8kvblR7ShPBX0KNZp+TBno516s/eA3urdr/sKV3UgP3Lb03+bSxN6u2ev/mXwEV3o2MBaMxCgDr+sNTxF6EtCcbncwfzlZZ/ljZuaDucu6mDusvq8fIMnA3A72UhK1xKrgx7DOi1/xLwS3V6RKAMW6PGKem7mN4r6POnkZqEJsy9Hl5PDQY8B4C6tr0xrR/FG0GNYJTRhdmS0rjwd9BgA7lLMeHooe05xtuld0tphNkarKvNK+hU5kpLsjwG0pO2lUfXUCkGPYY3WDbMx2lG8oT8ff1lbS2NBTwPgHh3IXw56BGu07FUZCVPVd2feUgc//gAtz5G0nqXIJS15xBzzq/rD2XfVycNWgdDo9EvaWbjGZvpq0TA/M/eB7s9dlMvtIUBoDFdm9SeT/6Y9hatBjxK4lgxzL9cgA6EVNV7b75/RkmH+Rf8RnezaqVprjg/gNh5Nf6KdxetBjxGolixbPpLSmY7Nd9ywCEDrSZpq22+k35Jh7vHy+vHkK1y3DIRU3K+qq42va27JMEtSzYnwGCkgpPYWruiFqeNBjxGYlgxz3k3qtb7DOtuxKehRADSAo4UHWrSrlrzBpNMv6Rtzv23N7yoAlmVNZU7bijd0ObU+6FGajrYBsFLC1PRw5oySXjnoUZquJcO8rTjK6jLQBnYWr+vhzOm2u6655cLc6RX1QO4CYQbagCPpSPasur1i0KM0VcuFeV15RsOV2aDHANAkSb/adicCWyrMEeMpxrXLQFtx5Gt/fmRhI/02WdJomTBH/Zq+Pf1rfW/6raBHAdBErqSn5z7QSxOvBD1K07TM5XJPpE/pYP5S63wnAVA3ERnF/fb5abklOtdfzepg7nJrDAugIfpqOR3IX1aqDfZhb4nWHcxfUg9bfQJtLWmqen76DT05/2Ho15pbIsxctwxAWgjW4dxFHciPBD1KQ9kdZmM0XJ5Woo3WlgDcXtzUdCxzWhtKk6Hdgc7qMEfk60eTr2ptdS7oUQBYZENlWv95/OXQPlnb6jD7cnQptZ4n+wG4pW2lMe0uXA3dmrPVYTaOq58NHNON+KqgRwFgoZ3FG/rB1InQHTlbHWZJWl+e0qpqOugxA
Fgqbmp6KHsu6DHqyuowO8bXkexZHiEF4LYGqxltKY4p5ZXlhGBZw+owuzLaUJ4KegwAluvyS/qTyVf0Vzd+oi6v9a/UaJlbsgHgdmLGk2NMKO55sPqIOeWV5XJNBoBlisjX89NvaEtxrKWv1LA2zAm/ohenjqurzTbIBrByjqStpXF9c+69lj6oszPMxmhbcVQby5Oh+LEEQHOtqczp2dn35JrW3GDfvjAboz2FK/rOzK9b+jsegOC4MjqaPbvwvMAWZF+YJT2e/kgpvxL0GABamCujB7Pn9eTcSUX9WtDj3BWuygAQWoO1jJ5In1J/Lav3u3framK15Ni/QGrlETMA1Isj6VD+sl6YOq5ki/wkbmWYfzZwTGPxgaDHABAiXV5RL038SrsLV5X0ykGPc1v2hdlx5EhaW5kNehIAIeLKaENlWj+afFXfnXlLnV7R2tu37QuzpEykQ+937ZLHxXIA6syRdF/hqv7q+k+0z9InoVh58m8+1q2T3TvV6+W1s3gj6HEAhIyjhWcIPpQ9p7Ibk+e4GkkOyzh2HKtaGWZJeiT9KVEG0FBbyhPaMjkhT65e7zukU107lIl2Bj2WZWE2RnFT09HMGQ1XZoKeBkCbiMjX0/Mn9UDugk527dBbvQfkOZHA5rEqzJvLE/r+1Al1e0Xu+gPQVI6k/lpOT86f0kA1q2vJ1TrdsVnFSLLps1gV5vsKV9Ubgr1UAbQuV0b35y/pUP6S7itc1fHeQ7qeGGrqjSlWhflk1w5dSazRUHVez8yfDHocAG3M0cIzBbeUxvV/Vz2m051bm/bedpyCXDQZH9DZjk36pHOrJmO9khaelF1jYQNAQGLG00PZ8+rwSk3b49mqMH/u3838Rv21nCRpMtan93ruC3giAO1sa2lMf3njp1rXpIsSrAxzh19SzHjKRlJaVU3rkcxpbjUBEBhHC88VfGr+pFJeqeHvZ2WYX+s7rJ+selz/a/Uzyi+eETWSyo5VS+IA2szO4g39ePJVHc5ekNPATfjtK53j6ELHRklSxHhLv1x1IhpNrNLW0nhQkwGANpUnta48rc+SazQX62nIe1h5xPw5I0eXUus1klyrmPGIMgArROTrudl31NGgZ5JaG+YOr6Q/nnhFfbWc4i329AEA4fb5pXTPT7/ZkDVn+5YyFkWNpy2lcUXVmg9TBBB+O4o39KPJ1zQZ79OJ3kPKRjvq8rrWhhkAbOdoYSuJTeUJFd2EspGUJGkm1qvLqXUrfl1rwzxYzcjhthIALcCR9LX0R0v/fqpz2z2F2co15qHKvJ6fPqEIYQbQghJ+9QtXld0tK8Oc8svqadDZTgBotN3Fa3pi/tSKb+G2MsySOFYG0LIcSduKoyv+fCvDPB4f0OXkcNBjAMCKra7O66WJX61oScO+MBujx9MfabCaCXoSAFixuKlpuDKjnlrhrm/fti/MkrYXb6jPywc9BgDckw6/rL8c/akO5y7c1edZGeaR5LDSkfpcqA0AQXG0sJ/zrsL1u/o8a8McvYdLTQDAJpvKk/qzsZ/pz8Z+tqyPt/IGk6fnP1CnXw56DACoiw6/rC3liWV/vJVHzADQzqwLc49XUNKvBj0GAATGqjB31/J6cep1DdSyX/j1mlyVnFhAUwFAc1kT5g6vpB9NvqoN5akv/V7ZjStTp+30AMB21pz8S/plDVdmbvnQ1U6/pE6/8Q9ABAAbWHPEnI526ePObUGPAQCBsybMnhNRYfGJ2ADQzuxYyjBGEflyVrhFHgCEiRVh3lIa1/en31DCrwQ9CgAELvCljE6vqNXVefV4BSXMF5+GnY2k5AU/IgA0VeDV258f0bdm37nl7+XdpDwn8BEBoKmsWMowcpb+82Zrq3NBjAMAgQr8cPTDrh3627Xf0nSsN+hRAMAKgYe57MQ0WMtoqJoOehQAsELgYd5SGtcfzr4b9BgAYI3Aw7ymOqcUl8kBwJLAwwwA+KLArsroq2b1zbn31F/N3vmDAaCNBBbmhKlqd+GaXHEbNgDcLLCljKlYn34+cESj8cGgRgAAKwUWZt9xdaprh2K/dxs2ALS7QE/+OcYoyRUZAPAFgYbZc1yNxlcFOQIAWCfQMD8xf0q7iteCHAEArBNYmDeVJnQwf+mWz/gDgHYWWJifnvtAPV4xqLcHAGs19zrmmx4dxZEyANxaU8Pc6Zf00sSvFDGe+mq5Zr41ALSMhoa5q1ZQ1HhKRzu1sTylJ+ZPaU1llg06AOA2GhrmP5p5W9tLo/qga6f2Fq6omzVlALijxoXZGLkyihlPx7JnG/Y2ABA2dVtVSHllbShNqtMrSsboxanXtak0Ua+XB4C2Ubcj5s2lcf1o6jW93ntIvV5e24ujSppqvV4eANrGisLsGl9HM2cUNzWd69ioiVj/0u89nv5I7i2eeA0AWJ4Vh/mx9Mfq8ks6nLugV/of1LryjCQpwv7KAHBP7nkpo7+W04tTx+sxCwBAdxHmiPEU9xfWjKPGY6kCABpk2WH+3vSb2lX43U5wcTa4B4CGWPblcgfyI/KciK4m1+h6YrU87t8DgIZYdl1/3n9EuUhKBTehy6nhRs4EAG1t2UsZz869J0fSUHVeyrM7HAA0yrLDzMIFADQHvQUAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALAMYQYAyxBmALCMY4wxQQ8BAPgdjpgBwDKEGQAsQ5gBwDKEGQAsQ5gBwDKEGQAsQ5gBwDKEGQAsQ5gBwDL/H6pkk6l6Jml8AAAAAElFTkSuQmCC", 124 | "text/plain": [ 125 | "
" 126 | ] 127 | }, 128 | "metadata": {}, 129 | "output_type": "display_data" 130 | } 131 | ], 132 | "source": [ 133 | "IMG_PATH = 'demo.jpg'\n", 134 | "result = inference_model(model, IMG_PATH)\n", 135 | "pred_map = result.pred_sem_seg.data.cpu().numpy().squeeze().astype(np.uint8)\n", 136 | "pred_map_painted = np.array(palette_coco_stuff)[pred_map].astype(np.uint8)\n", 137 | "plt.axis('off')\n", 138 | "plt.imshow(pred_map_painted)\n", 139 | "plt.show()" 140 | ] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "tos", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.9.17" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 2 164 | } 165 | -------------------------------------------------------------------------------- /demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-arda-aydn/ITACLIP/030dccb8d9524ccecb8598528df583728ff61a15/demo.jpg -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import itaclip_segmentor 4 | import custom_datasets 5 | 6 | from mmengine.config import Config 7 | from mmengine.runner import Runner 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='ITACLIP evaluation with MMSeg') 12 | parser.add_argument('--config', default='') 13 | parser.add_argument('--work-dir', default='./work_dir/') 14 | parser.add_argument( 15 | '--show', action='store_true', help='show prediction results') 16 | parser.add_argument( 17 | '--show_dir', 18 | default='', 19 | help='directory to save visualizaion images') 20 | parser.add_argument( 21 | '--launcher', 22 | choices=['none', 'pytorch', 'slurm', 'mpi'], 23 | default='none', 24 | help='job launcher') 25 | # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` 26 | # will pass the `--local-rank` parameter to `tools/train.py` instead 27 | # of `--local_rank`. 28 | parser.add_argument('--local_rank', '--local-rank', type=int, default=0) 29 | args = parser.parse_args() 30 | if 'LOCAL_RANK' not in os.environ: 31 | os.environ['LOCAL_RANK'] = str(args.local_rank) 32 | 33 | return args 34 | 35 | def trigger_visualization_hook(cfg, args): 36 | default_hooks = cfg.default_hooks 37 | if 'visualization' in default_hooks: 38 | visualization_hook = default_hooks['visualization'] 39 | # Turn on visualization 40 | visualization_hook['draw'] = True 41 | if args.show: 42 | visualization_hook['show'] = True 43 | visualization_hook['wait_time'] = args.wait_time 44 | if args.show_dir: 45 | visualizer = cfg.visualizer 46 | visualizer['save_dir'] = args.show_dir 47 | else: 48 | raise RuntimeError( 49 | 'VisualizationHook must be included in default_hooks.' 
50 | 'refer to usage ' 51 | '"visualization=dict(type=\'VisualizationHook\')"') 52 | 53 | return cfg 54 | 55 | def main(): 56 | args = parse_args() 57 | 58 | cfg = Config.fromfile(args.config) 59 | cfg.launcher = args.launcher 60 | cfg.work_dir = args.work_dir 61 | 62 | if args.show or args.show_dir: 63 | cfg = trigger_visualization_hook(cfg, args) 64 | 65 | runner = Runner.from_cfg(cfg) 66 | runner.test() 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /figs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-arda-aydn/ITACLIP/030dccb8d9524ccecb8598528df583728ff61a15/figs/overview.png -------------------------------------------------------------------------------- /itaclip_segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import torch.nn as nn 4 | import sys 5 | sys.path.append("..") 6 | import cv2 7 | from clip import clip 8 | from prompts.imagenet_template import openai_imagenet_template 9 | import torchvision.transforms as T 10 | import numpy as np 11 | from mmseg.models.segmentors import BaseSegmentor 12 | from mmseg.models.data_preprocessor import SegDataPreProcessor 13 | from mmengine.structures import PixelData 14 | 15 | from mmseg.registry import MODELS 16 | from pamr import PAMR 17 | 18 | 19 | @MODELS.register_module() 20 | class ITACLIP_Segmentor(BaseSegmentor): 21 | def __init__(self, model_name, name_path, dataset_name, device=torch.device('cuda'), pretrained = None, 22 | train_cfg = None, pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40, 23 | slide_stride=112, slide_crop=224, area_thd=None, img_engineering = False, auxiliary_text_path = None, 24 | attn_self = True, def_coefficient=0.2, img_eng_coefficient=0.75, width_chunk_size = None): 25 | 26 | assert dataset_name in ['coco_stuff','coco_object','voc21','context60'] 27 | bg = False 28 | if dataset_name in ['coco_object','voc21','context60']: 29 | bg = True # sets True when the dataset contains the "background" class. 
30 | self.bg = bg 31 | data_preprocessor = SegDataPreProcessor( 32 | mean=[122.771, 116.746, 104.094], 33 | std=[68.501, 66.632, 70.323], 34 | rgb_to_bgr=True) 35 | super().__init__(data_preprocessor=data_preprocessor) 36 | self.device = device 37 | self.net, _ = clip.load(model_name, device=self.device, jit=False) 38 | query_words, self.query_idx = get_cls_idx(name_path) 39 | self.num_queries = len(query_words) 40 | self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device) 41 | self.img_engineering = img_engineering 42 | self.transforms = ([T.Grayscale(), 43 | T.GaussianBlur(kernel_size=11,sigma=5)]) # first-category augmentations 44 | self.flip_transforms = ([T.RandomVerticalFlip(p=1), 45 | T.RandomHorizontalFlip(p=1)]) # second-category augmentations 46 | self.attn_self = attn_self # self-self attention 47 | self.def_coefficient = def_coefficient 48 | self.img_eng_coefficient = img_eng_coefficient 49 | self.width_chunk_size = width_chunk_size # This variable is used to reduce GPU memory usage when num_cls != num_queries 50 | 51 | if auxiliary_text_path is None: 52 | self.query_features = self.text_feature(query_words) 53 | else: 54 | auxiliary_texts = self.get_aux_text(auxiliary_text_path) 55 | original_features = self.text_feature(query_words) 56 | aux_features = self.text_feature(auxiliary_texts) 57 | if self.bg: 58 | self.query_features = torch.zeros_like(original_features) 59 | num_bg_words = (self.query_idx == 0).sum().item() 60 | aux_features = aux_features[self.query_idx[num_bg_words:] - 1] 61 | self.query_features[num_bg_words:] = (1 - self.def_coefficient) * original_features[num_bg_words:] + (self.def_coefficient) * aux_features 62 | self.query_features[:num_bg_words] = original_features[:num_bg_words] 63 | else: 64 | aux_features = aux_features[self.query_idx] 65 | self.query_features = (1 - self.def_coefficient) * original_features + (self.def_coefficient) * aux_features 66 | 67 | self.logit_scale = logit_scale 68 | self.prob_thd = prob_thd 69 | self.area_thd = area_thd 70 | self.slide_stride = slide_stride 71 | self.slide_crop = slide_crop 72 | 73 | if pamr_steps > 0: 74 | self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device) 75 | else: 76 | self.pamr = None 77 | 78 | def perform_in_chunks(self, seg_logits, query_idx, num_cls, num_queries, width_chunk_size=200): 79 | device = seg_logits.device 80 | height, width = seg_logits.shape[-2:] 81 | seg_logits = seg_logits.unsqueeze(0) 82 | output = torch.zeros((num_cls, height, width), device=device) 83 | cls_index = nn.functional.one_hot(query_idx) 84 | cls_index = cls_index.T.view(num_cls, num_queries, 1, 1) 85 | 86 | for i in range(0, width, width_chunk_size): 87 | chunk_end = min(i + width_chunk_size, width) 88 | output[:,:,i:chunk_end] = (seg_logits[:,:,:,i:chunk_end] * cls_index).max(1)[0] 89 | 90 | return output 91 | 92 | def get_aux_text(self, path): 93 | aux_text = [] 94 | with open(path,'r') as f: 95 | aux_text = f.readlines() 96 | for i,name in enumerate(aux_text): 97 | name = name.replace('\n','') 98 | aux_text[i] = name.split('>=')[1] 99 | 100 | return aux_text 101 | 102 | def get_flipped_logits(self, flip_logits, transforms, size, w, h, out_dim): 103 | logit_list = [] 104 | for i,flip_logit in enumerate(flip_logits): 105 | flip_logit = flip_logit.permute(0, 2, 1).reshape(-1, out_dim, w, h) 106 | logit = nn.functional.interpolate(flip_logit, size=size, mode='bilinear') 107 | logit = transforms[i](logit) 108 | logit_list.append(logit) 109 | logits = torch.mean(torch.stack(logit_list),dim=0) 110 | 
return logits 111 | 112 | def forward_feature(self, img, text_features, logit_size=None): 113 | if type(img) == list: 114 | img = img[0] 115 | 116 | img_list = [] 117 | flip_list = [] 118 | if not self.img_engineering: 119 | image_features = self.net.encode_image(img, return_all=True, attn_self=self.attn_self, device=self.device) 120 | image_features /= image_features.norm(dim=-1, keepdim=True) 121 | img_list.append(image_features) 122 | else: 123 | torch.manual_seed(0) 124 | image_features = self.net.encode_image(img, return_all=True, attn_self=self.attn_self, device=self.device) 125 | image_features /= image_features.norm(dim=-1, keepdim=True) 126 | img_list.append(image_features) 127 | for transform in self.transforms: 128 | new_img = transform(img.squeeze()) 129 | new_img = new_img.unsqueeze(0) 130 | if new_img.shape[1] == 1: 131 | new_img = new_img.expand(1,3,-1,-1) 132 | image_features = self.net.encode_image(new_img, return_all=True, attn_self=self.attn_self, device=self.device) 133 | image_features /= image_features.norm(dim=-1, keepdim=True) 134 | img_list.append(image_features) 135 | 136 | for transform in self.flip_transforms: 137 | new_img = transform(img.squeeze()) 138 | new_img = new_img.unsqueeze(0) 139 | if new_img.shape[1] == 1: 140 | new_img = new_img.expand(1,3,-1,-1) 141 | flipped_image_features = self.net.encode_image(new_img, return_all=True, attn_self=self.attn_self, device=self.device) 142 | flipped_image_features /= flipped_image_features.norm(dim=-1, keepdim=True) 143 | flip_list.append(flipped_image_features) 144 | 145 | image_features = torch.mean(torch.stack(img_list), dim=0) 146 | 147 | image_features = image_features[:, 1:] 148 | logits = image_features @ text_features.T 149 | if self.img_engineering: 150 | flip_logit_list = [] 151 | for flip_img_features in flip_list: 152 | flip_img_features = flip_img_features[:, 1:] 153 | flip_logit_list.append(flip_img_features @ text_features.T) 154 | 155 | patch_size = self.net.visual.patch_size 156 | w, h = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size 157 | out_dim = logits.shape[-1] 158 | logits = logits.permute(0, 2, 1).reshape(-1, out_dim, w, h) 159 | 160 | if logit_size == None: 161 | logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear') 162 | if self.img_engineering: 163 | flip_logits = self.get_flipped_logits(flip_logit_list,self.flip_transforms, 164 | size=img.shape[-2:], w = w, h = h, out_dim = out_dim) 165 | logits = (self.img_eng_coefficient) * logits + (1 - self.img_eng_coefficient) * flip_logits 166 | else: 167 | logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear') 168 | if self.img_engineering: 169 | flip_logits = self.get_flipped_logits(flip_logit_list,self.flip_transforms, 170 | size=logit_size, w = w, h = h, out_dim = out_dim) 171 | logits = (self.img_eng_coefficient) * logits + (1 - self.img_eng_coefficient) * flip_logits 172 | return logits 173 | 174 | def text_feature(self, query_words, templates=openai_imagenet_template): 175 | query_features = [] 176 | with torch.no_grad(): 177 | for qw in query_words: 178 | query = clip.tokenize([temp(qw) for temp in templates]).to(self.device) 179 | feature = self.net.encode_text(query) 180 | feature /= feature.norm(dim=-1, keepdim=True) 181 | feature = feature.mean(dim=0) 182 | feature /= feature.norm() 183 | query_features.append(feature.unsqueeze(0)) 184 | 185 | return torch.cat(query_features, dim=0) 186 | 187 | def forward_slide(self, img, img_metas, text_features, query_idx, pamr=None, 
stride=112, crop_size=224): 188 | """Inference by sliding-window with overlap. 189 | If h_crop > h_img or w_crop > w_img, the small patch will be used to 190 | decode without padding. 191 | """ 192 | 193 | if type(img) == list: 194 | img = img[0].unsqueeze(0) 195 | if type(stride) == int: 196 | stride = (stride, stride) 197 | if type(crop_size) == int: 198 | crop_size = (crop_size, crop_size) 199 | 200 | h_stride, w_stride = stride 201 | h_crop, w_crop = crop_size 202 | batch_size, _, h_img, w_img = img.shape 203 | out_channels = len(query_idx) 204 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 205 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 206 | preds = img.new_zeros((batch_size, out_channels, h_img, w_img), device=self.device) 207 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img), device=self.device) 208 | for h_idx in range(h_grids): 209 | for w_idx in range(w_grids): 210 | y1 = h_idx * h_stride 211 | x1 = w_idx * w_stride 212 | y2 = min(y1 + h_crop, h_img) 213 | x2 = min(x1 + w_crop, w_img) 214 | y1 = max(y2 - h_crop, 0) 215 | x1 = max(x2 - w_crop, 0) 216 | crop_img = img[:, :, y1:y2, x1:x2] 217 | crop_seg_logit = self.forward_feature(crop_img, text_features=text_features) 218 | preds += nn.functional.pad(crop_seg_logit, 219 | (int(x1), int(preds.shape[3] - x2), int(y1), 220 | int(preds.shape[2] - y2))) 221 | 222 | count_mat[:, :, y1:y2, x1:x2] += 1 223 | assert (count_mat == 0).sum() == 0 224 | 225 | preds = preds / count_mat 226 | img_size = img_metas[0]['ori_shape'][:2] 227 | logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear') 228 | 229 | if pamr: 230 | img = nn.functional.interpolate(img, size=img_size, mode='bilinear') 231 | self.pamr = self.pamr.to(self.device) 232 | logits = self.pamr(img, logits.to(img.dtype)).to(img.dtype) 233 | 234 | return logits 235 | 236 | def predict(self, inputs, data_samples): 237 | self.net = self.net.to(self.device) 238 | inputs = inputs.to(self.device) 239 | if data_samples is not None: 240 | batch_img_metas = [ 241 | data_sample.metainfo for data_sample in data_samples 242 | ] 243 | else: 244 | batch_img_metas = [ 245 | dict( 246 | ori_shape=inputs.shape[2:], 247 | img_shape=inputs.shape[2:], 248 | pad_shape=inputs.shape[2:], 249 | padding_size=[0, 0, 0, 0]) 250 | ] * inputs.shape[0] 251 | 252 | if type(inputs) == list: 253 | inputs = inputs[0].unsqueeze(0) 254 | 255 | if self.slide_crop > 0: 256 | query_idx = self.query_idx 257 | seg_logits = self.forward_slide(inputs, batch_img_metas, self.query_features, query_idx, self.pamr, self.slide_stride, self.slide_crop) 258 | else: 259 | query_idx = self.query_idx 260 | seg_logits = self.forward_feature(inputs, self.query_features, batch_img_metas[0]['ori_shape']) 261 | 262 | return self.postprocess_result(seg_logits, data_samples, query_idx) 263 | 264 | def postprocess_result(self, seg_logits, data_samples, query_idx): 265 | batch_size = seg_logits.shape[0] 266 | for i in range(batch_size): 267 | seg_logits = seg_logits[i] * self.logit_scale 268 | seg_logits = seg_logits.softmax(0) # n_queries * w * h 269 | 270 | num_cls, num_queries = max(query_idx) + 1, len(query_idx) 271 | if num_cls != num_queries: 272 | if self.width_chunk_size is None: 273 | seg_logits = seg_logits.unsqueeze(0) 274 | cls_index = nn.functional.one_hot(query_idx) 275 | cls_index = cls_index.T.view(num_cls, num_queries, 1, 1) 276 | seg_logits = (seg_logits * cls_index).max(1)[0] 277 | else: 278 | width_chunk_size = self.width_chunk_size 279 | seg_logits = 
self.perform_in_chunks(seg_logits, query_idx, num_cls, num_queries, width_chunk_size=width_chunk_size) 280 | 281 | if self.area_thd is not None: 282 | # Force segmentations with area < self.area_thd to 0 (background) 283 | predictions = nn.functional.one_hot(seg_logits.argmax(0), num_cls).to(seg_logits.dtype) 284 | area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True) 285 | area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logits.dtype) 286 | seg_logits[1:] *= area_pred.transpose(0, -1) 287 | 288 | seg_pred = seg_logits.argmax(0, keepdim=True) 289 | seg_pred[seg_logits.max(0, keepdim=True)[0] < self.prob_thd] = 0 290 | 291 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) 292 | seg_pred = torch.from_numpy(cv2.morphologyEx(seg_pred.squeeze().cpu().numpy().astype(np.uint8), cv2.MORPH_CLOSE, kernel)).unsqueeze(0) 293 | 294 | data_samples[i].set_data({ 295 | 'seg_logits': 296 | PixelData(**{'data': seg_logits}), 297 | 'pred_sem_seg': 298 | PixelData(**{'data': seg_pred}) 299 | }) 300 | 301 | return data_samples 302 | 303 | def _forward(data_samples): 304 | """ 305 | """ 306 | 307 | def inference(self, img, batch_img_metas): 308 | """ 309 | """ 310 | 311 | def encode_decode(self, inputs, batch_img_metas): 312 | """ 313 | """ 314 | 315 | def extract_feat(self, inputs): 316 | """ 317 | """ 318 | 319 | def loss(self, inputs, data_samples): 320 | """ 321 | """ 322 | 323 | def get_cls_idx(path): 324 | with open(path, 'r') as f: 325 | name_sets = f.readlines() 326 | num_cls = len(name_sets) 327 | 328 | class_names, class_indices = [], [] 329 | for idx in range(num_cls): 330 | names_i = name_sets[idx].split(', ') 331 | class_names += names_i 332 | class_indices += [idx for _ in range(len(names_i))] 333 | class_names = [item.replace('\n', '') for item in class_names] 334 | return class_names, class_indices 335 | -------------------------------------------------------------------------------- /llama3_definition_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | 5 | cache_dir = YOUR_CACHE_DIR 6 | os.environ['HF_HOME'] = cache_dir 7 | print(os.getenv('HF_HOME')) 8 | import transformers 9 | 10 | dataset_name = 'coco_object' 11 | assert dataset_name in ['coco_stuff','coco_object','voc21','context60','cityscapes'] 12 | bg = False 13 | if dataset_name in ['coco_object','voc21','context60']: 14 | bg = True 15 | 16 | access_token = YOUR_HF_TOKEN 17 | path = f'/ITACLIP/configs/cls_{dataset_name}.txt' 18 | model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # different LLaMa models can be used here 19 | txt_path = f'/ITACLIP/llama_generated_texts/{dataset_name}_definitions.txt' 20 | 21 | with open(path, 'r') as f: 22 | if bg: 23 | next(f) 24 | name_sets = f.readlines() 25 | 26 | for i, name in enumerate(name_sets): 27 | name_sets[i] = name.replace('\n','') 28 | if len(name_sets[i].split(',')) > 1: 29 | name_sets[i] = name_sets[i].split(',')[0] 30 | 31 | print(name_sets) 32 | print(len(name_sets)) 33 | 34 | pipeline = transformers.pipeline( 35 | "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", token=access_token) 36 | 37 | start_time = time.time() 38 | 39 | for class_name in name_sets: 40 | 41 | messages = [ 42 | {"role": "system", "content": "Give a brief definition of prompted word like given example definitions: house >= a building that people, usually one family, live in; car >= a road vehicle with an engine, four wheels, and seats for a 
small number of people; (no more than 50 words, do not use extra words other than the definition of given word)"}, 43 | ] 44 | messages.append({"role": "user", "content": f"{class_name} >="}) 45 | 46 | print('class name: ', class_name) 47 | prompt = pipeline.tokenizer.apply_chat_template( 48 | messages, 49 | tokenize=False, 50 | add_generation_prompt=True 51 | ) 52 | 53 | terminators = [ 54 | pipeline.tokenizer.eos_token_id, 55 | pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") 56 | ] 57 | 58 | outputs = pipeline( 59 | prompt, 60 | max_new_tokens=256, 61 | eos_token_id=terminators, 62 | do_sample=True, 63 | temperature=0.6, 64 | top_p=0.9, 65 | ) 66 | 67 | print(outputs[0]["generated_text"][len(prompt):]) 68 | 69 | with open(txt_path, 'a') as file: 70 | file.write(f'{class_name} >=') 71 | file.write(outputs[0]["generated_text"][len(prompt):]) 72 | file.write('\n') 73 | 74 | end_time = time.time() 75 | 76 | print('total time: ', end_time - start_time) 77 | 78 | -------------------------------------------------------------------------------- /llama3_synonym_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | 5 | cache_dir = YOUR_CACHE_DIR 6 | os.environ['HF_HOME'] = cache_dir 7 | print(os.getenv('HF_HOME')) 8 | import transformers 9 | 10 | dataset_name = 'coco_object' 11 | assert dataset_name in ['coco_stuff','coco_object','voc21','context60','cityscapes'] 12 | bg = False 13 | if dataset_name in ['coco_object','voc21','context60']: 14 | bg = True 15 | 16 | access_token = YOUR_HF_TOKEN 17 | path = f'/ITACLIP/configs/cls_{dataset_name}.txt' 18 | model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # different LLaMa models can be used here 19 | txt_path = f"/ITACLIP/llama_generated_texts/{dataset_name}_synonyms.txt" 20 | 21 | with open(path, 'r') as f: 22 | if bg: 23 | next(f) 24 | name_sets = f.readlines() 25 | 26 | for i, name in enumerate(name_sets): 27 | name_sets[i] = name.replace('\n','') 28 | if len(name_sets[i].split(',')) > 1: 29 | name_sets[i] = name_sets[i].split(',')[0] 30 | 31 | print(name_sets) 32 | print(len(name_sets)) 33 | 34 | pipeline = transformers.pipeline( 35 | "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", token=access_token) 36 | 37 | start_time = time.time() 38 | 39 | for class_name in name_sets: 40 | messages = [ 41 | {"role": "system", "content": "Provide the synonym (thesaurus) for the prompted word. If a word does not have a synonym, give the closest meaning, as in the following example definitions: house ≥ home; car ≥ automobile. 
(Please provide exactly one word.)"}, 42 | ] 43 | 44 | messages.append({"role": "user", "content": f"{class_name} >="}) 45 | 46 | print('class name: ', class_name) 47 | prompt = pipeline.tokenizer.apply_chat_template( 48 | messages, 49 | tokenize=False, 50 | add_generation_prompt=True 51 | ) 52 | 53 | terminators = [ 54 | pipeline.tokenizer.eos_token_id, 55 | pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") 56 | ] 57 | 58 | outputs = pipeline( 59 | prompt, 60 | max_new_tokens=256, 61 | eos_token_id=terminators, 62 | do_sample=True, 63 | temperature=0.6, 64 | top_p=0.9, 65 | ) 66 | 67 | print(outputs[0]["generated_text"][len(prompt):]) 68 | 69 | with open(txt_path, 'a') as file: 70 | file.write(f'{class_name} >=') 71 | file.write(outputs[0]["generated_text"][len(prompt):]) 72 | file.write('\n') 73 | 74 | end_time = time.time() 75 | 76 | print('total time: ', end_time - start_time) 77 | 78 | -------------------------------------------------------------------------------- /llama_generated_texts/cityscapes_synonyms.txt: -------------------------------------------------------------------------------- 1 | road >=highway 2 | sidewalk >=path 3 | building >=structure 4 | wall >=barrier 5 | fence >=hedge 6 | pole >=rod 7 | trafficlight >=signal 8 | trafficsign >=sign 9 | vegetation >=flora 10 | terrain >=landscape 11 | sky >=heaven 12 | person >=individual 13 | rider >=passenger 14 | car >=vehicle 15 | truck >=pickup 16 | bus >=coach 17 | train >=locomotive 18 | motorcycle >=motorbike 19 | bicycle >=bike 20 | -------------------------------------------------------------------------------- /llama_generated_texts/coco_object_synonyms.txt: -------------------------------------------------------------------------------- 1 | person >=human 2 | bicycle >=bike 3 | car >=vehicle 4 | motorcycle >=bike 5 | airplane >=plane 6 | bus >=coach 7 | train >=locomotive 8 | truck >=lorry 9 | boat >=ship 10 | traffic light >=signal 11 | fire hydrant >=water valve 12 | stop sign >=Traffic signal 13 | parking meter >=paystation 14 | bench >=seat 15 | bird >=avian 16 | cat >=feline 17 | dog >=canine 18 | horse >=steed 19 | sheep >=Lamb 20 | cow >=cattle 21 | elephant >=Mammoth 22 | bear >=animal 23 | zebra >=horse 24 | giraffe >=tall 25 | backpack >=rucksack 26 | umbrella >=parasol 27 | handbag >=purse 28 | tie >=necktie 29 | suitcase >=bag 30 | frisbee >=disk 31 | skis >=snowboard 32 | snowboard >=ski 33 | sports ball >=ball 34 | kite >=balloon 35 | baseball bat >=club 36 | baseball glove >=Mitchell 37 | skateboard >=board 38 | surfboard >=board 39 | tennis racket >=Racquet 40 | bottle >=container 41 | wine glass >=Tumbler 42 | cup >=container 43 | fork >=utensil 44 | knife >=blade 45 | spoon >=utensil 46 | bowl >=dish 47 | banana >=fruit 48 | apple >=fruit 49 | sandwich >=wrap 50 | orange >=Tangerine 51 | broccoli >=cauliflower 52 | carrot >=orange 53 | hot dog >=Frank 54 | pizza >=pie 55 | donut >=treat 56 | cake >=pastry 57 | chair >=seat 58 | couch >=sofa 59 | potted plant >=Flowerpot 60 | bed >=couch 61 | dining table >=table 62 | toilet >=bathroom 63 | tv >=screen 64 | laptop >=computer 65 | mouse >=rodent 66 | remote >=distant 67 | keyboard >=typewriter 68 | cell phone >=mobile 69 | microwave >=oven 70 | oven >=stove 71 | toaster >=appliance 72 | sink >=basin 73 | refrigerator >=fridge 74 | book >=volume 75 | clock >=watch 76 | vase >=container 77 | scissors >=clippers 78 | teddy bear >=stuffed animal 79 | hair drier >=blower 80 | toothbrush >=brush 81 | 
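These LLaMA-generated auxiliary files (synonyms and definitions) follow a `class >=auxiliary text` format. In `itaclip_segmentor.py` (shown earlier), `get_aux_text` keeps only the text after `>=`, and the resulting CLIP text embedding is blended with the plain class-name embedding using `def_coefficient` (0.2 by default); entries belonging to the background class keep the plain embedding. A minimal sketch of that parsing and blending, with `clip_text(...)` as a hypothetical stand-in for the templated CLIP text encoding done by `text_feature`:

```
# Sketch only: parse one auxiliary line and blend embeddings the way ITACLIP_Segmentor does.
line = 'cat >=feline\n'
aux_text = line.replace('\n', '').split('>=')[1]  # -> 'feline'

def_coefficient = 0.2  # default in ITACLIP_Segmentor.__init__
# query_feature = (1 - def_coefficient) * clip_text('cat') + def_coefficient * clip_text('feline')
# clip_text(...) is hypothetical; the repo computes these features with text_feature() and CLIP.
print(aux_text)
```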
-------------------------------------------------------------------------------- /llama_generated_texts/coco_stuff_definitions.txt: -------------------------------------------------------------------------------- 1 | person >=a human being, a living individual. 2 | bicycle >=a vehicle with two wheels, powered by pedaling with the legs, designed for personal transportation. 3 | car >=a road vehicle with an engine, four wheels, and seats for a small number of people 4 | motorcycle >=a two-wheeled road vehicle with an engine, designed for one or two people to ride. 5 | airplane >=a powered, fixed-wing aircraft that carries passengers or cargo through the air. 6 | bus >=a large road vehicle designed to carry many people, typically used for public transportation. 7 | train >=a self-propelled vehicle on rails, powered by electricity, diesel, or steam, used for transporting people or goods. 8 | truck >=a large road vehicle with an engine, wheels, and an open or enclosed cargo bed for carrying goods or equipment. 9 | boat >=a watercraft designed to float or move on water, typically propelled by sails, oars, or a motor. 10 | trafficlight >=a device that uses red, yellow, and green lights to control the flow of traffic at an intersection or crossing. 11 | firehydrant >=a large, usually red, outdoor device that holds water for firefighters to use in extinguishing fires. 12 | stopsign >=a traffic sign with a red octagon shape and white lettering, indicating drivers must come to a complete halt before proceeding. 13 | parkingmeter >=a device that measures and charges for the time a vehicle is parked in a specific area. 14 | bench >=a long, flat piece of furniture, usually outdoors, for sitting or resting. 15 | bird >=a warm-blooded egg-laying vertebrate with feathers, wings, and a beak. 16 | cat >=a small typically furry mammal that purrs and is often kept as a pet. 17 | dog >=a domesticated carnivorous mammal that is often kept as a pet or used for hunting or herding. 18 | horse >=a large, hoofed, herbivorous mammal often domesticated and used for riding, transportation, or work. 19 | sheep >=a domesticated mammal of the genus Ovis, typically kept for its wool, milk, or meat. 20 | cow >=a large, hooved, herbivorous mammal with a distinctive set of horns and a cow-like sound. 21 | elephant >=a large mammal with a trunk, tusks, and a memory that never forgets. 22 | bear >=a large carnivorous mammal with shaggy fur, found in forests and mountains, that walks on four legs and has a distinctive growl. 23 | zebra >=a wild or domesticated equine mammal with a distinctive black and white striped coat. 24 | giraffe >=tall, even-toed ungulate mammal with a long neck and legs, spotted or patchy coat, and a distinctive pattern of spots or patches on its back. 25 | backpack >=a bag worn on the back to carry things, often used for hiking, school, or travel. 26 | umbrella >=a portable canopy of fabric or other material that is held above one's head to protect from rain or sun. 27 | handbag >=a bag carried by hand, typically made of fabric or leather, used to carry personal items such as cosmetics, keys, and money. 28 | tie >=a piece of cloth worn around the neck to fasten a shirt or collar. 29 | suitcase >=a portable bag with a zipper or hinges, used for carrying clothes, belongings, or other items while traveling. 30 | frisbee >=a flat, circular piece of plastic or other material, typically with a hole in the center, used for throwing and catching in a recreational or competitive game. 
31 | skis >=long, flat pieces of wood or plastic used for gliding over snow, typically worn by people for recreation or sport. 32 | snowboard >=a board with bindings, used for sliding down snow-covered slopes, typically with both feet attached. 33 | sportsball >=a ball used in various sports, typically made of leather or synthetic materials, designed for throwing, catching, and kicking. 34 | kite >=a light frame or framework covered with lightweight material, such as fabric, and attached to strings or lines, designed to fly in the air when lifted by the wind. 35 | baseballbat >=a wooden or metal club used in the sport of baseball, typically made of a single piece of wood or a composite material, used to hit the ball. 36 | baseballglove >=a mitt-shaped piece of equipment worn by a baseball player to catch and throw the ball. 37 | skateboard >=a flat, rectangular board with wheels, used for riding, jumping, and performing tricks on the ground, usually by standing on it with one foot and pushing with the other. 38 | surfboard >=a flat, usually rectangular piece of wood, fiberglass, or plastic, designed for riding on the surface of the water, typically used for surfing. 39 | tennisracket >=a sports equipment used to hit a ball in the game of tennis. 40 | bottle >=a container made of glass, plastic, or other materials, typically with a neck and a narrow opening, used for holding liquids or solids. 41 | wineglass >=a small, usually transparent, cup-shaped vessel for drinking wine, typically made of glass or crystal. 42 | cup >=a container, typically made of ceramic, glass, or plastic, used for drinking liquids. 43 | fork >=a utensil with tines used for piercing and lifting food. 44 | knife >=a cutting instrument with a sharp blade, used for various purposes such as food preparation, hunting, or self-defense. 45 | spoon >=a utensil used for eating or serving food, typically consisting of a long handle and a bowl-shaped or oval-shaped head. 46 | bowl >=a round, usually deep container made of ceramic, glass, or metal, used for serving or holding food, especially soup or cereal. 47 | banana >=a long, curved, yellow fruit that grows on plants and is often eaten as a snack. 48 | apple >=a sweet and juicy fruit that grows on trees. 49 | sandwich >=a food consisting of two or more slices of bread, often with fillings such as meat, cheese, vegetables, or condiments between them. 50 | orange >=a vibrant, juicy, and sweet fruit that grows on trees, typically orange in color. 51 | broccoli >=a green, edible vegetable with a large, tree-like head and thick, crunchy stalks. 52 | carrot >=a long, thin, orange vegetable that grows underground and is eaten raw or cooked. 53 | hotdog >=a type of savory food consisting of a grilled or steamed sausage served in a bun, often with various toppings such as ketchup, mustard, and relish. 54 | pizza >=a flatbread dish typically topped with tomato sauce, cheese, and various ingredients, often served hot. 55 | donut >=a sweet, ring-shaped baked food typically topped with sugar or glaze and often filled with cream or jelly. 56 | cake >=a sweet baked food often decorated and served as a treat or dessert. 57 | chair >=a piece of furniture for one person to sit on, typically having a back and legs. 58 | couch >=a piece of furniture for sitting or lying on, typically upholstered and designed for comfort, often used in a living room or family room. 59 | pottedplant >=a plant grown in a pot, often kept indoors. 60 | bed >=a piece of furniture for sleeping or resting on. 
61 | diningtable >=a piece of furniture with a flat surface and legs, used for holding food, drinks, and eating. 62 | toilet >=a plumbing fixture for personal hygiene and sanitation, typically a low-level ceramic bowl with a seat and lid, used for urination and defecation. 63 | tv >=a device for receiving and displaying video and audio signals, typically used for entertainment or information. 64 | laptop >=a portable personal computer with a keyboard, screen, and processing unit. 65 | mouse >=a small rodent with a pointed snout, large ears, and a long tail, typically found in homes and farms. 66 | remote >=a location or device that is far away from a central point or a person, often requiring specialized equipment or communication to access or control. 67 | keyboard >=a device with keys that allows a person to input text or commands into a computer or other electronic device. 68 | cellphone >=a portable electronic device used for communication, entertainment, and information processing, typically held in the hand. 69 | microwave >=a kitchen appliance that uses non-ionizing radiation to heat or cook food quickly. 70 | oven >=a cooking device that uses heat to cook or bake food. 71 | toaster >=a small electric appliance used for toasting slices of bread. 72 | sink >=a plumbing fixture for washing hands, dishes, or other objects, typically mounted in a countertop or wall. 73 | refrigerator >=a large electrical appliance for keeping food and drinks cool or frozen. 74 | book >=a written or printed work consisting of pages glued or sewn together. 75 | clock >=a device that shows the time, typically with hour and minute hands and sometimes seconds, used for measuring time. 76 | vase >=a container made of glass, ceramic, or other materials, typically decorative and used to hold flowers or other ornaments. 77 | scissors >=a handheld device with two blades that are used to cut various materials such as paper, fabric, or hair. 78 | teddybear >=a soft toy bear, typically made of stuffed fabric, designed for cuddling and often given as a gift to children. 79 | hairdrier >=a device used to dry and style hair, typically with hot air or heat. 80 | toothbrush >=a small brush used to clean teeth. 81 | banner >=a strip of cloth, paper, or plastic attached to a pole or hung from a building to display a message, logo, or design. 82 | blanket >=a piece of fabric, usually made of wool, cotton, or synthetic materials, used for keeping warm, covering, or decorating a bed or chair. 83 | branch >=a part of a tree that grows out from the trunk or main stem, often with leaves, flowers, or fruit. 84 | bridge >=a structure built over a body of water, valley, or road, connecting two or more land areas. 85 | building-other >=a structure with walls, floor, and roof, used for various purposes such as housing, industry, or leisure. 86 | bush >=a small shrub or low-growing tree, typically with thorns and a small canopy. 87 | cabinet >=a piece of furniture with shelves, drawers, or compartments for storing and organizing things. 88 | cage >=a structure made of bars or wires that encloses or confines something, often used to hold or protect an animal. 89 | cardboard >=a stiff, paper-like material made from wood pulp or other plant fibers, often used for packaging, boxes, and crafts. 90 | carpet >=a floor covering made of soft material, usually woven or tufted, and often made of wool, synthetic fibers, or a combination of both. 
91 | ceiling-other >=a surface above, typically made of material such as drywall, plaster, or wood, that covers and conceals the upper part of a room or building. 92 | ceiling-tile >=a flat piece of material, typically made of plastic, wood, or ceramic, installed on the inside of a building's ceiling to provide a smooth surface and sometimes to conceal wiring or insulation. 93 | cloth >=a material woven from fibers, often used to make clothing, fabric, or textile. 94 | clothes >=garments worn on the body to cover and protect it, often made of fabric, and may include items such as shirts, pants, dresses, and jackets. 95 | clouds >=visible masses of water droplets or ice crystals suspended in the air. 96 | counter >=a device used to measure or count the number of things, often in a repetitive or continuous manner. 97 | cupboard >=a piece of furniture with shelves or compartments for storing food, dishes, and other household items. 98 | curtain >=a hanging piece of fabric or other material used to cover or decorate a window, door, or room. 99 | desk-stuff >=objects, papers, and equipment placed or stored on a desk for work or study. 100 | dirt >=earth or soil that covers the ground. 101 | door-stuff >=things that are placed on or in a door. 102 | fence >=a barrier made of posts and rails, used to enclose or mark a boundary. 103 | floor-marble >=a type of flooring made from polished marble stone. 104 | floor-other >=a surface below the ceiling, usually made of materials like wood, tile, or carpet, on which people walk or stand in a building. 105 | floor-stone >=a flat, usually rectangular, piece of stone used as a surface in a building. 106 | floor-tile >=a flat piece of material, usually ceramic, stone, or wood, used to cover a floor. 107 | floor-wood >=a flat surface covered with wood, usually in a building 108 | flower >=a plant that produces colorful and often fragrant blooms. 109 | fog >=a cloud of tiny water droplets suspended in the air near the ground, reducing visibility. 110 | food-other >=edible substance for human consumption, typically obtained from plants, animals, or fungi, and prepared for eating in various ways. 111 | fruit >=a sweet and fleshy part of a plant that grows from a flower and contains seeds. 112 | furniture-other >=movable objects used to make a house or room comfortable and attractive. 113 | grass >=a type of green plant that grows in lawns, fields, and other areas. 114 | gravel >=small rounded stones or pebbles used for paving or surfacing roads, paths, and driveways. 115 | ground-other >=a region or area of land that is not covered by a body of water. 116 | hill >=a natural elevation of the earth's surface, typically rounded or conical in shape. 117 | house >=a building for human residence 118 | leaves >=a part of a plant that grows from the stem and is typically green, flat, and broad. 119 | light >=a source of illumination, such as a lamp, candle, or the sun. 120 | mat >=a flat piece of material, typically made of fabric, foam, or rubber, used as a covering or layer on a floor or surface. 121 | metal >=a naturally occurring chemical element or a material made from a combination of these elements, often hard, shiny, and used for construction, tools, and other purposes. 122 | mirror-stuff >=a reflective surface, typically made of glass with a metallic coating, used for personal grooming, decoration, or optical purposes. 123 | moss >=a small, non-vascular plant that grows close to the ground or on surfaces, often in damp or shady areas. 
124 | mountain >=a natural elevation of the earth's surface, typically formed by tectonic forces, with a summit and often with slopes and valleys. 125 | mud >=a soft, wet, and sticky earth substance formed by the mixture of water and soil particles. 126 | napkin >=a small piece of cloth used for wiping the mouth or nose, especially during meals. 127 | net >=a network of threads or strings stretched between posts, used for catching or trapping something. 128 | paper >=a thin, flexible material made from processed plant fibers, used for writing, printing, and other purposes. 129 | pavement >=a surface of a road or path made of hard material, such as concrete or asphalt. 130 | pillow >=a cushion or support for the head or neck while sleeping or resting. 131 | plant-other >=a living organism that grows in the ground, typically having leaves, stems, and roots, and producing flowers, fruits, or seeds. 132 | plastic >=A synthetic or semi-synthetic organic solids usually molded or extruded into various shapes and forms. 133 | platform >=A raised surface or structure that is used as a base or support for something else, such as a stage, a computer system, or a social media site. 134 | playingfield >=a designated area, usually marked with lines or boundaries, where a sport, game, or competition is played. 135 | railing >=a barrier or partition, typically made of wood, metal, or glass, that is attached to a wall, staircase, or balcony to provide support and prevent falls. 136 | railroad >=a network of tracks for trains, often operated by a company or government. 137 | river >=a natural flowing body of water that usually empties into a larger body of water, such as an ocean, lake, or sea. 138 | road >=a path or way made for travel by vehicles, pedestrians, or animals, usually paved or surfaced, and often marked by lines or signs. 139 | rock >=a small piece of stone or mineral that has broken off from a larger rock or has been worn smooth by erosion. 140 | roof >=A covering on top of a building, usually made of materials such as tiles, shingles, or metal, designed to protect the structure from weather and provide insulation. 141 | rug >=a piece of fabric, usually made of woven or tufted fibers, used to cover and decorate floors, often with a backing to prevent slipping or wrinkling. 142 | salad >=a mixture of small pieces of food, typically including vegetables, fruits, and sometimes cheese or meat, served cold. 143 | sand >=a loose granular material composed of finely divided rock and mineral particles. 144 | sea >=A large body of saltwater that is usually surrounded by land. 145 | shelf >=a flat structure attached to a wall or standing freestanding, used for holding or displaying objects. 146 | sky-other >=the atmosphere that surrounds the Earth, visible from the ground, extending upwards and outwards to the edge of space, and appearing blue or gray due to the scattering of sunlight by gases and particles. 147 | skyscraper >=a very tall, usually multi-story, building in a city or town. 148 | snow >=small, white, delicate ice crystals that fall from the sky during winter 149 | solid-other >=a material that is dense and has a fixed shape and volume, not liquid or gas. 150 | stairs >=a set of steps, usually made of wood or metal, connecting different levels of a building or structure. 151 | stone >=a small, hard, usually rounded or irregularly shaped piece of rock or mineral. 152 | straw >=a dry, hollow stem of a grain plant, typically used for making drinking straws or for animals to eat. 
153 | structural-other >=a thing that is not a part of a structure, but is connected to or near it, often providing additional support or functionality. 154 | table >=a piece of furniture with a flat surface and legs, used for holding objects, eating, or working. 155 | tent >=a portable shelter made of fabric or plastic, typically with a collapsible frame, used for temporary accommodation or outdoor recreation. 156 | textile-other >=a material made from fibers or other materials, woven, knitted, or otherwise manufactured for use in clothing, upholstery, or other applications. 157 | towel >=a piece of cloth used for drying the body after bathing or showering. 158 | tree >=a perennial plant with a single stem or trunk, supporting branches and leaves in most species. 159 | vegetable >=a plant or part of a plant that is eaten as food. 160 | wall-brick >=a vertical structure made of bricks, used to enclose or divide a space. 161 | wall-concrete >=A wall made of concrete, a hard, strong, and durable building material. 162 | wall-other >=a vertical structure, usually made of stone, brick, or wood, used to enclose, support, or divide a space or area. 163 | wall-panel >=a flat surface, usually made of wood, metal, or plastic, attached to a building's exterior or interior for decoration, insulation, or structural support. 164 | wall-stone >=a structure of stone, brick, or concrete used to enclose or separate an area, typically standing upright and vertical. 165 | wall-tile >=a flat piece of material, usually ceramic or porcelain, used to cover and decorate a wall. 166 | wall-wood >=a vertical structure made of wood, used to enclose or divide a room or building. 167 | water-other >=a clear, colorless, odorless, and tasteless liquid substance that is the main component of the oceans, lakes, and rivers, and is essential for human and animal life. 168 | waterdrops >=small drops of water that fall from the sky or from a surface. 169 | window-blind >=a decorative or functional covering for a window, typically made of fabric, plastic, or metal, used to block light, provide privacy, or add aesthetic appeal. 170 | window-other >=a transparent or translucent opening in a building's exterior or interior, typically framed by a surrounding structure, allowing natural light and air to enter or exit. 171 | wood >=a natural material that grows on trees, often used for building, furniture, and other objects. 172 | -------------------------------------------------------------------------------- /llama_generated_texts/context60_definitions.txt: -------------------------------------------------------------------------------- 1 | aeroplane >=a powered, fixed-wing aircraft that is used for flight. 2 | bag >=a flexible container made of fabric or other materials, used for carrying or storing things. 3 | bed >=a piece of furniture for sleeping or resting on, typically with a mattress and a frame, designed for one or more people. 4 | bedclothes >=fabric used to cover and decorate a bed, typically including a sheet, blankets, and a comforter. 5 | bench >=a piece of furniture, typically outdoors, with a flat surface and back, used for sitting or resting. 6 | bicycle >=a vehicle with two wheels, powered by pedaling, designed for one or two people to ride. 7 | bird >=a warm-blooded vertebrate animal with feathers, wings, and a beak, typically able to fly. 8 | boat >=a watercraft designed to float on water, typically propelled by sails, oars, or an engine, used for recreation, transportation, or fishing. 
9 | book >=a written or printed work consisting of pages glued or sewn together. 10 | bottle >=a container made of glass, plastic, or other materials, typically cylindrical in shape, used for holding liquids, such as water, soda, or wine. 11 | building >=a structure with walls, floor, and roof, designed for various purposes such as residence, commerce, industry, or recreation. 12 | bus >=a large road vehicle with an engine, wheels, and seats for many people, used for transporting people from one place to another. 13 | cabinet >=a piece of furniture with shelves, drawers, or doors for storing and displaying things, often used in a home or office. 14 | car >=a road vehicle with an engine, four wheels, and seats for a small number of people. 15 | cat >=a small typically furry carnivorous mammal that purrs and is often kept as a pet. 16 | ceiling >=a surface above the floor of a room, typically horizontal and parallel to it, that forms the upper boundary of the space. 17 | chair >=a piece of furniture for one person to sit on, typically with a back and legs. 18 | cloth >=a flexible material made from fibers, used for making clothing, bedding, and other textiles. 19 | computer >=a machine that can store, process, and display information, often used for work, communication, and entertainment. 20 | cow >=a large, hooved, herbivorous mammal, often domesticated and raised for its milk and meat. 21 | cup >=A container, usually made of ceramic, glass, or plastic, used for holding and drinking liquids, such as coffee, tea, or water. 22 | curtain >=a piece of fabric hung in a window or door to control light and privacy. 23 | dog >=a domesticated carnivorous mammal that is often kept as a pet or used for hunting, herding, or guarding. 24 | door >=a movable structure used to open or close an entrance to a building, vehicle, or other enclosed space. 25 | fence >=a barrier made of posts and rails or boards, used to enclose or mark a boundary. 26 | floor >=a surface that is below the ceiling and above the walls of a room or building, usually made of wood, stone, or other materials. 27 | flower >=a plant that produces colorful and fragrant parts, often used to decorate or give as a gift. 28 | food >=a substance taken in to provide nutrition and energy for the body. 29 | grass >=a type of green plant that grows in lawns, fields, and other areas. 30 | ground >=the surface of the Earth or a similar natural or artificial surface. 31 | horse >=a large, hoofed, herbivorous mammal often domesticated and used for riding, transportation, and work. 32 | keyboard >=a device with keys for typing and inputting data into a computer or other electronic device. 33 | light >=a source of illumination, typically in the form of electromagnetic radiation with a wavelength that is visible to the human eye. 34 | motorbike >=a two-wheeled road vehicle powered by an engine, designed for one or two people to ride. 35 | mountain >=a natural elevation of the earth's surface, usually rocky and steep, with a summit and often surrounded by valleys and hills. 36 | mouse >=a small rodent-like computer input device with a cord or wireless connection, used to interact with a computer or other electronic device. 37 | person >=a human being, typically an individual, male or female, with a distinct identity, thoughts, feelings, and experiences. 38 | plate >=a flat, usually round or oval, dishware item used for serving or holding food and drinks. 
39 | platform >=A raised level surface, often made of wood, concrete, or metal, used as a base for something, such as a stage, a computer, or a display. 40 | pottedplant >=a plant grown in a container, typically made of clay or plastic, and kept indoors or outdoors. 41 | road >=a path or way made for travel by vehicles, bicycles, or pedestrians. 42 | rock >=a small, naturally occurring piece of mineral material, often found on the ground or in the sea. 43 | sheep >=a domesticated mammal with a woolly coat, typically kept for its wool, milk, or meat. 44 | shelves >=flat surfaces attached to a wall or free-standing, used to hold and display objects such as books, decorative items, or storage containers. 45 | sidewalk >=a paved path or way alongside a road or street for pedestrian use. 46 | sign >=a visible indication or token of something, especially a message, warning, or direction, often displayed or displayed publicly. 47 | sky >=the atmosphere that surrounds the Earth, visible from the ground as a blue or gray expanse. 48 | snow >=a natural weather phenomenon composed of white or translucent ice crystals that fall from the sky during winter months. 49 | sofa >=a piece of furniture for sitting or lying on, typically with a back and arms, used in a living room or other room. 50 | table >=a piece of furniture with a flat surface, legs, and often drawers or shelves, used for holding objects, eating, or working. 51 | track >=a path or route made for a specific purpose, such as a railroad, bicycle path, or footpath. 52 | train >=a self-propelled vehicle on rails for carrying people or goods. 53 | tree >=a perennial plant with a single stem or trunk, supporting branches and leaves in most species. 54 | truck >=a large road vehicle with an engine, four wheels, and an open or enclosed body for carrying goods or equipment. 55 | tvmonitor >=a device for receiving and displaying video and audio signals, typically used for entertainment, information, or education. 56 | wall >=a vertical structure or barrier made of materials such as wood, brick, or stone, used to enclose, divide, or protect an area. 57 | water >=a clear, colorless, odorless, and tasteless liquid substance that is the most abundant compound on Earth, covering about 71% of its surface. 58 | window >=a transparent or translucent opening in a building, typically framed and hinged, that allows natural light and air to enter and provides a view outside. 59 | wood >=a natural material that grows on trees, used for building, furniture, and other purposes. 60 | -------------------------------------------------------------------------------- /llama_generated_texts/voc21_definitions.txt: -------------------------------------------------------------------------------- 1 | aeroplane >=a powered, fixed-wing aircraft that is used for flight. 2 | bicycle >=a vehicle with two wheels, powered by pedaling, designed for one or two people to ride. 3 | bird >=a small warm-blooded animal that has feathers, wings, and lays eggs. 4 | ship >=a large vessel that travels on water, often for transportation, trade, or recreation. 5 | bottle >=a container made of glass, plastic, or other materials, typically cylindrical or conical in shape, used for holding liquids or other substances. 6 | bus >=a large road vehicle designed to carry many people, typically used for public transportation. 
7 | car >=a road vehicle with an engine, four wheels, and seats for a small number of people 8 | cat >=a small typically furry carnivorous mammal that purrs, has claws, and is often kept as a pet. 9 | chair >=a piece of furniture for one person to sit on, typically with a back and legs. 10 | cow >=a large, hooved, herbivorous mammal with a distinctive sound. 11 | table >=a piece of furniture with a flat surface and legs, used for holding objects, eating, or working. 12 | dog >=a domesticated carnivorous mammal that is often kept as a pet or used for hunting, herding, or guarding. 13 | horse >=a large, hoofed, herbivorous mammal often domesticated and used for riding, transportation, or work. 14 | motorbike >=a two-wheeled road vehicle with an engine, powered by the rider's legs and feet. 15 | person >=a human being, either male or female. 16 | pottedplant >=a small plant grown in a container, usually kept indoors or in a garden. 17 | sheep >=a domesticated mammal that grazes on grass and is often kept on farms for its wool, milk, or meat. 18 | sofa >=a piece of furniture for sitting or reclining, typically with cushions and a back, designed for comfort and often used in living rooms. 19 | train >=a self-propelled vehicle on rails, used for transporting people or goods. 20 | television monitor >=a device that shows moving images and sounds, often used for entertainment, education, or communication. 21 | -------------------------------------------------------------------------------- /pamr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 TU Darmstadt 2 | # Licnese: Apache 2.0 License. 3 | # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | from functools import partial 9 | 10 | # 11 | # Helper modules 12 | # 13 | class LocalAffinity(nn.Module): 14 | 15 | def __init__(self, dilations=[1]): 16 | super(LocalAffinity, self).__init__() 17 | self.dilations = dilations 18 | weight = self._init_aff() 19 | self.register_buffer('kernel', weight) 20 | 21 | def _init_aff(self): 22 | # initialising the shift kernel 23 | weight = torch.zeros(8, 1, 3, 3) 24 | 25 | for i in range(weight.size(0)): 26 | weight[i, 0, 1, 1] = 1 27 | 28 | weight[0, 0, 0, 0] = -1 29 | weight[1, 0, 0, 1] = -1 30 | weight[2, 0, 0, 2] = -1 31 | 32 | weight[3, 0, 1, 0] = -1 33 | weight[4, 0, 1, 2] = -1 34 | 35 | weight[5, 0, 2, 0] = -1 36 | weight[6, 0, 2, 1] = -1 37 | weight[7, 0, 2, 2] = -1 38 | 39 | self.weight_check = weight.clone() 40 | 41 | return weight 42 | 43 | def forward(self, x): 44 | 45 | self.weight_check = self.weight_check.type_as(x) 46 | assert torch.all(self.weight_check.eq(self.kernel)) 47 | 48 | B,K,H,W = x.size() 49 | x = x.view(B*K,1,H,W) 50 | 51 | x_affs = [] 52 | for d in self.dilations: 53 | x_pad = F.pad(x, [d]*4, mode='replicate') 54 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d) 55 | x_affs.append(x_aff) 56 | 57 | x_aff = torch.cat(x_affs, 1) 58 | return x_aff.view(B,K,-1,H,W) 59 | 60 | class LocalAffinityCopy(LocalAffinity): 61 | 62 | def _init_aff(self): 63 | # initialising the shift kernel 64 | weight = torch.zeros(8, 1, 3, 3) 65 | 66 | weight[0, 0, 0, 0] = 1 67 | weight[1, 0, 0, 1] = 1 68 | weight[2, 0, 0, 2] = 1 69 | 70 | weight[3, 0, 1, 0] = 1 71 | weight[4, 0, 1, 2] = 1 72 | 73 | weight[5, 0, 2, 0] = 1 74 | weight[6, 0, 2, 1] = 1 75 | weight[7, 0, 2, 2] = 1 76 | 77 | self.weight_check = weight.clone() 78 | return weight 79 | 80 | 
class LocalStDev(LocalAffinity): 81 | 82 | def _init_aff(self): 83 | weight = torch.zeros(9, 1, 3, 3) 84 | weight.zero_() 85 | 86 | weight[0, 0, 0, 0] = 1 87 | weight[1, 0, 0, 1] = 1 88 | weight[2, 0, 0, 2] = 1 89 | 90 | weight[3, 0, 1, 0] = 1 91 | weight[4, 0, 1, 1] = 1 92 | weight[5, 0, 1, 2] = 1 93 | 94 | weight[6, 0, 2, 0] = 1 95 | weight[7, 0, 2, 1] = 1 96 | weight[8, 0, 2, 2] = 1 97 | 98 | self.weight_check = weight.clone() 99 | return weight 100 | 101 | def forward(self, x): 102 | # returns (B,K,P,H,W), where P is the number 103 | # of locations 104 | x = super(LocalStDev, self).forward(x) 105 | 106 | return x.std(2, keepdim=True) 107 | 108 | class LocalAffinityAbs(LocalAffinity): 109 | 110 | def forward(self, x): 111 | x = super(LocalAffinityAbs, self).forward(x) 112 | return torch.abs(x) 113 | 114 | # 115 | # PAMR module 116 | # 117 | class PAMR(nn.Module): 118 | 119 | def __init__(self, num_iter=1, dilations=[1]): 120 | super(PAMR, self).__init__() 121 | 122 | self.num_iter = num_iter 123 | self.aff_x = LocalAffinityAbs(dilations) 124 | self.aff_m = LocalAffinityCopy(dilations) 125 | self.aff_std = LocalStDev(dilations) 126 | 127 | def forward(self, x, mask): 128 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True) 129 | 130 | # x: [BxKxHxW] 131 | # mask: [BxCxHxW] 132 | B,K,H,W = x.size() 133 | _,C,_,_ = mask.size() 134 | 135 | x_std = self.aff_std(x) 136 | 137 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std) 138 | x = x.mean(1, keepdim=True) 139 | x = F.softmax(x, 2) 140 | 141 | for _ in range(self.num_iter): 142 | m = self.aff_m(mask) # [BxCxPxHxW] 143 | mask = (m * x).sum(2) 144 | 145 | # xvals: [BxCxHxW] 146 | return mask -------------------------------------------------------------------------------- /prompts/imagenet_template.py: -------------------------------------------------------------------------------- 1 | 2 | imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 3 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 4 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 5 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 6 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 7 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 8 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 9 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 10 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 11 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 12 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 13 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 14 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 15 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 16 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 17 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 18 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 19 | "coucal", "bee eater", "hornbill", "hummingbird", 
"jacamar", "toucan", "duck", 20 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 21 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 22 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 23 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 24 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 25 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 26 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 27 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 28 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 29 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 30 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 31 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 32 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 33 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 34 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 35 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 36 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 37 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 38 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 39 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 40 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 41 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 42 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 43 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 44 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 45 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 46 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 47 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 48 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 49 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 50 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 51 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 52 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 53 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 54 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 55 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 56 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 57 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 58 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 59 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 60 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 61 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 62 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 63 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 64 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 65 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 66 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 67 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 68 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 69 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 70 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 71 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 72 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 73 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 74 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 75 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 76 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 77 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 78 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 79 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 80 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 81 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 82 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 83 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 84 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 85 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 86 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 87 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 88 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 89 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 90 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 91 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 92 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 93 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 94 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 95 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 96 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 97 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 98 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 99 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 100 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 101 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 102 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 103 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 104 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 105 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 106 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 107 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 108 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 109 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 110 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 111 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 112 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 113 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 114 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 115 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 116 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 117 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 118 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 119 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 120 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 121 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 122 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 123 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 
124 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 125 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 126 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 127 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 128 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 129 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 130 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 131 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 132 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 133 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 134 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 135 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 136 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 137 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 138 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 139 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 140 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 141 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 142 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 143 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 144 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 145 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 146 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 147 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 148 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 149 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 150 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 151 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 152 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 153 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 154 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 155 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 156 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 157 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 158 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 159 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 160 | "lemon", "fig", "pineapple", "banana", "jackfruit", 
"cherimoya (custard apple)", "pomegranate", 161 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 162 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 163 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 164 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 165 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 166 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 167 | 168 | 169 | openai_imagenet_template = [ 170 | lambda c: f'a bad photo of a {c}.', 171 | lambda c: f'a photo of many {c}.', 172 | lambda c: f'a sculpture of a {c}.', 173 | lambda c: f'a photo of the hard to see {c}.', 174 | lambda c: f'a low resolution photo of the {c}.', 175 | lambda c: f'a rendering of a {c}.', 176 | lambda c: f'graffiti of a {c}.', 177 | lambda c: f'a bad photo of the {c}.', 178 | lambda c: f'a cropped photo of the {c}.', 179 | lambda c: f'a tattoo of a {c}.', 180 | lambda c: f'the embroidered {c}.', 181 | lambda c: f'a photo of a hard to see {c}.', 182 | lambda c: f'a bright photo of a {c}.', 183 | lambda c: f'a photo of a clean {c}.', 184 | lambda c: f'a photo of a dirty {c}.', 185 | lambda c: f'a dark photo of the {c}.', 186 | lambda c: f'a drawing of a {c}.', 187 | lambda c: f'a photo of my {c}.', 188 | lambda c: f'the plastic {c}.', 189 | lambda c: f'a photo of the cool {c}.', 190 | lambda c: f'a close-up photo of a {c}.', 191 | lambda c: f'a black and white photo of the {c}.', 192 | lambda c: f'a painting of the {c}.', 193 | lambda c: f'a painting of a {c}.', 194 | lambda c: f'a pixelated photo of the {c}.', 195 | lambda c: f'a sculpture of the {c}.', 196 | lambda c: f'a bright photo of the {c}.', 197 | lambda c: f'a cropped photo of a {c}.', 198 | lambda c: f'a plastic {c}.', 199 | lambda c: f'a photo of the dirty {c}.', 200 | lambda c: f'a jpeg corrupted photo of a {c}.', 201 | lambda c: f'a blurry photo of the {c}.', 202 | lambda c: f'a photo of the {c}.', 203 | lambda c: f'a good photo of the {c}.', 204 | lambda c: f'a rendering of the {c}.', 205 | lambda c: f'a {c} in a video game.', 206 | lambda c: f'a photo of one {c}.', 207 | lambda c: f'a doodle of a {c}.', 208 | lambda c: f'a close-up photo of the {c}.', 209 | lambda c: f'a photo of a {c}.', 210 | lambda c: f'the origami {c}.', 211 | lambda c: f'the {c} in a video game.', 212 | lambda c: f'a sketch of a {c}.', 213 | lambda c: f'a doodle of the {c}.', 214 | lambda c: f'a origami {c}.', 215 | lambda c: f'a low resolution photo of a {c}.', 216 | lambda c: f'the toy {c}.', 217 | lambda c: f'a rendition of the {c}.', 218 | lambda c: f'a photo of the clean {c}.', 219 | lambda c: f'a photo of a large {c}.', 220 | lambda c: f'a rendition of a {c}.', 221 | lambda c: f'a photo of a nice {c}.', 222 | lambda c: f'a photo of a weird {c}.', 223 | lambda c: f'a blurry photo of a {c}.', 224 | lambda c: f'a cartoon {c}.', 225 | lambda c: f'art of a {c}.', 226 | lambda c: f'a sketch of the {c}.', 227 | lambda c: f'a embroidered {c}.', 228 | lambda c: f'a pixelated photo of a {c}.', 229 | lambda c: f'itap of the {c}.', 230 | lambda c: f'a jpeg corrupted photo of the {c}.', 231 | lambda c: f'a good photo of a {c}.', 232 | lambda c: f'a plushie {c}.', 233 | lambda c: f'a photo of the nice {c}.', 234 | lambda c: f'a photo of the small {c}.', 235 | lambda c: f'a photo of the weird {c}.', 236 | 
lambda c: f'the cartoon {c}.', 237 | lambda c: f'art of the {c}.', 238 | lambda c: f'a drawing of the {c}.', 239 | lambda c: f'a photo of the large {c}.', 240 | lambda c: f'a black and white photo of a {c}.', 241 | lambda c: f'the plushie {c}.', 242 | lambda c: f'a dark photo of a {c}.', 243 | lambda c: f'itap of a {c}.', 244 | lambda c: f'graffiti of the {c}.', 245 | lambda c: f'a toy {c}.', 246 | lambda c: f'itap of my {c}.', 247 | lambda c: f'a photo of a cool {c}.', 248 | lambda c: f'a photo of a small {c}.', 249 | lambda c: f'a tattoo of the {c}.', 250 | ] --------------------------------------------------------------------------------
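The `pamr.py` module above implements Pixel-Adaptive Mask Refinement (PAMR), which sharpens coarse per-class masks using local color affinities computed from the input image. The snippet below is only a minimal illustrative sketch of how such a module can be called; it is not taken from this repository, and the iteration count, dilation list, tensor sizes, and class count are assumptions for demonstration rather than the repo's actual configuration.

```
# Illustrative sketch only (not part of the repository): refining coarse
# segmentation logits with the PAMR module from pamr.py.
# num_iter, dilations, and the tensor shapes below are assumed values.
import torch
from pamr import PAMR

# Values commonly used for PAMR elsewhere; the repo's own settings may differ.
pamr = PAMR(num_iter=10, dilations=[1, 2, 4, 8, 12, 24])

image = torch.rand(1, 3, 224, 224)          # RGB image tensor in [0, 1]
coarse_logits = torch.rand(1, 21, 56, 56)   # per-class logits at low resolution

with torch.no_grad():
    # PAMR upsamples the logits to the image size and iteratively refines
    # them with image-derived local affinities -> [1, 21, 224, 224]
    refined = pamr(image, coarse_logits)
```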
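The 80 `openai_imagenet_template` prompts above are the standard CLIP prompt-ensembling templates: each class name is expanded into every template, encoded with CLIP's text encoder, and the normalized embeddings are averaged into a single text classifier weight per class. The sketch below illustrates that generic pattern only; it is not code from this repository (ITACLIP's own text pipeline lives in `itaclip_segmentor.py`, not shown here), and the backbone name, helper name, and example class list are assumptions.

```
# Illustrative sketch only (not part of the repository): standard CLIP prompt
# ensembling with openai_imagenet_template. Backbone, helper name, and the
# example class list are assumptions for demonstration.
import torch
import clip  # the bundled clip package (clip/clip.py)
from prompts.imagenet_template import openai_imagenet_template

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)  # assumed backbone

def build_text_classifier(class_names):
    """Average the embeddings of all 80 prompt templates for each class."""
    weights = []
    with torch.no_grad():
        for name in class_names:
            prompts = [template(name) for template in openai_imagenet_template]
            tokens = clip.tokenize(prompts).to(device)       # [80, 77]
            emb = model.encode_text(tokens)                  # [80, D]
            emb = emb / emb.norm(dim=-1, keepdim=True)       # normalize per prompt
            mean_emb = emb.mean(dim=0)
            weights.append(mean_emb / mean_emb.norm())       # renormalize the average
    return torch.stack(weights)                              # [num_classes, D]

# e.g., text_weights = build_text_classifier(["road", "sidewalk", "building"])
```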