├── README.md ├── assets └── motivation.png ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ ├── CoCoOp │ ├── vit_b16_c16_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ └── vit_b16_c8_ep10_batch1.yaml │ ├── CoOp │ ├── rn101.yaml │ ├── rn101_ep50.yaml │ ├── rn50.yaml │ ├── rn50_ctxv1.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep50.yaml │ ├── rn50_ep50_ctxv1.yaml │ ├── rn50_val.yaml │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep50.yaml │ ├── vit_b32.yaml │ └── vit_b32_ep50.yaml │ ├── CoOp_testtime │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml │ ├── Unified │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml │ ├── VPT │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml │ ├── VPT_deep │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml │ ├── VPT_shallow │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml │ └── VPT_testtime │ ├── vit_b16_ep10.yaml │ ├── vit_b16_ep200.yaml │ └── vit_b16_ep50.yaml ├── scripts ├── README.md ├── cocoop │ ├── base2new_test.sh │ ├── xd_test.sh │ └── xd_train.sh ├── eval.sh ├── imagenet_coop.sh ├── imagenet_zero.sh ├── main.sh ├── unified │ ├── base2new_test_coop.sh │ ├── base2new_test_new.sh │ ├── base2new_test_new2.sh │ ├── base2new_test_new3.sh │ ├── base2new_train_caltech-101.sh │ ├── base2new_train_dtd.sh │ ├── base2new_train_eurosat.sh │ ├── base2new_train_fgvc_aircraft.sh │ ├── base2new_train_food101.sh │ ├── base2new_train_imagenet.sh │ ├── base2new_train_oxford_flowers.sh │ ├── base2new_train_oxford_pets.sh │ ├── base2new_train_stanford_cars.sh │ ├── base2new_train_sun397.sh │ └── base2new_train_ucf101.sh └── zeroshot.sh ├── train.py └── trainers ├── .flake8 ├── __init__.py ├── cocoop.py ├── coop.py ├── coop_testtime.py ├── imagenet_templates.py ├── linter.sh ├── losses.py ├── unified.py ├── utils.py ├── vpt.py ├── vpt_deep.py ├── vpt_shallow.py └── zsclip.py /README.md: -------------------------------------------------------------------------------- 1 |

Unified Vision and Language Prompt Learning

2 | 3 |

4 | arXiv | 5 | Code 6 |

7 | 8 | > **Unified Vision and Language Prompt Learning**
9 | > Yuhang Zang, Wei Li, Kaiyang Zhou, Chen Huang, Chen Change Loy
10 | > arXiv, 2022
11 | 12 |

13 | 14 |

15 | 16 | ## How to run 17 | 18 | This code is based on [CoOp](https://github.com/KaiyangZhou/CoOp), you may refer to the [install instruction](https://github.com/KaiyangZhou/CoOp?tab=readme-ov-file#how-to-install) 19 | 20 | The training scripts to re-produce the results are provided [here](scripts/unified). 21 | 22 | The model structure is defined [here](trainers/unified.py). 23 | -------------------------------------------------------------------------------- /assets/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/assets/motivation.png -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, :len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | def forward_prompt(self, x, prompt): 202 | l_length = len(self.resblocks) 203 | p_length = len(prompt) // l_length 204 | 205 | prompt = prompt.reshape(p_length, l_length, -1) 206 | prompt = prompt.permute(1, 0, 2) 207 | prompt = prompt.unsqueeze(2).repeat(1, 1, x.shape[1], 1) 208 | # prompt = prompt.unsqueeze(1).repeat(1, x.shape[1], 1) 209 | 210 | for ind, block in enumerate(self.resblocks): 211 | cls = x[0:1, :, :] 212 | if ind == 0: 213 | spatial = x[1:, :, :] 214 | else: 215 | spatial = x[1+p_length:, :, :] 216 | x = torch.cat([cls, prompt[ind, :], spatial], 0) 217 | # x = torch.cat([cls, prompt, spatial], 0) 218 | x = block(x) 219 | return x 220 | 221 | 222 | class VisionTransformer(nn.Module): 223 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 224 | super().__init__() 225 | self.input_resolution = input_resolution 226 | self.output_dim = output_dim 227 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 228 | 229 | scale = width ** -0.5 230 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 231 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 232 | self.ln_pre = LayerNorm(width) 233 | 234 | self.transformer = Transformer(width, layers, heads) 235 | 236 | self.ln_post = LayerNorm(width) 237 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 238 | 239 | def forward(self, x: torch.Tensor): 240 | x = self.conv1(x) # shape = [*, width, grid, grid] 241 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 242 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 243 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 244 | x = x + self.positional_embedding.to(x.dtype) 245 | x = self.ln_pre(x) 246 | 247 | x = x.permute(1, 0, 2) # NLD -> LND 248 | x = self.transformer(x) 249 | x = x.permute(1, 0, 2) # LND -> NLD 250 | 251 | x = self.ln_post(x[:, 0, :]) 252 | 253 | if self.proj is not None: 254 | x = x @ self.proj 255 | 256 | return x 257 | 258 | def forward_prompt(self, x: torch.Tensor, prompt): 259 | x = self.conv1(x) # shape = [*, width, grid, grid] 260 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 261 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 262 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 263 | x = x + self.positional_embedding.to(x.dtype) 264 | 265 | # 266 | cls = x[:, 0:1, :] 267 | spatial = x[:, 1:, :] 268 | prompt = prompt.unsqueeze(0).repeat(len(x), 1, 1) 269 | x = torch.cat([cls, prompt, spatial], 1) 270 | # 271 | 272 | x = self.ln_pre(x) 273 | 274 | x = x.permute(1, 0, 2) # NLD -> LND 275 | x = self.transformer(x) 276 | x = x.permute(1, 0, 2) # LND -> NLD 277 | 278 | x = self.ln_post(x[:, 0, :]) 279 | 280 | if self.proj is not None: 281 | out = x @ self.proj 282 | 283 | return out, x 284 | 285 | def forward_prompt_deep(self, x: torch.Tensor, prompt): 286 | x = self.conv1(x) # shape = [*, width, grid, grid] 287 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 288 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 289 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 290 | x = x + self.positional_embedding.to(x.dtype) 291 | 292 | x = self.ln_pre(x) 293 | 294 | x = x.permute(1, 0, 2) # NLD -> LND 295 | x = self.transformer.forward_prompt(x, prompt) 296 | x = x.permute(1, 0, 2) # LND -> NLD 297 | 298 | x = self.ln_post(x[:, 0, :]) 299 | 300 | if self.proj is not None: 301 | out = x @ self.proj 302 | 303 | return out, x 304 | 305 | 306 | class CLIP(nn.Module): 307 | def __init__(self, 308 | embed_dim: int, 309 | # vision 310 | image_resolution: int, 311 | vision_layers: Union[Tuple[int, int, int, int], int], 312 | vision_width: int, 313 | vision_patch_size: int, 314 | # text 315 | context_length: int, 316 | vocab_size: int, 317 | transformer_width: int, 318 | transformer_heads: int, 319 | transformer_layers: int 320 | ): 321 | super().__init__() 322 | 323 | self.context_length = context_length 324 | 325 | if isinstance(vision_layers, (tuple, list)): 326 | vision_heads = vision_width * 32 // 64 327 | self.visual = ModifiedResNet( 328 | layers=vision_layers, 329 | output_dim=embed_dim, 330 | heads=vision_heads, 331 | input_resolution=image_resolution, 332 | width=vision_width 333 | ) 334 | else: 335 | vision_heads = vision_width // 64 336 | self.visual = VisionTransformer( 337 | input_resolution=image_resolution, 338 | patch_size=vision_patch_size, 339 | width=vision_width, 340 | layers=vision_layers, 341 | heads=vision_heads, 342 | output_dim=embed_dim 343 | ) 344 | 345 | self.transformer = Transformer( 346 | width=transformer_width, 347 | layers=transformer_layers, 348 | heads=transformer_heads, 349 | attn_mask=self.build_attention_mask() 350 | ) 351 | 352 | self.vocab_size = vocab_size 353 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 354 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 355 | self.ln_final = LayerNorm(transformer_width) 356 | 357 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 358 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 359 | 360 | self.initialize_parameters() 361 | 362 | def initialize_parameters(self): 363 | nn.init.normal_(self.token_embedding.weight, std=0.02) 364 | nn.init.normal_(self.positional_embedding, std=0.01) 365 | 366 | if isinstance(self.visual, ModifiedResNet): 367 | if self.visual.attnpool is not None: 368 | std = self.visual.attnpool.c_proj.in_features ** -0.5 369 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 370 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 371 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 372 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 373 | 374 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 375 | for name, param in resnet_block.named_parameters(): 376 | if name.endswith("bn3.weight"): 377 | nn.init.zeros_(param) 378 | 379 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 380 | attn_std = self.transformer.width ** -0.5 381 | fc_std = (2 * self.transformer.width) ** -0.5 382 | for block in self.transformer.resblocks: 383 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 384 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 385 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 386 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 387 | 388 | if self.text_projection is not None: 389 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 390 | 391 | def build_attention_mask(self): 392 | # lazily create causal attention mask, with full attention between the vision tokens 393 | # pytorch uses additive attention mask; fill with -inf 394 | mask = torch.empty(self.context_length, self.context_length) 395 | mask.fill_(float("-inf")) 396 | mask.triu_(1) # zero out the lower diagonal 397 | return mask 398 | 399 | @property 400 | def dtype(self): 401 | return self.visual.conv1.weight.dtype 402 | 403 | def encode_image(self, image): 404 | return self.visual(image.type(self.dtype)) 405 | 406 | def encode_image_prompt(self, image, prompt): 407 | return self.visual.forward_prompt(image.type(self.dtype), prompt) 408 | 409 | def encode_image_prompt_deep(self, image, prompt): 410 | return self.visual.forward_prompt_deep(image.type(self.dtype), prompt) 411 | 412 | def encode_text(self, text): 413 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 414 | 415 | x = x + self.positional_embedding.type(self.dtype) 416 | x = x.permute(1, 0, 2) # NLD -> LND 417 | x = self.transformer(x) 418 | x = x.permute(1, 0, 2) # LND -> NLD 419 | x = self.ln_final(x).type(self.dtype) 420 | 421 | # x.shape = [batch_size, n_ctx, transformer.width] 422 | # take features from the eot embedding (eot_token is the highest number in each sequence) 423 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 424 | 425 | return x 426 | 427 | def forward(self, image, text): 428 | image_features = self.encode_image(image) 429 | text_features = self.encode_text(text) 430 | 431 | # normalized features 432 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 433 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 434 | 435 | # cosine similarity as logits 436 | logit_scale = self.logit_scale.exp() 437 | logits_per_image = logit_scale * image_features @ text_features.t() 438 | logits_per_text = logit_scale * text_features @ image_features.t() 439 | 440 | # shape = [global_batch_size, global_batch_size] 441 | return logits_per_image, logits_per_text 442 | 443 | 444 | def convert_weights(model: nn.Module): 445 | """Convert applicable model parameters to fp16""" 446 | 447 | def _convert_weights_to_fp16(l): 448 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 449 | l.weight.data = l.weight.data.half() 450 | if l.bias is not None: 451 | l.bias.data = l.bias.data.half() 452 | 453 | if isinstance(l, nn.MultiheadAttention): 454 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 455 | tensor = getattr(l, attr) 456 | if tensor is not None: 457 | tensor.data = tensor.data.half() 458 | 459 | for name in ["text_projection", "proj"]: 460 | if hasattr(l, name): 461 | attr = getattr(l, name) 462 | if attr is not None: 463 | attr.data = attr.data.half() 464 | 465 | model.apply(_convert_weights_to_fp16) 466 | 467 | 468 | def build_model(state_dict: dict): 469 | vit = "visual.proj" in state_dict 470 | 471 | if vit: 472 | vision_width = state_dict["visual.conv1.weight"].shape[0] 473 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 474 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 475 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 476 | image_resolution = vision_patch_size * grid_size 477 | else: 478 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 479 | vision_layers = tuple(counts) 480 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 481 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 482 | vision_patch_size = None 483 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 484 | image_resolution = output_width * 32 485 | 486 | embed_dim = state_dict["text_projection"].shape[1] 487 | context_length = state_dict["positional_embedding"].shape[0] 488 | vocab_size = state_dict["token_embedding.weight"].shape[0] 489 | transformer_width = state_dict["ln_final.weight"].shape[0] 490 | transformer_heads = transformer_width // 64 491 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 492 | 493 | model = CLIP( 494 | embed_dim, 495 | image_resolution, vision_layers, vision_width, vision_patch_size, 496 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 497 | ) 498 | 499 | for key in ["input_resolution", "context_length", "vocab_size"]: 500 | if key in state_dict: 501 | del state_dict[key] 502 | 503 | convert_weights(model) 504 | model.load_state_dict(state_dict) 505 | return model.eval() 506 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | -------------------------------------------------------------------------------- /configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 1 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/CoOp_testtime/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp_testtime/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp_testtime/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/Unified/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | -------------------------------------------------------------------------------- /configs/trainers/Unified/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | -------------------------------------------------------------------------------- /configs/trainers/Unified/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_deep/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 1 # 100 6 | NUM_WORKERS: 0 # 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_deep/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_deep/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_shallow/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 1 # 100 6 | NUM_WORKERS: 0 # 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_shallow/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_shallow/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_testtime/vit_b16_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_testtime/vit_b16_ep200.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT_testtime/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | coop_testtime: 包含了 coop 的结果: (82.69, 63.29, 71.70) 2 | 3 | VPT: 包含了 vpt-shallow 的结果 (prompt length = 4): (80.17, 70.06, 74.78) 4 | VPT_testtime: 包含了 vpt-mix 的结果 (prompt length = 4): (82.08, 69.10, 75.03) 5 | VPT_deep: 包含了 vpt-deep 的结果 (prompt length = 4): (83.64, 67.31, 74.59) 6 | 7 | Unified: 包含了 vpt-deep + coop, seperate 的结果 8 | Unified_v2: 包含了 vpt-mix + coop, seperate 的结果 9 | Unified_v3: 包含了 vpt-deep + coop, shared 的结果 (10e: 80.40, 74.24, 77.20; 200e: 84.35, 64.86, 73.33) 10 | Unified_v4: 包含了 vpt-deep + coop, DETR encoder 的结果 (200e: 74.56) 11 | Unified_v5: 包含了 vpt-deep + coop self-attn 的结果 (10e: 76.99; 200e: 74.20) -------------------------------------------------------------------------------- /scripts/cocoop/base2new_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoOp_testtime 8 | 9 | DATASET=stanford_cars 10 | SEED=1 11 | 12 | CFG=vit_b16_ep50 13 | SHOTS=16 14 | LOADEP=50 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Results are available in ${DIR}. Skip this job" 23 | else 24 | echo "Run this job and save the output to ${DIR}" 25 | 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | --model-dir ${MODEL_DIR} \ 34 | --load-epoch ${LOADEP} \ 35 | --eval-only \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | DATASET.SUBSAMPLE_CLASSES ${SUB} 38 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 10 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoCoOp 8 | 9 | DATASET=caltech101 10 | SEED=1 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | CUDA_VISIBLE_DEVICES=1 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | CFG=$2 11 | CTP=$3 12 | NCTX=$4 13 | SHOTS=$5 14 | CSC=$6 15 | 16 | # for SEED in 1 2 3 17 | for SEED in 1 18 | do 19 | CUDA_VISIBLE_DEVICES=1; python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 26 | --model-dir output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \ 27 | --eval-only \ 28 | TRAINER.COOP.N_CTX ${NCTX} \ 29 | TRAINER.COOP.CSC ${CSC} \ 30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 31 | done 32 | -------------------------------------------------------------------------------- /scripts/imagenet_coop.sh: -------------------------------------------------------------------------------- 1 | bash ./main.sh imagenet vit_b16_ep50 end 16 16 False 2 | -------------------------------------------------------------------------------- /scripts/imagenet_zero.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 bash zeroshot.sh imagenet vit_b16 2 | -------------------------------------------------------------------------------- /scripts/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | CFG=$2 # config file 11 | CTP=$3 # class token position (end or middle) 12 | NCTX=$4 # number of context tokens 13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 14 | CSC=$6 # class-specific context (False or True) 15 | 16 | for SEED in 1 2 3 17 | do 18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Results are available in ${DIR}. Skip this job" 21 | else 22 | echo "Run this job and save the output to ${DIR}" 23 | python train.py \ 24 | --root ${DATA} \ 25 | --seed ${SEED} \ 26 | --trainer ${TRAINER} \ 27 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 29 | --output-dir ${DIR} \ 30 | TRAINER.COOP.N_CTX ${NCTX} \ 31 | TRAINER.COOP.CSC ${CSC} \ 32 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 33 | DATASET.NUM_SHOTS ${SHOTS} 34 | fi 35 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_test_coop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoOp 8 | 9 | DATASET=imagenet 10 | 11 | CFG=vit_b16_ep50 12 | SHOTS=16 13 | LOADEP=50 14 | SUB=base 15 | 16 | for SEED in 1 17 | do 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | CUDA_VISIBLE_DEVICES=0 python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 27 | --output-dir ${DIR} \ 28 | --model-dir ${MODEL_DIR} \ 29 | --load-epoch ${LOADEP} \ 30 | --eval-only \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES ${SUB} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/unified/base2new_test_new.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=ucf101 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | LOADEP=200 14 | SUB=new 15 | 16 | for SEED in 1 17 | do 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | CUDA_VISIBLE_DEVICES=5 python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 27 | --output-dir ${DIR} \ 28 | --model-dir ${MODEL_DIR} \ 29 | --load-epoch ${LOADEP} \ 30 | --eval-only \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES ${SUB} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/unified/base2new_test_new2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=ucf101 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | LOADEP=200 14 | SUB=new 15 | 16 | for SEED in 2 17 | do 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | CUDA_VISIBLE_DEVICES=6 python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 27 | --output-dir ${DIR} \ 28 | --model-dir ${MODEL_DIR} \ 29 | --load-epoch ${LOADEP} \ 30 | --eval-only \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES ${SUB} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/unified/base2new_test_new3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=ucf101 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | LOADEP=200 14 | SUB=new 15 | 16 | for SEED in 3 17 | do 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | CUDA_VISIBLE_DEVICES=7 python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 27 | --output-dir ${DIR} \ 28 | --model-dir ${MODEL_DIR} \ 29 | --load-epoch ${LOADEP} \ 30 | --eval-only \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES ${SUB} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/unified/base2new_train_caltech-101.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=caltech101 10 | 11 | CFG=vit_b16_ep10 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done 27 | -------------------------------------------------------------------------------- /scripts/unified/base2new_train_dtd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=dtd 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_eurosat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=eurosat 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done 27 | -------------------------------------------------------------------------------- /scripts/unified/base2new_train_fgvc_aircraft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=fgvc_aircraft 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 1 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=5 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_food101.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=food101 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=imagenet 10 | 11 | CFG=vit_b16_ep50 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=2 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_oxford_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=oxford_flowers 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 1 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=5 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_oxford_pets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=oxford_pets 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_stanford_cars.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=stanford_cars 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 1 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=5 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_sun397.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=sun397 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/unified/base2new_train_ucf101.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=Unified 8 | 9 | DATASET=ucf101 10 | 11 | CFG=vit_b16_ep200 12 | SHOTS=16 13 | 14 | for SEED in 3 15 | do 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | CUDA_VISIBLE_DEVICES=7 python train.py \ 18 | --root ${DATA} \ 19 | --seed ${SEED} \ 20 | --trainer ${TRAINER} \ 21 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 23 | --output-dir ${DIR} \ 24 | DATASET.NUM_SHOTS ${SHOTS} \ 25 | DATASET.SUBSAMPLE_CLASSES base 26 | done -------------------------------------------------------------------------------- /scripts/zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=ZeroshotCLIP 8 | DATASET=$1 9 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16 10 | 11 | python train.py \ 12 | --root ${DATA} \ 13 | --trainer ${TRAINER} \ 14 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 15 | --config-file configs/trainers/CoOp/${CFG}.yaml \ 16 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \ 17 | --eval-only 18 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 5 | from dassl.config import get_cfg_default 6 | from dassl.engine import build_trainer 7 | 8 | # custom 9 | import datasets.oxford_pets 10 | import datasets.oxford_flowers 11 | import datasets.fgvc_aircraft 12 | import datasets.dtd 13 | import datasets.eurosat 14 | import datasets.stanford_cars 15 | import datasets.food101 16 | import datasets.sun397 17 | import datasets.caltech101 18 | import datasets.ucf101 19 | import datasets.imagenet 20 | 21 | import datasets.imagenet_sketch 22 | import datasets.imagenetv2 23 | import datasets.imagenet_a 24 | import datasets.imagenet_r 25 | 26 | import trainers.coop 27 | import trainers.cocoop 28 | import trainers.zsclip 29 | import trainers.coop_testtime 30 | import trainers.vpt 31 | import trainers.vpt_testtime 32 | import trainers.vpt_deep 33 | import trainers.vpt_shallow 34 | import trainers.unified 35 | import trainers.unified_v2 36 | import trainers.unified_v3 37 | import trainers.unified_v4 38 | import trainers.unified_v5 39 | import trainers.unified_v6 40 | 41 | 42 | def print_args(args, cfg): 43 | print("***************") 44 | print("** Arguments **") 45 | print("***************") 46 | optkeys = list(args.__dict__.keys()) 47 | optkeys.sort() 48 | for key in optkeys: 49 | print("{}: {}".format(key, args.__dict__[key])) 50 | print("************") 51 | print("** Config **") 52 | print("************") 53 | print(cfg) 54 | 55 | 56 | def reset_cfg(cfg, args): 57 | if args.root: 58 | cfg.DATASET.ROOT = args.root 59 | 60 | if args.output_dir: 61 | cfg.OUTPUT_DIR = args.output_dir 62 | 63 | if args.resume: 64 | cfg.RESUME = args.resume 65 | 66 | if args.seed: 67 | cfg.SEED = args.seed 68 | 69 | if args.source_domains: 70 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 71 | 72 | if args.target_domains: 73 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 74 | 75 | if args.transforms: 76 | cfg.INPUT.TRANSFORMS = args.transforms 77 | 78 | if args.trainer: 79 | cfg.TRAINER.NAME = args.trainer 80 | 81 | if args.backbone: 82 | cfg.MODEL.BACKBONE.NAME = args.backbone 83 | 84 | if args.head: 85 | cfg.MODEL.HEAD.NAME = args.head 86 | 87 | 88 | def extend_cfg(cfg): 89 | """ 90 | Add new config variables. 91 | 92 | E.g. 93 | from yacs.config import CfgNode as CN 94 | cfg.TRAINER.MY_MODEL = CN() 95 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 96 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 97 | cfg.TRAINER.MY_MODEL.PARAM_C = False 98 | """ 99 | from yacs.config import CfgNode as CN 100 | 101 | cfg.TRAINER.COOP = CN() 102 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 103 | cfg.TRAINER.COOP.CSC = False # class-specific context 104 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words 105 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 106 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 107 | 108 | cfg.TRAINER.COCOOP = CN() 109 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors 110 | cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words 111 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp 112 | 113 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 114 | 115 | 116 | def setup_cfg(args): 117 | cfg = get_cfg_default() 118 | extend_cfg(cfg) 119 | 120 | # 1. From the dataset config file 121 | if args.dataset_config_file: 122 | cfg.merge_from_file(args.dataset_config_file) 123 | 124 | # 2. From the method config file 125 | if args.config_file: 126 | cfg.merge_from_file(args.config_file) 127 | 128 | # 3. From input arguments 129 | reset_cfg(cfg, args) 130 | 131 | # 4. From optional input arguments 132 | cfg.merge_from_list(args.opts) 133 | 134 | cfg.freeze() 135 | 136 | return cfg 137 | 138 | 139 | def main(args): 140 | cfg = setup_cfg(args) 141 | if cfg.SEED >= 0: 142 | print("Setting fixed seed: {}".format(cfg.SEED)) 143 | set_random_seed(cfg.SEED) 144 | setup_logger(cfg.OUTPUT_DIR) 145 | 146 | if torch.cuda.is_available() and cfg.USE_CUDA: 147 | torch.backends.cudnn.benchmark = True 148 | 149 | print_args(args, cfg) 150 | print("Collecting env info ...") 151 | print("** System info **\n{}\n".format(collect_env_info())) 152 | 153 | trainer = build_trainer(cfg) 154 | 155 | if args.eval_only: 156 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 157 | trainer.test() 158 | return 159 | 160 | if not args.no_train: 161 | trainer.train() 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument("--root", type=str, default="", help="path to dataset") 167 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 168 | parser.add_argument( 169 | "--resume", 170 | type=str, 171 | default="", 172 | help="checkpoint directory (from which the training resumes)", 173 | ) 174 | parser.add_argument( 175 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 176 | ) 177 | parser.add_argument( 178 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 179 | ) 180 | parser.add_argument( 181 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 182 | ) 183 | parser.add_argument( 184 | "--transforms", type=str, nargs="+", help="data augmentation methods" 185 | ) 186 | parser.add_argument( 187 | "--config-file", type=str, default="", help="path to config file" 188 | ) 189 | parser.add_argument( 190 | "--dataset-config-file", 191 | type=str, 192 | default="", 193 | help="path to config file for dataset setup", 194 | ) 195 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 196 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 197 | parser.add_argument("--head", type=str, default="", help="name of head") 198 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 199 | parser.add_argument( 200 | "--model-dir", 201 | type=str, 202 | default="", 203 | help="load model from this directory for eval-only mode", 204 | ) 205 | parser.add_argument( 206 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 207 | ) 208 | parser.add_argument( 209 | "--no-train", action="store_true", help="do not call trainer.train()" 210 | ) 211 | parser.add_argument( 212 | "opts", 213 | default=None, 214 | nargs=argparse.REMAINDER, 215 | help="modify config options using the command-line", 216 | ) 217 | args = parser.parse_args() 218 | main(args) 219 | -------------------------------------------------------------------------------- /trainers/.flake8: -------------------------------------------------------------------------------- 1 | # This is an example .flake8 config, used when developing *Black* itself. 2 | # Keep in sync with setup.cfg which is used for source packages. 3 | 4 | [flake8] 5 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811 6 | max-line-length = 100 7 | max-complexity = 18 8 | select = B,C,E,F,W,T4,B9 9 | exclude = build 10 | per-file-ignores = 11 | **/__init__.py:F401,F403,E402 12 | **/configs/**.py:F401,E402 13 | configs/**.py:F401,E402 14 | **/tests/config/**.py:F401,E402 15 | tests/config/**.py:F401,E402 16 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/cocoop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.cuda.amp import GradScaler, autocast 7 | from torch.nn import functional as F 8 | 9 | from clip import clip 10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | from dassl.engine import TRAINER_REGISTRY, TrainerX 12 | from dassl.optim import build_lr_scheduler, build_optimizer 13 | from dassl.utils import load_checkpoint, load_pretrained_weights 14 | 15 | _tokenizer = _Tokenizer() 16 | 17 | 18 | def load_clip_to_cpu(cfg): 19 | backbone_name = cfg.MODEL.BACKBONE.NAME 20 | url = clip._MODELS[backbone_name] 21 | model_path = clip._download(url) 22 | 23 | try: 24 | # loading JIT archive 25 | model = torch.jit.load(model_path, map_location="cpu").eval() 26 | state_dict = None 27 | 28 | except RuntimeError: 29 | state_dict = torch.load(model_path, map_location="cpu") 30 | 31 | model = clip.build_model(state_dict or model.state_dict()) 32 | 33 | return model 34 | 35 | 36 | class TextEncoder(nn.Module): 37 | def __init__(self, clip_model): 38 | super().__init__() 39 | self.transformer = clip_model.transformer 40 | self.positional_embedding = clip_model.positional_embedding 41 | self.ln_final = clip_model.ln_final 42 | self.text_projection = clip_model.text_projection 43 | self.dtype = clip_model.dtype 44 | 45 | def forward(self, prompts, tokenized_prompts): 46 | x = prompts + self.positional_embedding.type(self.dtype) 47 | x = x.permute(1, 0, 2) # NLD -> LND 48 | x = self.transformer(x) 49 | x = x.permute(1, 0, 2) # LND -> NLD 50 | x = self.ln_final(x).type(self.dtype) 51 | 52 | # x.shape = [batch_size, n_ctx, transformer.width] 53 | # take features from the eot embedding (eot_token is the highest number in each sequence) 54 | x = x[torch.arange(x.shape[0]), 55 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 56 | 57 | return x 58 | 59 | 60 | class PromptLearner(nn.Module): 61 | def __init__(self, cfg, classnames, clip_model): 62 | super().__init__() 63 | n_cls = len(classnames) 64 | n_ctx = cfg.TRAINER.COCOOP.N_CTX 65 | ctx_init = cfg.TRAINER.COCOOP.CTX_INIT 66 | dtype = clip_model.dtype 67 | ctx_dim = clip_model.ln_final.weight.shape[0] 68 | vis_dim = clip_model.visual.output_dim 69 | clip_imsize = clip_model.visual.input_resolution 70 | cfg_imsize = cfg.INPUT.SIZE[0] 71 | assert cfg_imsize == clip_imsize 72 | # f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 73 | 74 | if ctx_init: 75 | # use given words to initialize context vectors 76 | ctx_init = ctx_init.replace("_", " ") 77 | n_ctx = len(ctx_init.split(" ")) 78 | prompt = clip.tokenize(ctx_init) 79 | with torch.no_grad(): 80 | embedding = clip_model.token_embedding(prompt).type(dtype) 81 | ctx_vectors = embedding[0, 1:1 + n_ctx, :] 82 | prompt_prefix = ctx_init 83 | else: 84 | # random initialization 85 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 86 | nn.init.normal_(ctx_vectors, std=0.02) 87 | prompt_prefix = " ".join(["X"] * n_ctx) 88 | 89 | print(f'Initial context: "{prompt_prefix}"') 90 | print(f"Number of context words (tokens): {n_ctx}") 91 | 92 | self.ctx = nn.Parameter(ctx_vectors) 93 | 94 | self.meta_net = nn.Sequential( 95 | OrderedDict([("linear1", nn.Linear(vis_dim, vis_dim // 16)), 96 | ("relu", nn.ReLU(inplace=True)), 97 | ("linear2", nn.Linear(vis_dim // 16, ctx_dim))])) 98 | 99 | if cfg.TRAINER.COCOOP.PREC == "fp16": 100 | self.meta_net.half() 101 | 102 | classnames = [name.replace("_", " ") for name in classnames] 103 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 104 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 105 | 106 | tokenized_prompts = torch.cat([clip.tokenize(p) 107 | for p in prompts]) # (n_cls, n_tkn) 108 | with torch.no_grad(): 109 | embedding = clip_model.token_embedding(tokenized_prompts).type( 110 | dtype) 111 | 112 | # These token vectors will be saved when in save_model(), 113 | # but they should be ignored in load_model() as we want to use 114 | # those computed using the current class names 115 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 116 | self.register_buffer("token_suffix", 117 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS 118 | 119 | self.n_cls = n_cls 120 | self.n_ctx = n_ctx 121 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 122 | self.name_lens = name_lens 123 | 124 | def construct_prompts(self, ctx, prefix, suffix, label=None): 125 | # dim0 is either batch_size (during training) or n_cls (during testing) 126 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 127 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 128 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 129 | 130 | if label is not None: 131 | prefix = prefix[label] 132 | suffix = suffix[label] 133 | 134 | prompts = torch.cat( 135 | [ 136 | prefix, # (dim0, 1, dim) 137 | ctx, # (dim0, n_ctx, dim) 138 | suffix, # (dim0, *, dim) 139 | ], 140 | dim=1, 141 | ) 142 | 143 | return prompts 144 | 145 | def forward(self, im_features): 146 | prefix = self.token_prefix 147 | suffix = self.token_suffix 148 | ctx = self.ctx # (n_ctx, ctx_dim) 149 | bias = self.meta_net(im_features) # (batch, ctx_dim) 150 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim) 151 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 152 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim) 153 | 154 | # Use instance-conditioned context tokens for all classes 155 | prompts = [] 156 | for ctx_shifted_i in ctx_shifted: 157 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) 158 | pts_i = self.construct_prompts(ctx_i, prefix, 159 | suffix) # (n_cls, n_tkn, ctx_dim) 160 | prompts.append(pts_i) 161 | prompts = torch.stack(prompts) 162 | 163 | return prompts 164 | 165 | 166 | class CustomCLIP(nn.Module): 167 | def __init__(self, cfg, classnames, clip_model): 168 | super().__init__() 169 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 170 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 171 | self.image_encoder = clip_model.visual 172 | self.text_encoder = TextEncoder(clip_model) 173 | self.logit_scale = clip_model.logit_scale 174 | self.dtype = clip_model.dtype 175 | 176 | def forward(self, image, label=None): 177 | tokenized_prompts = self.tokenized_prompts 178 | logit_scale = self.logit_scale.exp() 179 | 180 | image_features = self.image_encoder(image.type(self.dtype)) 181 | image_features = image_features / image_features.norm(dim=-1, 182 | keepdim=True) 183 | 184 | prompts = self.prompt_learner(image_features) 185 | 186 | logits = [] 187 | for pts_i, imf_i in zip(prompts, image_features): 188 | text_features = self.text_encoder(pts_i, tokenized_prompts) 189 | text_features = text_features / text_features.norm(dim=-1, 190 | keepdim=True) 191 | l_i = logit_scale * imf_i @ text_features.t() 192 | logits.append(l_i) 193 | logits = torch.stack(logits) 194 | 195 | if self.prompt_learner.training: 196 | return F.cross_entropy(logits, label) 197 | 198 | return logits 199 | 200 | 201 | @TRAINER_REGISTRY.register() 202 | class CoCoOp(TrainerX): 203 | def check_cfg(self, cfg): 204 | assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"] 205 | 206 | def build_model(self): 207 | cfg = self.cfg 208 | classnames = self.dm.dataset.classnames 209 | 210 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 211 | clip_model = load_clip_to_cpu(cfg) 212 | 213 | if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp": 214 | # CLIP's default precision is fp16 215 | clip_model.float() 216 | 217 | print("Building custom CLIP") 218 | self.model = CustomCLIP(cfg, classnames, clip_model) 219 | 220 | print("Turning off gradients in both the image and the text encoder") 221 | name_to_update = "prompt_learner" 222 | 223 | for name, param in self.model.named_parameters(): 224 | if name_to_update not in name: 225 | param.requires_grad_(False) 226 | 227 | # Double check 228 | enabled = set() 229 | for name, param in self.model.named_parameters(): 230 | if param.requires_grad: 231 | enabled.add(name) 232 | print(f"Parameters to be updated: {enabled}") 233 | 234 | if cfg.MODEL.INIT_WEIGHTS: 235 | load_pretrained_weights(self.model.prompt_learner, 236 | cfg.MODEL.INIT_WEIGHTS) 237 | 238 | self.model.to(self.device) 239 | # NOTE: only give prompt_learner to the optimizer 240 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 241 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 242 | self.register_model("prompt_learner", self.model.prompt_learner, 243 | self.optim, self.sched) 244 | 245 | self.scaler = GradScaler( 246 | ) if cfg.TRAINER.COCOOP.PREC == "amp" else None 247 | 248 | # Note that multi-gpu training could be slow because CLIP's size is 249 | # big, which slows down the copy operation in DataParallel 250 | device_count = torch.cuda.device_count() 251 | if device_count > 1: 252 | print( 253 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 254 | ) 255 | self.model = nn.DataParallel(self.model) 256 | 257 | def forward_backward(self, batch): 258 | image, label = self.parse_batch_train(batch) 259 | 260 | model = self.model 261 | optim = self.optim 262 | scaler = self.scaler 263 | 264 | prec = self.cfg.TRAINER.COCOOP.PREC 265 | if prec == "amp": 266 | with autocast(): 267 | loss = model(image, label) 268 | optim.zero_grad() 269 | scaler.scale(loss).backward() 270 | scaler.step(optim) 271 | scaler.update() 272 | else: 273 | loss = model(image, label) 274 | optim.zero_grad() 275 | loss.backward() 276 | optim.step() 277 | 278 | loss_summary = {"loss": loss.item()} 279 | 280 | if (self.batch_idx + 1) == self.num_batches: 281 | self.update_lr() 282 | 283 | return loss_summary 284 | 285 | def parse_batch_train(self, batch): 286 | input = batch["img"] 287 | label = batch["label"] 288 | input = input.to(self.device) 289 | label = label.to(self.device) 290 | return input, label 291 | 292 | def load_model(self, directory, epoch=None): 293 | if not directory: 294 | print( 295 | "Note that load_model() is skipped as no pretrained model is given" 296 | ) 297 | return 298 | 299 | names = self.get_model_names() 300 | 301 | # By default, the best model is loaded 302 | model_file = "model-best.pth.tar" 303 | 304 | if epoch is not None: 305 | model_file = "model.pth.tar-" + str(epoch) 306 | 307 | for name in names: 308 | model_path = osp.join(directory, name, model_file) 309 | 310 | if not osp.exists(model_path): 311 | raise FileNotFoundError( 312 | 'Model not found at "{}"'.format(model_path)) 313 | 314 | checkpoint = load_checkpoint(model_path) 315 | state_dict = checkpoint["state_dict"] 316 | epoch = checkpoint["epoch"] 317 | 318 | # Ignore fixed token vectors 319 | if "token_prefix" in state_dict: 320 | del state_dict["token_prefix"] 321 | 322 | if "token_suffix" in state_dict: 323 | del state_dict["token_suffix"] 324 | 325 | print("Loading weights to {} " 326 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 327 | # set strict=False 328 | self._models[name].load_state_dict(state_dict, strict=False) 329 | -------------------------------------------------------------------------------- /trainers/coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import GradScaler, autocast 6 | from torch.nn import functional as F 7 | 8 | from clip import clip 9 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.optim import build_lr_scheduler, build_optimizer 13 | from dassl.utils import load_checkpoint, load_pretrained_weights 14 | 15 | _tokenizer = _Tokenizer() 16 | 17 | 18 | def load_clip_to_cpu(cfg): 19 | backbone_name = cfg.MODEL.BACKBONE.NAME 20 | url = clip._MODELS[backbone_name] 21 | model_path = clip._download(url) 22 | 23 | try: 24 | # loading JIT archive 25 | model = torch.jit.load(model_path, map_location="cpu").eval() 26 | state_dict = None 27 | 28 | except RuntimeError: 29 | state_dict = torch.load(model_path, map_location="cpu") 30 | 31 | model = clip.build_model(state_dict or model.state_dict()) 32 | 33 | return model 34 | 35 | 36 | class TextEncoder(nn.Module): 37 | def __init__(self, clip_model): 38 | super().__init__() 39 | self.transformer = clip_model.transformer 40 | self.positional_embedding = clip_model.positional_embedding 41 | self.ln_final = clip_model.ln_final 42 | self.text_projection = clip_model.text_projection 43 | self.dtype = clip_model.dtype 44 | 45 | def forward(self, prompts, tokenized_prompts): 46 | x = prompts + self.positional_embedding.type(self.dtype) 47 | x = x.permute(1, 0, 2) # NLD -> LND 48 | x = self.transformer(x) 49 | x = x.permute(1, 0, 2) # LND -> NLD 50 | x = self.ln_final(x).type(self.dtype) 51 | 52 | # x.shape = [batch_size, n_ctx, transformer.width] 53 | # take features from the eot embedding (eot_token is the highest number in each sequence) 54 | x = x[torch.arange(x.shape[0]), 55 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 56 | 57 | return x 58 | 59 | 60 | class PromptLearner(nn.Module): 61 | def __init__(self, cfg, classnames, clip_model): 62 | super().__init__() 63 | n_cls = len(classnames) 64 | n_ctx = cfg.TRAINER.COOP.N_CTX 65 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 66 | dtype = clip_model.dtype 67 | ctx_dim = clip_model.ln_final.weight.shape[0] 68 | clip_imsize = clip_model.visual.input_resolution 69 | cfg_imsize = cfg.INPUT.SIZE[0] 70 | assert cfg_imsize == clip_imsize 71 | # f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 72 | 73 | if ctx_init: 74 | # use given words to initialize context vectors 75 | ctx_init = ctx_init.replace("_", " ") 76 | n_ctx = len(ctx_init.split(" ")) 77 | prompt = clip.tokenize(ctx_init) 78 | with torch.no_grad(): 79 | embedding = clip_model.token_embedding(prompt).type(dtype) 80 | ctx_vectors = embedding[0, 1:1 + n_ctx, :] 81 | prompt_prefix = ctx_init 82 | 83 | else: 84 | # random initialization 85 | if cfg.TRAINER.COOP.CSC: 86 | print("Initializing class-specific contexts") 87 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 88 | else: 89 | print("Initializing a generic context") 90 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 91 | nn.init.normal_(ctx_vectors, std=0.02) 92 | prompt_prefix = " ".join(["X"] * n_ctx) 93 | 94 | print(f'Initial context: "{prompt_prefix}"') 95 | print(f"Number of context words (tokens): {n_ctx}") 96 | 97 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 98 | 99 | classnames = [name.replace("_", " ") for name in classnames] 100 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 101 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 102 | 103 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 104 | with torch.no_grad(): 105 | embedding = clip_model.token_embedding(tokenized_prompts).type( 106 | dtype) 107 | 108 | # These token vectors will be saved when in save_model(), 109 | # but they should be ignored in load_model() as we want to use 110 | # those computed using the current class names 111 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 112 | self.register_buffer("token_suffix", 113 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS 114 | 115 | self.n_cls = n_cls 116 | self.n_ctx = n_ctx 117 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 118 | self.name_lens = name_lens 119 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 120 | 121 | def forward(self): 122 | ctx = self.ctx 123 | if ctx.dim() == 2: 124 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 125 | 126 | prefix = self.token_prefix 127 | suffix = self.token_suffix 128 | 129 | if self.class_token_position == "end": 130 | prompts = torch.cat( 131 | [ 132 | prefix, # (n_cls, 1, dim) 133 | ctx, # (n_cls, n_ctx, dim) 134 | suffix, # (n_cls, *, dim) 135 | ], 136 | dim=1, 137 | ) 138 | 139 | elif self.class_token_position == "middle": 140 | half_n_ctx = self.n_ctx // 2 141 | prompts = [] 142 | for i in range(self.n_cls): 143 | name_len = self.name_lens[i] 144 | prefix_i = prefix[i:i + 1, :, :] 145 | class_i = suffix[i:i + 1, :name_len, :] 146 | suffix_i = suffix[i:i + 1, name_len:, :] 147 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :] 148 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :] 149 | prompt = torch.cat( 150 | [ 151 | prefix_i, # (1, 1, dim) 152 | ctx_i_half1, # (1, n_ctx//2, dim) 153 | class_i, # (1, name_len, dim) 154 | ctx_i_half2, # (1, n_ctx//2, dim) 155 | suffix_i, # (1, *, dim) 156 | ], 157 | dim=1, 158 | ) 159 | prompts.append(prompt) 160 | prompts = torch.cat(prompts, dim=0) 161 | 162 | elif self.class_token_position == "front": 163 | prompts = [] 164 | for i in range(self.n_cls): 165 | name_len = self.name_lens[i] 166 | prefix_i = prefix[i:i + 1, :, :] 167 | class_i = suffix[i:i + 1, :name_len, :] 168 | suffix_i = suffix[i:i + 1, name_len:, :] 169 | ctx_i = ctx[i:i + 1, :, :] 170 | prompt = torch.cat( 171 | [ 172 | prefix_i, # (1, 1, dim) 173 | class_i, # (1, name_len, dim) 174 | ctx_i, # (1, n_ctx, dim) 175 | suffix_i, # (1, *, dim) 176 | ], 177 | dim=1, 178 | ) 179 | prompts.append(prompt) 180 | prompts = torch.cat(prompts, dim=0) 181 | 182 | else: 183 | raise ValueError 184 | 185 | return prompts 186 | 187 | 188 | class CustomCLIP(nn.Module): 189 | def __init__(self, cfg, classnames, clip_model): 190 | super().__init__() 191 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 192 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 193 | self.image_encoder = clip_model.visual 194 | self.text_encoder = TextEncoder(clip_model) 195 | self.logit_scale = clip_model.logit_scale 196 | self.dtype = clip_model.dtype 197 | 198 | def forward(self, image): 199 | image_features = self.image_encoder(image.type(self.dtype)) 200 | 201 | prompts = self.prompt_learner() 202 | tokenized_prompts = self.tokenized_prompts 203 | text_features = self.text_encoder(prompts, tokenized_prompts) 204 | 205 | image_features = image_features / image_features.norm(dim=-1, 206 | keepdim=True) 207 | text_features = text_features / text_features.norm(dim=-1, 208 | keepdim=True) 209 | 210 | logit_scale = self.logit_scale.exp() 211 | logits = logit_scale * image_features @ text_features.t() 212 | 213 | return logits 214 | 215 | 216 | @TRAINER_REGISTRY.register() 217 | class CoOp(TrainerX): 218 | """Context Optimization (CoOp). 219 | 220 | Learning to Prompt for Vision-Language Models 221 | https://arxiv.org/abs/2109.01134 222 | """ 223 | def check_cfg(self, cfg): 224 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 225 | 226 | def build_model(self): 227 | cfg = self.cfg 228 | classnames = self.dm.dataset.classnames 229 | 230 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 231 | clip_model = load_clip_to_cpu(cfg) 232 | 233 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 234 | # CLIP's default precision is fp16 235 | clip_model.float() 236 | 237 | print("Building custom CLIP") 238 | self.model = CustomCLIP(cfg, classnames, clip_model) 239 | 240 | print("Turning off gradients in both the image and the text encoder") 241 | for name, param in self.model.named_parameters(): 242 | if "prompt_learner" not in name: 243 | param.requires_grad_(False) 244 | 245 | if cfg.MODEL.INIT_WEIGHTS: 246 | load_pretrained_weights(self.model.prompt_learner, 247 | cfg.MODEL.INIT_WEIGHTS) 248 | 249 | self.model.to(self.device) 250 | # NOTE: only give prompt_learner to the optimizer 251 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 252 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 253 | self.register_model("prompt_learner", self.model.prompt_learner, 254 | self.optim, self.sched) 255 | 256 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 257 | 258 | # Note that multi-gpu training could be slow because CLIP's size is 259 | # big, which slows down the copy operation in DataParallel 260 | device_count = 1 # torch.cuda.device_count() 261 | if device_count > 1: 262 | print( 263 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 264 | ) 265 | self.model = nn.DataParallel(self.model) 266 | 267 | def forward_backward(self, batch): 268 | image, label = self.parse_batch_train(batch) 269 | 270 | prec = self.cfg.TRAINER.COOP.PREC 271 | if prec == "amp": 272 | with autocast(): 273 | output = self.model(image) 274 | loss = F.cross_entropy(output, label) 275 | self.optim.zero_grad() 276 | self.scaler.scale(loss).backward() 277 | self.scaler.step(self.optim) 278 | self.scaler.update() 279 | else: 280 | output = self.model(image) 281 | loss = F.cross_entropy(output, label) 282 | self.model_backward_and_update(loss) 283 | 284 | loss_summary = { 285 | "loss": loss.item(), 286 | "acc": compute_accuracy(output, label)[0].item(), 287 | } 288 | 289 | if (self.batch_idx + 1) == self.num_batches: 290 | self.update_lr() 291 | 292 | return loss_summary 293 | 294 | def parse_batch_train(self, batch): 295 | input = batch["img"] 296 | label = batch["label"] 297 | input = input.to(self.device) 298 | label = label.to(self.device) 299 | return input, label 300 | 301 | def load_model(self, directory, epoch=None): 302 | if not directory: 303 | print( 304 | "Note that load_model() is skipped as no pretrained model is given" 305 | ) 306 | return 307 | 308 | names = self.get_model_names() 309 | 310 | # By default, the best model is loaded 311 | model_file = "model-best.pth.tar" 312 | 313 | if epoch is not None: 314 | model_file = "model.pth.tar-" + str(epoch) 315 | 316 | for name in names: 317 | model_path = osp.join(directory, name, model_file) 318 | 319 | if not osp.exists(model_path): 320 | raise FileNotFoundError( 321 | 'Model not found at "{}"'.format(model_path)) 322 | 323 | checkpoint = load_checkpoint(model_path) 324 | state_dict = checkpoint["state_dict"] 325 | epoch = checkpoint["epoch"] 326 | 327 | # Ignore fixed token vectors 328 | if "token_prefix" in state_dict: 329 | del state_dict["token_prefix"] 330 | 331 | if "token_suffix" in state_dict: 332 | del state_dict["token_suffix"] 333 | 334 | print("Loading weights to {} " 335 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 336 | # set strict=False 337 | self._models[name].load_state_dict(state_dict, strict=False) 338 | -------------------------------------------------------------------------------- /trainers/coop_testtime.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from PIL import ImageFilter 5 | from torchvision import transforms 6 | from tqdm import tqdm 7 | 8 | from dassl.data.data_manager import (DataManager, DatasetWrapper, 9 | build_data_loader) 10 | from dassl.data.datasets import build_dataset 11 | from dassl.data.samplers import build_sampler 12 | from dassl.data.transforms import build_transform 13 | from dassl.engine import TRAINER_REGISTRY 14 | from dassl.utils import read_image 15 | 16 | from .coop import CoOp 17 | 18 | # from copy import deepcopy 19 | 20 | 21 | class GaussianBlur(object): 22 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 23 | def __init__(self, sigma=[.1, 2.]): 24 | self.sigma = sigma 25 | 26 | def __call__(self, x): 27 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 28 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 29 | return x 30 | 31 | 32 | class DatasetWrapper_aug(DatasetWrapper): 33 | def __init__(self, cfg, data_source, transform=None, is_train=False): 34 | super().__init__(cfg, data_source, transform, is_train) 35 | 36 | normalize = transforms.Normalize( 37 | mean=[0.48145466, 0.4578275, 0.40821073], 38 | std=[0.26862954, 0.26130258, 0.27577711]) 39 | augment = transforms.Compose([ 40 | transforms.RandomResizedCrop(224, scale=(0.08, 1.)), 41 | transforms.RandomApply( 42 | [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 43 | transforms.RandomGrayscale(p=0.2), 44 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | normalize, 48 | ]) 49 | self.augment = augment 50 | 51 | def __getitem__(self, idx): 52 | item = self.data_source[idx] 53 | 54 | output = { 55 | "label": item.label, 56 | "domain": item.domain, 57 | "impath": item.impath 58 | } 59 | 60 | img0 = read_image(item.impath) 61 | 62 | if self.transform is not None: 63 | if isinstance(self.transform, (list, tuple)): 64 | raise NotImplementedError 65 | for i, tfm in enumerate(self.transform): 66 | img = self._transform_image(tfm, img0) 67 | output["img"] = img 68 | else: 69 | img = self._transform_image(self.transform, img0) 70 | output["img"] = img 71 | for name in ["img_aug", "img_aug2", "img_aug3"]: 72 | output[name] = self._transform_image(self.augment, img0) 73 | 74 | if self.return_img0: 75 | output["img0"] = self.to_tensor(img0) 76 | 77 | return output 78 | 79 | 80 | def build_data_loader(cfg, 81 | sampler_type='SequentialSampler', 82 | data_source=None, 83 | batch_size=64, 84 | n_domain=0, 85 | n_ins=2, 86 | tfm=None, 87 | is_train=True, 88 | dataset_wrapper=None): 89 | # Build sampler 90 | if not is_train: 91 | random.shuffle(data_source) 92 | sampler = build_sampler(sampler_type, 93 | cfg=cfg, 94 | data_source=data_source, 95 | batch_size=batch_size, 96 | n_domain=n_domain, 97 | n_ins=n_ins) 98 | 99 | if dataset_wrapper is None: 100 | dataset_wrapper = DatasetWrapper 101 | 102 | # Build data loader 103 | data_loader = torch.utils.data.DataLoader( 104 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train), 105 | batch_size=batch_size, 106 | sampler=sampler, 107 | num_workers=cfg.DATALOADER.NUM_WORKERS, 108 | drop_last=is_train and len(data_source) >= batch_size, 109 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), 110 | ) 111 | assert len(data_loader) > 0 112 | 113 | return data_loader 114 | 115 | 116 | class DataManager_aug(DataManager): 117 | def __init__(self, 118 | cfg, 119 | custom_tfm_train=None, 120 | custom_tfm_test=None, 121 | dataset_wrapper=None): 122 | super().__init__(cfg, custom_tfm_train, custom_tfm_test, 123 | dataset_wrapper) 124 | # Build test_loader 125 | dataset = build_dataset(cfg) 126 | 127 | if custom_tfm_test is None: 128 | tfm_test = build_transform(cfg, is_train=False) 129 | else: 130 | print("* Using custom transform for testing") 131 | tfm_test = custom_tfm_test 132 | 133 | test_loader = build_data_loader( 134 | cfg, 135 | sampler_type=cfg.DATALOADER.TEST.SAMPLER, 136 | data_source=dataset.test, 137 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, 138 | tfm=tfm_test, 139 | is_train=False, 140 | dataset_wrapper=DatasetWrapper_aug, 141 | ) 142 | self.test_loader = test_loader 143 | 144 | 145 | def softmax_entropy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 146 | return -(y.softmax(1) * x.log_softmax(1)).sum(1).mean(0) 147 | 148 | 149 | @TRAINER_REGISTRY.register() 150 | class CoOp_testtime(CoOp): 151 | def build_data_loader(self): 152 | dm = DataManager_aug(self.cfg) 153 | 154 | self.train_loader_x = dm.train_loader_x 155 | self.train_loader_u = dm.train_loader_u # optional, can be None 156 | self.val_loader = dm.val_loader # optional, can be None 157 | self.test_loader = dm.test_loader 158 | self.num_classes = dm.num_classes 159 | self.num_source_domains = dm.num_source_domains 160 | self.lab2cname = dm.lab2cname # dict {label: classname} 161 | 162 | self.dm = dm 163 | 164 | def test(self, split=None): 165 | """A generic testing pipeline.""" 166 | # self.set_model_mode("eval") 167 | self.model.to(self.device) 168 | 169 | self._optims['prompt_learner'].param_groups[0]['lr'] = 2e-6 170 | self.evaluator.reset() 171 | 172 | if split is None: 173 | split = self.cfg.TEST.SPLIT 174 | 175 | if split == "val" and self.val_loader is not None: 176 | data_loader = self.val_loader 177 | print("Do evaluation on {} set".format(split)) 178 | else: 179 | data_loader = self.test_loader 180 | print("Do evaluation on test set") 181 | 182 | # model_state = deepcopy(self.model.state_dict()) 183 | for batch_idx, batch in enumerate(tqdm(data_loader)): 184 | # self.model.load_state_dict(model_state, strict=True) 185 | input, input_aug, input_aug2, label = self.parse_batch_test(batch) 186 | 187 | for _ in range(0): 188 | output = self.model_inference(input) 189 | output_aug = self.model_inference(input_aug) 190 | loss = softmax_entropy(output_aug, output) 191 | self.model_backward_and_update(loss) 192 | 193 | for _ in range(0): 194 | output = self.model_inference(input) 195 | output_aug = self.model_inference(input_aug2) 196 | loss = softmax_entropy(output_aug, output) 197 | self.model_backward_and_update(loss) 198 | 199 | with torch.no_grad(): 200 | output = self.model_inference(input) 201 | self.evaluator.process(output, label) 202 | 203 | results = self.evaluator.evaluate() 204 | 205 | for k, v in results.items(): 206 | tag = "{}/{}".format(split, k) 207 | self.write_scalar(tag, v, self.epoch) 208 | 209 | return list(results.values())[0] 210 | 211 | def parse_batch_test(self, batch): 212 | input = batch["img"] 213 | input_aug = batch["img_aug"] 214 | input_aug2 = batch["img_aug2"] 215 | label = batch["label"] 216 | 217 | input = input.to(self.device) 218 | input_aug = input_aug.to(self.device) 219 | input_aug2 = input_aug2.to(self.device) 220 | label = label.to(self.device) 221 | 222 | return input, input_aug, input_aug2, label 223 | -------------------------------------------------------------------------------- /trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | "the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- /trainers/linter.sh: -------------------------------------------------------------------------------- 1 | rm -rf `find -type d -name .ipynb_checkpoints` 2 | yapf -r -i ./*.py 3 | isort -rc ./*.py 4 | flake8 ./*.py 5 | -------------------------------------------------------------------------------- /trainers/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import get_rank, get_world_size 6 | 7 | # from .utils import all_gather_batch_with_grad 8 | 9 | 10 | class SIMCLRLoss(nn.Module): 11 | """ 12 | This is the SimCLR loss in https://arxiv.org/abs/2002.05709 13 | The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and 14 | the memory layout that can be reshaped into shape (2, batch_size, embedding_dim). 15 | This memory layout is consistent with the SimCLR collator in 16 | https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py 17 | Config params: 18 | temperature (float): the temperature to be applied on the logits 19 | """ 20 | def __init__(self, temperature=0.1): 21 | super().__init__() 22 | self.tau = temperature 23 | self.labels = None 24 | self.masks = None 25 | self.last_local_batch_size = None 26 | 27 | def forward(self, q_a, q_b): 28 | q_a = F.normalize(q_a, dim=-1, p=2) 29 | q_b = F.normalize(q_b, dim=-1, p=2) 30 | 31 | local_batch_size = q_a.size(0) 32 | 33 | # k_a, k_b = all_gather_batch_with_grad([q_a, q_b]) 34 | k_a, k_b = q_a, q_b 35 | 36 | if local_batch_size != self.last_local_batch_size: 37 | self.labels = local_batch_size * get_rank() + torch.arange( 38 | local_batch_size, device=q_a.device) 39 | total_batch_size = local_batch_size * get_world_size() 40 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9 41 | self.last_local_batch_size = local_batch_size 42 | 43 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau 44 | logits_aa = logits_aa - self.masks 45 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau 46 | logits_bb = logits_bb - self.masks 47 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau 48 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau 49 | 50 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), 51 | self.labels) 52 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), 53 | self.labels) 54 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples 55 | 56 | # compute accuracy 57 | with torch.no_grad(): 58 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), 59 | dim=-1) 60 | correct = pred.eq(self.labels).sum() 61 | acc = 100 * correct / local_batch_size 62 | 63 | return loss, acc 64 | -------------------------------------------------------------------------------- /trainers/unified.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import GradScaler, autocast 6 | from torch.nn import functional as F 7 | 8 | from clip import clip 9 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.optim import build_lr_scheduler, build_optimizer 13 | from dassl.utils import load_checkpoint, load_pretrained_weights 14 | 15 | from .coop import TextEncoder, load_clip_to_cpu 16 | 17 | _tokenizer = _Tokenizer() 18 | 19 | 20 | class PromptLearner(nn.Module): 21 | def __init__(self, cfg, classnames, clip_model): 22 | super().__init__() 23 | n_cls = len(classnames) 24 | n_ctx = cfg.TRAINER.COOP.N_CTX 25 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 26 | dtype = clip_model.dtype 27 | ctx_dim = clip_model.ln_final.weight.shape[0] 28 | clip_imsize = clip_model.visual.input_resolution 29 | cfg_imsize = cfg.INPUT.SIZE[0] 30 | assert cfg_imsize == clip_imsize 31 | # f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 32 | 33 | if ctx_init: 34 | # use given words to initialize context vectors 35 | ctx_init = ctx_init.replace("_", " ") 36 | n_ctx = len(ctx_init.split(" ")) 37 | prompt = clip.tokenize(ctx_init) 38 | with torch.no_grad(): 39 | embedding = clip_model.token_embedding(prompt).type(dtype) 40 | ctx_vectors = embedding[0, 1:1 + n_ctx, :] 41 | prompt_prefix = ctx_init 42 | 43 | else: 44 | # random initialization 45 | if cfg.TRAINER.COOP.CSC: 46 | print("Initializing class-specific contexts") 47 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 48 | else: 49 | print("Initializing a generic context") 50 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 51 | nn.init.normal_(ctx_vectors, std=0.02) 52 | prompt_prefix = " ".join(["X"] * n_ctx) 53 | 54 | print(f'Initial context: "{prompt_prefix}"') 55 | print(f"Number of context words (tokens): {n_ctx}") 56 | 57 | # self.ctx = nn.Parameter(ctx_vectors) # to be optimized 58 | ctx = ctx_vectors.unsqueeze(0) 59 | 60 | classnames = [name.replace("_", " ") for name in classnames] 61 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 62 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 63 | 64 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 65 | with torch.no_grad(): 66 | embedding = clip_model.token_embedding(tokenized_prompts).type( 67 | dtype) 68 | 69 | # These token vectors will be saved when in save_model(), 70 | # but they should be ignored in load_model() as we want to use 71 | # those computed using the current class names 72 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 73 | self.register_buffer("token_suffix", 74 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS 75 | 76 | self.n_cls = n_cls 77 | self.n_ctx = n_ctx 78 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 79 | self.name_lens = name_lens 80 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 81 | 82 | vis_dim = clip_model.visual.positional_embedding.shape[-1] 83 | visual_vectors = torch.empty(1, n_ctx * 2, ctx_dim, dtype=dtype) 84 | nn.init.normal_(visual_vectors, std=0.02) 85 | # self.visual_ctx = nn.Parameter(visual_vectors) 86 | self.uni_ctx = nn.Parameter(visual_vectors) 87 | # uni_ctx = torch.cat([ctx, visual_vectors], 1) 88 | # self.uni_ctx = nn.Parameter(uni_ctx) 89 | 90 | num_heads = 1 91 | dropout = 0.1 92 | self.self_attn = nn.MultiheadAttention(ctx_dim, 93 | num_heads, 94 | dropout=dropout) 95 | self.self_attn.type(dtype) 96 | self.dropout = nn.Dropout(dropout) 97 | self.dropout.type(dtype) 98 | self.dropout1 = nn.Dropout(dropout) 99 | self.dropout1.type(dtype) 100 | self.dropout2 = nn.Dropout(dropout) 101 | self.dropout2.type(dtype) 102 | 103 | self.norm1 = nn.LayerNorm(ctx_dim) 104 | self.norm1.type(dtype) 105 | self.norm2 = nn.LayerNorm(ctx_dim) 106 | self.norm2.type(dtype) 107 | 108 | self.linear1 = nn.Linear(ctx_dim, ctx_dim) 109 | self.linear1.type(dtype) 110 | self.linear2 = nn.Linear(ctx_dim, ctx_dim) 111 | self.linear2.type(dtype) 112 | self.activation = F.relu 113 | 114 | # self.mlp_text = nn.Linear(ctx_dim, ctx_dim) 115 | # self.mlp_text.type(dtype) 116 | self.mlp = nn.Linear(ctx_dim, vis_dim) 117 | self.mlp.type(dtype) 118 | 119 | def get_text_prompt(self): 120 | src = self.uni_ctx 121 | src2 = self.self_attn(src, src, src)[0] 122 | src = src + self.dropout1(src2) 123 | src = self.norm1(src) 124 | src2 = self.norm2(src) 125 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 126 | src = src + self.dropout2(src2) 127 | return src[0, :4, :] 128 | 129 | def get_visual_prompt(self): 130 | src = self.uni_ctx 131 | src2 = self.self_attn(src, src, src)[0] 132 | src = src + self.dropout1(src2) 133 | src = self.norm1(src) 134 | src2 = self.norm2(src) 135 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 136 | src = src + self.dropout2(src2) 137 | return self.mlp(src[0, 4:, :]) 138 | 139 | def forward(self): 140 | ctx = self.get_text_prompt() 141 | if ctx.dim() == 2: 142 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 143 | 144 | prefix = self.token_prefix 145 | suffix = self.token_suffix 146 | 147 | if self.class_token_position == "end": 148 | prompts = torch.cat( 149 | [ 150 | prefix, # (n_cls, 1, dim) 151 | ctx, # (n_cls, n_ctx, dim) 152 | suffix, # (n_cls, *, dim) 153 | ], 154 | dim=1, 155 | ) 156 | 157 | elif self.class_token_position == "middle": 158 | half_n_ctx = self.n_ctx // 2 159 | prompts = [] 160 | for i in range(self.n_cls): 161 | name_len = self.name_lens[i] 162 | prefix_i = prefix[i:i + 1, :, :] 163 | class_i = suffix[i:i + 1, :name_len, :] 164 | suffix_i = suffix[i:i + 1, name_len:, :] 165 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :] 166 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :] 167 | prompt = torch.cat( 168 | [ 169 | prefix_i, # (1, 1, dim) 170 | ctx_i_half1, # (1, n_ctx//2, dim) 171 | class_i, # (1, name_len, dim) 172 | ctx_i_half2, # (1, n_ctx//2, dim) 173 | suffix_i, # (1, *, dim) 174 | ], 175 | dim=1, 176 | ) 177 | prompts.append(prompt) 178 | prompts = torch.cat(prompts, dim=0) 179 | 180 | elif self.class_token_position == "front": 181 | prompts = [] 182 | for i in range(self.n_cls): 183 | name_len = self.name_lens[i] 184 | prefix_i = prefix[i:i + 1, :, :] 185 | class_i = suffix[i:i + 1, :name_len, :] 186 | suffix_i = suffix[i:i + 1, name_len:, :] 187 | ctx_i = ctx[i:i + 1, :, :] 188 | prompt = torch.cat( 189 | [ 190 | prefix_i, # (1, 1, dim) 191 | class_i, # (1, name_len, dim) 192 | ctx_i, # (1, n_ctx, dim) 193 | suffix_i, # (1, *, dim) 194 | ], 195 | dim=1, 196 | ) 197 | prompts.append(prompt) 198 | prompts = torch.cat(prompts, dim=0) 199 | 200 | else: 201 | raise ValueError 202 | 203 | return prompts 204 | 205 | 206 | class CustomCLIP(nn.Module): 207 | def __init__(self, cfg, classnames, clip_model): 208 | super().__init__() 209 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 210 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 211 | self.image_encoder = clip_model.visual 212 | self.text_encoder = TextEncoder(clip_model) 213 | self.logit_scale = clip_model.logit_scale 214 | self.dtype = clip_model.dtype 215 | 216 | def forward(self, image): 217 | # import time 218 | # torch.save(image.data.cpu(), f'img_{int(time.time())}.pt') 219 | 220 | # image_features = self.image_encoder(image.type(self.dtype)) 221 | visual_ctx = self.prompt_learner.get_visual_prompt() 222 | image_features, _ = self.image_encoder.forward_prompt_shallow( 223 | image.type(self.dtype), visual_ctx) 224 | 225 | prompts = self.prompt_learner() 226 | tokenized_prompts = self.tokenized_prompts 227 | text_features = self.text_encoder(prompts, tokenized_prompts) 228 | 229 | image_features = image_features / image_features.norm(dim=-1, 230 | keepdim=True) 231 | text_features = text_features / text_features.norm(dim=-1, 232 | keepdim=True) 233 | 234 | logit_scale = self.logit_scale.exp() 235 | logits = logit_scale * image_features @ text_features.t() 236 | 237 | return logits 238 | 239 | 240 | @TRAINER_REGISTRY.register() 241 | class Unified_v6(TrainerX): 242 | def check_cfg(self, cfg): 243 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 244 | 245 | def build_model(self): 246 | cfg = self.cfg 247 | classnames = self.dm.dataset.classnames 248 | 249 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 250 | clip_model = load_clip_to_cpu(cfg) 251 | 252 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 253 | # CLIP's default precision is fp16 254 | clip_model.float() 255 | 256 | print("Building custom CLIP") 257 | self.model = CustomCLIP(cfg, classnames, clip_model) 258 | 259 | print("Turning off gradients in both the image and the text encoder") 260 | for name, param in self.model.named_parameters(): 261 | if "prompt_learner" not in name: 262 | param.requires_grad_(False) 263 | 264 | if cfg.MODEL.INIT_WEIGHTS: 265 | load_pretrained_weights(self.model.prompt_learner, 266 | cfg.MODEL.INIT_WEIGHTS) 267 | 268 | self.model.to(self.device) 269 | # NOTE: only give prompt_learner to the optimizer 270 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 271 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 272 | self.register_model("prompt_learner", self.model.prompt_learner, 273 | self.optim, self.sched) 274 | 275 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 276 | 277 | # Note that multi-gpu training could be slow because CLIP's size is 278 | # big, which slows down the copy operation in DataParallel 279 | device_count = 1 # torch.cuda.device_count() 280 | if device_count > 1: 281 | print( 282 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 283 | ) 284 | self.model = nn.DataParallel(self.model) 285 | 286 | def forward_backward(self, batch): 287 | image, label = self.parse_batch_train(batch) 288 | 289 | prec = self.cfg.TRAINER.COOP.PREC 290 | if prec == "amp": 291 | with autocast(): 292 | output = self.model(image) 293 | loss = F.cross_entropy(output, label) 294 | self.optim.zero_grad() 295 | self.scaler.scale(loss).backward() 296 | self.scaler.step(self.optim) 297 | self.scaler.update() 298 | else: 299 | output = self.model(image) 300 | loss = F.cross_entropy(output, label) 301 | self.model_backward_and_update(loss) 302 | 303 | loss_summary = { 304 | "loss": loss.item(), 305 | "acc": compute_accuracy(output, label)[0].item(), 306 | } 307 | 308 | if (self.batch_idx + 1) == self.num_batches: 309 | self.update_lr() 310 | 311 | return loss_summary 312 | 313 | def parse_batch_train(self, batch): 314 | input = batch["img"] 315 | label = batch["label"] 316 | input = input.to(self.device) 317 | label = label.to(self.device) 318 | return input, label 319 | 320 | def load_model(self, directory, epoch=None): 321 | if not directory: 322 | print( 323 | "Note that load_model() is skipped as no pretrained model is given" 324 | ) 325 | return 326 | 327 | names = self.get_model_names() 328 | 329 | # By default, the best model is loaded 330 | model_file = "model-best.pth.tar" 331 | 332 | if epoch is not None: 333 | model_file = "model.pth.tar-" + str(epoch) 334 | 335 | for name in names: 336 | model_path = osp.join(directory, name, model_file) 337 | 338 | if not osp.exists(model_path): 339 | raise FileNotFoundError( 340 | 'Model not found at "{}"'.format(model_path)) 341 | 342 | checkpoint = load_checkpoint(model_path) 343 | state_dict = checkpoint["state_dict"] 344 | epoch = checkpoint["epoch"] 345 | 346 | # Ignore fixed token vectors 347 | if "token_prefix" in state_dict: 348 | del state_dict["token_prefix"] 349 | 350 | if "token_suffix" in state_dict: 351 | del state_dict["token_suffix"] 352 | 353 | print("Loading weights to {} " 354 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 355 | # set strict=False 356 | self._models[name].load_state_dict(state_dict, strict=False) 357 | -------------------------------------------------------------------------------- /trainers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import os 5 | import random 6 | import shutil 7 | 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 10 | import numpy as np 11 | import torch 12 | import torch.autograd as autograd 13 | import torch.distributed as dist 14 | from PIL import ImageFilter 15 | from torch.nn import DataParallel 16 | from torch.nn.parallel import DistributedDataParallel 17 | 18 | 19 | def get_model(model): 20 | if isinstance(model, DataParallel): 21 | return model.module 22 | elif isinstance(model, DistributedDataParallel): 23 | return model.module 24 | else: 25 | return model 26 | 27 | 28 | def setup_for_distributed(is_master): 29 | """ 30 | This function disables printing when not in master process 31 | """ 32 | import builtins as __builtin__ 33 | builtin_print = __builtin__.print 34 | 35 | def print(*args, **kwargs): 36 | force = kwargs.pop('force', False) 37 | if is_master or force: 38 | builtin_print(*args, **kwargs) 39 | 40 | __builtin__.print = print 41 | 42 | 43 | def is_dist_avail_and_initialized(): 44 | if not dist.is_available(): 45 | return False 46 | if not dist.is_initialized(): 47 | return False 48 | return True 49 | 50 | 51 | def get_world_size(): 52 | if not is_dist_avail_and_initialized(): 53 | return 1 54 | return dist.get_world_size() 55 | 56 | 57 | def get_rank(): 58 | if not is_dist_avail_and_initialized(): 59 | return 0 60 | return dist.get_rank() 61 | 62 | 63 | def is_main_process(): 64 | return get_rank() == 0 65 | 66 | 67 | def save_on_master(state, is_best, output_dir): 68 | if is_main_process(): 69 | ckpt_path = f'{output_dir}/checkpoint.pt' 70 | best_path = f'{output_dir}/checkpoint_best.pt' 71 | torch.save(state, ckpt_path) 72 | if is_best: 73 | shutil.copyfile(ckpt_path, best_path) 74 | 75 | 76 | def init_distributed_mode(args): 77 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 78 | args.rank = int(os.environ["RANK"]) 79 | args.world_size = int(os.environ['WORLD_SIZE']) 80 | args.gpu = int(os.environ['LOCAL_RANK']) 81 | elif 'SLURM_PROCID' in os.environ: 82 | args.rank = int(os.environ['SLURM_PROCID']) 83 | args.gpu = args.rank % torch.cuda.device_count() 84 | else: 85 | print('Not using distributed mode') 86 | args.distributed = False 87 | return 88 | 89 | args.distributed = True 90 | 91 | torch.cuda.set_device(args.gpu) 92 | args.dist_backend = 'nccl' 93 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), 94 | flush=True) 95 | torch.distributed.init_process_group(backend=args.dist_backend, 96 | init_method=args.dist_url, 97 | world_size=args.world_size, 98 | rank=args.rank) 99 | torch.distributed.barrier() 100 | setup_for_distributed(args.rank == 0) 101 | 102 | 103 | def scaled_all_reduce(tensors, is_scale=True): 104 | """Performs the scaled all_reduce operation on the provided tensors. 105 | The input tensors are modified in-place. Currently supports only the sum 106 | reduction operator. The reduced values are scaled by the inverse size of the 107 | world size. 108 | """ 109 | world_size = get_world_size() 110 | # There is no need for reduction in the single-proc case 111 | if world_size == 1: 112 | return tensors 113 | # Queue the reductions 114 | reductions = [] 115 | for tensor in tensors: 116 | reduction = dist.all_reduce(tensor, async_op=True) 117 | reductions.append(reduction) 118 | # Wait for reductions to finish 119 | for reduction in reductions: 120 | reduction.wait() 121 | # Scale the results 122 | if is_scale: 123 | for tensor in tensors: 124 | tensor.mul_(1.0 / world_size) 125 | return tensors 126 | 127 | 128 | def all_gather_batch(tensors): 129 | """ 130 | Performs all_gather operation on the provided tensors. 131 | """ 132 | # Queue the gathered tensors 133 | world_size = get_world_size() 134 | # There is no need for reduction in the single-proc case 135 | if world_size == 1: 136 | return tensors 137 | tensor_list = [] 138 | output_tensor = [] 139 | for tensor in tensors: 140 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 141 | dist.all_gather( 142 | tensor_all, 143 | tensor, 144 | async_op=False # performance opt 145 | ) 146 | 147 | tensor_list.append(tensor_all) 148 | 149 | for tensor_all in tensor_list: 150 | output_tensor.append(torch.cat(tensor_all, dim=0)) 151 | return output_tensor 152 | 153 | 154 | class GatherLayer(autograd.Function): 155 | """ 156 | Gather tensors from all workers with support for backward propagation: 157 | This implementation does not cut the gradients as torch.distributed.all_gather does. 158 | """ 159 | @staticmethod 160 | def forward(ctx, x): 161 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 162 | dist.all_gather(output, x) 163 | return tuple(output) 164 | 165 | @staticmethod 166 | def backward(ctx, *grads): 167 | all_gradients = torch.stack(grads) 168 | dist.all_reduce(all_gradients) 169 | return all_gradients[dist.get_rank()] 170 | 171 | 172 | def all_gather_batch_with_grad(tensors): 173 | """ 174 | Performs all_gather operation on the provided tensors. 175 | Graph remains connected for backward grad computation. 176 | """ 177 | # Queue the gathered tensors 178 | world_size = get_world_size() 179 | # There is no need for reduction in the single-proc case 180 | if world_size == 1: 181 | return tensors 182 | tensor_list = [] 183 | output_tensor = [] 184 | 185 | for tensor in tensors: 186 | tensor_all = GatherLayer.apply(tensor) 187 | tensor_list.append(tensor_all) 188 | 189 | for tensor_all in tensor_list: 190 | output_tensor.append(torch.cat(tensor_all, dim=0)) 191 | return output_tensor 192 | 193 | 194 | def cosine_scheduler(base_value, 195 | final_value, 196 | epochs, 197 | niter_per_ep, 198 | warmup_epochs=0, 199 | start_warmup_value=0): 200 | warmup_schedule = np.array([]) 201 | warmup_iters = warmup_epochs * niter_per_ep 202 | if warmup_epochs > 0: 203 | warmup_schedule = np.linspace(start_warmup_value, base_value, 204 | warmup_iters) 205 | 206 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 207 | schedule = final_value + 0.5 * (base_value - final_value) * ( 208 | 1 + np.cos(np.pi * iters / len(iters))) 209 | 210 | schedule = np.concatenate((warmup_schedule, schedule)) 211 | assert len(schedule) == epochs * niter_per_ep 212 | return schedule 213 | 214 | 215 | class GaussianBlur(object): 216 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 217 | def __init__(self, sigma=[.1, 2.]): 218 | self.sigma = sigma 219 | 220 | def __call__(self, x): 221 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 222 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 223 | return x 224 | -------------------------------------------------------------------------------- /trainers/vpt.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | from copy import deepcopy 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.cuda.amp import GradScaler, autocast 8 | from torch.nn import functional as F 9 | from tqdm import tqdm 10 | 11 | from clip import clip 12 | from dassl.data.data_manager import DataManager 13 | from dassl.data.datasets import build_dataset 14 | from dassl.data.transforms import build_transform 15 | from dassl.engine import TRAINER_REGISTRY, TrainerX 16 | from dassl.metrics import compute_accuracy 17 | from dassl.optim import build_lr_scheduler, build_optimizer 18 | from dassl.utils import load_checkpoint, load_pretrained_weights 19 | 20 | from .coop import load_clip_to_cpu 21 | from .coop_testtime import DatasetWrapper_aug, build_data_loader 22 | from .losses import SIMCLRLoss 23 | from .zsclip import CUSTOM_TEMPLATES 24 | 25 | 26 | class DataManager_aug(DataManager): 27 | def __init__(self, 28 | cfg, 29 | custom_tfm_train=None, 30 | custom_tfm_test=None, 31 | dataset_wrapper=None): 32 | # Load dataset 33 | dataset = build_dataset(cfg) 34 | 35 | # Build transform 36 | if custom_tfm_train is None: 37 | tfm_train = build_transform(cfg, is_train=True) 38 | else: 39 | print("* Using custom transform for training") 40 | tfm_train = custom_tfm_train 41 | 42 | if custom_tfm_test is None: 43 | tfm_test = build_transform(cfg, is_train=False) 44 | else: 45 | print("* Using custom transform for testing") 46 | tfm_test = custom_tfm_test 47 | 48 | # Build train_loader_x 49 | train_loader_x = build_data_loader( 50 | cfg, 51 | sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER, 52 | data_source=dataset.train_x, 53 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 54 | n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN, 55 | n_ins=cfg.DATALOADER.TRAIN_X.N_INS, 56 | tfm=tfm_train, 57 | is_train=True, 58 | dataset_wrapper=DatasetWrapper_aug, 59 | ) 60 | 61 | # Build train_loader_u 62 | train_loader_u = None 63 | if dataset.train_u: 64 | sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER 65 | batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE 66 | n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN 67 | n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS 68 | 69 | if cfg.DATALOADER.TRAIN_U.SAME_AS_X: 70 | sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER 71 | batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE 72 | n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN 73 | n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS 74 | 75 | train_loader_u = build_data_loader( 76 | cfg, 77 | sampler_type=sampler_type_, 78 | data_source=dataset.train_u, 79 | batch_size=batch_size_, 80 | n_domain=n_domain_, 81 | n_ins=n_ins_, 82 | tfm=tfm_train, 83 | is_train=True, 84 | dataset_wrapper=dataset_wrapper, 85 | ) 86 | 87 | # Build val_loader 88 | val_loader = None 89 | if dataset.val: 90 | val_loader = build_data_loader( 91 | cfg, 92 | sampler_type=cfg.DATALOADER.TEST.SAMPLER, 93 | data_source=dataset.val, 94 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, 95 | tfm=tfm_test, 96 | is_train=False, 97 | dataset_wrapper=dataset_wrapper, 98 | ) 99 | 100 | # Build test_loader 101 | test_loader = build_data_loader( 102 | cfg, 103 | sampler_type=cfg.DATALOADER.TEST.SAMPLER, 104 | data_source=dataset.test, 105 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, 106 | tfm=tfm_test, 107 | is_train=False, 108 | dataset_wrapper=DatasetWrapper_aug, 109 | ) 110 | 111 | # Attributes 112 | self._num_classes = dataset.num_classes 113 | self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS) 114 | self._lab2cname = dataset.lab2cname 115 | 116 | # Dataset and data-loaders 117 | self.dataset = dataset 118 | self.train_loader_x = train_loader_x 119 | self.train_loader_u = train_loader_u 120 | self.val_loader = val_loader 121 | self.test_loader = test_loader 122 | 123 | if cfg.VERBOSE: 124 | self.show_dataset_summary(cfg) 125 | 126 | 127 | class CustomVPT(nn.Module): 128 | def __init__(self, cfg, classnames, clip_model, device): 129 | super().__init__() 130 | 131 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 132 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 133 | print(f"Prompts: {prompts}") 134 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 135 | prompts = prompts.to(device) 136 | clip_model = clip_model.to(device) 137 | 138 | with torch.no_grad(): 139 | text_features = clip_model.encode_text(prompts) 140 | text_features = text_features / text_features.norm(dim=-1, 141 | keepdim=True) 142 | 143 | self.text_features = text_features 144 | self.clip_model = clip_model 145 | 146 | n_ctx = cfg.TRAINER.COOP.N_CTX 147 | dtype = clip_model.dtype 148 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1] 149 | 150 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device) 151 | nn.init.normal_(ctx_vectors, std=0.02) 152 | self.visual_ctx = nn.Parameter(ctx_vectors) 153 | # self.mlp = self._build_mlp(768, 4096, 256, dtype=dtype) 154 | self.mlp = self._build_mlp(768, 512, 256, dtype=dtype) 155 | self.mlp.to(device) 156 | 157 | def _build_mlp(self, in_dim, mlp_dim, out_dim, dtype): 158 | return nn.Sequential( 159 | OrderedDict([ 160 | ("layer1", nn.Linear(in_dim, mlp_dim).to(dtype)), 161 | # ("bn1", nn.SyncBatchNorm(mlp_dim)), 162 | ("relu1", nn.ReLU(inplace=True)), 163 | # ("layer2", nn.Linear(mlp_dim, mlp_dim).to(dtype)), 164 | # ("bn2", nn.SyncBatchNorm(mlp_dim)), 165 | # ("relu2", nn.ReLU(inplace=True)), 166 | ("layer3", nn.Linear(mlp_dim, out_dim).to(dtype)), 167 | ])) 168 | 169 | def forward(self, image, aug1, aug2, training=True): 170 | if training: 171 | image_features, _ = self.clip_model.encode_image_prompt( 172 | image, self.visual_ctx) 173 | image_features = image_features / image_features.norm(dim=-1, 174 | keepdim=True) 175 | logit_scale = self.clip_model.logit_scale.exp() 176 | logits = logit_scale * image_features @ self.text_features.t() 177 | else: 178 | logits = None 179 | 180 | _, h2 = self.clip_model.encode_image_prompt(aug1, self.visual_ctx) 181 | _, h3 = self.clip_model.encode_image_prompt(aug2, self.visual_ctx) 182 | aug1_embed = self.mlp(h2) 183 | aug2_embed = self.mlp(h3) 184 | 185 | return logits, aug1_embed, aug2_embed 186 | 187 | 188 | @TRAINER_REGISTRY.register() 189 | class VPT(TrainerX): 190 | def check_cfg(self, cfg): 191 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 192 | 193 | def build_model(self): 194 | cfg = self.cfg 195 | classnames = self.dm.dataset.classnames 196 | 197 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 198 | clip_model = load_clip_to_cpu(cfg) 199 | 200 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 201 | # CLIP's default precision is fp16 202 | clip_model.float() 203 | 204 | print("Building custom VPT") 205 | self.model = CustomVPT(cfg, classnames, clip_model, self.device) 206 | 207 | print("Turning off gradients in both the image and the text encoder") 208 | for name, param in self.model.named_parameters(): 209 | # if "ctx" not in name: 210 | if "clip" in name: 211 | param.requires_grad_(False) 212 | 213 | if cfg.MODEL.INIT_WEIGHTS: 214 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 215 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 216 | 217 | self.model.to(self.device) 218 | # NOTE: only give prompt_learner to the optimizer 219 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 220 | self.optim = build_optimizer(self.model, cfg.OPTIM) 221 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 222 | 223 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 224 | self.register_model("model", self.model, self.optim, self.sched) 225 | 226 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 227 | 228 | # Note that multi-gpu training could be slow because CLIP's size is 229 | # big, which slows down the copy operation in DataParallel 230 | device_count = 1 # torch.cuda.device_count() 231 | if device_count > 1: 232 | print( 233 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 234 | ) 235 | self.model = nn.DataParallel(self.model) 236 | 237 | self.ssl_loss = SIMCLRLoss() 238 | 239 | def forward_backward(self, batch): 240 | image, image_aug, image_aug2, label = self.parse_batch_train(batch) 241 | 242 | prec = self.cfg.TRAINER.COOP.PREC 243 | if prec == "amp": 244 | with autocast(): 245 | output, aug1_embed, aug2_embed = self.model( 246 | image, image_aug, image_aug2) 247 | loss_ce = F.cross_entropy(output, label) 248 | loss_ssl, acc_ssl = self.ssl_loss(aug1_embed, aug2_embed) 249 | loss = loss_ce + loss_ssl * 0.0 250 | self.optim.zero_grad() 251 | self.scaler.scale(loss).backward() 252 | self.scaler.step(self.optim) 253 | self.scaler.update() 254 | else: 255 | output, aug1_embed, aug2_embed = self.model( 256 | image, image_aug, image_aug2) 257 | loss_ce = F.cross_entropy(output, label) 258 | loss_ssl, acc_ssl = self.ssl_loss(aug1_embed, aug2_embed) 259 | loss = loss_ce + loss_ssl * 0.0 260 | self.model_backward_and_update(loss) 261 | 262 | loss_summary = { 263 | "loss": loss.item(), 264 | "loss_ce": loss_ce.item(), 265 | "loss_ssl": loss_ssl.item(), 266 | "acc": compute_accuracy(output, label)[0].item(), 267 | "acc_ssl": acc_ssl.item() 268 | } 269 | 270 | if (self.batch_idx + 1) == self.num_batches: 271 | self.update_lr() 272 | 273 | return loss_summary 274 | 275 | def load_model(self, directory, epoch=None): 276 | if not directory: 277 | print( 278 | "Note that load_model() is skipped as no pretrained model is given" 279 | ) 280 | return 281 | 282 | names = self.get_model_names() 283 | 284 | # By default, the best model is loaded 285 | model_file = "model-best.pth.tar" 286 | 287 | if epoch is not None: 288 | model_file = "model.pth.tar-" + str(epoch) 289 | 290 | for name in names: 291 | model_path = osp.join(directory, name, model_file) 292 | 293 | if not osp.exists(model_path): 294 | raise FileNotFoundError( 295 | 'Model not found at "{}"'.format(model_path)) 296 | 297 | checkpoint = load_checkpoint(model_path) 298 | state_dict = checkpoint["state_dict"] 299 | epoch = checkpoint["epoch"] 300 | 301 | # Ignore fixed token vectors 302 | if "token_prefix" in state_dict: 303 | del state_dict["token_prefix"] 304 | 305 | if "token_suffix" in state_dict: 306 | del state_dict["token_suffix"] 307 | 308 | print("Loading weights to {} " 309 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 310 | # set strict=False 311 | self._models[name].load_state_dict(state_dict, strict=False) 312 | 313 | def parse_batch_train(self, batch): 314 | input = batch["img"] 315 | input_aug = batch["img_aug"] 316 | input_aug2 = batch["img_aug2"] 317 | label = batch["label"] 318 | input = input.to(self.device) 319 | input_aug = input_aug.to(self.device) 320 | input_aug2 = input_aug2.to(self.device) 321 | label = label.to(self.device) 322 | return input, input_aug, input_aug2, label 323 | 324 | # test-time training 325 | 326 | def build_data_loader(self): 327 | dm = DataManager_aug(self.cfg) 328 | 329 | self.train_loader_x = dm.train_loader_x 330 | self.train_loader_u = dm.train_loader_u # optional, can be None 331 | self.val_loader = dm.val_loader # optional, can be None 332 | self.test_loader = dm.test_loader 333 | self.num_classes = dm.num_classes 334 | self.num_source_domains = dm.num_source_domains 335 | self.lab2cname = dm.lab2cname # dict {label: classname} 336 | 337 | self.dm = dm 338 | 339 | def test(self, split=None): 340 | """A generic testing pipeline.""" 341 | # self.set_model_mode("eval") 342 | self.model.to(self.device) 343 | 344 | self._optims['model'].param_groups[0]['lr'] = 2e-4 345 | self.evaluator.reset() 346 | 347 | if split is None: 348 | split = self.cfg.TEST.SPLIT 349 | 350 | if split == "val" and self.val_loader is not None: 351 | data_loader = self.val_loader 352 | print("Do evaluation on {} set".format(split)) 353 | else: 354 | data_loader = self.test_loader 355 | print("Do evaluation on test set") 356 | 357 | model_state = deepcopy(self.model.state_dict()) 358 | for batch_idx, batch in enumerate(tqdm(data_loader)): 359 | self.model.load_state_dict(model_state, strict=True) 360 | input, input_aug, input_aug2, input_aug3, label = self.parse_batch_test( 361 | batch) 362 | 363 | for _ in range(0): 364 | logits, aug1_embed, aug2_embed = self.model_inference( 365 | input, input_aug, input_aug2, training=False) 366 | loss_ssl, _ = self.ssl_loss(aug1_embed, aug2_embed) 367 | self.model_backward_and_update(loss_ssl) 368 | 369 | for _ in range(0): 370 | logits, aug1_embed, aug2_embed = self.model_inference( 371 | input, input_aug2, input_aug3, training=False) 372 | loss_ssl, _ = self.ssl_loss(aug1_embed, aug2_embed) 373 | self.model_backward_and_update(loss_ssl) 374 | 375 | with torch.no_grad(): 376 | output, _, _ = self.model_inference(input, 377 | input_aug, 378 | input_aug2, 379 | training=True) 380 | self.evaluator.process(output, label) 381 | results = self.evaluator.evaluate() 382 | 383 | for k, v in results.items(): 384 | tag = "{}/{}".format(split, k) 385 | self.write_scalar(tag, v, self.epoch) 386 | 387 | return list(results.values())[0] 388 | 389 | def model_inference(self, input, input_aug, input_aug2, training=True): 390 | return self.model(input, input_aug, input_aug2, training=training) 391 | 392 | def parse_batch_test(self, batch): 393 | input = batch["img"] 394 | input_aug = batch["img_aug"] 395 | input_aug2 = batch["img_aug2"] 396 | input_aug3 = batch["img_aug3"] 397 | label = batch["label"] 398 | input = input.to(self.device) 399 | input_aug = input_aug.to(self.device) 400 | input_aug2 = input_aug2.to(self.device) 401 | input_aug3 = input_aug3.to(self.device) 402 | label = label.to(self.device) 403 | return input, input_aug, input_aug2, input_aug3, label 404 | -------------------------------------------------------------------------------- /trainers/vpt_deep.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import GradScaler, autocast 6 | from torch.nn import functional as F 7 | 8 | from clip import clip 9 | from dassl.engine import TRAINER_REGISTRY, TrainerX 10 | from dassl.metrics import compute_accuracy 11 | from dassl.optim import build_lr_scheduler, build_optimizer 12 | from dassl.utils import load_checkpoint, load_pretrained_weights 13 | 14 | from .coop import load_clip_to_cpu 15 | from .zsclip import CUSTOM_TEMPLATES 16 | 17 | 18 | class CustomVPT(nn.Module): 19 | def __init__(self, cfg, classnames, clip_model, device): 20 | super().__init__() 21 | 22 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 23 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 24 | print(f"Prompts: {prompts}") 25 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 26 | prompts = prompts.to(device) 27 | clip_model = clip_model.to(device) 28 | 29 | with torch.no_grad(): 30 | text_features = clip_model.encode_text(prompts) 31 | text_features = text_features / text_features.norm(dim=-1, 32 | keepdim=True) 33 | 34 | self.text_features = text_features 35 | self.clip_model = clip_model 36 | 37 | n_ctx = cfg.TRAINER.COOP.N_CTX 38 | dtype = clip_model.dtype 39 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1] 40 | 41 | ctx_vectors = torch.empty(n_ctx * 12, 42 | ctx_dim, 43 | dtype=dtype, 44 | device=device) 45 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device) 46 | nn.init.normal_(ctx_vectors, std=0.02) 47 | self.visual_ctx = nn.Parameter(ctx_vectors) 48 | 49 | def forward(self, image): 50 | image_features, _ = self.clip_model.encode_image_prompt_deep( 51 | image, self.visual_ctx) 52 | image_features = image_features / image_features.norm(dim=-1, 53 | keepdim=True) 54 | logit_scale = self.clip_model.logit_scale.exp() 55 | logits = logit_scale * image_features @ self.text_features.t() 56 | 57 | return logits 58 | 59 | 60 | @TRAINER_REGISTRY.register() 61 | class VPT_deep(TrainerX): 62 | def check_cfg(self, cfg): 63 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 64 | 65 | def build_model(self): 66 | cfg = self.cfg 67 | classnames = self.dm.dataset.classnames 68 | 69 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 70 | clip_model = load_clip_to_cpu(cfg) 71 | 72 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 73 | # CLIP's default precision is fp16 74 | clip_model.float() 75 | 76 | print("Building custom VPT") 77 | self.model = CustomVPT(cfg, classnames, clip_model, self.device) 78 | 79 | print("Turning off gradients in both the image and the text encoder") 80 | for name, param in self.model.named_parameters(): 81 | if "clip" in name: 82 | param.requires_grad_(False) 83 | 84 | if cfg.MODEL.INIT_WEIGHTS: 85 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 86 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 87 | 88 | self.model.to(self.device) 89 | # NOTE: only give prompt_learner to the optimizer 90 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 91 | self.optim = build_optimizer(self.model, cfg.OPTIM) 92 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 93 | 94 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 95 | self.register_model("model", self.model, self.optim, self.sched) 96 | 97 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 98 | 99 | # Note that multi-gpu training could be slow because CLIP's size is 100 | # big, which slows down the copy operation in DataParallel 101 | device_count = 1 # torch.cuda.device_count() 102 | if device_count > 1: 103 | print( 104 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 105 | ) 106 | self.model = nn.DataParallel(self.model) 107 | 108 | def forward_backward(self, batch): 109 | image, label = self.parse_batch_train(batch) 110 | 111 | prec = self.cfg.TRAINER.COOP.PREC 112 | if prec == "amp": 113 | with autocast(): 114 | output = self.model(image) 115 | loss = F.cross_entropy(output, label) 116 | self.optim.zero_grad() 117 | self.scaler.scale(loss).backward() 118 | self.scaler.step(self.optim) 119 | self.scaler.update() 120 | else: 121 | output = self.model(image) 122 | loss = F.cross_entropy(output, label) 123 | self.model_backward_and_update(loss) 124 | 125 | loss_summary = { 126 | "loss": loss.item(), 127 | "acc": compute_accuracy(output, label)[0].item(), 128 | } 129 | 130 | if (self.batch_idx + 1) == self.num_batches: 131 | self.update_lr() 132 | 133 | return loss_summary 134 | 135 | def parse_batch_train(self, batch): 136 | input = batch["img"] 137 | label = batch["label"] 138 | input = input.to(self.device) 139 | label = label.to(self.device) 140 | return input, label 141 | 142 | def load_model(self, directory, epoch=None): 143 | if not directory: 144 | print( 145 | "Note that load_model() is skipped as no pretrained model is given" 146 | ) 147 | return 148 | 149 | names = self.get_model_names() 150 | 151 | # By default, the best model is loaded 152 | model_file = "model-best.pth.tar" 153 | 154 | if epoch is not None: 155 | model_file = "model.pth.tar-" + str(epoch) 156 | 157 | for name in names: 158 | model_path = osp.join(directory, name, model_file) 159 | 160 | if not osp.exists(model_path): 161 | raise FileNotFoundError( 162 | 'Model not found at "{}"'.format(model_path)) 163 | 164 | checkpoint = load_checkpoint(model_path) 165 | state_dict = checkpoint["state_dict"] 166 | epoch = checkpoint["epoch"] 167 | 168 | # Ignore fixed token vectors 169 | if "token_prefix" in state_dict: 170 | del state_dict["token_prefix"] 171 | 172 | if "token_suffix" in state_dict: 173 | del state_dict["token_suffix"] 174 | 175 | print("Loading weights to {} " 176 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 177 | # set strict=False 178 | self._models[name].load_state_dict(state_dict, strict=False) 179 | -------------------------------------------------------------------------------- /trainers/vpt_shallow.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import GradScaler, autocast 6 | from torch.nn import functional as F 7 | 8 | from clip import clip 9 | from dassl.engine import TRAINER_REGISTRY, TrainerX 10 | from dassl.metrics import compute_accuracy 11 | from dassl.optim import build_lr_scheduler, build_optimizer 12 | from dassl.utils import load_checkpoint, load_pretrained_weights 13 | 14 | from .coop import load_clip_to_cpu 15 | from .zsclip import CUSTOM_TEMPLATES 16 | 17 | 18 | class CustomVPT(nn.Module): 19 | def __init__(self, cfg, classnames, clip_model, device): 20 | super().__init__() 21 | 22 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 23 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 24 | print(f"Prompts: {prompts}") 25 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 26 | prompts = prompts.to(device) 27 | clip_model = clip_model.to(device) 28 | 29 | with torch.no_grad(): 30 | text_features = clip_model.encode_text(prompts) 31 | text_features = text_features / text_features.norm(dim=-1, 32 | keepdim=True) 33 | 34 | self.text_features = text_features 35 | self.clip_model = clip_model 36 | 37 | n_ctx = cfg.TRAINER.COOP.N_CTX 38 | dtype = clip_model.dtype 39 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1] 40 | 41 | ctx_vectors = torch.empty(n_ctx * 1, 42 | ctx_dim, 43 | dtype=dtype, 44 | device=device) 45 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device) 46 | nn.init.normal_(ctx_vectors, std=0.02) 47 | self.visual_ctx = nn.Parameter(ctx_vectors) 48 | 49 | def forward(self, image): 50 | # import time 51 | # torch.save(image.data.cpu(), f'img_{int(time.time())}.pt') 52 | 53 | image_features, _ = self.clip_model.encode_image_prompt_shallow( 54 | image, self.visual_ctx) 55 | image_features = image_features / image_features.norm(dim=-1, 56 | keepdim=True) 57 | logit_scale = self.clip_model.logit_scale.exp() 58 | logits = logit_scale * image_features @ self.text_features.t() 59 | 60 | return logits 61 | 62 | 63 | @TRAINER_REGISTRY.register() 64 | class VPT_shallow(TrainerX): 65 | def check_cfg(self, cfg): 66 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 67 | 68 | def build_model(self): 69 | cfg = self.cfg 70 | classnames = self.dm.dataset.classnames 71 | 72 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 73 | clip_model = load_clip_to_cpu(cfg) 74 | 75 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 76 | # CLIP's default precision is fp16 77 | clip_model.float() 78 | 79 | print("Building custom VPT") 80 | self.model = CustomVPT(cfg, classnames, clip_model, self.device) 81 | 82 | print("Turning off gradients in both the image and the text encoder") 83 | for name, param in self.model.named_parameters(): 84 | if "clip" in name: 85 | param.requires_grad_(False) 86 | 87 | if cfg.MODEL.INIT_WEIGHTS: 88 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 89 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 90 | 91 | self.model.to(self.device) 92 | # NOTE: only give prompt_learner to the optimizer 93 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 94 | self.optim = build_optimizer(self.model, cfg.OPTIM) 95 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 96 | 97 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 98 | self.register_model("model", self.model, self.optim, self.sched) 99 | 100 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 101 | 102 | # Note that multi-gpu training could be slow because CLIP's size is 103 | # big, which slows down the copy operation in DataParallel 104 | device_count = 1 # torch.cuda.device_count() 105 | if device_count > 1: 106 | print( 107 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 108 | ) 109 | self.model = nn.DataParallel(self.model) 110 | 111 | def forward_backward(self, batch): 112 | image, label = self.parse_batch_train(batch) 113 | 114 | prec = self.cfg.TRAINER.COOP.PREC 115 | if prec == "amp": 116 | with autocast(): 117 | output = self.model(image) 118 | loss = F.cross_entropy(output, label) 119 | self.optim.zero_grad() 120 | self.scaler.scale(loss).backward() 121 | self.scaler.step(self.optim) 122 | self.scaler.update() 123 | else: 124 | output = self.model(image) 125 | loss = F.cross_entropy(output, label) 126 | self.model_backward_and_update(loss) 127 | 128 | loss_summary = { 129 | "loss": loss.item(), 130 | "acc": compute_accuracy(output, label)[0].item(), 131 | } 132 | 133 | if (self.batch_idx + 1) == self.num_batches: 134 | self.update_lr() 135 | 136 | return loss_summary 137 | 138 | def parse_batch_train(self, batch): 139 | input = batch["img"] 140 | label = batch["label"] 141 | input = input.to(self.device) 142 | label = label.to(self.device) 143 | return input, label 144 | 145 | def load_model(self, directory, epoch=None): 146 | if not directory: 147 | print( 148 | "Note that load_model() is skipped as no pretrained model is given" 149 | ) 150 | return 151 | 152 | names = self.get_model_names() 153 | 154 | # By default, the best model is loaded 155 | model_file = "model-best.pth.tar" 156 | 157 | if epoch is not None: 158 | model_file = "model.pth.tar-" + str(epoch) 159 | 160 | for name in names: 161 | model_path = osp.join(directory, name, model_file) 162 | 163 | if not osp.exists(model_path): 164 | raise FileNotFoundError( 165 | 'Model not found at "{}"'.format(model_path)) 166 | 167 | checkpoint = load_checkpoint(model_path) 168 | state_dict = checkpoint["state_dict"] 169 | epoch = checkpoint["epoch"] 170 | 171 | # Ignore fixed token vectors 172 | if "token_prefix" in state_dict: 173 | del state_dict["token_prefix"] 174 | 175 | if "token_suffix" in state_dict: 176 | del state_dict["token_suffix"] 177 | 178 | print("Loading weights to {} " 179 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 180 | # set strict=False 181 | self._models[name].load_state_dict(state_dict, strict=False) 182 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from clip import clip 4 | from dassl.engine import TRAINER_REGISTRY, TrainerX 5 | 6 | from .coop import load_clip_to_cpu 7 | from .imagenet_templates import IMAGENET_TEMPLATES_SELECT 8 | 9 | CUSTOM_TEMPLATES = { 10 | "OxfordPets": "a photo of a {}, a type of pet.", 11 | "OxfordFlowers": "a photo of a {}, a type of flower.", 12 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 13 | "DescribableTextures": "{} texture.", 14 | "EuroSAT": "a centered satellite photo of {}.", 15 | "StanfordCars": "a photo of a {}.", 16 | "Food101": "a photo of {}, a type of food.", 17 | "SUN397": "a photo of a {}.", 18 | "Caltech101": "a photo of a {}.", 19 | "UCF101": "a photo of a person doing {}.", 20 | "ImageNet": "a photo of a {}.", 21 | "ImageNetSketch": "a photo of a {}.", 22 | "ImageNetV2": "a photo of a {}.", 23 | "ImageNetA": "a photo of a {}.", 24 | "ImageNetR": "a photo of a {}.", 25 | } 26 | 27 | 28 | @TRAINER_REGISTRY.register() 29 | class ZeroshotCLIP(TrainerX): 30 | def build_model(self): 31 | cfg = self.cfg 32 | classnames = self.dm.dataset.classnames 33 | 34 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 35 | clip_model = load_clip_to_cpu(cfg) 36 | clip_model.to(self.device) 37 | 38 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 39 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 40 | print(f"Prompts: {prompts}") 41 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 42 | prompts = prompts.to(self.device) 43 | 44 | with torch.no_grad(): 45 | text_features = clip_model.encode_text(prompts) 46 | text_features = text_features / text_features.norm(dim=-1, 47 | keepdim=True) 48 | 49 | self.text_features = text_features 50 | self.clip_model = clip_model 51 | 52 | def model_inference(self, image): 53 | image_features = self.clip_model.encode_image(image) 54 | image_features = image_features / image_features.norm(dim=-1, 55 | keepdim=True) 56 | logit_scale = self.clip_model.logit_scale.exp() 57 | logits = logit_scale * image_features @ self.text_features.t() 58 | return logits 59 | 60 | 61 | @TRAINER_REGISTRY.register() 62 | class ZeroshotCLIP2(ZeroshotCLIP): 63 | """Prompt ensembling.""" 64 | 65 | # templates = IMAGENET_TEMPLATES 66 | templates = IMAGENET_TEMPLATES_SELECT 67 | 68 | def build_model(self): 69 | cfg = self.cfg 70 | classnames = self.dm.dataset.classnames 71 | 72 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 73 | clip_model = load_clip_to_cpu(cfg) 74 | clip_model.to(self.device) 75 | 76 | for params in clip_model.parameters(): 77 | params.requires_grad_(False) 78 | 79 | # add custom-made prompt 80 | if cfg.DATASET.NAME != "ImageNet": 81 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 82 | 83 | num_temp = len(self.templates) 84 | print(f"Prompt ensembling (n={num_temp})") 85 | 86 | mean_text_features = 0 87 | for i, temp in enumerate(self.templates): 88 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 89 | prompts = torch.cat([clip.tokenize(p) 90 | for p in prompts]).to(self.device) 91 | text_features = clip_model.encode_text(prompts) 92 | text_features = text_features / text_features.norm(dim=-1, 93 | keepdim=True) 94 | mean_text_features = mean_text_features + text_features 95 | mean_text_features = mean_text_features / num_temp 96 | mean_text_features = mean_text_features / mean_text_features.norm( 97 | dim=-1, keepdim=True) 98 | 99 | self.text_features = mean_text_features 100 | self.clip_model = clip_model 101 | --------------------------------------------------------------------------------