├── README.md
├── assets
│   └── motivation.png
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   ├── datasets
│   │   ├── caltech101.yaml
│   │   ├── dtd.yaml
│   │   ├── eurosat.yaml
│   │   ├── fgvc_aircraft.yaml
│   │   ├── food101.yaml
│   │   ├── imagenet.yaml
│   │   ├── imagenet_a.yaml
│   │   ├── imagenet_r.yaml
│   │   ├── imagenet_sketch.yaml
│   │   ├── imagenetv2.yaml
│   │   ├── oxford_flowers.yaml
│   │   ├── oxford_pets.yaml
│   │   ├── stanford_cars.yaml
│   │   ├── sun397.yaml
│   │   └── ucf101.yaml
│   └── trainers
│       ├── CoCoOp
│       │   ├── vit_b16_c16_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1_ctxv1.yaml
│       │   └── vit_b16_c8_ep10_batch1.yaml
│       ├── CoOp
│       │   ├── rn101.yaml
│       │   ├── rn101_ep50.yaml
│       │   ├── rn50.yaml
│       │   ├── rn50_ctxv1.yaml
│       │   ├── rn50_ep100.yaml
│       │   ├── rn50_ep50.yaml
│       │   ├── rn50_ep50_ctxv1.yaml
│       │   ├── rn50_val.yaml
│       │   ├── vit_b16_ep10.yaml
│       │   ├── vit_b16_ep100.yaml
│       │   ├── vit_b16_ep50.yaml
│       │   ├── vit_b32.yaml
│       │   └── vit_b32_ep50.yaml
│       ├── CoOp_testtime
│       │   ├── vit_b16_ep10.yaml
│       │   ├── vit_b16_ep200.yaml
│       │   └── vit_b16_ep50.yaml
│       ├── Unified
│       │   ├── vit_b16_ep100.yaml
│       │   ├── vit_b16_ep200.yaml
│       │   └── vit_b16_ep50.yaml
│       ├── VPT
│       │   ├── vit_b16_ep10.yaml
│       │   ├── vit_b16_ep100.yaml
│       │   ├── vit_b16_ep200.yaml
│       │   └── vit_b16_ep50.yaml
│       ├── VPT_deep
│       │   ├── vit_b16_ep10.yaml
│       │   ├── vit_b16_ep200.yaml
│       │   └── vit_b16_ep50.yaml
│       ├── VPT_shallow
│       │   ├── vit_b16_ep10.yaml
│       │   ├── vit_b16_ep200.yaml
│       │   └── vit_b16_ep50.yaml
│       └── VPT_testtime
│           ├── vit_b16_ep10.yaml
│           ├── vit_b16_ep200.yaml
│           └── vit_b16_ep50.yaml
├── scripts
│   ├── README.md
│   ├── cocoop
│   │   ├── base2new_test.sh
│   │   ├── xd_test.sh
│   │   └── xd_train.sh
│   ├── eval.sh
│   ├── imagenet_coop.sh
│   ├── imagenet_zero.sh
│   ├── main.sh
│   ├── unified
│   │   ├── base2new_test_coop.sh
│   │   ├── base2new_test_new.sh
│   │   ├── base2new_test_new2.sh
│   │   ├── base2new_test_new3.sh
│   │   ├── base2new_train_caltech-101.sh
│   │   ├── base2new_train_dtd.sh
│   │   ├── base2new_train_eurosat.sh
│   │   ├── base2new_train_fgvc_aircraft.sh
│   │   ├── base2new_train_food101.sh
│   │   ├── base2new_train_imagenet.sh
│   │   ├── base2new_train_oxford_flowers.sh
│   │   ├── base2new_train_oxford_pets.sh
│   │   ├── base2new_train_stanford_cars.sh
│   │   ├── base2new_train_sun397.sh
│   │   └── base2new_train_ucf101.sh
│   └── zeroshot.sh
├── train.py
└── trainers
    ├── .flake8
    ├── __init__.py
    ├── cocoop.py
    ├── coop.py
    ├── coop_testtime.py
    ├── imagenet_templates.py
    ├── linter.sh
    ├── losses.py
    ├── unified.py
    ├── utils.py
    ├── vpt.py
    ├── vpt_deep.py
    ├── vpt_shallow.py
    └── zsclip.py
/README.md:
--------------------------------------------------------------------------------
1 | # Unified Vision and Language Prompt Learning
2 |
3 |
7 |
8 | > **Unified Vision and Language Prompt Learning**
9 | > Yuhang Zang, Wei Li, Kaiyang Zhou, Chen Huang, Chen Change Loy
10 | > arXiv, 2022
11 |
12 |
13 |
14 |
15 |
16 | ## How to run
17 |
18 | This code is based on [CoOp](https://github.com/KaiyangZhou/CoOp); please refer to its [installation instructions](https://github.com/KaiyangZhou/CoOp?tab=readme-ov-file#how-to-install).
19 |
20 | The training scripts to reproduce the results are provided [here](scripts/unified).
21 |
22 | The model structure is defined [here](trainers/unified.py).
23 |
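24 | A minimal launch sketch, assuming the CoOp-style `train.py` interface this code builds on (the trainer name, dataset root, and output directory below are placeholders; the scripts under [scripts/unified](scripts/unified) contain the authoritative arguments):
25 |
26 | ```bash
27 | # Hypothetical example; check scripts/unified/*.sh for the exact flags used in the paper runs.
28 | python train.py \
29 |   --root /path/to/datasets \
30 |   --seed 1 \
31 |   --trainer Unified \
32 |   --dataset-config-file configs/datasets/imagenet.yaml \
33 |   --config-file configs/trainers/Unified/vit_b16_ep50.yaml \
34 |   --output-dir output/unified/imagenet/seed1
35 | ```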
--------------------------------------------------------------------------------
/assets/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/assets/motivation.png
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:3] if v.isdigit()) < (1, 7, 1):  # compare numerically; a plain string compare misorders e.g. "1.10" vs "1.7"
23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | }
37 |
38 |
39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
40 | os.makedirs(root, exist_ok=True)
41 | filename = os.path.basename(url)
42 |
43 | expected_sha256 = url.split("/")[-2]
44 | download_target = os.path.join(root, filename)
45 |
46 | if os.path.exists(download_target) and not os.path.isfile(download_target):
47 | raise RuntimeError(f"{download_target} exists and is not a regular file")
48 |
49 | if os.path.isfile(download_target):
50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51 | return download_target
52 | else:
53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54 |
55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
57 | while True:
58 | buffer = source.read(8192)
59 | if not buffer:
60 | break
61 |
62 | output.write(buffer)
63 | loop.update(len(buffer))
64 |
65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
67 |
68 | return download_target
69 |
70 |
71 | def _transform(n_px):
72 | return Compose([
73 | Resize(n_px, interpolation=BICUBIC),
74 | CenterCrop(n_px),
75 | lambda image: image.convert("RGB"),
76 | ToTensor(),
77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
78 | ])
79 |
80 |
81 | def available_models() -> List[str]:
82 | """Returns the names of available CLIP models"""
83 | return list(_MODELS.keys())
84 |
85 |
86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
87 | """Load a CLIP model
88 |
89 | Parameters
90 | ----------
91 | name : str
92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
93 |
94 | device : Union[str, torch.device]
95 | The device to put the loaded model
96 |
97 | jit : bool
98 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
99 |
100 | Returns
101 | -------
102 | model : torch.nn.Module
103 | The CLIP model
104 |
105 | preprocess : Callable[[PIL.Image], torch.Tensor]
106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
107 | """
108 | if name in _MODELS:
109 | model_path = _download(_MODELS[name])
110 | elif os.path.isfile(name):
111 | model_path = name
112 | else:
113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
114 |
115 | try:
116 | # loading JIT archive
117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
118 | state_dict = None
119 | except RuntimeError:
120 | # loading saved state dict
121 | if jit:
122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
123 | jit = False
124 | state_dict = torch.load(model_path, map_location="cpu")
125 |
126 | if not jit:
127 | model = build_model(state_dict or model.state_dict()).to(device)
128 | if str(device) == "cpu":
129 | model.float()
130 | return model, _transform(model.visual.input_resolution)
131 |
132 | # patch the device names
133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
135 |
136 | def patch_device(module):
137 | try:
138 | graphs = [module.graph] if hasattr(module, "graph") else []
139 | except RuntimeError:
140 | graphs = []
141 |
142 | if hasattr(module, "forward1"):
143 | graphs.append(module.forward1.graph)
144 |
145 | for graph in graphs:
146 | for node in graph.findAllNodes("prim::Constant"):
147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
148 | node.copyAttributes(device_node)
149 |
150 | model.apply(patch_device)
151 | patch_device(model.encode_image)
152 | patch_device(model.encode_text)
153 |
154 | # patch dtype to float32 on CPU
155 | if str(device) == "cpu":
156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
158 | float_node = float_input.node()
159 |
160 | def patch_float(module):
161 | try:
162 | graphs = [module.graph] if hasattr(module, "graph") else []
163 | except RuntimeError:
164 | graphs = []
165 |
166 | if hasattr(module, "forward1"):
167 | graphs.append(module.forward1.graph)
168 |
169 | for graph in graphs:
170 | for node in graph.findAllNodes("aten::to"):
171 | inputs = list(node.inputs())
172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
173 | if inputs[i].node()["value"] == 5:
174 | inputs[i].node().copyAttributes(float_node)
175 |
176 | model.apply(patch_float)
177 | patch_float(model.encode_image)
178 | patch_float(model.encode_text)
179 |
180 | model.float()
181 |
182 | return model, _transform(model.input_resolution.item())
183 |
184 |
185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
186 | """
187 | Returns the tokenized representation of given input string(s)
188 |
189 | Parameters
190 | ----------
191 | texts : Union[str, List[str]]
192 | An input string or a list of input strings to tokenize
193 |
194 | context_length : int
195 | The context length to use; all CLIP models use 77 as the context length
196 |
197 | truncate: bool
198 | Whether to truncate the text in case its encoding is longer than the context length
199 |
200 | Returns
201 | -------
202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
203 | """
204 | if isinstance(texts, str):
205 | texts = [texts]
206 |
207 | sot_token = _tokenizer.encoder["<|startoftext|>"]
208 | eot_token = _tokenizer.encoder["<|endoftext|>"]
209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
211 |
212 | for i, tokens in enumerate(all_tokens):
213 | if len(tokens) > context_length:
214 | if truncate:
215 | tokens = tokens[:context_length]
216 | tokens[-1] = eot_token
217 | else:
218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
219 | result[i, :len(tokens)] = torch.tensor(tokens)
220 |
221 | return result
222 |
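223 | # Usage sketch for this module's public API (the model name, device, file name, and
224 | # prompt strings below are illustrative placeholders):
225 | #
226 | #   model, preprocess = load("ViT-B/16", device="cuda")
227 | #   image = preprocess(Image.open("example.jpg")).unsqueeze(0).to("cuda")
228 | #   text = tokenize(["a photo of a dog", "a photo of a cat"]).to("cuda")
229 | #   with torch.no_grad():
230 | #       logits_per_image, logits_per_text = model(image, text)
231 | #       probs = logits_per_image.softmax(dim=-1)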
--------------------------------------------------------------------------------
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = self.relu(bn(conv(x)))
139 | x = self.avgpool(x)
140 | return x
141 |
142 | x = x.type(self.conv1.weight.dtype)
143 | x = stem(x)
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 | x = self.attnpool(x)
149 |
150 | return x
151 |
152 |
153 | class LayerNorm(nn.LayerNorm):
154 | """Subclass torch's LayerNorm to handle fp16."""
155 |
156 | def forward(self, x: torch.Tensor):
157 | orig_type = x.dtype
158 | ret = super().forward(x.type(torch.float32))
159 | return ret.type(orig_type)
160 |
161 |
162 | class QuickGELU(nn.Module):
163 | def forward(self, x: torch.Tensor):
164 | return x * torch.sigmoid(1.702 * x)
165 |
166 |
167 | class ResidualAttentionBlock(nn.Module):
168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 | super().__init__()
170 |
171 | self.attn = nn.MultiheadAttention(d_model, n_head)
172 | self.ln_1 = LayerNorm(d_model)
173 | self.mlp = nn.Sequential(OrderedDict([
174 | ("c_fc", nn.Linear(d_model, d_model * 4)),
175 | ("gelu", QuickGELU()),
176 | ("c_proj", nn.Linear(d_model * 4, d_model))
177 | ]))
178 | self.ln_2 = LayerNorm(d_model)
179 | self.attn_mask = attn_mask
180 |
181 | def attention(self, x: torch.Tensor):
182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 |
185 | def forward(self, x: torch.Tensor):
186 | x = x + self.attention(self.ln_1(x))
187 | x = x + self.mlp(self.ln_2(x))
188 | return x
189 |
190 |
191 | class Transformer(nn.Module):
192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 | self.width = width
195 | self.layers = layers
196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 |
198 | def forward(self, x: torch.Tensor):
199 | return self.resblocks(x)
200 |
201 | def forward_prompt(self, x, prompt):
202 | l_length = len(self.resblocks)
203 | p_length = len(prompt) // l_length
204 |
205 | prompt = prompt.reshape(p_length, l_length, -1)
206 | prompt = prompt.permute(1, 0, 2)
207 | prompt = prompt.unsqueeze(2).repeat(1, 1, x.shape[1], 1)
208 | # prompt = prompt.unsqueeze(1).repeat(1, x.shape[1], 1)
209 |
210 | for ind, block in enumerate(self.resblocks):
211 | cls = x[0:1, :, :]
212 | if ind == 0:
213 | spatial = x[1:, :, :]
214 | else:
215 |                 spatial = x[1+p_length:, :, :]  # drop the prompt tokens inserted at the previous block
216 |             x = torch.cat([cls, prompt[ind, :], spatial], 0)  # prepend this block's prompt tokens after the CLS token
217 | # x = torch.cat([cls, prompt, spatial], 0)
218 | x = block(x)
219 | return x
220 |
221 |
222 | class VisionTransformer(nn.Module):
223 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
224 | super().__init__()
225 | self.input_resolution = input_resolution
226 | self.output_dim = output_dim
227 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
228 |
229 | scale = width ** -0.5
230 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
231 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
232 | self.ln_pre = LayerNorm(width)
233 |
234 | self.transformer = Transformer(width, layers, heads)
235 |
236 | self.ln_post = LayerNorm(width)
237 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
238 |
239 | def forward(self, x: torch.Tensor):
240 | x = self.conv1(x) # shape = [*, width, grid, grid]
241 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
242 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
243 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
244 | x = x + self.positional_embedding.to(x.dtype)
245 | x = self.ln_pre(x)
246 |
247 | x = x.permute(1, 0, 2) # NLD -> LND
248 | x = self.transformer(x)
249 | x = x.permute(1, 0, 2) # LND -> NLD
250 |
251 | x = self.ln_post(x[:, 0, :])
252 |
253 | if self.proj is not None:
254 | x = x @ self.proj
255 |
256 | return x
257 |
258 | def forward_prompt(self, x: torch.Tensor, prompt):
259 | x = self.conv1(x) # shape = [*, width, grid, grid]
260 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
261 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
262 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
263 | x = x + self.positional_embedding.to(x.dtype)
264 |
265 |         # insert the visual prompt tokens between the class token and the patch tokens
266 | cls = x[:, 0:1, :]
267 | spatial = x[:, 1:, :]
268 | prompt = prompt.unsqueeze(0).repeat(len(x), 1, 1)
269 | x = torch.cat([cls, prompt, spatial], 1)
270 | #
271 |
272 | x = self.ln_pre(x)
273 |
274 | x = x.permute(1, 0, 2) # NLD -> LND
275 | x = self.transformer(x)
276 | x = x.permute(1, 0, 2) # LND -> NLD
277 |
278 | x = self.ln_post(x[:, 0, :])
279 |
280 | if self.proj is not None:
281 | out = x @ self.proj
282 |
283 | return out, x
284 |
285 | def forward_prompt_deep(self, x: torch.Tensor, prompt):
286 | x = self.conv1(x) # shape = [*, width, grid, grid]
287 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
288 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
289 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
290 | x = x + self.positional_embedding.to(x.dtype)
291 |
292 | x = self.ln_pre(x)
293 |
294 | x = x.permute(1, 0, 2) # NLD -> LND
295 | x = self.transformer.forward_prompt(x, prompt)
296 | x = x.permute(1, 0, 2) # LND -> NLD
297 |
298 | x = self.ln_post(x[:, 0, :])
299 |
300 | if self.proj is not None:
301 | out = x @ self.proj
302 |
303 | return out, x
304 |
305 |
306 | class CLIP(nn.Module):
307 | def __init__(self,
308 | embed_dim: int,
309 | # vision
310 | image_resolution: int,
311 | vision_layers: Union[Tuple[int, int, int, int], int],
312 | vision_width: int,
313 | vision_patch_size: int,
314 | # text
315 | context_length: int,
316 | vocab_size: int,
317 | transformer_width: int,
318 | transformer_heads: int,
319 | transformer_layers: int
320 | ):
321 | super().__init__()
322 |
323 | self.context_length = context_length
324 |
325 | if isinstance(vision_layers, (tuple, list)):
326 | vision_heads = vision_width * 32 // 64
327 | self.visual = ModifiedResNet(
328 | layers=vision_layers,
329 | output_dim=embed_dim,
330 | heads=vision_heads,
331 | input_resolution=image_resolution,
332 | width=vision_width
333 | )
334 | else:
335 | vision_heads = vision_width // 64
336 | self.visual = VisionTransformer(
337 | input_resolution=image_resolution,
338 | patch_size=vision_patch_size,
339 | width=vision_width,
340 | layers=vision_layers,
341 | heads=vision_heads,
342 | output_dim=embed_dim
343 | )
344 |
345 | self.transformer = Transformer(
346 | width=transformer_width,
347 | layers=transformer_layers,
348 | heads=transformer_heads,
349 | attn_mask=self.build_attention_mask()
350 | )
351 |
352 | self.vocab_size = vocab_size
353 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
354 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
355 | self.ln_final = LayerNorm(transformer_width)
356 |
357 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
358 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
359 |
360 | self.initialize_parameters()
361 |
362 | def initialize_parameters(self):
363 | nn.init.normal_(self.token_embedding.weight, std=0.02)
364 | nn.init.normal_(self.positional_embedding, std=0.01)
365 |
366 | if isinstance(self.visual, ModifiedResNet):
367 | if self.visual.attnpool is not None:
368 | std = self.visual.attnpool.c_proj.in_features ** -0.5
369 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
370 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
371 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
372 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
373 |
374 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
375 | for name, param in resnet_block.named_parameters():
376 | if name.endswith("bn3.weight"):
377 | nn.init.zeros_(param)
378 |
379 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
380 | attn_std = self.transformer.width ** -0.5
381 | fc_std = (2 * self.transformer.width) ** -0.5
382 | for block in self.transformer.resblocks:
383 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
384 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
385 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
386 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
387 |
388 | if self.text_projection is not None:
389 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
390 |
391 | def build_attention_mask(self):
392 | # lazily create causal attention mask, with full attention between the vision tokens
393 | # pytorch uses additive attention mask; fill with -inf
394 | mask = torch.empty(self.context_length, self.context_length)
395 | mask.fill_(float("-inf"))
396 | mask.triu_(1) # zero out the lower diagonal
397 | return mask
398 |
399 | @property
400 | def dtype(self):
401 | return self.visual.conv1.weight.dtype
402 |
403 | def encode_image(self, image):
404 | return self.visual(image.type(self.dtype))
405 |
406 | def encode_image_prompt(self, image, prompt):
407 | return self.visual.forward_prompt(image.type(self.dtype), prompt)
408 |
409 | def encode_image_prompt_deep(self, image, prompt):
410 | return self.visual.forward_prompt_deep(image.type(self.dtype), prompt)
411 |
412 | def encode_text(self, text):
413 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
414 |
415 | x = x + self.positional_embedding.type(self.dtype)
416 | x = x.permute(1, 0, 2) # NLD -> LND
417 | x = self.transformer(x)
418 | x = x.permute(1, 0, 2) # LND -> NLD
419 | x = self.ln_final(x).type(self.dtype)
420 |
421 | # x.shape = [batch_size, n_ctx, transformer.width]
422 | # take features from the eot embedding (eot_token is the highest number in each sequence)
423 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
424 |
425 | return x
426 |
427 | def forward(self, image, text):
428 | image_features = self.encode_image(image)
429 | text_features = self.encode_text(text)
430 |
431 | # normalized features
432 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
433 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
434 |
435 | # cosine similarity as logits
436 | logit_scale = self.logit_scale.exp()
437 | logits_per_image = logit_scale * image_features @ text_features.t()
438 | logits_per_text = logit_scale * text_features @ image_features.t()
439 |
440 | # shape = [global_batch_size, global_batch_size]
441 | return logits_per_image, logits_per_text
442 |
443 |
444 | def convert_weights(model: nn.Module):
445 | """Convert applicable model parameters to fp16"""
446 |
447 | def _convert_weights_to_fp16(l):
448 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
449 | l.weight.data = l.weight.data.half()
450 | if l.bias is not None:
451 | l.bias.data = l.bias.data.half()
452 |
453 | if isinstance(l, nn.MultiheadAttention):
454 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
455 | tensor = getattr(l, attr)
456 | if tensor is not None:
457 | tensor.data = tensor.data.half()
458 |
459 | for name in ["text_projection", "proj"]:
460 | if hasattr(l, name):
461 | attr = getattr(l, name)
462 | if attr is not None:
463 | attr.data = attr.data.half()
464 |
465 | model.apply(_convert_weights_to_fp16)
466 |
467 |
468 | def build_model(state_dict: dict):
469 | vit = "visual.proj" in state_dict
470 |
471 | if vit:
472 | vision_width = state_dict["visual.conv1.weight"].shape[0]
473 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
474 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
475 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
476 | image_resolution = vision_patch_size * grid_size
477 | else:
478 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
479 | vision_layers = tuple(counts)
480 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
481 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
482 | vision_patch_size = None
483 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
484 | image_resolution = output_width * 32
485 |
486 | embed_dim = state_dict["text_projection"].shape[1]
487 | context_length = state_dict["positional_embedding"].shape[0]
488 | vocab_size = state_dict["token_embedding.weight"].shape[0]
489 | transformer_width = state_dict["ln_final.weight"].shape[0]
490 | transformer_heads = transformer_width // 64
491 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
492 |
493 | model = CLIP(
494 | embed_dim,
495 | image_resolution, vision_layers, vision_width, vision_patch_size,
496 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
497 | )
498 |
499 | for key in ["input_resolution", "context_length", "vocab_size"]:
500 | if key in state_dict:
501 | del state_dict[key]
502 |
503 | convert_weights(model)
504 | model.load_state_dict(state_dict)
505 | return model.eval()
506 |
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 |     Returns a mapping between utf-8 bytes and corresponding unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
24 |     and we avoid mapping to whitespace/control characters that the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
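134 | # Usage sketch (the sentence is an arbitrary placeholder):
135 | #   tokenizer = SimpleTokenizer()
136 | #   ids = tokenizer.encode("a photo of a cat")
137 | #   text = tokenizer.decode(ids)  # roughly recovers the lower-cased input text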
--------------------------------------------------------------------------------
/configs/datasets/caltech101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Caltech101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/dtd.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "DescribableTextures"
3 |
--------------------------------------------------------------------------------
/configs/datasets/eurosat.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "EuroSAT"
3 |
--------------------------------------------------------------------------------
/configs/datasets/fgvc_aircraft.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "FGVCAircraft"
3 |
--------------------------------------------------------------------------------
/configs/datasets/food101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Food101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNet"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_a.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetA"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_r.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetR"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_sketch.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetSketch"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenetv2.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetV2"
3 |
--------------------------------------------------------------------------------
/configs/datasets/oxford_flowers.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordFlowers"
--------------------------------------------------------------------------------
/configs/datasets/oxford_pets.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordPets"
--------------------------------------------------------------------------------
/configs/datasets/stanford_cars.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "StanfordCars"
3 |
--------------------------------------------------------------------------------
/configs/datasets/sun397.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SUN397"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ucf101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "UCF101"
3 |
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 16
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 |   MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 8
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn101.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN101"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn101_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN101"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: "a photo of a"
34 |
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep50_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_val.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 200
4 | TEST:
5 | BATCH_SIZE: 200
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | MODEL:
16 | BACKBONE:
17 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b32.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/32"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b32_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/32"
--------------------------------------------------------------------------------
/configs/trainers/CoOp_testtime/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp_testtime/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp_testtime/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/Unified/vit_b16_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 |
--------------------------------------------------------------------------------
/configs/trainers/Unified/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 |
--------------------------------------------------------------------------------
/configs/trainers/Unified/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 |
--------------------------------------------------------------------------------
/configs/trainers/VPT/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT/vit_b16_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_deep/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 1 # 100
6 | NUM_WORKERS: 0 # 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_deep/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_deep/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_shallow/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 1 # 100
6 | NUM_WORKERS: 0 # 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_shallow/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_shallow/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_testtime/vit_b16_ep10.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_testtime/vit_b16_ep200.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/VPT_testtime/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | coop_testtime: contains the CoOp results: (82.69, 63.29, 71.70)
2 |
3 | VPT: contains the VPT-shallow results (prompt length = 4): (80.17, 70.06, 74.78)
4 | VPT_testtime: contains the VPT-mix results (prompt length = 4): (82.08, 69.10, 75.03)
5 | VPT_deep: contains the VPT-deep results (prompt length = 4): (83.64, 67.31, 74.59)
6 |
7 | Unified: contains the VPT-deep + CoOp (separate) results
8 | Unified_v2: contains the VPT-mix + CoOp (separate) results
9 | Unified_v3: contains the VPT-deep + CoOp (shared) results (10e: 80.40, 74.24, 77.20; 200e: 84.35, 64.86, 73.33)
10 | Unified_v4: contains the VPT-deep + CoOp (DETR encoder) results (200e: 74.56)
11 | Unified_v5: contains the VPT-deep + CoOp (self-attn) results (10e: 76.99; 200e: 74.20)
--------------------------------------------------------------------------------
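
For reference, the third number in each `(base, new, H)` triple above is the harmonic mean of the base-class and new-class accuracies, the standard summary metric for the base-to-new generalization benchmark. A quick check:

```python
def harmonic_mean(base_acc: float, new_acc: float) -> float:
    """Harmonic mean H of base- and new-class accuracy."""
    return 2 * base_acc * new_acc / (base_acc + new_acc)


print(f"{harmonic_mean(82.69, 63.29):.2f}")  # 71.70, matches the coop_testtime row
print(f"{harmonic_mean(80.17, 70.06):.2f}")  # 74.77, vs. 74.78 in the VPT row (rounding)
```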
/scripts/cocoop/base2new_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=CoOp_testtime
8 |
9 | DATASET=stanford_cars
10 | SEED=1
11 |
12 | CFG=vit_b16_ep50
13 | SHOTS=16
14 | LOADEP=50
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Results are available in ${DIR}. Skip this job"
23 | else
24 | echo "Run this job and save the output to ${DIR}"
25 |
26 | python train.py \
27 | --root ${DATA} \
28 | --seed ${SEED} \
29 | --trainer ${TRAINER} \
30 | --dataset-config-file configs/datasets/${DATASET}.yaml \
31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
32 | --output-dir ${DIR} \
33 | --model-dir ${MODEL_DIR} \
34 | --load-epoch ${LOADEP} \
35 | --eval-only \
36 | DATASET.NUM_SHOTS ${SHOTS} \
37 | DATASET.SUBSAMPLE_CLASSES ${SUB}
38 | fi
--------------------------------------------------------------------------------
/scripts/cocoop/xd_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=CoCoOp
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 10 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/cocoop/xd_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=CoCoOp
8 |
9 | DATASET=caltech101
10 | SEED=1
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | CUDA_VISIBLE_DEVICES=1 python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=CoOp
8 |
9 | DATASET=$1
10 | CFG=$2
11 | CTP=$3
12 | NCTX=$4
13 | SHOTS=$5
14 | CSC=$6
15 |
16 | # for SEED in 1 2 3
17 | for SEED in 1
18 | do
19 |     CUDA_VISIBLE_DEVICES=1 python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \
26 | --model-dir output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \
27 | --eval-only \
28 | TRAINER.COOP.N_CTX ${NCTX} \
29 | TRAINER.COOP.CSC ${CSC} \
30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP}
31 | done
32 |
--------------------------------------------------------------------------------
/scripts/imagenet_coop.sh:
--------------------------------------------------------------------------------
1 | bash ./main.sh imagenet vit_b16_ep50 end 16 16 False
2 |
--------------------------------------------------------------------------------
/scripts/imagenet_zero.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=5 bash zeroshot.sh imagenet vit_b16
2 |
--------------------------------------------------------------------------------
/scripts/main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=CoOp
8 |
9 | DATASET=$1
10 | CFG=$2 # config file
11 | CTP=$3 # class token position (end or middle)
12 | NCTX=$4 # number of context tokens
13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
14 | CSC=$6 # class-specific context (False or True)
15 |
16 | for SEED in 1 2 3
17 | do
18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
19 | if [ -d "$DIR" ]; then
20 | echo "Results are available in ${DIR}. Skip this job"
21 | else
22 | echo "Run this job and save the output to ${DIR}"
23 | python train.py \
24 | --root ${DATA} \
25 | --seed ${SEED} \
26 | --trainer ${TRAINER} \
27 | --dataset-config-file configs/datasets/${DATASET}.yaml \
28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
29 | --output-dir ${DIR} \
30 | TRAINER.COOP.N_CTX ${NCTX} \
31 | TRAINER.COOP.CSC ${CSC} \
32 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
33 | DATASET.NUM_SHOTS ${SHOTS}
34 | fi
35 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_test_coop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=CoOp
8 |
9 | DATASET=imagenet
10 |
11 | CFG=vit_b16_ep50
12 | SHOTS=16
13 | LOADEP=50
14 | SUB=base
15 |
16 | for SEED in 1
17 | do
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | CUDA_VISIBLE_DEVICES=0 python train.py \
22 | --root ${DATA} \
23 | --seed ${SEED} \
24 | --trainer ${TRAINER} \
25 | --dataset-config-file configs/datasets/${DATASET}.yaml \
26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
27 | --output-dir ${DIR} \
28 | --model-dir ${MODEL_DIR} \
29 | --load-epoch ${LOADEP} \
30 | --eval-only \
31 | DATASET.NUM_SHOTS ${SHOTS} \
32 | DATASET.SUBSAMPLE_CLASSES ${SUB}
33 | done
34 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_test_new.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=ucf101
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 | LOADEP=200
14 | SUB=new
15 |
16 | for SEED in 1
17 | do
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | CUDA_VISIBLE_DEVICES=5 python train.py \
22 | --root ${DATA} \
23 | --seed ${SEED} \
24 | --trainer ${TRAINER} \
25 | --dataset-config-file configs/datasets/${DATASET}.yaml \
26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
27 | --output-dir ${DIR} \
28 | --model-dir ${MODEL_DIR} \
29 | --load-epoch ${LOADEP} \
30 | --eval-only \
31 | DATASET.NUM_SHOTS ${SHOTS} \
32 | DATASET.SUBSAMPLE_CLASSES ${SUB}
33 | done
34 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_test_new2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=ucf101
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 | LOADEP=200
14 | SUB=new
15 |
16 | for SEED in 2
17 | do
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | CUDA_VISIBLE_DEVICES=6 python train.py \
22 | --root ${DATA} \
23 | --seed ${SEED} \
24 | --trainer ${TRAINER} \
25 | --dataset-config-file configs/datasets/${DATASET}.yaml \
26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
27 | --output-dir ${DIR} \
28 | --model-dir ${MODEL_DIR} \
29 | --load-epoch ${LOADEP} \
30 | --eval-only \
31 | DATASET.NUM_SHOTS ${SHOTS} \
32 | DATASET.SUBSAMPLE_CLASSES ${SUB}
33 | done
34 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_test_new3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=ucf101
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 | LOADEP=200
14 | SUB=new
15 |
16 | for SEED in 3
17 | do
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | CUDA_VISIBLE_DEVICES=7 python train.py \
22 | --root ${DATA} \
23 | --seed ${SEED} \
24 | --trainer ${TRAINER} \
25 | --dataset-config-file configs/datasets/${DATASET}.yaml \
26 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
27 | --output-dir ${DIR} \
28 | --model-dir ${MODEL_DIR} \
29 | --load-epoch ${LOADEP} \
30 | --eval-only \
31 | DATASET.NUM_SHOTS ${SHOTS} \
32 | DATASET.SUBSAMPLE_CLASSES ${SUB}
33 | done
34 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_caltech-101.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=caltech101
10 |
11 | CFG=vit_b16_ep10
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
27 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_dtd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=dtd
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_eurosat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=eurosat
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
27 |
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_fgvc_aircraft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=fgvc_aircraft
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 1
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=5 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_food101.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=food101
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=imagenet
10 |
11 | CFG=vit_b16_ep50
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=2 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_oxford_flowers.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=oxford_flowers
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 1
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=5 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_oxford_pets.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=oxford_pets
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_stanford_cars.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=stanford_cars
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 1
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=5 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_sun397.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=sun397
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/unified/base2new_train_ucf101.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=Unified
8 |
9 | DATASET=ucf101
10 |
11 | CFG=vit_b16_ep200
12 | SHOTS=16
13 |
14 | for SEED in 3
15 | do
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | CUDA_VISIBLE_DEVICES=7 python train.py \
18 | --root ${DATA} \
19 | --seed ${SEED} \
20 | --trainer ${TRAINER} \
21 | --dataset-config-file configs/datasets/${DATASET}.yaml \
22 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
23 | --output-dir ${DIR} \
24 | DATASET.NUM_SHOTS ${SHOTS} \
25 | DATASET.SUBSAMPLE_CLASSES base
26 | done
--------------------------------------------------------------------------------
/scripts/zeroshot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=ZeroshotCLIP
8 | DATASET=$1
9 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16
10 |
11 | python train.py \
12 | --root ${DATA} \
13 | --trainer ${TRAINER} \
14 | --dataset-config-file configs/datasets/${DATASET}.yaml \
15 | --config-file configs/trainers/CoOp/${CFG}.yaml \
16 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \
17 | --eval-only
18 |
--------------------------------------------------------------------------------
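
`zeroshot.sh` evaluates plain CLIP through the `ZeroshotCLIP` trainer defined in `trainers/zsclip.py`. As a generic, self-contained illustration (this is not the trainer code itself; the class names and image path below are hypothetical), zero-shot CLIP classification boils down to:

```python
import torch
from PIL import Image

from clip import clip  # the vendored CLIP package in this repo

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

classnames = ["golden retriever", "tabby cat", "airliner"]  # hypothetical classes
texts = clip.tokenize([f"a photo of a {c}." for c in classnames]).to(device)
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical image

with torch.no_grad():
    image_feat = model.encode_image(image)
    text_feat = model.encode_text(texts)
    # Scaled cosine similarity between the image and each class prompt
    image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    logits = model.logit_scale.exp() * image_feat @ text_feat.t()

print("Predicted class:", classnames[logits.argmax(dim=-1).item()])
```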
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
5 | from dassl.config import get_cfg_default
6 | from dassl.engine import build_trainer
7 |
8 | # custom
9 | import datasets.oxford_pets
10 | import datasets.oxford_flowers
11 | import datasets.fgvc_aircraft
12 | import datasets.dtd
13 | import datasets.eurosat
14 | import datasets.stanford_cars
15 | import datasets.food101
16 | import datasets.sun397
17 | import datasets.caltech101
18 | import datasets.ucf101
19 | import datasets.imagenet
20 |
21 | import datasets.imagenet_sketch
22 | import datasets.imagenetv2
23 | import datasets.imagenet_a
24 | import datasets.imagenet_r
25 |
26 | import trainers.coop
27 | import trainers.cocoop
28 | import trainers.zsclip
29 | import trainers.coop_testtime
30 | import trainers.vpt
31 | import trainers.vpt_deep
32 | import trainers.vpt_shallow
33 | import trainers.unified
34 | # The variants below are referenced in scripts/README.md but their modules are
35 | # not part of this repository; importing them would raise ImportError.
36 | # import trainers.vpt_testtime
37 | # import trainers.unified_v2
38 | # import trainers.unified_v3
39 | # import trainers.unified_v4, trainers.unified_v5, trainers.unified_v6
40 |
41 |
42 | def print_args(args, cfg):
43 | print("***************")
44 | print("** Arguments **")
45 | print("***************")
46 | optkeys = list(args.__dict__.keys())
47 | optkeys.sort()
48 | for key in optkeys:
49 | print("{}: {}".format(key, args.__dict__[key]))
50 | print("************")
51 | print("** Config **")
52 | print("************")
53 | print(cfg)
54 |
55 |
56 | def reset_cfg(cfg, args):
57 | if args.root:
58 | cfg.DATASET.ROOT = args.root
59 |
60 | if args.output_dir:
61 | cfg.OUTPUT_DIR = args.output_dir
62 |
63 | if args.resume:
64 | cfg.RESUME = args.resume
65 |
66 | if args.seed:
67 | cfg.SEED = args.seed
68 |
69 | if args.source_domains:
70 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
71 |
72 | if args.target_domains:
73 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
74 |
75 | if args.transforms:
76 | cfg.INPUT.TRANSFORMS = args.transforms
77 |
78 | if args.trainer:
79 | cfg.TRAINER.NAME = args.trainer
80 |
81 | if args.backbone:
82 | cfg.MODEL.BACKBONE.NAME = args.backbone
83 |
84 | if args.head:
85 | cfg.MODEL.HEAD.NAME = args.head
86 |
87 |
88 | def extend_cfg(cfg):
89 | """
90 | Add new config variables.
91 |
92 | E.g.
93 | from yacs.config import CfgNode as CN
94 | cfg.TRAINER.MY_MODEL = CN()
95 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
96 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
97 | cfg.TRAINER.MY_MODEL.PARAM_C = False
98 | """
99 | from yacs.config import CfgNode as CN
100 |
101 | cfg.TRAINER.COOP = CN()
102 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
103 | cfg.TRAINER.COOP.CSC = False # class-specific context
104 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words
105 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
106 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
107 |
108 | cfg.TRAINER.COCOOP = CN()
109 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
110 | cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words
111 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
112 |
113 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
114 |
115 |
116 | def setup_cfg(args):
117 | cfg = get_cfg_default()
118 | extend_cfg(cfg)
119 |
120 | # 1. From the dataset config file
121 | if args.dataset_config_file:
122 | cfg.merge_from_file(args.dataset_config_file)
123 |
124 | # 2. From the method config file
125 | if args.config_file:
126 | cfg.merge_from_file(args.config_file)
127 |
128 | # 3. From input arguments
129 | reset_cfg(cfg, args)
130 |
131 | # 4. From optional input arguments
132 | cfg.merge_from_list(args.opts)
133 |
134 | cfg.freeze()
135 |
136 | return cfg
137 |
138 |
139 | def main(args):
140 | cfg = setup_cfg(args)
141 | if cfg.SEED >= 0:
142 | print("Setting fixed seed: {}".format(cfg.SEED))
143 | set_random_seed(cfg.SEED)
144 | setup_logger(cfg.OUTPUT_DIR)
145 |
146 | if torch.cuda.is_available() and cfg.USE_CUDA:
147 | torch.backends.cudnn.benchmark = True
148 |
149 | print_args(args, cfg)
150 | print("Collecting env info ...")
151 | print("** System info **\n{}\n".format(collect_env_info()))
152 |
153 | trainer = build_trainer(cfg)
154 |
155 | if args.eval_only:
156 | trainer.load_model(args.model_dir, epoch=args.load_epoch)
157 | trainer.test()
158 | return
159 |
160 | if not args.no_train:
161 | trainer.train()
162 |
163 |
164 | if __name__ == "__main__":
165 | parser = argparse.ArgumentParser()
166 | parser.add_argument("--root", type=str, default="", help="path to dataset")
167 | parser.add_argument("--output-dir", type=str, default="", help="output directory")
168 | parser.add_argument(
169 | "--resume",
170 | type=str,
171 | default="",
172 | help="checkpoint directory (from which the training resumes)",
173 | )
174 | parser.add_argument(
175 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed"
176 | )
177 | parser.add_argument(
178 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
179 | )
180 | parser.add_argument(
181 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
182 | )
183 | parser.add_argument(
184 | "--transforms", type=str, nargs="+", help="data augmentation methods"
185 | )
186 | parser.add_argument(
187 | "--config-file", type=str, default="", help="path to config file"
188 | )
189 | parser.add_argument(
190 | "--dataset-config-file",
191 | type=str,
192 | default="",
193 | help="path to config file for dataset setup",
194 | )
195 | parser.add_argument("--trainer", type=str, default="", help="name of trainer")
196 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
197 | parser.add_argument("--head", type=str, default="", help="name of head")
198 | parser.add_argument("--eval-only", action="store_true", help="evaluation only")
199 | parser.add_argument(
200 | "--model-dir",
201 | type=str,
202 | default="",
203 | help="load model from this directory for eval-only mode",
204 | )
205 | parser.add_argument(
206 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation"
207 | )
208 | parser.add_argument(
209 | "--no-train", action="store_true", help="do not call trainer.train()"
210 | )
211 | parser.add_argument(
212 | "opts",
213 | default=None,
214 | nargs=argparse.REMAINDER,
215 | help="modify config options using the command-line",
216 | )
217 | args = parser.parse_args()
218 | main(args)
219 |
--------------------------------------------------------------------------------
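
The trailing `KEY VALUE` pairs that the shell scripts append after the named flags (e.g. `DATASET.NUM_SHOTS 16 DATASET.SUBSAMPLE_CLASSES base`) land in `args.opts` and are applied by `cfg.merge_from_list(args.opts)` inside `setup_cfg`. A toy yacs example of that mechanism (the two-field config below is a stripped-down stand-in for Dassl's full default config):

```python
from yacs.config import CfgNode as CN

# Minimal stand-in for the default config; the real keys come from dassl + extend_cfg()
cfg = CN()
cfg.DATASET = CN()
cfg.DATASET.NUM_SHOTS = -1
cfg.DATASET.SUBSAMPLE_CLASSES = "all"

# What a script passes as trailing command-line opts
opts = ["DATASET.NUM_SHOTS", "16", "DATASET.SUBSAMPLE_CLASSES", "base"]
cfg.merge_from_list(opts)  # yacs parses "16" back to an int to match the default's type

print(cfg.DATASET.NUM_SHOTS)          # 16
print(cfg.DATASET.SUBSAMPLE_CLASSES)  # base
```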
/trainers/.flake8:
--------------------------------------------------------------------------------
1 | # This is an example .flake8 config, used when developing *Black* itself.
2 | # Keep in sync with setup.cfg which is used for source packages.
3 |
4 | [flake8]
5 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811
6 | max-line-length = 100
7 | max-complexity = 18
8 | select = B,C,E,F,W,T4,B9
9 | exclude = build
10 | per-file-ignores =
11 | **/__init__.py:F401,F403,E402
12 | **/configs/**.py:F401,E402
13 | configs/**.py:F401,E402
14 | **/tests/config/**.py:F401,E402
15 | tests/config/**.py:F401,E402
16 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuhangzang/UPT/3d1640fcfd2532fd651041bc955fc5baff51c71f/trainers/__init__.py
--------------------------------------------------------------------------------
/trainers/cocoop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.cuda.amp import GradScaler, autocast
7 | from torch.nn import functional as F
8 |
9 | from clip import clip
10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
11 | from dassl.engine import TRAINER_REGISTRY, TrainerX
12 | from dassl.optim import build_lr_scheduler, build_optimizer
13 | from dassl.utils import load_checkpoint, load_pretrained_weights
14 |
15 | _tokenizer = _Tokenizer()
16 |
17 |
18 | def load_clip_to_cpu(cfg):
19 | backbone_name = cfg.MODEL.BACKBONE.NAME
20 | url = clip._MODELS[backbone_name]
21 | model_path = clip._download(url)
22 |
23 | try:
24 | # loading JIT archive
25 | model = torch.jit.load(model_path, map_location="cpu").eval()
26 | state_dict = None
27 |
28 | except RuntimeError:
29 | state_dict = torch.load(model_path, map_location="cpu")
30 |
31 | model = clip.build_model(state_dict or model.state_dict())
32 |
33 | return model
34 |
35 |
36 | class TextEncoder(nn.Module):
37 | def __init__(self, clip_model):
38 | super().__init__()
39 | self.transformer = clip_model.transformer
40 | self.positional_embedding = clip_model.positional_embedding
41 | self.ln_final = clip_model.ln_final
42 | self.text_projection = clip_model.text_projection
43 | self.dtype = clip_model.dtype
44 |
45 | def forward(self, prompts, tokenized_prompts):
46 | x = prompts + self.positional_embedding.type(self.dtype)
47 | x = x.permute(1, 0, 2) # NLD -> LND
48 | x = self.transformer(x)
49 | x = x.permute(1, 0, 2) # LND -> NLD
50 | x = self.ln_final(x).type(self.dtype)
51 |
52 | # x.shape = [batch_size, n_ctx, transformer.width]
53 | # take features from the eot embedding (eot_token is the highest number in each sequence)
54 | x = x[torch.arange(x.shape[0]),
55 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
56 |
57 | return x
58 |
59 |
60 | class PromptLearner(nn.Module):
61 | def __init__(self, cfg, classnames, clip_model):
62 | super().__init__()
63 | n_cls = len(classnames)
64 | n_ctx = cfg.TRAINER.COCOOP.N_CTX
65 | ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
66 | dtype = clip_model.dtype
67 | ctx_dim = clip_model.ln_final.weight.shape[0]
68 | vis_dim = clip_model.visual.output_dim
69 | clip_imsize = clip_model.visual.input_resolution
70 | cfg_imsize = cfg.INPUT.SIZE[0]
71 |         assert cfg_imsize == clip_imsize, \
72 |             f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
73 |
74 | if ctx_init:
75 | # use given words to initialize context vectors
76 | ctx_init = ctx_init.replace("_", " ")
77 | n_ctx = len(ctx_init.split(" "))
78 | prompt = clip.tokenize(ctx_init)
79 | with torch.no_grad():
80 | embedding = clip_model.token_embedding(prompt).type(dtype)
81 | ctx_vectors = embedding[0, 1:1 + n_ctx, :]
82 | prompt_prefix = ctx_init
83 | else:
84 | # random initialization
85 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
86 | nn.init.normal_(ctx_vectors, std=0.02)
87 | prompt_prefix = " ".join(["X"] * n_ctx)
88 |
89 | print(f'Initial context: "{prompt_prefix}"')
90 | print(f"Number of context words (tokens): {n_ctx}")
91 |
92 | self.ctx = nn.Parameter(ctx_vectors)
93 |
94 | self.meta_net = nn.Sequential(
95 | OrderedDict([("linear1", nn.Linear(vis_dim, vis_dim // 16)),
96 | ("relu", nn.ReLU(inplace=True)),
97 | ("linear2", nn.Linear(vis_dim // 16, ctx_dim))]))
98 |
99 | if cfg.TRAINER.COCOOP.PREC == "fp16":
100 | self.meta_net.half()
101 |
102 | classnames = [name.replace("_", " ") for name in classnames]
103 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
104 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
105 |
106 | tokenized_prompts = torch.cat([clip.tokenize(p)
107 | for p in prompts]) # (n_cls, n_tkn)
108 | with torch.no_grad():
109 | embedding = clip_model.token_embedding(tokenized_prompts).type(
110 | dtype)
111 |
112 | # These token vectors will be saved when in save_model(),
113 | # but they should be ignored in load_model() as we want to use
114 | # those computed using the current class names
115 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
116 | self.register_buffer("token_suffix",
117 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS
118 |
119 | self.n_cls = n_cls
120 | self.n_ctx = n_ctx
121 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
122 | self.name_lens = name_lens
123 |
124 | def construct_prompts(self, ctx, prefix, suffix, label=None):
125 | # dim0 is either batch_size (during training) or n_cls (during testing)
126 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
127 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
128 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
129 |
130 | if label is not None:
131 | prefix = prefix[label]
132 | suffix = suffix[label]
133 |
134 | prompts = torch.cat(
135 | [
136 | prefix, # (dim0, 1, dim)
137 | ctx, # (dim0, n_ctx, dim)
138 | suffix, # (dim0, *, dim)
139 | ],
140 | dim=1,
141 | )
142 |
143 | return prompts
144 |
145 | def forward(self, im_features):
146 | prefix = self.token_prefix
147 | suffix = self.token_suffix
148 | ctx = self.ctx # (n_ctx, ctx_dim)
149 | bias = self.meta_net(im_features) # (batch, ctx_dim)
150 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
151 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
152 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
153 |
154 | # Use instance-conditioned context tokens for all classes
155 | prompts = []
156 | for ctx_shifted_i in ctx_shifted:
157 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
158 | pts_i = self.construct_prompts(ctx_i, prefix,
159 | suffix) # (n_cls, n_tkn, ctx_dim)
160 | prompts.append(pts_i)
161 | prompts = torch.stack(prompts)
162 |
163 | return prompts
164 |
165 |
166 | class CustomCLIP(nn.Module):
167 | def __init__(self, cfg, classnames, clip_model):
168 | super().__init__()
169 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
170 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
171 | self.image_encoder = clip_model.visual
172 | self.text_encoder = TextEncoder(clip_model)
173 | self.logit_scale = clip_model.logit_scale
174 | self.dtype = clip_model.dtype
175 |
176 | def forward(self, image, label=None):
177 | tokenized_prompts = self.tokenized_prompts
178 | logit_scale = self.logit_scale.exp()
179 |
180 | image_features = self.image_encoder(image.type(self.dtype))
181 | image_features = image_features / image_features.norm(dim=-1,
182 | keepdim=True)
183 |
184 | prompts = self.prompt_learner(image_features)
185 |
186 | logits = []
187 | for pts_i, imf_i in zip(prompts, image_features):
188 | text_features = self.text_encoder(pts_i, tokenized_prompts)
189 | text_features = text_features / text_features.norm(dim=-1,
190 | keepdim=True)
191 | l_i = logit_scale * imf_i @ text_features.t()
192 | logits.append(l_i)
193 | logits = torch.stack(logits)
194 |
195 | if self.prompt_learner.training:
196 | return F.cross_entropy(logits, label)
197 |
198 | return logits
199 |
200 |
201 | @TRAINER_REGISTRY.register()
202 | class CoCoOp(TrainerX):
203 | def check_cfg(self, cfg):
204 | assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]
205 |
206 | def build_model(self):
207 | cfg = self.cfg
208 | classnames = self.dm.dataset.classnames
209 |
210 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
211 | clip_model = load_clip_to_cpu(cfg)
212 |
213 | if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp":
214 | # CLIP's default precision is fp16
215 | clip_model.float()
216 |
217 | print("Building custom CLIP")
218 | self.model = CustomCLIP(cfg, classnames, clip_model)
219 |
220 | print("Turning off gradients in both the image and the text encoder")
221 | name_to_update = "prompt_learner"
222 |
223 | for name, param in self.model.named_parameters():
224 | if name_to_update not in name:
225 | param.requires_grad_(False)
226 |
227 | # Double check
228 | enabled = set()
229 | for name, param in self.model.named_parameters():
230 | if param.requires_grad:
231 | enabled.add(name)
232 | print(f"Parameters to be updated: {enabled}")
233 |
234 | if cfg.MODEL.INIT_WEIGHTS:
235 | load_pretrained_weights(self.model.prompt_learner,
236 | cfg.MODEL.INIT_WEIGHTS)
237 |
238 | self.model.to(self.device)
239 | # NOTE: only give prompt_learner to the optimizer
240 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
241 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
242 | self.register_model("prompt_learner", self.model.prompt_learner,
243 | self.optim, self.sched)
244 |
245 | self.scaler = GradScaler(
246 | ) if cfg.TRAINER.COCOOP.PREC == "amp" else None
247 |
248 | # Note that multi-gpu training could be slow because CLIP's size is
249 | # big, which slows down the copy operation in DataParallel
250 | device_count = torch.cuda.device_count()
251 | if device_count > 1:
252 | print(
253 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
254 | )
255 | self.model = nn.DataParallel(self.model)
256 |
257 | def forward_backward(self, batch):
258 | image, label = self.parse_batch_train(batch)
259 |
260 | model = self.model
261 | optim = self.optim
262 | scaler = self.scaler
263 |
264 | prec = self.cfg.TRAINER.COCOOP.PREC
265 | if prec == "amp":
266 | with autocast():
267 | loss = model(image, label)
268 | optim.zero_grad()
269 | scaler.scale(loss).backward()
270 | scaler.step(optim)
271 | scaler.update()
272 | else:
273 | loss = model(image, label)
274 | optim.zero_grad()
275 | loss.backward()
276 | optim.step()
277 |
278 | loss_summary = {"loss": loss.item()}
279 |
280 | if (self.batch_idx + 1) == self.num_batches:
281 | self.update_lr()
282 |
283 | return loss_summary
284 |
285 | def parse_batch_train(self, batch):
286 | input = batch["img"]
287 | label = batch["label"]
288 | input = input.to(self.device)
289 | label = label.to(self.device)
290 | return input, label
291 |
292 | def load_model(self, directory, epoch=None):
293 | if not directory:
294 | print(
295 | "Note that load_model() is skipped as no pretrained model is given"
296 | )
297 | return
298 |
299 | names = self.get_model_names()
300 |
301 | # By default, the best model is loaded
302 | model_file = "model-best.pth.tar"
303 |
304 | if epoch is not None:
305 | model_file = "model.pth.tar-" + str(epoch)
306 |
307 | for name in names:
308 | model_path = osp.join(directory, name, model_file)
309 |
310 | if not osp.exists(model_path):
311 | raise FileNotFoundError(
312 | 'Model not found at "{}"'.format(model_path))
313 |
314 | checkpoint = load_checkpoint(model_path)
315 | state_dict = checkpoint["state_dict"]
316 | epoch = checkpoint["epoch"]
317 |
318 | # Ignore fixed token vectors
319 | if "token_prefix" in state_dict:
320 | del state_dict["token_prefix"]
321 |
322 | if "token_suffix" in state_dict:
323 | del state_dict["token_suffix"]
324 |
325 | print("Loading weights to {} "
326 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
327 | # set strict=False
328 | self._models[name].load_state_dict(state_dict, strict=False)
329 |
--------------------------------------------------------------------------------
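
To make the tensor shapes in `PromptLearner.forward` above concrete, here is a dummy-tensor walk-through (illustration only; the dimensions assume ViT-B/16, i.e. `vis_dim = ctx_dim = 512`, with CLIP's 77-token context length and `n_ctx = 4`):

```python
import torch

batch, n_cls, n_ctx, ctx_dim, n_tkn = 2, 5, 4, 512, 77

ctx = torch.randn(n_ctx, ctx_dim)                        # shared learnable context
bias = torch.randn(batch, ctx_dim).unsqueeze(1)          # meta_net(im_features): (batch, 1, ctx_dim)
ctx_shifted = ctx.unsqueeze(0) + bias                    # (batch, n_ctx, ctx_dim)

prefix = torch.randn(n_cls, 1, ctx_dim)                  # SOS token embedding per class
suffix = torch.randn(n_cls, n_tkn - 1 - n_ctx, ctx_dim)  # class name + EOS + padding

prompts = torch.stack([
    torch.cat([prefix, c.unsqueeze(0).expand(n_cls, -1, -1), suffix], dim=1)
    for c in ctx_shifted
])
print(prompts.shape)  # torch.Size([2, 5, 77, 512]) -> (batch, n_cls, n_tkn, ctx_dim)
```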
/trainers/coop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.cuda.amp import GradScaler, autocast
6 | from torch.nn import functional as F
7 |
8 | from clip import clip
9 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
10 | from dassl.engine import TRAINER_REGISTRY, TrainerX
11 | from dassl.metrics import compute_accuracy
12 | from dassl.optim import build_lr_scheduler, build_optimizer
13 | from dassl.utils import load_checkpoint, load_pretrained_weights
14 |
15 | _tokenizer = _Tokenizer()
16 |
17 |
18 | def load_clip_to_cpu(cfg):
19 | backbone_name = cfg.MODEL.BACKBONE.NAME
20 | url = clip._MODELS[backbone_name]
21 | model_path = clip._download(url)
22 |
23 | try:
24 | # loading JIT archive
25 | model = torch.jit.load(model_path, map_location="cpu").eval()
26 | state_dict = None
27 |
28 | except RuntimeError:
29 | state_dict = torch.load(model_path, map_location="cpu")
30 |
31 | model = clip.build_model(state_dict or model.state_dict())
32 |
33 | return model
34 |
35 |
36 | class TextEncoder(nn.Module):
37 | def __init__(self, clip_model):
38 | super().__init__()
39 | self.transformer = clip_model.transformer
40 | self.positional_embedding = clip_model.positional_embedding
41 | self.ln_final = clip_model.ln_final
42 | self.text_projection = clip_model.text_projection
43 | self.dtype = clip_model.dtype
44 |
45 | def forward(self, prompts, tokenized_prompts):
46 | x = prompts + self.positional_embedding.type(self.dtype)
47 | x = x.permute(1, 0, 2) # NLD -> LND
48 | x = self.transformer(x)
49 | x = x.permute(1, 0, 2) # LND -> NLD
50 | x = self.ln_final(x).type(self.dtype)
51 |
52 | # x.shape = [batch_size, n_ctx, transformer.width]
53 | # take features from the eot embedding (eot_token is the highest number in each sequence)
54 | x = x[torch.arange(x.shape[0]),
55 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
56 |
57 | return x
58 |
59 |
60 | class PromptLearner(nn.Module):
61 | def __init__(self, cfg, classnames, clip_model):
62 | super().__init__()
63 | n_cls = len(classnames)
64 | n_ctx = cfg.TRAINER.COOP.N_CTX
65 | ctx_init = cfg.TRAINER.COOP.CTX_INIT
66 | dtype = clip_model.dtype
67 | ctx_dim = clip_model.ln_final.weight.shape[0]
68 | clip_imsize = clip_model.visual.input_resolution
69 | cfg_imsize = cfg.INPUT.SIZE[0]
70 |         assert cfg_imsize == clip_imsize, \
71 |             f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
72 |
73 | if ctx_init:
74 | # use given words to initialize context vectors
75 | ctx_init = ctx_init.replace("_", " ")
76 | n_ctx = len(ctx_init.split(" "))
77 | prompt = clip.tokenize(ctx_init)
78 | with torch.no_grad():
79 | embedding = clip_model.token_embedding(prompt).type(dtype)
80 | ctx_vectors = embedding[0, 1:1 + n_ctx, :]
81 | prompt_prefix = ctx_init
82 |
83 | else:
84 | # random initialization
85 | if cfg.TRAINER.COOP.CSC:
86 | print("Initializing class-specific contexts")
87 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
88 | else:
89 | print("Initializing a generic context")
90 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
91 | nn.init.normal_(ctx_vectors, std=0.02)
92 | prompt_prefix = " ".join(["X"] * n_ctx)
93 |
94 | print(f'Initial context: "{prompt_prefix}"')
95 | print(f"Number of context words (tokens): {n_ctx}")
96 |
97 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
98 |
99 | classnames = [name.replace("_", " ") for name in classnames]
100 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
101 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
102 |
103 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
104 | with torch.no_grad():
105 | embedding = clip_model.token_embedding(tokenized_prompts).type(
106 | dtype)
107 |
108 |         # These token vectors will be saved by save_model(),
109 | # but they should be ignored in load_model() as we want to use
110 | # those computed using the current class names
111 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
112 | self.register_buffer("token_suffix",
113 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS
114 |
115 | self.n_cls = n_cls
116 | self.n_ctx = n_ctx
117 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
118 | self.name_lens = name_lens
119 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
120 |
121 | def forward(self):
122 | ctx = self.ctx
123 | if ctx.dim() == 2:
124 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
125 |
126 | prefix = self.token_prefix
127 | suffix = self.token_suffix
128 |
129 | if self.class_token_position == "end":
130 | prompts = torch.cat(
131 | [
132 | prefix, # (n_cls, 1, dim)
133 | ctx, # (n_cls, n_ctx, dim)
134 | suffix, # (n_cls, *, dim)
135 | ],
136 | dim=1,
137 | )
138 |
139 | elif self.class_token_position == "middle":
140 | half_n_ctx = self.n_ctx // 2
141 | prompts = []
142 | for i in range(self.n_cls):
143 | name_len = self.name_lens[i]
144 | prefix_i = prefix[i:i + 1, :, :]
145 | class_i = suffix[i:i + 1, :name_len, :]
146 | suffix_i = suffix[i:i + 1, name_len:, :]
147 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
148 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
149 | prompt = torch.cat(
150 | [
151 | prefix_i, # (1, 1, dim)
152 | ctx_i_half1, # (1, n_ctx//2, dim)
153 | class_i, # (1, name_len, dim)
154 | ctx_i_half2, # (1, n_ctx//2, dim)
155 | suffix_i, # (1, *, dim)
156 | ],
157 | dim=1,
158 | )
159 | prompts.append(prompt)
160 | prompts = torch.cat(prompts, dim=0)
161 |
162 | elif self.class_token_position == "front":
163 | prompts = []
164 | for i in range(self.n_cls):
165 | name_len = self.name_lens[i]
166 | prefix_i = prefix[i:i + 1, :, :]
167 | class_i = suffix[i:i + 1, :name_len, :]
168 | suffix_i = suffix[i:i + 1, name_len:, :]
169 | ctx_i = ctx[i:i + 1, :, :]
170 | prompt = torch.cat(
171 | [
172 | prefix_i, # (1, 1, dim)
173 | class_i, # (1, name_len, dim)
174 | ctx_i, # (1, n_ctx, dim)
175 | suffix_i, # (1, *, dim)
176 | ],
177 | dim=1,
178 | )
179 | prompts.append(prompt)
180 | prompts = torch.cat(prompts, dim=0)
181 |
182 | else:
183 | raise ValueError
184 |
185 | return prompts
186 |
187 |
188 | class CustomCLIP(nn.Module):
189 | def __init__(self, cfg, classnames, clip_model):
190 | super().__init__()
191 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
192 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
193 | self.image_encoder = clip_model.visual
194 | self.text_encoder = TextEncoder(clip_model)
195 | self.logit_scale = clip_model.logit_scale
196 | self.dtype = clip_model.dtype
197 |
198 | def forward(self, image):
199 | image_features = self.image_encoder(image.type(self.dtype))
200 |
201 | prompts = self.prompt_learner()
202 | tokenized_prompts = self.tokenized_prompts
203 | text_features = self.text_encoder(prompts, tokenized_prompts)
204 |
205 | image_features = image_features / image_features.norm(dim=-1,
206 | keepdim=True)
207 | text_features = text_features / text_features.norm(dim=-1,
208 | keepdim=True)
209 |
210 | logit_scale = self.logit_scale.exp()
211 | logits = logit_scale * image_features @ text_features.t()
212 |
213 | return logits
214 |
215 |
216 | @TRAINER_REGISTRY.register()
217 | class CoOp(TrainerX):
218 | """Context Optimization (CoOp).
219 |
220 | Learning to Prompt for Vision-Language Models
221 | https://arxiv.org/abs/2109.01134
222 | """
223 | def check_cfg(self, cfg):
224 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
225 |
226 | def build_model(self):
227 | cfg = self.cfg
228 | classnames = self.dm.dataset.classnames
229 |
230 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
231 | clip_model = load_clip_to_cpu(cfg)
232 |
233 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
234 | # CLIP's default precision is fp16
235 | clip_model.float()
236 |
237 | print("Building custom CLIP")
238 | self.model = CustomCLIP(cfg, classnames, clip_model)
239 |
240 | print("Turning off gradients in both the image and the text encoder")
241 | for name, param in self.model.named_parameters():
242 | if "prompt_learner" not in name:
243 | param.requires_grad_(False)
244 |
245 | if cfg.MODEL.INIT_WEIGHTS:
246 | load_pretrained_weights(self.model.prompt_learner,
247 | cfg.MODEL.INIT_WEIGHTS)
248 |
249 | self.model.to(self.device)
250 | # NOTE: only give prompt_learner to the optimizer
251 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
252 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
253 | self.register_model("prompt_learner", self.model.prompt_learner,
254 | self.optim, self.sched)
255 |
256 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
257 |
258 | # Note that multi-gpu training could be slow because CLIP's size is
259 | # big, which slows down the copy operation in DataParallel
260 | device_count = 1 # torch.cuda.device_count()
261 | if device_count > 1:
262 | print(
263 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
264 | )
265 | self.model = nn.DataParallel(self.model)
266 |
267 | def forward_backward(self, batch):
268 | image, label = self.parse_batch_train(batch)
269 |
270 | prec = self.cfg.TRAINER.COOP.PREC
271 | if prec == "amp":
272 | with autocast():
273 | output = self.model(image)
274 | loss = F.cross_entropy(output, label)
275 | self.optim.zero_grad()
276 | self.scaler.scale(loss).backward()
277 | self.scaler.step(self.optim)
278 | self.scaler.update()
279 | else:
280 | output = self.model(image)
281 | loss = F.cross_entropy(output, label)
282 | self.model_backward_and_update(loss)
283 |
284 | loss_summary = {
285 | "loss": loss.item(),
286 | "acc": compute_accuracy(output, label)[0].item(),
287 | }
288 |
289 | if (self.batch_idx + 1) == self.num_batches:
290 | self.update_lr()
291 |
292 | return loss_summary
293 |
294 | def parse_batch_train(self, batch):
295 | input = batch["img"]
296 | label = batch["label"]
297 | input = input.to(self.device)
298 | label = label.to(self.device)
299 | return input, label
300 |
301 | def load_model(self, directory, epoch=None):
302 | if not directory:
303 | print(
304 | "Note that load_model() is skipped as no pretrained model is given"
305 | )
306 | return
307 |
308 | names = self.get_model_names()
309 |
310 | # By default, the best model is loaded
311 | model_file = "model-best.pth.tar"
312 |
313 | if epoch is not None:
314 | model_file = "model.pth.tar-" + str(epoch)
315 |
316 | for name in names:
317 | model_path = osp.join(directory, name, model_file)
318 |
319 | if not osp.exists(model_path):
320 | raise FileNotFoundError(
321 | 'Model not found at "{}"'.format(model_path))
322 |
323 | checkpoint = load_checkpoint(model_path)
324 | state_dict = checkpoint["state_dict"]
325 | epoch = checkpoint["epoch"]
326 |
327 | # Ignore fixed token vectors
328 | if "token_prefix" in state_dict:
329 | del state_dict["token_prefix"]
330 |
331 | if "token_suffix" in state_dict:
332 | del state_dict["token_suffix"]
333 |
334 | print("Loading weights to {} "
335 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
336 | # set strict=False
337 | self._models[name].load_state_dict(state_dict, strict=False)
338 |
--------------------------------------------------------------------------------
/trainers/coop_testtime.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from PIL import ImageFilter
5 | from torchvision import transforms
6 | from tqdm import tqdm
7 |
8 | from dassl.data.data_manager import (DataManager, DatasetWrapper,
9 | build_data_loader)
10 | from dassl.data.datasets import build_dataset
11 | from dassl.data.samplers import build_sampler
12 | from dassl.data.transforms import build_transform
13 | from dassl.engine import TRAINER_REGISTRY
14 | from dassl.utils import read_image
15 |
16 | from .coop import CoOp
17 |
18 | # from copy import deepcopy
19 |
20 |
21 | class GaussianBlur(object):
22 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
23 | def __init__(self, sigma=[.1, 2.]):
24 | self.sigma = sigma
25 |
26 | def __call__(self, x):
27 | sigma = random.uniform(self.sigma[0], self.sigma[1])
28 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
29 | return x
30 |
31 |
32 | class DatasetWrapper_aug(DatasetWrapper):
33 | def __init__(self, cfg, data_source, transform=None, is_train=False):
34 | super().__init__(cfg, data_source, transform, is_train)
35 |
36 | normalize = transforms.Normalize(
37 | mean=[0.48145466, 0.4578275, 0.40821073],
38 | std=[0.26862954, 0.26130258, 0.27577711])
39 | augment = transforms.Compose([
40 | transforms.RandomResizedCrop(224, scale=(0.08, 1.)),
41 | transforms.RandomApply(
42 | [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
43 | transforms.RandomGrayscale(p=0.2),
44 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
45 | transforms.RandomHorizontalFlip(),
46 | transforms.ToTensor(),
47 | normalize,
48 | ])
49 | self.augment = augment
50 |
51 | def __getitem__(self, idx):
52 | item = self.data_source[idx]
53 |
54 | output = {
55 | "label": item.label,
56 | "domain": item.domain,
57 | "impath": item.impath
58 | }
59 |
60 | img0 = read_image(item.impath)
61 |
62 | if self.transform is not None:
63 | if isinstance(self.transform, (list, tuple)):
64 | raise NotImplementedError
65 | for i, tfm in enumerate(self.transform):
66 | img = self._transform_image(tfm, img0)
67 | output["img"] = img
68 | else:
69 | img = self._transform_image(self.transform, img0)
70 | output["img"] = img
71 | for name in ["img_aug", "img_aug2", "img_aug3"]:
72 | output[name] = self._transform_image(self.augment, img0)
73 |
74 | if self.return_img0:
75 | output["img0"] = self.to_tensor(img0)
76 |
77 | return output
78 |
79 |
80 | def build_data_loader(cfg,
81 | sampler_type='SequentialSampler',
82 | data_source=None,
83 | batch_size=64,
84 | n_domain=0,
85 | n_ins=2,
86 | tfm=None,
87 | is_train=True,
88 | dataset_wrapper=None):
89 | # Build sampler
90 | if not is_train:
91 | random.shuffle(data_source)
92 | sampler = build_sampler(sampler_type,
93 | cfg=cfg,
94 | data_source=data_source,
95 | batch_size=batch_size,
96 | n_domain=n_domain,
97 | n_ins=n_ins)
98 |
99 | if dataset_wrapper is None:
100 | dataset_wrapper = DatasetWrapper
101 |
102 | # Build data loader
103 | data_loader = torch.utils.data.DataLoader(
104 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
105 | batch_size=batch_size,
106 | sampler=sampler,
107 | num_workers=cfg.DATALOADER.NUM_WORKERS,
108 | drop_last=is_train and len(data_source) >= batch_size,
109 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
110 | )
111 | assert len(data_loader) > 0
112 |
113 | return data_loader
114 |
115 |
116 | class DataManager_aug(DataManager):
117 | def __init__(self,
118 | cfg,
119 | custom_tfm_train=None,
120 | custom_tfm_test=None,
121 | dataset_wrapper=None):
122 | super().__init__(cfg, custom_tfm_train, custom_tfm_test,
123 | dataset_wrapper)
124 | # Build test_loader
125 | dataset = build_dataset(cfg)
126 |
127 | if custom_tfm_test is None:
128 | tfm_test = build_transform(cfg, is_train=False)
129 | else:
130 | print("* Using custom transform for testing")
131 | tfm_test = custom_tfm_test
132 |
133 | test_loader = build_data_loader(
134 | cfg,
135 | sampler_type=cfg.DATALOADER.TEST.SAMPLER,
136 | data_source=dataset.test,
137 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
138 | tfm=tfm_test,
139 | is_train=False,
140 | dataset_wrapper=DatasetWrapper_aug,
141 | )
142 | self.test_loader = test_loader
143 |
144 |
145 | def softmax_entropy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
146 | return -(y.softmax(1) * x.log_softmax(1)).sum(1).mean(0)
147 |
148 |
149 | @TRAINER_REGISTRY.register()
150 | class CoOp_testtime(CoOp):
151 | def build_data_loader(self):
152 | dm = DataManager_aug(self.cfg)
153 |
154 | self.train_loader_x = dm.train_loader_x
155 | self.train_loader_u = dm.train_loader_u # optional, can be None
156 | self.val_loader = dm.val_loader # optional, can be None
157 | self.test_loader = dm.test_loader
158 | self.num_classes = dm.num_classes
159 | self.num_source_domains = dm.num_source_domains
160 | self.lab2cname = dm.lab2cname # dict {label: classname}
161 |
162 | self.dm = dm
163 |
164 | def test(self, split=None):
165 | """A generic testing pipeline."""
166 | # self.set_model_mode("eval")
167 | self.model.to(self.device)
168 |
169 | self._optims['prompt_learner'].param_groups[0]['lr'] = 2e-6
170 | self.evaluator.reset()
171 |
172 | if split is None:
173 | split = self.cfg.TEST.SPLIT
174 |
175 | if split == "val" and self.val_loader is not None:
176 | data_loader = self.val_loader
177 | print("Do evaluation on {} set".format(split))
178 | else:
179 | data_loader = self.test_loader
180 | print("Do evaluation on test set")
181 |
182 | # model_state = deepcopy(self.model.state_dict())
183 | for batch_idx, batch in enumerate(tqdm(data_loader)):
184 | # self.model.load_state_dict(model_state, strict=True)
185 | input, input_aug, input_aug2, label = self.parse_batch_test(batch)
186 |
187 | for _ in range(0):
188 | output = self.model_inference(input)
189 | output_aug = self.model_inference(input_aug)
190 | loss = softmax_entropy(output_aug, output)
191 | self.model_backward_and_update(loss)
192 |
193 | for _ in range(0):
194 | output = self.model_inference(input)
195 | output_aug = self.model_inference(input_aug2)
196 | loss = softmax_entropy(output_aug, output)
197 | self.model_backward_and_update(loss)
198 |
199 | with torch.no_grad():
200 | output = self.model_inference(input)
201 | self.evaluator.process(output, label)
202 |
203 | results = self.evaluator.evaluate()
204 |
205 | for k, v in results.items():
206 | tag = "{}/{}".format(split, k)
207 | self.write_scalar(tag, v, self.epoch)
208 |
209 | return list(results.values())[0]
210 |
211 | def parse_batch_test(self, batch):
212 | input = batch["img"]
213 | input_aug = batch["img_aug"]
214 | input_aug2 = batch["img_aug2"]
215 | label = batch["label"]
216 |
217 | input = input.to(self.device)
218 | input_aug = input_aug.to(self.device)
219 | input_aug2 = input_aug2.to(self.device)
220 | label = label.to(self.device)
221 |
222 | return input, input_aug, input_aug2, label
223 |
--------------------------------------------------------------------------------
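`CoOp_testtime` adapts at inference by minimizing a consistency objective between predictions on the clean image and an augmented view, using `softmax_entropy` defined above; note that both adaptation loops are currently disabled (`for _ in range(0)`), so the trainer evaluates the prompts as-is. A minimal sketch of the objective itself, with random logits standing in for the model outputs:

```python
import torch

def softmax_entropy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # cross-entropy between soft targets y.softmax(1) and predictions x.log_softmax(1)
    return -(y.softmax(1) * x.log_softmax(1)).sum(1).mean(0)

# placeholders for model(image) and model(image_aug); only the augmented branch
# is marked as requiring grad here, to show where the update signal would flow
logits_clean = torch.randn(8, 100)
logits_aug = torch.randn(8, 100, requires_grad=True)

loss = softmax_entropy(logits_aug, logits_clean)
loss.backward()
print(loss.item(), logits_aug.grad.shape)  # scalar loss, (8, 100) gradient
```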
/trainers/imagenet_templates.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
2 |
3 | IMAGENET_TEMPLATES = [
4 | "a bad photo of a {}.",
5 | "a photo of many {}.",
6 | "a sculpture of a {}.",
7 | "a photo of the hard to see {}.",
8 | "a low resolution photo of the {}.",
9 | "a rendering of a {}.",
10 | "graffiti of a {}.",
11 | "a bad photo of the {}.",
12 | "a cropped photo of the {}.",
13 | "a tattoo of a {}.",
14 | "the embroidered {}.",
15 | "a photo of a hard to see {}.",
16 | "a bright photo of a {}.",
17 | "a photo of a clean {}.",
18 | "a photo of a dirty {}.",
19 | "a dark photo of the {}.",
20 | "a drawing of a {}.",
21 | "a photo of my {}.",
22 | "the plastic {}.",
23 | "a photo of the cool {}.",
24 | "a close-up photo of a {}.",
25 | "a black and white photo of the {}.",
26 | "a painting of the {}.",
27 | "a painting of a {}.",
28 | "a pixelated photo of the {}.",
29 | "a sculpture of the {}.",
30 | "a bright photo of the {}.",
31 | "a cropped photo of a {}.",
32 | "a plastic {}.",
33 | "a photo of the dirty {}.",
34 | "a jpeg corrupted photo of a {}.",
35 | "a blurry photo of the {}.",
36 | "a photo of the {}.",
37 | "a good photo of the {}.",
38 | "a rendering of the {}.",
39 | "a {} in a video game.",
40 | "a photo of one {}.",
41 | "a doodle of a {}.",
42 | "a close-up photo of the {}.",
43 | "a photo of a {}.",
44 | "the origami {}.",
45 | "the {} in a video game.",
46 | "a sketch of a {}.",
47 | "a doodle of the {}.",
48 | "a origami {}.",
49 | "a low resolution photo of a {}.",
50 | "the toy {}.",
51 | "a rendition of the {}.",
52 | "a photo of the clean {}.",
53 | "a photo of a large {}.",
54 | "a rendition of a {}.",
55 | "a photo of a nice {}.",
56 | "a photo of a weird {}.",
57 | "a blurry photo of a {}.",
58 | "a cartoon {}.",
59 | "art of a {}.",
60 | "a sketch of the {}.",
61 | "a embroidered {}.",
62 | "a pixelated photo of a {}.",
63 | "itap of the {}.",
64 | "a jpeg corrupted photo of the {}.",
65 | "a good photo of a {}.",
66 | "a plushie {}.",
67 | "a photo of the nice {}.",
68 | "a photo of the small {}.",
69 | "a photo of the weird {}.",
70 | "the cartoon {}.",
71 | "art of the {}.",
72 | "a drawing of the {}.",
73 | "a photo of the large {}.",
74 | "a black and white photo of a {}.",
75 | "the plushie {}.",
76 | "a dark photo of a {}.",
77 | "itap of a {}.",
78 | "graffiti of the {}.",
79 | "a toy {}.",
80 | "itap of my {}.",
81 | "a photo of a cool {}.",
82 | "a photo of a small {}.",
83 | "a tattoo of the {}.",
84 | ]
85 |
86 | IMAGENET_TEMPLATES_SELECT = [
87 | "itap of a {}.",
88 | "a bad photo of the {}.",
89 | "a origami {}.",
90 | "a photo of the large {}.",
91 | "a {} in a video game.",
92 | "art of the {}.",
93 | "a photo of the small {}.",
94 | ]
95 |
--------------------------------------------------------------------------------
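These templates are the standard CLIP prompt-ensembling set: every class name is formatted into each template, the text embeddings are averaged, and the re-normalized mean acts as that class's zero-shot classifier weight. A sketch of that use, assuming a CLIP model loaded via `load_clip_to_cpu` and the repo root on `PYTHONPATH`:

```python
import torch

from clip import clip
from trainers.imagenet_templates import IMAGENET_TEMPLATES_SELECT

@torch.no_grad()
def build_zeroshot_weights(clip_model, classnames, device="cpu"):
    """Average normalized text embeddings over all templates for each class."""
    weights = []
    for name in classnames:
        prompts = [t.format(name.replace("_", " ")) for t in IMAGENET_TEMPLATES_SELECT]
        tokens = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        feats = clip_model.encode_text(tokens)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        mean = feats.mean(dim=0)
        weights.append(mean / mean.norm())
    return torch.stack(weights)  # (num_classes, embed_dim)

# usage (hypothetical): W = build_zeroshot_weights(clip_model, ["dog", "cat"])
#                       logits = logit_scale * image_features @ W.t()
```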
/trainers/linter.sh:
--------------------------------------------------------------------------------
1 | rm -rf `find -type d -name .ipynb_checkpoints`
2 | yapf -r -i ./*.py
3 | isort -rc ./*.py
4 | flake8 ./*.py
5 |
--------------------------------------------------------------------------------
/trainers/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .utils import get_rank, get_world_size
6 |
7 | # from .utils import all_gather_batch_with_grad
8 |
9 |
10 | class SIMCLRLoss(nn.Module):
11 | """
12 | This is the SimCLR loss in https://arxiv.org/abs/2002.05709
13 |     forward() expects the two augmented views q_a and q_b, each of shape
14 |     (batch_size, embedding_dim); both are L2-normalized before the
15 |     contrastive logits are computed. See also the SimCLR collator in
16 |     https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py
17 | Config params:
18 | temperature (float): the temperature to be applied on the logits
19 | """
20 | def __init__(self, temperature=0.1):
21 | super().__init__()
22 | self.tau = temperature
23 | self.labels = None
24 | self.masks = None
25 | self.last_local_batch_size = None
26 |
27 | def forward(self, q_a, q_b):
28 | q_a = F.normalize(q_a, dim=-1, p=2)
29 | q_b = F.normalize(q_b, dim=-1, p=2)
30 |
31 | local_batch_size = q_a.size(0)
32 |
33 | # k_a, k_b = all_gather_batch_with_grad([q_a, q_b])
34 | k_a, k_b = q_a, q_b
35 |
36 | if local_batch_size != self.last_local_batch_size:
37 | self.labels = local_batch_size * get_rank() + torch.arange(
38 | local_batch_size, device=q_a.device)
39 | total_batch_size = local_batch_size * get_world_size()
40 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
41 | self.last_local_batch_size = local_batch_size
42 |
43 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
44 | logits_aa = logits_aa - self.masks
45 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
46 | logits_bb = logits_bb - self.masks
47 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
48 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau
49 |
50 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1),
51 | self.labels)
52 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1),
53 | self.labels)
54 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples
55 |
56 | # compute accuracy
57 | with torch.no_grad():
58 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1),
59 | dim=-1)
60 | correct = pred.eq(self.labels).sum()
61 | acc = 100 * correct / local_batch_size
62 |
63 | return loss, acc
64 |
--------------------------------------------------------------------------------
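A single-process usage sketch for `SIMCLRLoss` (with `torch.distributed` uninitialized, `get_world_size()` is 1 and nothing is gathered); `q_a` and `q_b` stand in for the two projected augmented-view embeddings that the VPT trainer feeds it:

```python
import torch

from trainers.losses import SIMCLRLoss  # assumes the repo root (and its dassl/clip deps) are importable

criterion = SIMCLRLoss(temperature=0.1)

batch_size, embed_dim = 16, 256
q_a = torch.randn(batch_size, embed_dim, requires_grad=True)
q_b = torch.randn(batch_size, embed_dim, requires_grad=True)

loss, acc = criterion(q_a, q_b)  # InfoNCE over both views, self-similarity masked out
loss.backward()
print(f"loss={loss.item():.4f}  acc={acc.item():.1f}%")
```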
/trainers/unified.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.cuda.amp import GradScaler, autocast
6 | from torch.nn import functional as F
7 |
8 | from clip import clip
9 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
10 | from dassl.engine import TRAINER_REGISTRY, TrainerX
11 | from dassl.metrics import compute_accuracy
12 | from dassl.optim import build_lr_scheduler, build_optimizer
13 | from dassl.utils import load_checkpoint, load_pretrained_weights
14 |
15 | from .coop import TextEncoder, load_clip_to_cpu
16 |
17 | _tokenizer = _Tokenizer()
18 |
19 |
20 | class PromptLearner(nn.Module):
21 | def __init__(self, cfg, classnames, clip_model):
22 | super().__init__()
23 | n_cls = len(classnames)
24 | n_ctx = cfg.TRAINER.COOP.N_CTX
25 | ctx_init = cfg.TRAINER.COOP.CTX_INIT
26 | dtype = clip_model.dtype
27 | ctx_dim = clip_model.ln_final.weight.shape[0]
28 | clip_imsize = clip_model.visual.input_resolution
29 | cfg_imsize = cfg.INPUT.SIZE[0]
30 |         assert cfg_imsize == clip_imsize, \
31 |             f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
32 |
33 | if ctx_init:
34 | # use given words to initialize context vectors
35 | ctx_init = ctx_init.replace("_", " ")
36 | n_ctx = len(ctx_init.split(" "))
37 | prompt = clip.tokenize(ctx_init)
38 | with torch.no_grad():
39 | embedding = clip_model.token_embedding(prompt).type(dtype)
40 | ctx_vectors = embedding[0, 1:1 + n_ctx, :]
41 | prompt_prefix = ctx_init
42 |
43 | else:
44 | # random initialization
45 | if cfg.TRAINER.COOP.CSC:
46 | print("Initializing class-specific contexts")
47 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
48 | else:
49 | print("Initializing a generic context")
50 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
51 | nn.init.normal_(ctx_vectors, std=0.02)
52 | prompt_prefix = " ".join(["X"] * n_ctx)
53 |
54 | print(f'Initial context: "{prompt_prefix}"')
55 | print(f"Number of context words (tokens): {n_ctx}")
56 |
57 | # self.ctx = nn.Parameter(ctx_vectors) # to be optimized
58 | ctx = ctx_vectors.unsqueeze(0)
59 |
60 | classnames = [name.replace("_", " ") for name in classnames]
61 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
62 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
63 |
64 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
65 | with torch.no_grad():
66 | embedding = clip_model.token_embedding(tokenized_prompts).type(
67 | dtype)
68 |
69 |         # These token vectors will be saved by save_model(),
70 | # but they should be ignored in load_model() as we want to use
71 | # those computed using the current class names
72 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
73 | self.register_buffer("token_suffix",
74 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS
75 |
76 | self.n_cls = n_cls
77 | self.n_ctx = n_ctx
78 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
79 | self.name_lens = name_lens
80 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
81 |
82 | vis_dim = clip_model.visual.positional_embedding.shape[-1]
83 | visual_vectors = torch.empty(1, n_ctx * 2, ctx_dim, dtype=dtype)
84 | nn.init.normal_(visual_vectors, std=0.02)
85 | # self.visual_ctx = nn.Parameter(visual_vectors)
86 | self.uni_ctx = nn.Parameter(visual_vectors)
87 | # uni_ctx = torch.cat([ctx, visual_vectors], 1)
88 | # self.uni_ctx = nn.Parameter(uni_ctx)
89 |
90 | num_heads = 1
91 | dropout = 0.1
92 | self.self_attn = nn.MultiheadAttention(ctx_dim,
93 | num_heads,
94 | dropout=dropout)
95 | self.self_attn.type(dtype)
96 | self.dropout = nn.Dropout(dropout)
97 | self.dropout.type(dtype)
98 | self.dropout1 = nn.Dropout(dropout)
99 | self.dropout1.type(dtype)
100 | self.dropout2 = nn.Dropout(dropout)
101 | self.dropout2.type(dtype)
102 |
103 | self.norm1 = nn.LayerNorm(ctx_dim)
104 | self.norm1.type(dtype)
105 | self.norm2 = nn.LayerNorm(ctx_dim)
106 | self.norm2.type(dtype)
107 |
108 | self.linear1 = nn.Linear(ctx_dim, ctx_dim)
109 | self.linear1.type(dtype)
110 | self.linear2 = nn.Linear(ctx_dim, ctx_dim)
111 | self.linear2.type(dtype)
112 | self.activation = F.relu
113 |
114 | # self.mlp_text = nn.Linear(ctx_dim, ctx_dim)
115 | # self.mlp_text.type(dtype)
116 | self.mlp = nn.Linear(ctx_dim, vis_dim)
117 | self.mlp.type(dtype)
118 |
119 | def get_text_prompt(self):
120 | src = self.uni_ctx
121 | src2 = self.self_attn(src, src, src)[0]
122 | src = src + self.dropout1(src2)
123 | src = self.norm1(src)
124 | src2 = self.norm2(src)
125 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
126 | src = src + self.dropout2(src2)
127 |         return src[0, :self.n_ctx, :]  # first half of the unified prompt -> text context
128 |
129 | def get_visual_prompt(self):
130 | src = self.uni_ctx
131 | src2 = self.self_attn(src, src, src)[0]
132 | src = src + self.dropout1(src2)
133 | src = self.norm1(src)
134 | src2 = self.norm2(src)
135 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
136 | src = src + self.dropout2(src2)
137 |         return self.mlp(src[0, self.n_ctx:, :])  # second half -> visual context (ctx_dim -> vis_dim)
138 |
139 | def forward(self):
140 | ctx = self.get_text_prompt()
141 | if ctx.dim() == 2:
142 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
143 |
144 | prefix = self.token_prefix
145 | suffix = self.token_suffix
146 |
147 | if self.class_token_position == "end":
148 | prompts = torch.cat(
149 | [
150 | prefix, # (n_cls, 1, dim)
151 | ctx, # (n_cls, n_ctx, dim)
152 | suffix, # (n_cls, *, dim)
153 | ],
154 | dim=1,
155 | )
156 |
157 | elif self.class_token_position == "middle":
158 | half_n_ctx = self.n_ctx // 2
159 | prompts = []
160 | for i in range(self.n_cls):
161 | name_len = self.name_lens[i]
162 | prefix_i = prefix[i:i + 1, :, :]
163 | class_i = suffix[i:i + 1, :name_len, :]
164 | suffix_i = suffix[i:i + 1, name_len:, :]
165 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
166 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
167 | prompt = torch.cat(
168 | [
169 | prefix_i, # (1, 1, dim)
170 | ctx_i_half1, # (1, n_ctx//2, dim)
171 | class_i, # (1, name_len, dim)
172 | ctx_i_half2, # (1, n_ctx//2, dim)
173 | suffix_i, # (1, *, dim)
174 | ],
175 | dim=1,
176 | )
177 | prompts.append(prompt)
178 | prompts = torch.cat(prompts, dim=0)
179 |
180 | elif self.class_token_position == "front":
181 | prompts = []
182 | for i in range(self.n_cls):
183 | name_len = self.name_lens[i]
184 | prefix_i = prefix[i:i + 1, :, :]
185 | class_i = suffix[i:i + 1, :name_len, :]
186 | suffix_i = suffix[i:i + 1, name_len:, :]
187 | ctx_i = ctx[i:i + 1, :, :]
188 | prompt = torch.cat(
189 | [
190 | prefix_i, # (1, 1, dim)
191 | class_i, # (1, name_len, dim)
192 | ctx_i, # (1, n_ctx, dim)
193 | suffix_i, # (1, *, dim)
194 | ],
195 | dim=1,
196 | )
197 | prompts.append(prompt)
198 | prompts = torch.cat(prompts, dim=0)
199 |
200 | else:
201 | raise ValueError
202 |
203 | return prompts
204 |
205 |
206 | class CustomCLIP(nn.Module):
207 | def __init__(self, cfg, classnames, clip_model):
208 | super().__init__()
209 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
210 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
211 | self.image_encoder = clip_model.visual
212 | self.text_encoder = TextEncoder(clip_model)
213 | self.logit_scale = clip_model.logit_scale
214 | self.dtype = clip_model.dtype
215 |
216 | def forward(self, image):
217 | # import time
218 | # torch.save(image.data.cpu(), f'img_{int(time.time())}.pt')
219 |
220 | # image_features = self.image_encoder(image.type(self.dtype))
221 | visual_ctx = self.prompt_learner.get_visual_prompt()
222 | image_features, _ = self.image_encoder.forward_prompt_shallow(
223 | image.type(self.dtype), visual_ctx)
224 |
225 | prompts = self.prompt_learner()
226 | tokenized_prompts = self.tokenized_prompts
227 | text_features = self.text_encoder(prompts, tokenized_prompts)
228 |
229 | image_features = image_features / image_features.norm(dim=-1,
230 | keepdim=True)
231 | text_features = text_features / text_features.norm(dim=-1,
232 | keepdim=True)
233 |
234 | logit_scale = self.logit_scale.exp()
235 | logits = logit_scale * image_features @ text_features.t()
236 |
237 | return logits
238 |
239 |
240 | @TRAINER_REGISTRY.register()
241 | class Unified_v6(TrainerX):
242 | def check_cfg(self, cfg):
243 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
244 |
245 | def build_model(self):
246 | cfg = self.cfg
247 | classnames = self.dm.dataset.classnames
248 |
249 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
250 | clip_model = load_clip_to_cpu(cfg)
251 |
252 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
253 | # CLIP's default precision is fp16
254 | clip_model.float()
255 |
256 | print("Building custom CLIP")
257 | self.model = CustomCLIP(cfg, classnames, clip_model)
258 |
259 | print("Turning off gradients in both the image and the text encoder")
260 | for name, param in self.model.named_parameters():
261 | if "prompt_learner" not in name:
262 | param.requires_grad_(False)
263 |
264 | if cfg.MODEL.INIT_WEIGHTS:
265 | load_pretrained_weights(self.model.prompt_learner,
266 | cfg.MODEL.INIT_WEIGHTS)
267 |
268 | self.model.to(self.device)
269 | # NOTE: only give prompt_learner to the optimizer
270 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
271 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
272 | self.register_model("prompt_learner", self.model.prompt_learner,
273 | self.optim, self.sched)
274 |
275 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
276 |
277 | # Note that multi-gpu training could be slow because CLIP's size is
278 | # big, which slows down the copy operation in DataParallel
279 | device_count = 1 # torch.cuda.device_count()
280 | if device_count > 1:
281 | print(
282 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
283 | )
284 | self.model = nn.DataParallel(self.model)
285 |
286 | def forward_backward(self, batch):
287 | image, label = self.parse_batch_train(batch)
288 |
289 | prec = self.cfg.TRAINER.COOP.PREC
290 | if prec == "amp":
291 | with autocast():
292 | output = self.model(image)
293 | loss = F.cross_entropy(output, label)
294 | self.optim.zero_grad()
295 | self.scaler.scale(loss).backward()
296 | self.scaler.step(self.optim)
297 | self.scaler.update()
298 | else:
299 | output = self.model(image)
300 | loss = F.cross_entropy(output, label)
301 | self.model_backward_and_update(loss)
302 |
303 | loss_summary = {
304 | "loss": loss.item(),
305 | "acc": compute_accuracy(output, label)[0].item(),
306 | }
307 |
308 | if (self.batch_idx + 1) == self.num_batches:
309 | self.update_lr()
310 |
311 | return loss_summary
312 |
313 | def parse_batch_train(self, batch):
314 | input = batch["img"]
315 | label = batch["label"]
316 | input = input.to(self.device)
317 | label = label.to(self.device)
318 | return input, label
319 |
320 | def load_model(self, directory, epoch=None):
321 | if not directory:
322 | print(
323 | "Note that load_model() is skipped as no pretrained model is given"
324 | )
325 | return
326 |
327 | names = self.get_model_names()
328 |
329 | # By default, the best model is loaded
330 | model_file = "model-best.pth.tar"
331 |
332 | if epoch is not None:
333 | model_file = "model.pth.tar-" + str(epoch)
334 |
335 | for name in names:
336 | model_path = osp.join(directory, name, model_file)
337 |
338 | if not osp.exists(model_path):
339 | raise FileNotFoundError(
340 | 'Model not found at "{}"'.format(model_path))
341 |
342 | checkpoint = load_checkpoint(model_path)
343 | state_dict = checkpoint["state_dict"]
344 | epoch = checkpoint["epoch"]
345 |
346 | # Ignore fixed token vectors
347 | if "token_prefix" in state_dict:
348 | del state_dict["token_prefix"]
349 |
350 | if "token_suffix" in state_dict:
351 | del state_dict["token_suffix"]
352 |
353 | print("Loading weights to {} "
354 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
355 | # set strict=False
356 | self._models[name].load_state_dict(state_dict, strict=False)
357 |
--------------------------------------------------------------------------------
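The shape flow in the `PromptLearner` above: a single learnable tensor `uni_ctx` of shape `(1, 2*n_ctx, ctx_dim)` is passed through one self-attention + feed-forward block and then split; the first `n_ctx` tokens become the text context, and the remaining `n_ctx` tokens are projected to the visual width and injected as visual prompts via `forward_prompt_shallow`. A shape-only sketch with assumed ViT-B/16 sizes (ctx_dim=512, vis_dim=768, n_ctx=4) and dropouts omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_ctx, ctx_dim, vis_dim = 4, 512, 768  # assumed ViT-B/16 text / visual widths

uni_ctx = nn.Parameter(torch.randn(1, 2 * n_ctx, ctx_dim) * 0.02)  # shared prompt
self_attn = nn.MultiheadAttention(ctx_dim, num_heads=1)
norm1, norm2 = nn.LayerNorm(ctx_dim), nn.LayerNorm(ctx_dim)
linear1, linear2 = nn.Linear(ctx_dim, ctx_dim), nn.Linear(ctx_dim, ctx_dim)
mlp = nn.Linear(ctx_dim, vis_dim)  # projects the visual half to the image-encoder width

src = uni_ctx
src = norm1(src + self_attn(src, src, src)[0])      # attention sub-layer + residual
src = src + linear2(F.relu(linear1(norm2(src))))    # feed-forward sub-layer + residual

text_ctx = src[0, :n_ctx, :]         # (n_ctx, ctx_dim): spliced between token_prefix and token_suffix
visual_ctx = mlp(src[0, n_ctx:, :])  # (n_ctx, vis_dim): fed to the image encoder as visual prompts
print(text_ctx.shape, visual_ctx.shape)  # torch.Size([4, 512]) torch.Size([4, 768])
```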
/trainers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | import random
8 | import shutil
9 |
10 | import numpy as np
11 | import torch
12 | import torch.autograd as autograd
13 | import torch.distributed as dist
14 | from PIL import ImageFilter
15 | from torch.nn import DataParallel
16 | from torch.nn.parallel import DistributedDataParallel
17 |
18 |
19 | def get_model(model):
20 | if isinstance(model, DataParallel):
21 | return model.module
22 | elif isinstance(model, DistributedDataParallel):
23 | return model.module
24 | else:
25 | return model
26 |
27 |
28 | def setup_for_distributed(is_master):
29 | """
30 | This function disables printing when not in master process
31 | """
32 | import builtins as __builtin__
33 | builtin_print = __builtin__.print
34 |
35 | def print(*args, **kwargs):
36 | force = kwargs.pop('force', False)
37 | if is_master or force:
38 | builtin_print(*args, **kwargs)
39 |
40 | __builtin__.print = print
41 |
42 |
43 | def is_dist_avail_and_initialized():
44 | if not dist.is_available():
45 | return False
46 | if not dist.is_initialized():
47 | return False
48 | return True
49 |
50 |
51 | def get_world_size():
52 | if not is_dist_avail_and_initialized():
53 | return 1
54 | return dist.get_world_size()
55 |
56 |
57 | def get_rank():
58 | if not is_dist_avail_and_initialized():
59 | return 0
60 | return dist.get_rank()
61 |
62 |
63 | def is_main_process():
64 | return get_rank() == 0
65 |
66 |
67 | def save_on_master(state, is_best, output_dir):
68 | if is_main_process():
69 | ckpt_path = f'{output_dir}/checkpoint.pt'
70 | best_path = f'{output_dir}/checkpoint_best.pt'
71 | torch.save(state, ckpt_path)
72 | if is_best:
73 | shutil.copyfile(ckpt_path, best_path)
74 |
75 |
76 | def init_distributed_mode(args):
77 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
78 | args.rank = int(os.environ["RANK"])
79 | args.world_size = int(os.environ['WORLD_SIZE'])
80 | args.gpu = int(os.environ['LOCAL_RANK'])
81 | elif 'SLURM_PROCID' in os.environ:
82 | args.rank = int(os.environ['SLURM_PROCID'])
83 | args.gpu = args.rank % torch.cuda.device_count()
84 | else:
85 | print('Not using distributed mode')
86 | args.distributed = False
87 | return
88 |
89 | args.distributed = True
90 |
91 | torch.cuda.set_device(args.gpu)
92 | args.dist_backend = 'nccl'
93 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url),
94 | flush=True)
95 | torch.distributed.init_process_group(backend=args.dist_backend,
96 | init_method=args.dist_url,
97 | world_size=args.world_size,
98 | rank=args.rank)
99 | torch.distributed.barrier()
100 | setup_for_distributed(args.rank == 0)
101 |
102 |
103 | def scaled_all_reduce(tensors, is_scale=True):
104 | """Performs the scaled all_reduce operation on the provided tensors.
105 | The input tensors are modified in-place. Currently supports only the sum
106 | reduction operator. The reduced values are scaled by the inverse size of the
107 | world size.
108 | """
109 | world_size = get_world_size()
110 | # There is no need for reduction in the single-proc case
111 | if world_size == 1:
112 | return tensors
113 | # Queue the reductions
114 | reductions = []
115 | for tensor in tensors:
116 | reduction = dist.all_reduce(tensor, async_op=True)
117 | reductions.append(reduction)
118 | # Wait for reductions to finish
119 | for reduction in reductions:
120 | reduction.wait()
121 | # Scale the results
122 | if is_scale:
123 | for tensor in tensors:
124 | tensor.mul_(1.0 / world_size)
125 | return tensors
126 |
127 |
128 | def all_gather_batch(tensors):
129 | """
130 | Performs all_gather operation on the provided tensors.
131 | """
132 | # Queue the gathered tensors
133 | world_size = get_world_size()
134 | # There is no need for reduction in the single-proc case
135 | if world_size == 1:
136 | return tensors
137 | tensor_list = []
138 | output_tensor = []
139 | for tensor in tensors:
140 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
141 | dist.all_gather(
142 | tensor_all,
143 | tensor,
144 | async_op=False # performance opt
145 | )
146 |
147 | tensor_list.append(tensor_all)
148 |
149 | for tensor_all in tensor_list:
150 | output_tensor.append(torch.cat(tensor_all, dim=0))
151 | return output_tensor
152 |
153 |
154 | class GatherLayer(autograd.Function):
155 | """
156 | Gather tensors from all workers with support for backward propagation:
157 | This implementation does not cut the gradients as torch.distributed.all_gather does.
158 | """
159 | @staticmethod
160 | def forward(ctx, x):
161 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
162 | dist.all_gather(output, x)
163 | return tuple(output)
164 |
165 | @staticmethod
166 | def backward(ctx, *grads):
167 | all_gradients = torch.stack(grads)
168 | dist.all_reduce(all_gradients)
169 | return all_gradients[dist.get_rank()]
170 |
171 |
172 | def all_gather_batch_with_grad(tensors):
173 | """
174 | Performs all_gather operation on the provided tensors.
175 | Graph remains connected for backward grad computation.
176 | """
177 | # Queue the gathered tensors
178 | world_size = get_world_size()
179 | # There is no need for reduction in the single-proc case
180 | if world_size == 1:
181 | return tensors
182 | tensor_list = []
183 | output_tensor = []
184 |
185 | for tensor in tensors:
186 | tensor_all = GatherLayer.apply(tensor)
187 | tensor_list.append(tensor_all)
188 |
189 | for tensor_all in tensor_list:
190 | output_tensor.append(torch.cat(tensor_all, dim=0))
191 | return output_tensor
192 |
193 |
194 | def cosine_scheduler(base_value,
195 | final_value,
196 | epochs,
197 | niter_per_ep,
198 | warmup_epochs=0,
199 | start_warmup_value=0):
200 | warmup_schedule = np.array([])
201 | warmup_iters = warmup_epochs * niter_per_ep
202 | if warmup_epochs > 0:
203 | warmup_schedule = np.linspace(start_warmup_value, base_value,
204 | warmup_iters)
205 |
206 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
207 | schedule = final_value + 0.5 * (base_value - final_value) * (
208 | 1 + np.cos(np.pi * iters / len(iters)))
209 |
210 | schedule = np.concatenate((warmup_schedule, schedule))
211 | assert len(schedule) == epochs * niter_per_ep
212 | return schedule
213 |
214 |
215 | class GaussianBlur(object):
216 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
217 | def __init__(self, sigma=[.1, 2.]):
218 | self.sigma = sigma
219 |
220 | def __call__(self, x):
221 | sigma = random.uniform(self.sigma[0], self.sigma[1])
222 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
223 | return x
224 |
--------------------------------------------------------------------------------
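`cosine_scheduler` precomputes one value per training iteration: a linear warmup over `warmup_epochs * niter_per_ep` steps, followed by a cosine decay from `base_value` to `final_value`. A short usage sketch (the hyperparameter values are illustrative):

```python
from trainers.utils import cosine_scheduler  # assumes the repo root is importable

# e.g. a learning-rate schedule: 10 epochs, 100 iterations/epoch, 1 warmup epoch
lr_schedule = cosine_scheduler(base_value=2e-3, final_value=1e-5,
                               epochs=10, niter_per_ep=100, warmup_epochs=1)

assert len(lr_schedule) == 10 * 100  # one entry per iteration
print(lr_schedule[0], lr_schedule[99], lr_schedule[-1])  # warmup start, end of warmup, end of decay
```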
/trainers/vpt.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 | from copy import deepcopy
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.cuda.amp import GradScaler, autocast
8 | from torch.nn import functional as F
9 | from tqdm import tqdm
10 |
11 | from clip import clip
12 | from dassl.data.data_manager import DataManager
13 | from dassl.data.datasets import build_dataset
14 | from dassl.data.transforms import build_transform
15 | from dassl.engine import TRAINER_REGISTRY, TrainerX
16 | from dassl.metrics import compute_accuracy
17 | from dassl.optim import build_lr_scheduler, build_optimizer
18 | from dassl.utils import load_checkpoint, load_pretrained_weights
19 |
20 | from .coop import load_clip_to_cpu
21 | from .coop_testtime import DatasetWrapper_aug, build_data_loader
22 | from .losses import SIMCLRLoss
23 | from .zsclip import CUSTOM_TEMPLATES
24 |
25 |
26 | class DataManager_aug(DataManager):
27 | def __init__(self,
28 | cfg,
29 | custom_tfm_train=None,
30 | custom_tfm_test=None,
31 | dataset_wrapper=None):
32 | # Load dataset
33 | dataset = build_dataset(cfg)
34 |
35 | # Build transform
36 | if custom_tfm_train is None:
37 | tfm_train = build_transform(cfg, is_train=True)
38 | else:
39 | print("* Using custom transform for training")
40 | tfm_train = custom_tfm_train
41 |
42 | if custom_tfm_test is None:
43 | tfm_test = build_transform(cfg, is_train=False)
44 | else:
45 | print("* Using custom transform for testing")
46 | tfm_test = custom_tfm_test
47 |
48 | # Build train_loader_x
49 | train_loader_x = build_data_loader(
50 | cfg,
51 | sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
52 | data_source=dataset.train_x,
53 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
54 | n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
55 | n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
56 | tfm=tfm_train,
57 | is_train=True,
58 | dataset_wrapper=DatasetWrapper_aug,
59 | )
60 |
61 | # Build train_loader_u
62 | train_loader_u = None
63 | if dataset.train_u:
64 | sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
65 | batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
66 | n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
67 | n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
68 |
69 | if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
70 | sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
71 | batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
72 | n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
73 | n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
74 |
75 | train_loader_u = build_data_loader(
76 | cfg,
77 | sampler_type=sampler_type_,
78 | data_source=dataset.train_u,
79 | batch_size=batch_size_,
80 | n_domain=n_domain_,
81 | n_ins=n_ins_,
82 | tfm=tfm_train,
83 | is_train=True,
84 | dataset_wrapper=dataset_wrapper,
85 | )
86 |
87 | # Build val_loader
88 | val_loader = None
89 | if dataset.val:
90 | val_loader = build_data_loader(
91 | cfg,
92 | sampler_type=cfg.DATALOADER.TEST.SAMPLER,
93 | data_source=dataset.val,
94 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
95 | tfm=tfm_test,
96 | is_train=False,
97 | dataset_wrapper=dataset_wrapper,
98 | )
99 |
100 | # Build test_loader
101 | test_loader = build_data_loader(
102 | cfg,
103 | sampler_type=cfg.DATALOADER.TEST.SAMPLER,
104 | data_source=dataset.test,
105 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
106 | tfm=tfm_test,
107 | is_train=False,
108 | dataset_wrapper=DatasetWrapper_aug,
109 | )
110 |
111 | # Attributes
112 | self._num_classes = dataset.num_classes
113 | self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
114 | self._lab2cname = dataset.lab2cname
115 |
116 | # Dataset and data-loaders
117 | self.dataset = dataset
118 | self.train_loader_x = train_loader_x
119 | self.train_loader_u = train_loader_u
120 | self.val_loader = val_loader
121 | self.test_loader = test_loader
122 |
123 | if cfg.VERBOSE:
124 | self.show_dataset_summary(cfg)
125 |
126 |
127 | class CustomVPT(nn.Module):
128 | def __init__(self, cfg, classnames, clip_model, device):
129 | super().__init__()
130 |
131 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
132 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
133 | print(f"Prompts: {prompts}")
134 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
135 | prompts = prompts.to(device)
136 | clip_model = clip_model.to(device)
137 |
138 | with torch.no_grad():
139 | text_features = clip_model.encode_text(prompts)
140 | text_features = text_features / text_features.norm(dim=-1,
141 | keepdim=True)
142 |
143 | self.text_features = text_features
144 | self.clip_model = clip_model
145 |
146 | n_ctx = cfg.TRAINER.COOP.N_CTX
147 | dtype = clip_model.dtype
148 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1]
149 |
150 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device)
151 | nn.init.normal_(ctx_vectors, std=0.02)
152 | self.visual_ctx = nn.Parameter(ctx_vectors)
153 | # self.mlp = self._build_mlp(768, 4096, 256, dtype=dtype)
154 | self.mlp = self._build_mlp(768, 512, 256, dtype=dtype)
155 | self.mlp.to(device)
156 |
157 | def _build_mlp(self, in_dim, mlp_dim, out_dim, dtype):
158 | return nn.Sequential(
159 | OrderedDict([
160 | ("layer1", nn.Linear(in_dim, mlp_dim).to(dtype)),
161 | # ("bn1", nn.SyncBatchNorm(mlp_dim)),
162 | ("relu1", nn.ReLU(inplace=True)),
163 | # ("layer2", nn.Linear(mlp_dim, mlp_dim).to(dtype)),
164 | # ("bn2", nn.SyncBatchNorm(mlp_dim)),
165 | # ("relu2", nn.ReLU(inplace=True)),
166 | ("layer3", nn.Linear(mlp_dim, out_dim).to(dtype)),
167 | ]))
168 |
169 | def forward(self, image, aug1, aug2, training=True):
170 | if training:
171 | image_features, _ = self.clip_model.encode_image_prompt(
172 | image, self.visual_ctx)
173 | image_features = image_features / image_features.norm(dim=-1,
174 | keepdim=True)
175 | logit_scale = self.clip_model.logit_scale.exp()
176 | logits = logit_scale * image_features @ self.text_features.t()
177 | else:
178 | logits = None
179 |
180 | _, h2 = self.clip_model.encode_image_prompt(aug1, self.visual_ctx)
181 | _, h3 = self.clip_model.encode_image_prompt(aug2, self.visual_ctx)
182 | aug1_embed = self.mlp(h2)
183 | aug2_embed = self.mlp(h3)
184 |
185 | return logits, aug1_embed, aug2_embed
186 |
187 |
188 | @TRAINER_REGISTRY.register()
189 | class VPT(TrainerX):
190 | def check_cfg(self, cfg):
191 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
192 |
193 | def build_model(self):
194 | cfg = self.cfg
195 | classnames = self.dm.dataset.classnames
196 |
197 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
198 | clip_model = load_clip_to_cpu(cfg)
199 |
200 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
201 | # CLIP's default precision is fp16
202 | clip_model.float()
203 |
204 | print("Building custom VPT")
205 | self.model = CustomVPT(cfg, classnames, clip_model, self.device)
206 |
207 | print("Turning off gradients in both the image and the text encoder")
208 | for name, param in self.model.named_parameters():
209 | # if "ctx" not in name:
210 | if "clip" in name:
211 | param.requires_grad_(False)
212 |
213 | if cfg.MODEL.INIT_WEIGHTS:
214 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
215 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
216 |
217 | self.model.to(self.device)
218 | # NOTE: only give prompt_learner to the optimizer
219 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
220 | self.optim = build_optimizer(self.model, cfg.OPTIM)
221 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
222 |
223 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
224 | self.register_model("model", self.model, self.optim, self.sched)
225 |
226 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
227 |
228 | # Note that multi-gpu training could be slow because CLIP's size is
229 | # big, which slows down the copy operation in DataParallel
230 | device_count = 1 # torch.cuda.device_count()
231 | if device_count > 1:
232 | print(
233 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
234 | )
235 | self.model = nn.DataParallel(self.model)
236 |
237 | self.ssl_loss = SIMCLRLoss()
238 |
239 | def forward_backward(self, batch):
240 | image, image_aug, image_aug2, label = self.parse_batch_train(batch)
241 |
242 | prec = self.cfg.TRAINER.COOP.PREC
243 | if prec == "amp":
244 | with autocast():
245 | output, aug1_embed, aug2_embed = self.model(
246 | image, image_aug, image_aug2)
247 | loss_ce = F.cross_entropy(output, label)
248 | loss_ssl, acc_ssl = self.ssl_loss(aug1_embed, aug2_embed)
249 | loss = loss_ce + loss_ssl * 0.0
250 | self.optim.zero_grad()
251 | self.scaler.scale(loss).backward()
252 | self.scaler.step(self.optim)
253 | self.scaler.update()
254 | else:
255 | output, aug1_embed, aug2_embed = self.model(
256 | image, image_aug, image_aug2)
257 | loss_ce = F.cross_entropy(output, label)
258 | loss_ssl, acc_ssl = self.ssl_loss(aug1_embed, aug2_embed)
259 | loss = loss_ce + loss_ssl * 0.0
260 | self.model_backward_and_update(loss)
261 |
262 | loss_summary = {
263 | "loss": loss.item(),
264 | "loss_ce": loss_ce.item(),
265 | "loss_ssl": loss_ssl.item(),
266 | "acc": compute_accuracy(output, label)[0].item(),
267 | "acc_ssl": acc_ssl.item()
268 | }
269 |
270 | if (self.batch_idx + 1) == self.num_batches:
271 | self.update_lr()
272 |
273 | return loss_summary
274 |
275 | def load_model(self, directory, epoch=None):
276 | if not directory:
277 | print(
278 | "Note that load_model() is skipped as no pretrained model is given"
279 | )
280 | return
281 |
282 | names = self.get_model_names()
283 |
284 | # By default, the best model is loaded
285 | model_file = "model-best.pth.tar"
286 |
287 | if epoch is not None:
288 | model_file = "model.pth.tar-" + str(epoch)
289 |
290 | for name in names:
291 | model_path = osp.join(directory, name, model_file)
292 |
293 | if not osp.exists(model_path):
294 | raise FileNotFoundError(
295 | 'Model not found at "{}"'.format(model_path))
296 |
297 | checkpoint = load_checkpoint(model_path)
298 | state_dict = checkpoint["state_dict"]
299 | epoch = checkpoint["epoch"]
300 |
301 | # Ignore fixed token vectors
302 | if "token_prefix" in state_dict:
303 | del state_dict["token_prefix"]
304 |
305 | if "token_suffix" in state_dict:
306 | del state_dict["token_suffix"]
307 |
308 | print("Loading weights to {} "
309 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
310 | # set strict=False
311 | self._models[name].load_state_dict(state_dict, strict=False)
312 |
313 | def parse_batch_train(self, batch):
314 | input = batch["img"]
315 | input_aug = batch["img_aug"]
316 | input_aug2 = batch["img_aug2"]
317 | label = batch["label"]
318 | input = input.to(self.device)
319 | input_aug = input_aug.to(self.device)
320 | input_aug2 = input_aug2.to(self.device)
321 | label = label.to(self.device)
322 | return input, input_aug, input_aug2, label
323 |
324 | # test-time training
325 |
326 | def build_data_loader(self):
327 | dm = DataManager_aug(self.cfg)
328 |
329 | self.train_loader_x = dm.train_loader_x
330 | self.train_loader_u = dm.train_loader_u # optional, can be None
331 | self.val_loader = dm.val_loader # optional, can be None
332 | self.test_loader = dm.test_loader
333 | self.num_classes = dm.num_classes
334 | self.num_source_domains = dm.num_source_domains
335 | self.lab2cname = dm.lab2cname # dict {label: classname}
336 |
337 | self.dm = dm
338 |
339 | def test(self, split=None):
340 | """A generic testing pipeline."""
341 | # self.set_model_mode("eval")
342 | self.model.to(self.device)
343 |
344 | self._optims['model'].param_groups[0]['lr'] = 2e-4
345 | self.evaluator.reset()
346 |
347 | if split is None:
348 | split = self.cfg.TEST.SPLIT
349 |
350 | if split == "val" and self.val_loader is not None:
351 | data_loader = self.val_loader
352 | print("Do evaluation on {} set".format(split))
353 | else:
354 | data_loader = self.test_loader
355 | print("Do evaluation on test set")
356 |
357 | model_state = deepcopy(self.model.state_dict())
358 | for batch_idx, batch in enumerate(tqdm(data_loader)):
359 | self.model.load_state_dict(model_state, strict=True)
360 | input, input_aug, input_aug2, input_aug3, label = self.parse_batch_test(
361 | batch)
362 |
363 | for _ in range(0):
364 | logits, aug1_embed, aug2_embed = self.model_inference(
365 | input, input_aug, input_aug2, training=False)
366 | loss_ssl, _ = self.ssl_loss(aug1_embed, aug2_embed)
367 | self.model_backward_and_update(loss_ssl)
368 |
369 | for _ in range(0):
370 | logits, aug1_embed, aug2_embed = self.model_inference(
371 | input, input_aug2, input_aug3, training=False)
372 | loss_ssl, _ = self.ssl_loss(aug1_embed, aug2_embed)
373 | self.model_backward_and_update(loss_ssl)
374 |
375 | with torch.no_grad():
376 | output, _, _ = self.model_inference(input,
377 | input_aug,
378 | input_aug2,
379 | training=True)
380 | self.evaluator.process(output, label)
381 | results = self.evaluator.evaluate()
382 |
383 | for k, v in results.items():
384 | tag = "{}/{}".format(split, k)
385 | self.write_scalar(tag, v, self.epoch)
386 |
387 | return list(results.values())[0]
388 |
389 | def model_inference(self, input, input_aug, input_aug2, training=True):
390 | return self.model(input, input_aug, input_aug2, training=training)
391 |
392 | def parse_batch_test(self, batch):
393 | input = batch["img"]
394 | input_aug = batch["img_aug"]
395 | input_aug2 = batch["img_aug2"]
396 | input_aug3 = batch["img_aug3"]
397 | label = batch["label"]
398 | input = input.to(self.device)
399 | input_aug = input_aug.to(self.device)
400 | input_aug2 = input_aug2.to(self.device)
401 | input_aug3 = input_aug3.to(self.device)
402 | label = label.to(self.device)
403 | return input, input_aug, input_aug2, input_aug3, label
404 |
--------------------------------------------------------------------------------
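Both VPT trainers rely on `encode_image_prompt` / `encode_image_prompt_deep`, which are added to the CLIP visual encoder in `clip/model.py` (not shown here). As a rough illustration only, and not the repository's actual implementation, shallow visual prompting amounts to concatenating a few learnable tokens with the patch tokens before the transformer blocks:

```python
import torch
import torch.nn as nn

class ShallowVisualPrompt(nn.Module):
    """Illustrative sketch of VPT-shallow token injection (hypothetical, not clip/model.py)."""
    def __init__(self, n_ctx: int, width: int):
        super().__init__()
        self.prompt = nn.Parameter(torch.empty(n_ctx, width))
        nn.init.normal_(self.prompt, std=0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, 1 + n_patches, width), CLS token first
        prompt = self.prompt.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        return torch.cat([tokens, prompt], dim=1)  # (batch, 1 + n_patches + n_ctx, width)

tokens = torch.randn(2, 197, 768)                 # ViT-B/16: CLS + 14*14 patches
print(ShallowVisualPrompt(4, 768)(tokens).shape)  # torch.Size([2, 201, 768])
```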
/trainers/vpt_deep.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.cuda.amp import GradScaler, autocast
6 | from torch.nn import functional as F
7 |
8 | from clip import clip
9 | from dassl.engine import TRAINER_REGISTRY, TrainerX
10 | from dassl.metrics import compute_accuracy
11 | from dassl.optim import build_lr_scheduler, build_optimizer
12 | from dassl.utils import load_checkpoint, load_pretrained_weights
13 |
14 | from .coop import load_clip_to_cpu
15 | from .zsclip import CUSTOM_TEMPLATES
16 |
17 |
18 | class CustomVPT(nn.Module):
19 | def __init__(self, cfg, classnames, clip_model, device):
20 | super().__init__()
21 |
22 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
23 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
24 | print(f"Prompts: {prompts}")
25 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
26 | prompts = prompts.to(device)
27 | clip_model = clip_model.to(device)
28 |
29 | with torch.no_grad():
30 | text_features = clip_model.encode_text(prompts)
31 | text_features = text_features / text_features.norm(dim=-1,
32 | keepdim=True)
33 |
34 | self.text_features = text_features
35 | self.clip_model = clip_model
36 |
37 | n_ctx = cfg.TRAINER.COOP.N_CTX
38 | dtype = clip_model.dtype
39 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1]
40 |
41 | ctx_vectors = torch.empty(n_ctx * 12,
42 | ctx_dim,
43 | dtype=dtype,
44 | device=device)
45 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device)
46 | nn.init.normal_(ctx_vectors, std=0.02)
47 | self.visual_ctx = nn.Parameter(ctx_vectors)
48 |
49 | def forward(self, image):
50 | image_features, _ = self.clip_model.encode_image_prompt_deep(
51 | image, self.visual_ctx)
52 | image_features = image_features / image_features.norm(dim=-1,
53 | keepdim=True)
54 | logit_scale = self.clip_model.logit_scale.exp()
55 | logits = logit_scale * image_features @ self.text_features.t()
56 |
57 | return logits
58 |
59 |
60 | @TRAINER_REGISTRY.register()
61 | class VPT_deep(TrainerX):
62 | def check_cfg(self, cfg):
63 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
64 |
65 | def build_model(self):
66 | cfg = self.cfg
67 | classnames = self.dm.dataset.classnames
68 |
69 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
70 | clip_model = load_clip_to_cpu(cfg)
71 |
72 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
73 | # CLIP's default precision is fp16
74 | clip_model.float()
75 |
76 | print("Building custom VPT")
77 | self.model = CustomVPT(cfg, classnames, clip_model, self.device)
78 |
79 | print("Turning off gradients in both the image and the text encoder")
80 | for name, param in self.model.named_parameters():
81 | if "clip" in name:
82 | param.requires_grad_(False)
83 |
84 | if cfg.MODEL.INIT_WEIGHTS:
85 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
86 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
87 |
88 | self.model.to(self.device)
89 | # NOTE: only give prompt_learner to the optimizer
90 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
91 | self.optim = build_optimizer(self.model, cfg.OPTIM)
92 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
93 |
94 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
95 | self.register_model("model", self.model, self.optim, self.sched)
96 |
97 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
98 |
99 | # Note that multi-gpu training could be slow because CLIP's size is
100 | # big, which slows down the copy operation in DataParallel
101 | device_count = 1  # DataParallel disabled; restore torch.cuda.device_count() to use multiple GPUs
102 | if device_count > 1:
103 | print(
104 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
105 | )
106 | self.model = nn.DataParallel(self.model)
107 |
108 | def forward_backward(self, batch):
109 | image, label = self.parse_batch_train(batch)
110 |
111 | prec = self.cfg.TRAINER.COOP.PREC
112 | if prec == "amp":
113 | with autocast():
114 | output = self.model(image)
115 | loss = F.cross_entropy(output, label)
116 | self.optim.zero_grad()
117 | self.scaler.scale(loss).backward()
118 | self.scaler.step(self.optim)
119 | self.scaler.update()
120 | else:
121 | output = self.model(image)
122 | loss = F.cross_entropy(output, label)
123 | self.model_backward_and_update(loss)
124 |
125 | loss_summary = {
126 | "loss": loss.item(),
127 | "acc": compute_accuracy(output, label)[0].item(),
128 | }
129 |
130 | if (self.batch_idx + 1) == self.num_batches:
131 | self.update_lr()
132 |
133 | return loss_summary
134 |
135 | def parse_batch_train(self, batch):
136 | input = batch["img"]
137 | label = batch["label"]
138 | input = input.to(self.device)
139 | label = label.to(self.device)
140 | return input, label
141 |
142 | def load_model(self, directory, epoch=None):
143 | if not directory:
144 | print(
145 | "Note that load_model() is skipped as no pretrained model is given"
146 | )
147 | return
148 |
149 | names = self.get_model_names()
150 |
151 | # By default, the best model is loaded
152 | model_file = "model-best.pth.tar"
153 |
154 | if epoch is not None:
155 | model_file = "model.pth.tar-" + str(epoch)
156 |
157 | for name in names:
158 | model_path = osp.join(directory, name, model_file)
159 |
160 | if not osp.exists(model_path):
161 | raise FileNotFoundError(
162 | 'Model not found at "{}"'.format(model_path))
163 |
164 | checkpoint = load_checkpoint(model_path)
165 | state_dict = checkpoint["state_dict"]
166 | epoch = checkpoint["epoch"]
167 |
168 | # Ignore fixed token vectors
169 | if "token_prefix" in state_dict:
170 | del state_dict["token_prefix"]
171 |
172 | if "token_suffix" in state_dict:
173 | del state_dict["token_suffix"]
174 |
175 | print("Loading weights to {} "
176 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
177 | # set strict=False
178 | self._models[name].load_state_dict(state_dict, strict=False)
179 |
--------------------------------------------------------------------------------
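
A note on `vpt_deep.py` above: the trainer only allocates the `n_ctx * 12` prompt vectors; the injection itself happens in `clip_model.encode_image_prompt_deep`, presumably implemented in `clip/model.py` and not shown in this section. The sketch below illustrates the usual deep-VPT scheme (each transformer layer gets its own fresh group of `n_ctx` prompts, replacing the prompt slots from the previous layer); it is an illustration under assumed names (`vit`, `vit.blocks`), not the repository's implementation.

```python
import torch

def encode_image_prompt_deep_sketch(vit, x, visual_ctx, n_ctx):
    """Illustrative deep visual-prompt injection (hypothetical `vit.blocks`).

    x          : (B, 1 + N, D) patch embeddings with the class token prepended
    visual_ctx : (n_ctx * L, D) learned prompts, one group of n_ctx per layer
    """
    B, L = x.shape[0], len(vit.blocks)
    ctx = visual_ctx.view(L, n_ctx, -1)                   # (L, n_ctx, D)
    for i, block in enumerate(vit.blocks):
        prompts = ctx[i].unsqueeze(0).expand(B, -1, -1)   # (B, n_ctx, D)
        if i == 0:
            # first layer: insert prompts between the class token and the patches
            x = torch.cat([x[:, :1], prompts, x[:, 1:]], dim=1)
        else:
            # deeper layers: drop the previous layer's prompt outputs and
            # substitute this layer's own learned prompts
            x = torch.cat([x[:, :1], prompts, x[:, 1 + n_ctx:]], dim=1)
        x = block(x)
    return x[:, 0]                                        # class-token feature
```
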
/trainers/vpt_shallow.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.cuda.amp import GradScaler, autocast
6 | from torch.nn import functional as F
7 |
8 | from clip import clip
9 | from dassl.engine import TRAINER_REGISTRY, TrainerX
10 | from dassl.metrics import compute_accuracy
11 | from dassl.optim import build_lr_scheduler, build_optimizer
12 | from dassl.utils import load_checkpoint, load_pretrained_weights
13 |
14 | from .coop import load_clip_to_cpu
15 | from .zsclip import CUSTOM_TEMPLATES
16 |
17 |
18 | class CustomVPT(nn.Module):
19 | def __init__(self, cfg, classnames, clip_model, device):
20 | super().__init__()
21 |
22 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
23 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
24 | print(f"Prompts: {prompts}")
25 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
26 | prompts = prompts.to(device)
27 | clip_model = clip_model.to(device)
28 |
29 | with torch.no_grad():
30 | text_features = clip_model.encode_text(prompts)
31 | text_features = text_features / text_features.norm(dim=-1,
32 | keepdim=True)
33 |
34 | self.text_features = text_features
35 | self.clip_model = clip_model
36 |
37 | n_ctx = cfg.TRAINER.COOP.N_CTX
38 | dtype = clip_model.dtype
39 | ctx_dim = self.clip_model.visual.positional_embedding.shape[-1]
40 |
41 | ctx_vectors = torch.empty(n_ctx * 1,  # a single set of n_ctx prompts, prepended only at the input layer
42 | ctx_dim,
43 | dtype=dtype,
44 | device=device)
45 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device)
46 | nn.init.normal_(ctx_vectors, std=0.02)
47 | self.visual_ctx = nn.Parameter(ctx_vectors)
48 |
49 | def forward(self, image):
50 | # debug leftover: dump the input batch to disk
51 | # import time; torch.save(image.data.cpu(), f'img_{int(time.time())}.pt')
52 |
53 | image_features, _ = self.clip_model.encode_image_prompt_shallow(
54 | image, self.visual_ctx)
55 | image_features = image_features / image_features.norm(dim=-1,
56 | keepdim=True)
57 | logit_scale = self.clip_model.logit_scale.exp()
58 | logits = logit_scale * image_features @ self.text_features.t()
59 |
60 | return logits
61 |
62 |
63 | @TRAINER_REGISTRY.register()
64 | class VPT_shallow(TrainerX):
65 | def check_cfg(self, cfg):
66 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
67 |
68 | def build_model(self):
69 | cfg = self.cfg
70 | classnames = self.dm.dataset.classnames
71 |
72 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
73 | clip_model = load_clip_to_cpu(cfg)
74 |
75 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
76 | # CLIP's default precision is fp16
77 | clip_model.float()
78 |
79 | print("Building custom VPT")
80 | self.model = CustomVPT(cfg, classnames, clip_model, self.device)
81 |
82 | print("Turning off gradients in both the image and the text encoder")
83 | for name, param in self.model.named_parameters():
84 | if "clip" in name:
85 | param.requires_grad_(False)
86 |
87 | if cfg.MODEL.INIT_WEIGHTS:
88 | # load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
89 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
90 |
91 | self.model.to(self.device)
92 | # NOTE: only give prompt_learner to the optimizer
93 | # self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
94 | self.optim = build_optimizer(self.model, cfg.OPTIM)
95 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
96 |
97 | # self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
98 | self.register_model("model", self.model, self.optim, self.sched)
99 |
100 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
101 |
102 | # Note that multi-gpu training could be slow because CLIP's size is
103 | # big, which slows down the copy operation in DataParallel
104 | device_count = 1  # DataParallel disabled; restore torch.cuda.device_count() to use multiple GPUs
105 | if device_count > 1:
106 | print(
107 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
108 | )
109 | self.model = nn.DataParallel(self.model)
110 |
111 | def forward_backward(self, batch):
112 | image, label = self.parse_batch_train(batch)
113 |
114 | prec = self.cfg.TRAINER.COOP.PREC
115 | if prec == "amp":
116 | with autocast():
117 | output = self.model(image)
118 | loss = F.cross_entropy(output, label)
119 | self.optim.zero_grad()
120 | self.scaler.scale(loss).backward()
121 | self.scaler.step(self.optim)
122 | self.scaler.update()
123 | else:
124 | output = self.model(image)
125 | loss = F.cross_entropy(output, label)
126 | self.model_backward_and_update(loss)
127 |
128 | loss_summary = {
129 | "loss": loss.item(),
130 | "acc": compute_accuracy(output, label)[0].item(),
131 | }
132 |
133 | if (self.batch_idx + 1) == self.num_batches:
134 | self.update_lr()
135 |
136 | return loss_summary
137 |
138 | def parse_batch_train(self, batch):
139 | input = batch["img"]
140 | label = batch["label"]
141 | input = input.to(self.device)
142 | label = label.to(self.device)
143 | return input, label
144 |
145 | def load_model(self, directory, epoch=None):
146 | if not directory:
147 | print(
148 | "Note that load_model() is skipped as no pretrained model is given"
149 | )
150 | return
151 |
152 | names = self.get_model_names()
153 |
154 | # By default, the best model is loaded
155 | model_file = "model-best.pth.tar"
156 |
157 | if epoch is not None:
158 | model_file = "model.pth.tar-" + str(epoch)
159 |
160 | for name in names:
161 | model_path = osp.join(directory, name, model_file)
162 |
163 | if not osp.exists(model_path):
164 | raise FileNotFoundError(
165 | 'Model not found at "{}"'.format(model_path))
166 |
167 | checkpoint = load_checkpoint(model_path)
168 | state_dict = checkpoint["state_dict"]
169 | epoch = checkpoint["epoch"]
170 |
171 | # Ignore fixed token vectors
172 | if "token_prefix" in state_dict:
173 | del state_dict["token_prefix"]
174 |
175 | if "token_suffix" in state_dict:
176 | del state_dict["token_suffix"]
177 |
178 | print("Loading weights to {} "
179 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
180 | # set strict=False
181 | self._models[name].load_state_dict(state_dict, strict=False)
182 |
--------------------------------------------------------------------------------
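
For comparison with the deep sketch earlier, the shallow variant used by `vpt_shallow.py` prepends its single group of `n_ctx` prompts once, before the first transformer layer, and lets them flow through the rest of the encoder as ordinary tokens. Again a sketch under assumed names, with the real logic presumably in `clip/model.py`'s `encode_image_prompt_shallow`:

```python
import torch

def encode_image_prompt_shallow_sketch(vit, x, visual_ctx):
    """Illustrative shallow visual-prompt injection (hypothetical `vit.blocks`)."""
    B = x.shape[0]
    prompts = visual_ctx.unsqueeze(0).expand(B, -1, -1)   # (B, n_ctx, D)
    x = torch.cat([x[:, :1], prompts, x[:, 1:]], dim=1)   # [CLS, prompts, patches]
    for block in vit.blocks:                              # prompts are not replaced per layer
        x = block(x)
    return x[:, 0]
```
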
/trainers/zsclip.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from clip import clip
4 | from dassl.engine import TRAINER_REGISTRY, TrainerX
5 |
6 | from .coop import load_clip_to_cpu
7 | from .imagenet_templates import IMAGENET_TEMPLATES_SELECT
8 |
9 | CUSTOM_TEMPLATES = {
10 | "OxfordPets": "a photo of a {}, a type of pet.",
11 | "OxfordFlowers": "a photo of a {}, a type of flower.",
12 | "FGVCAircraft": "a photo of a {}, a type of aircraft.",
13 | "DescribableTextures": "{} texture.",
14 | "EuroSAT": "a centered satellite photo of {}.",
15 | "StanfordCars": "a photo of a {}.",
16 | "Food101": "a photo of {}, a type of food.",
17 | "SUN397": "a photo of a {}.",
18 | "Caltech101": "a photo of a {}.",
19 | "UCF101": "a photo of a person doing {}.",
20 | "ImageNet": "a photo of a {}.",
21 | "ImageNetSketch": "a photo of a {}.",
22 | "ImageNetV2": "a photo of a {}.",
23 | "ImageNetA": "a photo of a {}.",
24 | "ImageNetR": "a photo of a {}.",
25 | }
26 |
27 |
28 | @TRAINER_REGISTRY.register()
29 | class ZeroshotCLIP(TrainerX):
30 | def build_model(self):
31 | cfg = self.cfg
32 | classnames = self.dm.dataset.classnames
33 |
34 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
35 | clip_model = load_clip_to_cpu(cfg)
36 | clip_model.to(self.device)
37 |
38 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
39 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
40 | print(f"Prompts: {prompts}")
41 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
42 | prompts = prompts.to(self.device)
43 |
44 | with torch.no_grad():
45 | text_features = clip_model.encode_text(prompts)
46 | text_features = text_features / text_features.norm(dim=-1,
47 | keepdim=True)
48 |
49 | self.text_features = text_features
50 | self.clip_model = clip_model
51 |
52 | def model_inference(self, image):
53 | image_features = self.clip_model.encode_image(image)
54 | image_features = image_features / image_features.norm(dim=-1,
55 | keepdim=True)
56 | logit_scale = self.clip_model.logit_scale.exp()
57 | logits = logit_scale * image_features @ self.text_features.t()
58 | return logits
59 |
60 |
61 | @TRAINER_REGISTRY.register()
62 | class ZeroshotCLIP2(ZeroshotCLIP):
63 | """Prompt ensembling."""
64 |
65 | # templates = IMAGENET_TEMPLATES
66 | templates = IMAGENET_TEMPLATES_SELECT
67 |
68 | def build_model(self):
69 | cfg = self.cfg
70 | classnames = self.dm.dataset.classnames
71 |
72 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
73 | clip_model = load_clip_to_cpu(cfg)
74 | clip_model.to(self.device)
75 |
76 | for params in clip_model.parameters():
77 | params.requires_grad_(False)
78 |
79 | # add custom-made prompt
80 | if cfg.DATASET.NAME != "ImageNet":
81 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]
82 |
83 | num_temp = len(self.templates)
84 | print(f"Prompt ensembling (n={num_temp})")
85 |
86 | mean_text_features = 0
87 | for i, temp in enumerate(self.templates):
88 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
89 | prompts = torch.cat([clip.tokenize(p)
90 | for p in prompts]).to(self.device)
91 | text_features = clip_model.encode_text(prompts)
92 | text_features = text_features / text_features.norm(dim=-1,
93 | keepdim=True)
94 | mean_text_features = mean_text_features + text_features
95 | mean_text_features = mean_text_features / num_temp
96 | mean_text_features = mean_text_features / mean_text_features.norm(
97 | dim=-1, keepdim=True)
98 |
99 | self.text_features = mean_text_features
100 | self.clip_model = clip_model
101 |
--------------------------------------------------------------------------------
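
One subtlety in `ZeroshotCLIP2` above: `templates` is a class attribute bound to the imported `IMAGENET_TEMPLATES_SELECT` list, so `self.templates += [...]` extends that shared list in place. This is harmless for a single run, but the dataset-specific template would accumulate if several trainers were built in the same process. A minimal, self-contained demonstration of the pattern (hypothetical class, not the repository's code):

```python
class Ensembler:
    templates = ["a photo of a {}."]   # class attribute, shared by every instance

    def build(self, extra):
        self.templates += [extra]      # += extends the *shared class list* in place

a, b = Ensembler(), Ensembler()
a.build("a sketch of a {}.")
print(Ensembler.templates)  # ['a photo of a {}.', 'a sketch of a {}.']
print(b.templates)          # b sees a's addition too
```

A side-effect-free alternative is to copy before extending, e.g. `templates = list(self.templates) + [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]`.
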