├── README.md
├── losses.py
├── models.py
├── models_rils.py
├── pos_embed.py
├── utils.py
└── vision_transformer.py

/README.md:
--------------------------------------------------------------------------------
1 | # RILS: Masked Visual Reconstruction in Language Semantic Space
2 | 
3 | This repo includes the official implementation of [*RILS: Masked Visual Reconstruction in Language Semantic Space*](https://arxiv.org/abs/2301.06958)
4 | 
5 | ## News 🆕
6 | - `2023/01/21`: RILS is accepted by CVPR 2023! Congratulations and thanks to all my co-authors!
7 | 
8 | ## Catalog
9 | - [ ] Code & Checkpoints Release
10 | - [x] Initialization
11 | 
12 | ## Acknowledgement ♥️
13 | 
14 | Part of this code is borrowed from [```SLIP```](https://github.com/facebookresearch/slip), [```MAE```](https://github.com/facebookresearch/mae) and [```BEiT```](https://github.com/microsoft/unilm/tree/master/beit), thanks for their awesome work!
15 | 
16 | ## Citation 📑
17 | 
18 | If you find our project helpful to your research, please star and cite our paper:
19 | 
20 | ```
21 | @article{yang2023RILS,
22 |   title={Masked Visual Reconstruction in Language Semantic Space},
23 |   author={Yang, Shusheng and Ge, Yixiao and Yi, Kun and Li, Dian and Shan, Ying and Qie, Xiaohu and Wang, Xinggang},
24 |   journal={arXiv preprint arXiv:2301.06958},
25 |   year={2023}
26 | }
27 | ```
28 | 
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | import utils
11 | 
12 | 
13 | class CLIPLoss(nn.Module):
14 |     def __init__(self):
15 |         super().__init__()
16 |         self.labels = None
17 |         self.last_local_batch_size = None
18 | 
19 |     def forward(self, outputs):
20 |         image_embed = outputs['image_embed']
21 |         text_embed = outputs['text_embed']
22 |         logit_scale = outputs['logit_scale']
23 |         local_batch_size = image_embed.size(0)
24 | 
25 |         if local_batch_size != self.last_local_batch_size:
26 |             self.labels = local_batch_size * utils.get_rank() + torch.arange(
27 |                 local_batch_size, device=image_embed.device
28 |             )
29 |             self.last_local_batch_size = local_batch_size
30 | 
31 |         # normalized features
32 |         image_embed = F.normalize(image_embed, dim=-1, p=2)
33 |         text_embed = F.normalize(text_embed, dim=-1, p=2)
34 | 
35 |         # gather features from all GPUs
36 |         image_embed_all, text_embed_all = \
37 |             utils.all_gather_batch([image_embed, text_embed])
38 | 
39 |         # cosine similarity as logits
40 |         logits_per_image = logit_scale * image_embed @ text_embed_all.t()
41 |         logits_per_text = logit_scale * text_embed @ image_embed_all.t()
42 | 
43 |         clip_loss = (F.cross_entropy(logits_per_image, self.labels) + \
44 |             F.cross_entropy(logits_per_text, self.labels)) / 2
45 | 
46 |         # compute accuracy
47 |         with torch.no_grad():
48 |             pred = torch.argmax(logits_per_image, dim=-1)
49 |             correct = pred.eq(self.labels).sum()
50 |             acc = 100 * correct / local_batch_size
51 | 
52 |         return {'loss': clip_loss, 'clip_loss': clip_loss, 'clip_acc': acc}
53 | 
54 | 
55 | class SIMCLRLoss(nn.Module):
56 |     """
57 |     This is the SimCLR loss in https://arxiv.org/abs/2002.05709
58 |     The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and
59 |     the memory layout that can be
reshaped into shape (2, batch_size, embedding_dim). 60 | This memory layout is consistent with the SimCLR collator in 61 | https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py 62 | Config params: 63 | temperature (float): the temperature to be applied on the logits 64 | """ 65 | 66 | def __init__(self, temperature=0.1): 67 | super().__init__() 68 | self.tau = temperature 69 | self.labels = None 70 | self.masks = None 71 | self.last_local_batch_size = None 72 | 73 | def forward(self, outputs): 74 | q_a = outputs['aug1_embed'] 75 | q_b = outputs['aug2_embed'] 76 | 77 | q_a = F.normalize(q_a, dim=-1, p=2) 78 | q_b = F.normalize(q_b, dim=-1, p=2) 79 | 80 | local_batch_size = q_a.size(0) 81 | 82 | k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b]) 83 | 84 | if local_batch_size != self.last_local_batch_size: 85 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 86 | local_batch_size, device=q_a.device 87 | ) 88 | total_batch_size = local_batch_size * utils.get_world_size() 89 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9 90 | self.last_local_batch_size = local_batch_size 91 | 92 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau 93 | logits_aa = logits_aa - self.masks 94 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau 95 | logits_bb = logits_bb - self.masks 96 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau 97 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau 98 | 99 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels) 100 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels) 101 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples 102 | 103 | # compute accuracy 104 | with torch.no_grad(): 105 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1) 106 | correct = pred.eq(self.labels).sum() 107 | acc = 100 * correct / local_batch_size 108 | 109 | return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc} 110 | 111 | 112 | class SLIPLoss(nn.Module): 113 | def __init__(self, ssl_loss, ssl_scale): 114 | super().__init__() 115 | self.clip_loss = CLIPLoss() 116 | self.ssl_loss = ssl_loss 117 | self.ssl_scale = ssl_scale 118 | 119 | def forward(self, outputs): 120 | clip_loss_dict = self.clip_loss(outputs) 121 | clip_loss = clip_loss_dict['clip_loss'] 122 | clip_acc = clip_loss_dict['clip_acc'] 123 | 124 | ssl_loss_dict = self.ssl_loss(outputs) 125 | ssl_loss = ssl_loss_dict['ssl_loss'] 126 | ssl_acc = ssl_loss_dict['ssl_acc'] 127 | 128 | return {'loss': clip_loss + self.ssl_scale * ssl_loss, 129 | 'clip_loss': clip_loss, 130 | 'clip_acc': clip_acc, 131 | 'ssl_loss': ssl_loss, 132 | 'ssl_acc': ssl_acc} 133 | 134 | 135 | class RILSLoss(nn.Module): 136 | def __init__( 137 | self, 138 | stu_tau=0.1, 139 | tea_tau=0.04, 140 | loss_weight=0.5, 141 | ): 142 | super().__init__() 143 | self.labels = None 144 | self.last_local_batch_size = None 145 | 146 | self.stu_tau = stu_tau 147 | self.tea_tau = tea_tau 148 | self.loss_weight = loss_weight 149 | 150 | def forward(self, outputs): 151 | image_embed = outputs['image_embed'] 152 | text_embed = outputs['text_embed'] 153 | logit_scale = outputs['logit_scale'] 154 | local_batch_size = image_embed.size(0) 155 | 156 | if local_batch_size != self.last_local_batch_size: 157 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 158 | local_batch_size, device=image_embed.device 159 | ) 160 | self.last_local_batch_size = 
local_batch_size 161 | 162 | # normalized features 163 | image_embed = F.normalize(image_embed, dim=-1, p=2) 164 | text_embed = F.normalize(text_embed, dim=-1, p=2) 165 | 166 | # gather features from all GPUs 167 | image_embed_all, text_embed_all = \ 168 | utils.all_gather_batch([image_embed, text_embed]) 169 | 170 | # cosine similarity as logits 171 | logits_per_image = logit_scale * image_embed @ text_embed_all.t() 172 | logits_per_text = logit_scale * text_embed @ image_embed_all.t() 173 | 174 | clip_loss = (F.cross_entropy(logits_per_image, self.labels) + \ 175 | F.cross_entropy(logits_per_text, self.labels)) / 2 176 | 177 | # compute accuracy 178 | with torch.no_grad(): 179 | pred = torch.argmax(logits_per_image, dim=-1) 180 | correct = pred.eq(self.labels).sum() 181 | acc = 100 * correct / local_batch_size 182 | 183 | masked_feat, masked_pred, unmasked_feat, mask = outputs["masked_feat"], outputs["masked_pred"], outputs["unmasked_feat"], outputs["mask"] 184 | masked_pred = F.normalize(masked_pred[:, 1:], dim=-1, p=2) # rm CLS token 185 | unmasked_feat = F.normalize(unmasked_feat[:, 1:], dim=-1, p=2) # rm CLS token 186 | masked_pred_logits = 1 / self.stu_tau * masked_pred @ text_embed_all.t() # student logits 187 | unmasked_feat_logits = (1 / self.tea_tau * unmasked_feat @ text_embed_all.t()).detach() # teacher logits 188 | 189 | # kldivergence for masked reconstruction 190 | recon_loss = -unmasked_feat_logits.softmax(-1) * masked_pred_logits.log_softmax(-1) + unmasked_feat_logits.softmax(-1) * unmasked_feat_logits.log_softmax(-1) 191 | 192 | # we only calculate reconstruction loss for correct-retrieved samples 193 | recon_loss = recon_loss.sum(dim=-1) 194 | mask *= pred.eq(self.labels).unsqueeze(-1) 195 | recon_loss = (recon_loss * mask).sum() / (mask.sum() + 1e-6) 196 | 197 | recon_loss *= self.loss_weight # loss weight 198 | 199 | loss = clip_loss + recon_loss 200 | 201 | return {'loss': loss, 'clip_loss': clip_loss, 'clip_acc': acc, 'recon_loss': recon_loss} 202 | 203 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Modified from github.com/openai/CLIP 8 | from collections import OrderedDict 9 | 10 | import numpy as np 11 | import timm 12 | import torch 13 | from torch import nn 14 | 15 | import losses 16 | from models_rils import rils_vit_base_patch16_dec768d1b, rils_vit_large_patch16_dec1024d1b 17 | 18 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 19 | 20 | def trunc_normal_(tensor, mean=0., std=1.): 21 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 22 | 23 | class LayerNorm(nn.LayerNorm): 24 | """Subclass torch's LayerNorm to handle fp16.""" 25 | 26 | def forward(self, x: torch.Tensor): 27 | orig_type = x.dtype 28 | ret = super().forward(x.type(torch.float32)) 29 | return ret.type(orig_type) 30 | 31 | 32 | class QuickGELU(nn.Module): 33 | def forward(self, x: torch.Tensor): 34 | return x * torch.sigmoid(1.702 * x) 35 | 36 | 37 | class ResidualAttentionBlock(nn.Module): 38 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 39 | super().__init__() 40 | 41 | self.attn = nn.MultiheadAttention(d_model, n_head) 42 | self.ln_1 = LayerNorm(d_model) 43 | self.mlp = nn.Sequential(OrderedDict([ 44 | ("c_fc", nn.Linear(d_model, d_model * 4)), 45 | ("gelu", QuickGELU()), 46 | ("c_proj", nn.Linear(d_model * 4, d_model)) 47 | ])) 48 | self.ln_2 = LayerNorm(d_model) 49 | self.attn_mask = attn_mask 50 | 51 | def attention(self, x: torch.Tensor): 52 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 53 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 54 | 55 | def forward(self, x: torch.Tensor): 56 | x = x + self.attention(self.ln_1(x)) 57 | x = x + self.mlp(self.ln_2(x)) 58 | return x 59 | 60 | 61 | class ResidualCrossAttentionBlock(nn.Module): 62 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 63 | super().__init__() 64 | 65 | self.self_attn = nn.MultiheadAttention(d_model, n_head) 66 | self.ln_1 = LayerNorm(d_model) 67 | self.mlp = nn.Sequential(OrderedDict([ 68 | ("c_fc", nn.Linear(d_model, d_model * 4)), 69 | ("gelu", QuickGELU()), 70 | ("c_proj", nn.Linear(d_model * 4, d_model)) 71 | ])) 72 | self.ln_2 = LayerNorm(d_model) 73 | self.cross_attn = nn.MultiheadAttention(d_model, n_head) 74 | self.ln_3 = LayerNorm(d_model) 75 | self.attn_mask = attn_mask 76 | 77 | def attention(self, x: torch.Tensor): 78 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 79 | return self.self_attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 80 | 81 | def cross_attention(self, q, kv): 82 | return self.cross_attn(q, kv, kv, need_weights=False)[0] 83 | 84 | def forward(self, x, kv): 85 | x = x + self.attention(self.ln_1(x)) 86 | x = x + self.cross_attention(self.ln_3(x), self.ln_3(kv)) 87 | x = x + self.mlp(self.ln_2(x)) 88 | return x 89 | 90 | 91 | class Transformer(nn.Module): 92 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 93 | super().__init__() 94 | self.width = width 95 | self.layers = layers 96 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 97 | 98 | def forward(self, x: torch.Tensor): 99 | return self.resblocks(x) 100 | 101 | 102 | class CrossTransformer(nn.Module): 103 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 104 | super().__init__() 105 | self.width = width 106 | self.layers = layers 107 | 
self.resblocks = nn.ModuleList([ResidualCrossAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 108 | 109 | def forward(self, x, kv): 110 | for blk in self.resblocks: 111 | x = blk(x, kv) 112 | return x 113 | 114 | 115 | class CLIP(nn.Module): 116 | def __init__(self, 117 | embed_dim: int, 118 | # vision 119 | vision_width: int, 120 | vision_model: nn.Module, 121 | # text 122 | context_length: int, 123 | vocab_size: int, 124 | transformer_width: int, 125 | transformer_heads: int, 126 | transformer_layers: int, 127 | clip_gap=True, 128 | **kwargs, 129 | ): 130 | super().__init__() 131 | 132 | self.context_length = context_length 133 | self.vision_width = vision_width 134 | self.clip_gap = clip_gap 135 | 136 | self.visual = vision_model 137 | 138 | self.transformer = Transformer( 139 | width=transformer_width, 140 | layers=transformer_layers, 141 | heads=transformer_heads, 142 | attn_mask=self.build_attention_mask(), 143 | ) 144 | 145 | self.vocab_size = vocab_size 146 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 147 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 148 | self.ln_final = LayerNorm(transformer_width) 149 | 150 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) 151 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 152 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 153 | 154 | self.initialize_parameters() 155 | 156 | def initialize_parameters(self): 157 | nn.init.normal_(self.token_embedding.weight, std=0.02) 158 | nn.init.normal_(self.positional_embedding, std=0.01) 159 | 160 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 161 | attn_std = self.transformer.width ** -0.5 162 | fc_std = (2 * self.transformer.width) ** -0.5 163 | for block in self.transformer.resblocks: 164 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 165 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 166 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 167 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 168 | 169 | nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) 170 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 171 | 172 | def build_attention_mask(self): 173 | # lazily create causal attention mask, with full attention between the vision tokens 174 | # pytorch uses additive attention mask; fill with -inf 175 | mask = torch.empty(self.context_length, self.context_length) 176 | mask.fill_(float("-inf")) 177 | mask.triu_(1) # zero out the lower diagonal 178 | return mask 179 | 180 | def encode_image(self, image): 181 | x = self.visual(image) 182 | if self.clip_gap: 183 | x = x[:, 1:].mean(1) 184 | else: 185 | x = x[:, 0] 186 | 187 | x = x @ self.image_projection 188 | 189 | return x 190 | 191 | def encode_text(self, text): 192 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 193 | x = x + self.positional_embedding 194 | x = x.permute(1, 0, 2) # NLD -> LND 195 | x = self.transformer(x) 196 | x = x.permute(1, 0, 2) # LND -> NLD 197 | x = self.ln_final(x) 198 | 199 | # x.shape = [batch_size, n_ctx, transformer.width] 200 | # take features from the eot embedding (eot_token is the highest number in each sequence) 201 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 202 | 203 | return x 204 | 205 | def forward(self, image, text): 206 | image_embed = self.encode_image(image) 207 | 
text_embed = self.encode_text(text) 208 | 209 | return {'image_embed': image_embed, 210 | 'text_embed': text_embed, 211 | 'logit_scale': self.logit_scale.exp()} 212 | 213 | 214 | class ProjectionHead(nn.Module): 215 | def __init__(self, 216 | num_layers=2, 217 | in_dim=768, 218 | hidden_dim=4096, 219 | out_dim=32768): 220 | super().__init__() 221 | assert num_layers > 1 222 | 223 | layers = [] 224 | for _ in range(num_layers): 225 | if _ == 0: 226 | layers.append(nn.Linear(in_dim, hidden_dim)) 227 | layers.append(nn.BatchNorm1d(hidden_dim)) 228 | layers.append(nn.GELU()) 229 | elif _ == num_layers - 1: 230 | layers.append(nn.Linear(hidden_dim, out_dim)) 231 | layers.append(nn.BatchNorm1d(out_dim, affine=False)) 232 | else: 233 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 234 | layers.append(nn.BatchNorm1d(hidden_dim)) 235 | layers.append(nn.GELU()) 236 | self.layers = nn.Sequential(*layers) 237 | 238 | self.init_std = .02 239 | self.apply(self._init_weights) 240 | 241 | def _init_weights(self, m): 242 | if isinstance(m, nn.Linear): 243 | trunc_normal_(m.weight, std=self.init_std) 244 | if isinstance(m, nn.Linear) and m.bias is not None: 245 | nn.init.constant_(m.bias, 0) 246 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 247 | if m.bias is not None: 248 | nn.init.constant_(m.bias, 0) 249 | if m.weight is not None: 250 | nn.init.constant_(m.weight, 1.0) 251 | else: 252 | pass 253 | 254 | def forward(self, x): 255 | return self.layers(x) 256 | 257 | 258 | class SIMCLR(nn.Module): 259 | def __init__(self, 260 | # vision 261 | vision_width: int, 262 | vision_model: nn.Module, 263 | # ssl 264 | ssl_mlp_dim: int, 265 | ssl_emb_dim: int, 266 | **kwargs, 267 | ): 268 | super().__init__() 269 | 270 | self.vision_width = vision_width 271 | self.visual = vision_model 272 | 273 | self.image_mlp = self._build_mlp(in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) 274 | 275 | def _build_mlp(self, in_dim, mlp_dim, out_dim): 276 | return nn.Sequential(OrderedDict([ 277 | ("layer1", nn.Linear(in_dim, mlp_dim)), 278 | ("bn1", nn.SyncBatchNorm(mlp_dim)), 279 | ("relu1", nn.ReLU(inplace=True)), 280 | ("layer2", nn.Linear(mlp_dim, mlp_dim)), 281 | ("bn2", nn.SyncBatchNorm(mlp_dim)), 282 | ("relu2", nn.ReLU(inplace=True)), 283 | ("layer3", nn.Linear(mlp_dim, out_dim)), 284 | ])) 285 | 286 | def encode_image(self, image): 287 | x = self.visual(image) 288 | 289 | return x 290 | 291 | def forward(self, aug1, aug2): 292 | h1 = self.visual(aug1) 293 | h2 = self.visual(aug2) 294 | 295 | aug1_embed = self.image_mlp(h1) 296 | aug2_embed = self.image_mlp(h2) 297 | 298 | return {'aug1_embed': aug1_embed, 299 | 'aug2_embed': aug2_embed} 300 | 301 | 302 | class SLIP(CLIP): 303 | def __init__(self, 304 | ssl_mlp_dim: int, 305 | ssl_emb_dim: int, 306 | **kwargs, 307 | ): 308 | super().__init__(**kwargs) 309 | 310 | self.image_mlp = self._build_mlp(in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) 311 | 312 | def _build_mlp(self, in_dim, mlp_dim, out_dim): 313 | return nn.Sequential(OrderedDict([ 314 | ("layer1", nn.Linear(in_dim, mlp_dim)), 315 | ("bn1", nn.SyncBatchNorm(mlp_dim)), 316 | ("relu1", nn.ReLU(inplace=True)), 317 | ("layer2", nn.Linear(mlp_dim, mlp_dim)), 318 | ("bn2", nn.SyncBatchNorm(mlp_dim)), 319 | ("relu2", nn.ReLU(inplace=True)), 320 | ("layer3", nn.Linear(mlp_dim, out_dim)), 321 | ])) 322 | 323 | def forward(self, image, text, aug1, aug2): 324 | aug1_embed = self.image_mlp(self.visual(aug1)) 325 | aug2_embed = 
self.image_mlp(self.visual(aug2)) 326 | 327 | image_embed = self.encode_image(image) 328 | text_embed = self.encode_text(text) 329 | 330 | return {'image_embed': image_embed, 331 | 'text_embed': text_embed, 332 | 'logit_scale': self.logit_scale.exp(), 333 | 'aug1_embed': aug1_embed, 334 | 'aug2_embed': aug2_embed} 335 | 336 | 337 | class RILS(CLIP): 338 | def __init__(self, 339 | mask_ratio=0.75, 340 | **kwargs, 341 | ): 342 | super().__init__(**kwargs) 343 | 344 | self.mask_ratio = mask_ratio 345 | 346 | def encode_image(self, image): 347 | assert not self.training 348 | masked_feat, masked_pred, unmasked_feat, mask = self.visual(image, mask_ratio=0.) 349 | image_embed = unmasked_feat[:, 1:, ...].mean(1) @ self.image_projection 350 | return image_embed 351 | 352 | def forward(self, image, text): 353 | masked_feat, masked_pred, unmasked_feat, mask = self.visual(image, mask_ratio=self.mask_ratio if self.training else 0.) 354 | image_embed = unmasked_feat[:, 1:, ...].mean(1) @ self.image_projection 355 | text_embed = self.encode_text(text) 356 | 357 | return {'image_embed': image_embed, 358 | 'text_embed': text_embed, 359 | 'logit_scale': self.logit_scale.exp(), 360 | 'masked_feat': masked_feat @ self.image_projection, 361 | 'masked_pred': masked_pred @ self.image_projection, 362 | 'unmasked_feat': unmasked_feat @ self.image_projection, 363 | 'mask': mask} 364 | 365 | 366 | def get_loss(model, ssl_temp, ssl_scale): 367 | if model.startswith('SLIP'): 368 | ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp) 369 | return losses.SLIPLoss(ssl_loss, ssl_scale) 370 | if model.startswith('CLIP'): 371 | return losses.CLIPLoss() 372 | if model.startswith('SIMCLR'): 373 | return losses.SIMCLRLoss(temperature=ssl_temp) 374 | if model.startswith('RILS'): 375 | return losses.RILSLoss() 376 | raise NotImplementedError 377 | 378 | 379 | def get_metric_names(model): 380 | if model.startswith('SLIP'): 381 | return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc'] 382 | elif model.startswith('CLIP'): 383 | return ['loss', 'clip_loss', 'clip_acc'] 384 | elif model.startswith('SIMCLR'): 385 | return ['loss', 'ssl_loss', 'ssl_acc'] 386 | elif model.startswith('RILS'): 387 | return ['loss', 'clip_loss', 'clip_acc', 'recon_loss'] 388 | else: 389 | raise NotImplementedError 390 | 391 | def RILS_VITB16(**kwargs): 392 | vision_model = rils_vit_base_patch16_dec768d1b() 393 | model = RILS(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, 394 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 395 | 396 | return model 397 | 398 | def RILS_VITL16(**kwargs): 399 | vision_model = rils_vit_large_patch16_dec1024d1b() 400 | model = RILS(embed_dim=768, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, 401 | transformer_width=768, transformer_heads=12, transformer_layers=12, **kwargs) 402 | 403 | return model 404 | -------------------------------------------------------------------------------- /models_rils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from vision_transformer import PatchEmbed, Block 7 | 8 | from pos_embed import get_2d_sincos_pos_embed 9 | 10 | 11 | class MaskedAutoencoderViT(nn.Module): 12 | """ Masked Autoencoder with VisionTransformer backbone 13 | """ 14 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 15 | embed_dim=1024, depth=24, num_heads=16, 16 | 
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 17 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True): 18 | super().__init__() 19 | 20 | # -------------------------------------------------------------------------- 21 | # MAE encoder specifics 22 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 23 | num_patches = self.patch_embed.num_patches 24 | 25 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 26 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 27 | 28 | self.blocks = nn.ModuleList([ 29 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer, beit_qkv_bias=True) 30 | for i in range(depth)]) 31 | self.norm = norm_layer(embed_dim) 32 | # -------------------------------------------------------------------------- 33 | 34 | # -------------------------------------------------------------------------- 35 | # MAE decoder specifics 36 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 37 | 38 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 39 | 40 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 41 | 42 | self.decoder_blocks = nn.ModuleList([ 43 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer, beit_qkv_bias=True) 44 | for i in range(decoder_depth)]) 45 | 46 | self.decoder_norm = norm_layer(decoder_embed_dim) 47 | self.decoder_pred = nn.Linear(decoder_embed_dim, embed_dim, bias=True) # decoder to patch 48 | # -------------------------------------------------------------------------- 49 | 50 | self.norm_pix_loss = norm_pix_loss 51 | 52 | self.initialize_weights() 53 | 54 | def initialize_weights(self): 55 | # initialization 56 | # initialize (and freeze) pos_embed by sin-cos embedding 57 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 58 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 59 | 60 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 61 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 62 | 63 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 64 | w = self.patch_embed.proj.weight.data 65 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 66 | 67 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
68 | torch.nn.init.normal_(self.cls_token, std=.02) 69 | torch.nn.init.normal_(self.mask_token, std=.02) 70 | 71 | # initialize nn.Linear and nn.LayerNorm 72 | self.apply(self._init_weights) 73 | 74 | def _init_weights(self, m): 75 | if isinstance(m, nn.Linear): 76 | # we use xavier_uniform following official JAX ViT: 77 | torch.nn.init.xavier_uniform_(m.weight) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | 84 | def patchify(self, imgs): 85 | """ 86 | imgs: (N, 3, H, W) 87 | x: (N, L, patch_size**2 *3) 88 | """ 89 | p = self.patch_embed.patch_size[0] 90 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 91 | 92 | h = w = imgs.shape[2] // p 93 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 94 | x = torch.einsum('nchpwq->nhwpqc', x) 95 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 96 | return x 97 | 98 | def unpatchify(self, x): 99 | """ 100 | x: (N, L, patch_size**2 *3) 101 | imgs: (N, 3, H, W) 102 | """ 103 | p = self.patch_embed.patch_size[0] 104 | h = w = int(x.shape[1]**.5) 105 | assert h * w == x.shape[1] 106 | 107 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 108 | x = torch.einsum('nhwpqc->nchpwq', x) 109 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 110 | return imgs 111 | 112 | def random_masking(self, x, mask_ratio): 113 | """ 114 | Perform per-sample random masking by per-sample shuffling. 115 | Per-sample shuffling is done by argsort random noise. 116 | x: [N, L, D], sequence 117 | """ 118 | N, L, D = x.shape # batch, length, dim 119 | len_keep = int(L * (1 - mask_ratio)) 120 | 121 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 122 | 123 | # sort noise for each sample 124 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 125 | ids_restore = torch.argsort(ids_shuffle, dim=1) 126 | 127 | # keep the first subset 128 | ids_keep = ids_shuffle[:, :len_keep] 129 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 130 | 131 | # generate the binary mask: 0 is keep, 1 is remove 132 | mask = torch.ones([N, L], device=x.device) 133 | mask[:, :len_keep] = 0 134 | # unshuffle to get the binary mask 135 | mask = torch.gather(mask, dim=1, index=ids_restore) 136 | 137 | return x_masked, mask, ids_restore 138 | 139 | def forward_encoder(self, x, mask_ratio): 140 | # embed patches 141 | x = self.patch_embed(x) 142 | 143 | # add pos embed w/o cls token 144 | x = x + self.pos_embed[:, 1:, :] 145 | 146 | # masking: length -> length * mask_ratio 147 | if mask_ratio > 0: 148 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 149 | else: 150 | x, mask, ids_restore = x, None, None 151 | 152 | # append cls token 153 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 154 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 155 | x = torch.cat((cls_tokens, x), dim=1) 156 | 157 | # apply Transformer blocks 158 | for blk in self.blocks: 159 | x = blk(x) 160 | x = self.norm(x) 161 | 162 | return x, mask, ids_restore 163 | 164 | def forward_decoder(self, x, ids_restore): 165 | # embed tokens 166 | x = self.decoder_embed(x) 167 | 168 | # append mask tokens to sequence 169 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 170 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 171 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, 
x.shape[2])) # unshuffle 172 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 173 | 174 | # add pos embed 175 | x = x + self.decoder_pos_embed 176 | 177 | # apply Transformer blocks 178 | for blk in self.decoder_blocks: 179 | x = blk(x) 180 | x = self.decoder_norm(x) 181 | 182 | # predictor projection 183 | x = self.decoder_pred(x) 184 | 185 | return x 186 | 187 | def forward(self, imgs, mask_ratio=0.75): 188 | if mask_ratio == 0: 189 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 190 | return None, None, latent, None 191 | else: 192 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 193 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 194 | full_latent, _, _ = self.forward_encoder(imgs, 0) 195 | return latent, pred, full_latent, mask 196 | 197 | def rils_vit_base_patch16_dec768d1b(**kwargs): 198 | model = MaskedAutoencoderViT( 199 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 200 | decoder_embed_dim=768, decoder_depth=1, decoder_num_heads=12, 201 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 202 | return model 203 | 204 | def rils_vit_large_patch16_dec1024d1b(**kwargs): 205 | model = MaskedAutoencoderViT( 206 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 207 | decoder_embed_dim=1024, decoder_depth=1, decoder_num_heads=16, 208 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 209 | return model 210 | -------------------------------------------------------------------------------- /pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=float)  # np.float was removed in recent NumPy; use the builtin float
57 |     omega /= embed_dim / 2.
58 |     omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import numpy as np 7 | import os 8 | import random 9 | import shutil 10 | import torch 11 | import torch.distributed as dist 12 | import torch.autograd as autograd 13 | 14 | from PIL import ImageFilter 15 | 16 | import os 17 | import sys 18 | import time 19 | import math 20 | import random 21 | import datetime 22 | import subprocess 23 | from collections import defaultdict, deque 24 | 25 | import numpy as np 26 | import torch 27 | from torch import nn 28 | import torch.distributed as dist 29 | from PIL import ImageFilter, ImageOps 30 | 31 | def get_model(model): 32 | if isinstance(model, torch.nn.DataParallel) \ 33 | or isinstance(model, torch.nn.parallel.DistributedDataParallel): 34 | return model.module 35 | else: 36 | return model 37 | 38 | 39 | def setup_for_distributed(is_master, args): 40 | """ 41 | This function disables printing when not in master process 42 | """ 43 | import logging 44 | import builtins 45 | import datetime 46 | builtin_print = builtins.print 47 | logging.basicConfig(filename=os.path.join(args.output_dir, 'output.txt'), 48 | filemode='a', 49 | format='%(asctime)s %(name)s %(levelname)s %(message)s', 50 | datefmt='%Y-%m-%d,%H:%M:%S', 51 | level=logging.INFO) 52 | 53 | def print(*args, **kwargs): 54 | force = kwargs.pop('force', False) 55 | force = force or (get_world_size() > 8) 56 | if is_master or force: 57 | now = datetime.datetime.now().time() 58 | builtin_print('[{}] '.format(now), end='') # print with time stamp 59 | builtin_print(*args, **kwargs) 60 | logging.info(*args) 61 | 62 | builtins.print = print 63 | 64 | 65 | def is_dist_avail_and_initialized(): 66 | if not dist.is_available(): 67 | return False 68 | if not dist.is_initialized(): 69 | return False 70 | return True 71 | 72 | 73 | def get_world_size(): 74 | if not is_dist_avail_and_initialized(): 75 | return 1 76 | return dist.get_world_size() 77 | 78 | 79 | def get_rank(): 80 | if not is_dist_avail_and_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main_process(): 86 | return get_rank() == 0 87 | 88 | 89 | def save_on_master(state, is_best, output_dir): 90 | if is_main_process(): 91 | ckpt_path = f'{output_dir}/checkpoint.pt' 92 | best_path = f'{output_dir}/checkpoint_best.pt' 93 | torch.save(state, ckpt_path) 94 | if is_best: 95 | shutil.copyfile(ckpt_path, best_path) 96 | 97 | 98 | def init_distributed_mode(args): 99 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 100 | args.rank = int(os.environ["RANK"]) 101 | args.world_size = int(os.environ['WORLD_SIZE']) 102 | args.gpu = int(os.environ['LOCAL_RANK']) 103 | elif 'SLURM_PROCID' in os.environ: 104 | args.rank = int(os.environ['SLURM_PROCID']) 105 | args.gpu = args.rank % torch.cuda.device_count() 106 | else: 107 | print('Not using distributed mode') 108 | args.distributed = False 109 | return 110 | 111 | args.distributed = True 112 | 113 | torch.cuda.set_device(args.gpu) 114 | args.dist_backend = 'nccl' 115 | print('| distributed init (rank {}): {}'.format( 116 | args.rank, args.dist_url), flush=True) 117 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 118 | world_size=args.world_size, rank=args.rank) 119 | torch.distributed.barrier() 120 | setup_for_distributed(args.rank == 0, args) 121 | 122 | 123 | def scaled_all_reduce(tensors, is_scale=True): 124 | """Performs the scaled all_reduce operation on the provided tensors. 125 | The input tensors are modified in-place. Currently supports only the sum 126 | reduction operator. 
The reduced values are scaled by the inverse size of the 127 | world size. 128 | """ 129 | world_size = get_world_size() 130 | # There is no need for reduction in the single-proc case 131 | if world_size == 1: 132 | return tensors 133 | # Queue the reductions 134 | reductions = [] 135 | for tensor in tensors: 136 | reduction = dist.all_reduce(tensor, async_op=True) 137 | reductions.append(reduction) 138 | # Wait for reductions to finish 139 | for reduction in reductions: 140 | reduction.wait() 141 | # Scale the results 142 | if is_scale: 143 | for tensor in tensors: 144 | tensor.mul_(1.0 / world_size) 145 | return tensors 146 | 147 | 148 | def all_gather_batch(tensors): 149 | """ 150 | Performs all_gather operation on the provided tensors. 151 | """ 152 | # Queue the gathered tensors 153 | world_size = get_world_size() 154 | # There is no need for reduction in the single-proc case 155 | if world_size == 1: 156 | return tensors 157 | tensor_list = [] 158 | output_tensor = [] 159 | for tensor in tensors: 160 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 161 | dist.all_gather( 162 | tensor_all, 163 | tensor, 164 | async_op=False # performance opt 165 | ) 166 | 167 | tensor_list.append(tensor_all) 168 | 169 | for tensor_all in tensor_list: 170 | output_tensor.append(torch.cat(tensor_all, dim=0)) 171 | return output_tensor 172 | 173 | 174 | class GatherLayer(autograd.Function): 175 | """ 176 | Gather tensors from all workers with support for backward propagation: 177 | This implementation does not cut the gradients as torch.distributed.all_gather does. 178 | """ 179 | 180 | @staticmethod 181 | def forward(ctx, x): 182 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 183 | dist.all_gather(output, x) 184 | return tuple(output) 185 | 186 | @staticmethod 187 | def backward(ctx, *grads): 188 | all_gradients = torch.stack(grads) 189 | dist.all_reduce(all_gradients) 190 | return all_gradients[dist.get_rank()] 191 | 192 | 193 | def all_gather_batch_with_grad(tensors): 194 | """ 195 | Performs all_gather operation on the provided tensors. 196 | Graph remains connected for backward grad computation. 
197 | """ 198 | # Queue the gathered tensors 199 | world_size = get_world_size() 200 | # There is no need for reduction in the single-proc case 201 | if world_size == 1: 202 | return tensors 203 | tensor_list = [] 204 | output_tensor = [] 205 | 206 | for tensor in tensors: 207 | tensor_all = GatherLayer.apply(tensor) 208 | tensor_list.append(tensor_all) 209 | 210 | for tensor_all in tensor_list: 211 | output_tensor.append(torch.cat(tensor_all, dim=0)) 212 | return output_tensor 213 | 214 | 215 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 216 | warmup_schedule = np.array([]) 217 | warmup_iters = warmup_epochs * niter_per_ep 218 | if warmup_epochs > 0: 219 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 220 | 221 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 222 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 223 | 224 | schedule = np.concatenate((warmup_schedule, schedule)) 225 | assert len(schedule) == epochs * niter_per_ep 226 | return schedule 227 | 228 | 229 | class GaussianBlur(object): 230 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 231 | 232 | def __init__(self, sigma=[.1, 2.]): 233 | self.sigma = sigma 234 | 235 | def __call__(self, x): 236 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 237 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 238 | return x 239 | 240 | class SmoothedValue(object): 241 | """Track a series of values and provide access to smoothed values over a 242 | window or the global series average. 243 | """ 244 | 245 | def __init__(self, window_size=20, fmt=None): 246 | if fmt is None: 247 | fmt = "{median:.6f} ({global_avg:.6f})" 248 | self.deque = deque(maxlen=window_size) 249 | self.total = 0.0 250 | self.count = 0 251 | self.fmt = fmt 252 | 253 | def update(self, value, n=1): 254 | self.deque.append(value) 255 | self.count += n 256 | self.total += value * n 257 | 258 | def synchronize_between_processes(self): 259 | """ 260 | Warning: does not synchronize the deque! 
261 | """ 262 | if not is_dist_avail_and_initialized(): 263 | return 264 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 265 | dist.barrier() 266 | dist.all_reduce(t) 267 | t = t.tolist() 268 | self.count = int(t[0]) 269 | self.total = t[1] 270 | 271 | @property 272 | def median(self): 273 | d = torch.tensor(list(self.deque)) 274 | return d.median().item() 275 | 276 | @property 277 | def avg(self): 278 | d = torch.tensor(list(self.deque), dtype=torch.float32) 279 | return d.mean().item() 280 | 281 | @property 282 | def global_avg(self): 283 | return self.total / self.count 284 | 285 | @property 286 | def max(self): 287 | return max(self.deque) 288 | 289 | @property 290 | def value(self): 291 | return self.deque[-1] 292 | 293 | def __str__(self): 294 | return self.fmt.format( 295 | median=self.median, 296 | avg=self.avg, 297 | global_avg=self.global_avg, 298 | max=self.max, 299 | value=self.value) 300 | 301 | class MetricLogger(object): 302 | def __init__(self, delimiter="\t"): 303 | self.meters = defaultdict(SmoothedValue) 304 | self.delimiter = delimiter 305 | 306 | def update(self, **kwargs): 307 | for k, v in kwargs.items(): 308 | if isinstance(v, torch.Tensor): 309 | v = v.item() 310 | assert isinstance(v, (float, int)) 311 | self.meters[k].update(v) 312 | 313 | def __getattr__(self, attr): 314 | if attr in self.meters: 315 | return self.meters[attr] 316 | if attr in self.__dict__: 317 | return self.__dict__[attr] 318 | raise AttributeError("'{}' object has no attribute '{}'".format( 319 | type(self).__name__, attr)) 320 | 321 | def __str__(self): 322 | loss_str = [] 323 | for name, meter in self.meters.items(): 324 | loss_str.append( 325 | "{}: {}".format(name, str(meter)) 326 | ) 327 | return self.delimiter.join(loss_str) 328 | 329 | def synchronize_between_processes(self): 330 | for meter in self.meters.values(): 331 | meter.synchronize_between_processes() 332 | 333 | def add_meter(self, name, meter): 334 | self.meters[name] = meter 335 | 336 | def log_every(self, iterable, print_freq, header=None): 337 | i = 0 338 | if not header: 339 | header = '' 340 | start_time = time.time() 341 | end = time.time() 342 | iter_time = SmoothedValue(fmt='{avg:.6f}') 343 | data_time = SmoothedValue(fmt='{avg:.6f}') 344 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 345 | if torch.cuda.is_available(): 346 | log_msg = self.delimiter.join([ 347 | header, 348 | '[{0' + space_fmt + '}/{1}]', 349 | 'eta: {eta}', 350 | '{meters}', 351 | 'time: {time}', 352 | 'data: {data}', 353 | 'max mem: {memory:.0f}' 354 | ]) 355 | else: 356 | log_msg = self.delimiter.join([ 357 | header, 358 | '[{0' + space_fmt + '}/{1}]', 359 | 'eta: {eta}', 360 | '{meters}', 361 | 'time: {time}', 362 | 'data: {data}' 363 | ]) 364 | MB = 1024.0 * 1024.0 365 | for obj in iterable: 366 | data_time.update(time.time() - end) 367 | yield obj 368 | iter_time.update(time.time() - end) 369 | if i % print_freq == 0 or i == len(iterable) - 1: 370 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 371 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 372 | if torch.cuda.is_available(): 373 | print(log_msg.format( 374 | i, len(iterable), eta=eta_string, 375 | meters=str(self), 376 | time=str(iter_time), data=str(data_time), 377 | memory=torch.cuda.max_memory_allocated() / MB)) 378 | else: 379 | print(log_msg.format( 380 | i, len(iterable), eta=eta_string, 381 | meters=str(self), 382 | time=str(iter_time), data=str(data_time))) 383 | i += 1 384 | end = time.time() 385 | total_time = 
time.time() - start_time 386 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 387 | print('{} Total time: {} ({:.6f} s / it)'.format( 388 | header, total_time_str, total_time / len(iterable))) -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import collections 3 | from itertools import repeat 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 11 | from timm.models.layers import variance_scaling_ 12 | 13 | def trunc_normal_(tensor, mean=0., std=1.): 14 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 15 | 16 | 17 | def lecun_normal_(tensor): 18 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 19 | 20 | 21 | def _ntuple(n): 22 | def parse(x): 23 | if isinstance(x, collections.abc.Iterable): 24 | return x 25 | return tuple(repeat(x, n)) 26 | 27 | return parse 28 | 29 | 30 | to_1tuple = _ntuple(1) 31 | to_2tuple = _ntuple(2) 32 | to_3tuple = _ntuple(3) 33 | to_4tuple = _ntuple(4) 34 | to_ntuple = _ntuple 35 | 36 | 37 | def drop_path(x, drop_prob: float = 0., training: bool = False): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of 39 | residual blocks). 40 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 41 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 42 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 43 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 44 | 'survival rate' as the argument. 45 | """ 46 | if drop_prob == 0. 
or not training: 47 | return x 48 | keep_prob = 1 - drop_prob 49 | shape = (x.shape[0], ) + (1, ) * ( 50 | x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 51 | random_tensor = keep_prob + torch.rand( 52 | shape, dtype=x.dtype, device=x.device) 53 | random_tensor.floor_() # binarize 54 | output = x.div(keep_prob) * random_tensor 55 | return output 56 | 57 | 58 | class DropPath(nn.Module): 59 | """Drop paths (Stochastic Depth) per sample (when applied in main path of 60 | residual blocks).""" 61 | def __init__(self, drop_prob=None): 62 | super(DropPath, self).__init__() 63 | self.drop_prob = drop_prob 64 | 65 | def forward(self, x): 66 | return drop_path(x, self.drop_prob, self.training) 67 | 68 | def extra_repr(self) -> str: 69 | return 'p={}'.format(self.drop_prob) 70 | 71 | 72 | class PatchEmbed(nn.Module): 73 | """2D Image to Patch Embedding.""" 74 | def __init__(self, 75 | img_size=224, 76 | patch_size=16, 77 | in_chans=3, 78 | embed_dim=768, 79 | norm_layer=None, 80 | flatten=True): 81 | super().__init__() 82 | img_size = to_2tuple(img_size) 83 | patch_size = to_2tuple(patch_size) 84 | self.img_size = img_size 85 | self.patch_size = patch_size 86 | self.grid_size = (img_size[0] // patch_size[0], 87 | img_size[1] // patch_size[1]) 88 | self.num_patches = self.grid_size[0] * self.grid_size[1] 89 | self.flatten = flatten 90 | 91 | self.proj = nn.Conv2d(in_chans, 92 | embed_dim, 93 | kernel_size=patch_size, 94 | stride=patch_size) 95 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 96 | 97 | def forward(self, x): 98 | B, C, H, W = x.shape 99 | assert H == self.img_size[ 100 | 0], f"Input image height ({H}) doesn't match model ({self.img_size[0]})." 101 | assert W == self.img_size[ 102 | 1], f"Input image width ({W}) doesn't match model ({self.img_size[1]})." 
103 | x = self.proj(x) 104 | if self.flatten: 105 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 106 | x = self.norm(x) 107 | return x 108 | 109 | 110 | class Attention(nn.Module): 111 | def __init__(self, 112 | dim, 113 | num_heads=8, 114 | qkv_bias=False, 115 | attn_drop=0., 116 | proj_drop=0., 117 | beit_qkv_bias=False): 118 | super().__init__() 119 | self.num_heads = num_heads 120 | head_dim = dim // num_heads 121 | self.scale = head_dim**-0.5 122 | self.beit_qkv_bias = beit_qkv_bias 123 | 124 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias and not beit_qkv_bias) 125 | if beit_qkv_bias: 126 | self.q_bias = nn.Parameter(torch.zeros(dim)) 127 | self.v_bias = nn.Parameter(torch.zeros(dim)) 128 | self.attn_drop = nn.Dropout(attn_drop) 129 | self.proj = nn.Linear(dim, dim) 130 | self.proj_drop = nn.Dropout(proj_drop) 131 | 132 | def forward(self, x, rel_pos_bias=None): 133 | B, N, C = x.shape 134 | if not self.beit_qkv_bias: 135 | qkv = self.qkv(x) 136 | else: 137 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 138 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 139 | 140 | qkv = qkv.reshape(B, N, 3, self.num_heads, 141 | C // self.num_heads).permute(2, 0, 3, 1, 4) 142 | q, k, v = qkv.unbind( 143 | 0) # make torchscript happy (cannot use tensor as tuple) 144 | 145 | attn = (q @ k.transpose(-2, -1)) * self.scale 146 | if rel_pos_bias is not None: 147 | attn = attn + rel_pos_bias 148 | attn = attn.softmax(dim=-1) 149 | attn = self.attn_drop(attn) 150 | 151 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 152 | x = self.proj(x) 153 | x = self.proj_drop(x) 154 | return x 155 | 156 | def extra_repr(self): 157 | if self.beit_qkv_bias: 158 | return f"(q_bias): torch.Tensor({self.q_bias.shape}, requires_grad={self.q_bias.requires_grad})" +\ 159 | f"\n(v_bias): torch.Tensor({self.v_bias.shape}, requires_grad={self.v_bias.requires_grad})" 160 | 161 | 162 | class Mlp(nn.Module): 163 | """MLP as used in Vision Transformer, MLP-Mixer and related networks.""" 164 | def __init__(self, 165 | in_features, 166 | hidden_features=None, 167 | out_features=None, 168 | act_layer=nn.GELU, 169 | drop=0.): 170 | super().__init__() 171 | out_features = out_features or in_features 172 | hidden_features = hidden_features or in_features 173 | drop_probs = to_2tuple(drop) 174 | 175 | self.fc1 = nn.Linear(in_features, hidden_features) 176 | self.act = act_layer() 177 | self.drop1 = nn.Dropout(drop_probs[0]) 178 | self.fc2 = nn.Linear(hidden_features, out_features) 179 | self.drop2 = nn.Dropout(drop_probs[1]) 180 | 181 | def forward(self, x): 182 | x = self.fc1(x) 183 | x = self.act(x) 184 | x = self.drop1(x) 185 | x = self.fc2(x) 186 | x = self.drop2(x) 187 | return x 188 | 189 | 190 | class Block(nn.Module): 191 | def __init__(self, 192 | dim, 193 | num_heads, 194 | mlp_ratio=4., 195 | qkv_bias=False, 196 | drop=0., 197 | attn_drop=0., 198 | drop_path=0., 199 | act_layer=nn.GELU, 200 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 201 | init_values=None, 202 | beit_qkv_bias=False): 203 | super().__init__() 204 | self.norm1 = norm_layer(dim) 205 | self.attn = Attention(dim, 206 | num_heads=num_heads, 207 | qkv_bias=qkv_bias, 208 | attn_drop=attn_drop, 209 | proj_drop=drop, 210 | beit_qkv_bias=beit_qkv_bias) 211 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 212 | self.drop_path = DropPath( 213 | drop_path) if drop_path > 0. 
else nn.Identity() 214 | self.norm2 = norm_layer(dim) 215 | mlp_hidden_dim = int(dim * mlp_ratio) 216 | self.mlp = Mlp(in_features=dim, 217 | hidden_features=mlp_hidden_dim, 218 | act_layer=act_layer, 219 | drop=drop) 220 | 221 | self.init_values = init_values 222 | if self.init_values is not None: 223 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 224 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 225 | 226 | def forward(self, x, rel_pos_bias=None): 227 | if self.init_values is None: 228 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 229 | x = x + self.drop_path(self.mlp(self.norm2(x))) 230 | else: 231 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 232 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 233 | return x 234 | 235 | class RelativePositionBias(nn.Module): 236 | 237 | def __init__(self, window_size, num_heads): 238 | super().__init__() 239 | self.window_size = window_size 240 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 241 | self.relative_position_bias_table = nn.Parameter( 242 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 243 | 244 | # get pair-wise relative position index for each token inside the window 245 | coords_h = torch.arange(window_size[0]) 246 | coords_w = torch.arange(window_size[1]) 247 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 248 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 249 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 250 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 251 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 252 | relative_coords[:, :, 1] += window_size[1] - 1 253 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 254 | 255 | relative_position_index = \ 256 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) 257 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 258 | relative_position_index[0, 0:] = self.num_relative_distance - 3 259 | relative_position_index[0:, 0] = self.num_relative_distance - 2 260 | relative_position_index[0, 0] = self.num_relative_distance - 1 261 | 262 | self.register_buffer("relative_position_index", relative_position_index) 263 | 264 | # trunc_normal_(self.relative_position_bias_table, std=.02) 265 | 266 | def forward(self): 267 | relative_position_bias = \ 268 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 269 | self.window_size[0] * self.window_size[1] + 1, 270 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 271 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 272 | 273 | class VisionTransformer(nn.Module): 274 | """Vision Transformer. 
275 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 276 | - https://arxiv.org/abs/2010.11929 277 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 278 | - https://arxiv.org/abs/2012.12877 279 | """ 280 | def __init__(self, 281 | img_size=224, 282 | patch_size=16, 283 | in_chans=3, 284 | num_classes=1000, 285 | embed_dim=768, 286 | depth=12, 287 | num_heads=12, 288 | mlp_ratio=4., 289 | qkv_bias=True, 290 | representation_size=None, 291 | drop_rate=0., 292 | attn_drop_rate=0., 293 | drop_path_rate=0., 294 | embed_layer=PatchEmbed, 295 | norm_layer=None, 296 | act_layer=None, 297 | init_std=0.02, 298 | init_values=None, 299 | beit_qkv_bias=False, 300 | abs_pos_bias=True, 301 | rel_pos_bias=False, 302 | shared_rel_pos_bias=False): 303 | """ 304 | Args: 305 | img_size (int, tuple): input image size 306 | patch_size (int, tuple): patch size 307 | in_chans (int): number of input channels 308 | num_classes (int): number of classes for classification head 309 | embed_dim (int): embedding dimension 310 | depth (int): depth of transformer 311 | num_heads (int): number of attention heads 312 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 313 | qkv_bias (bool): enable bias for qkv if True 314 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 315 | drop_rate (float): dropout rate 316 | attn_drop_rate (float): attention dropout rate 317 | drop_path_rate (float): stochastic depth rate 318 | embed_layer (nn.Module): patch embedding layer 319 | norm_layer: (nn.Module): normalization layer 320 | """ 321 | super().__init__() 322 | self.num_classes = num_classes 323 | self.num_heads = num_heads 324 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 325 | self.num_tokens = 1 326 | self.abs_pos_bias = abs_pos_bias 327 | self.rel_pos_bias = rel_pos_bias 328 | self.shared_rel_pos_bias = shared_rel_pos_bias 329 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 330 | act_layer = act_layer or nn.GELU 331 | 332 | self.patch_embed = embed_layer(img_size=img_size, 333 | patch_size=patch_size, 334 | in_chans=in_chans, 335 | embed_dim=embed_dim) 336 | num_patches = self.patch_embed.num_patches 337 | 338 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 339 | if self.abs_pos_bias: 340 | self.pos_embed = nn.Parameter( 341 | torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 342 | elif self.rel_pos_bias: 343 | if self.shared_rel_pos_bias: 344 | self.rel_pos_bias = RelativePositionBias(self.patch_embed.grid_size, num_heads) 345 | else: 346 | raise NotImplementedError 347 | else: 348 | raise NotImplementedError 349 | self.pos_drop = nn.Dropout(p=drop_rate) 350 | 351 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) 352 | ] # stochastic depth decay rule 353 | self.blocks = nn.Sequential(*[ 354 | Block(dim=embed_dim, 355 | num_heads=num_heads, 356 | mlp_ratio=mlp_ratio, 357 | qkv_bias=qkv_bias, 358 | drop=drop_rate, 359 | attn_drop=attn_drop_rate, 360 | drop_path=dpr[i], 361 | norm_layer=norm_layer, 362 | act_layer=act_layer, 363 | init_values=init_values, 364 | beit_qkv_bias=beit_qkv_bias) for i in range(depth) 365 | ]) 366 | self.norm = norm_layer(embed_dim) 367 | 368 | 369 | # Representation layer 370 | if representation_size: 371 | self.num_features = representation_size 372 | self.pre_logits = nn.Sequential( 373 | OrderedDict([('fc', nn.Linear(embed_dim, 
representation_size)), 374 | ('act', nn.Tanh())])) 375 | else: 376 | self.pre_logits = nn.Identity() 377 | 378 | # Classifier head(s) 379 | self.head = nn.Linear( 380 | self.num_features, 381 | num_classes) if num_classes > 0 else nn.Identity() 382 | 383 | self.init_std = init_std 384 | if self.abs_pos_bias: 385 | trunc_normal_(self.pos_embed, std=self.init_std) 386 | trunc_normal_(self.cls_token, std=self.init_std) 387 | self.apply(self._init_weights) 388 | self.fix_init_weight() 389 | 390 | def fix_init_weight(self): 391 | def rescale(param, layer_id): 392 | param.div_(math.sqrt(2.0 * layer_id)) 393 | 394 | for layer_id, layer in enumerate(self.blocks): 395 | rescale(layer.attn.proj.weight.data, layer_id + 1) 396 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 397 | 398 | def _init_weights(self, m): 399 | if isinstance(m, nn.Linear): 400 | trunc_normal_(m.weight, std=self.init_std) 401 | if isinstance(m, nn.Linear) and m.bias is not None: 402 | nn.init.constant_(m.bias, 0) 403 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): 404 | nn.init.constant_(m.bias, 0) 405 | nn.init.constant_(m.weight, 1.0) 406 | elif isinstance(m, nn.Conv2d): 407 | lecun_normal_(m.weight) 408 | if m.bias is not None: 409 | nn.init.constant_(m.bias, 0) 410 | 411 | @torch.jit.ignore 412 | def no_weight_decay(self): 413 | return {'pos_embed', 'cls_token'} 414 | 415 | def forward_features(self, x): 416 | x = self.patch_embed(x) 417 | cls_token = self.cls_token.expand( 418 | x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 419 | x = torch.cat((cls_token, x), dim=1) 420 | 421 | if self.abs_pos_bias: 422 | x = self.pos_drop(x + self.pos_embed) 423 | x = self.blocks(x) 424 | elif self.rel_pos_bias and self.shared_rel_pos_bias: 425 | for blk in self.blocks: 426 | x = blk(x, rel_pos_bias=self.rel_pos_bias()) 427 | else: 428 | raise NotImplementedError 429 | 430 | x = self.norm(x) 431 | return self.pre_logits(x) 432 | 433 | def forward(self, x): 434 | x = self.forward_features(x) 435 | x = self.head(x) 436 | return x 437 | --------------------------------------------------------------------------------
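
Putting the files above together — the following is a minimal, hypothetical smoke test (not part of the original release) showing how the pieces are wired: `models.RILS_VITB16` assembles the CLIP text transformer with the MAE-style masked visual encoder/decoder from `models_rils.py`, and `losses.RILSLoss` (obtained via `get_loss`) consumes the resulting output dict. It assumes `torch` and `timm` are installed, that the files above are importable as-is, and it uses random token ids purely as a stand-in for real CLIP BPE tokenization.

```python
# Hypothetical usage sketch, not part of the original repo.
import torch

from models import RILS_VITB16, get_loss

model = RILS_VITB16()   # ViT-B/16 visual encoder + 1-block decoder, CLIP text tower
model.train()           # masking (mask_ratio=0.75) is only applied in training mode

images = torch.randn(2, 3, 224, 224)      # dummy image batch
texts = torch.randint(0, 49408, (2, 77))  # dummy token ids, context length 77

outputs = model(images, texts)            # image/text embeds + masked/unmasked visual features
criterion = get_loss('RILS', ssl_temp=0.1, ssl_scale=1.0)  # dispatches to losses.RILSLoss()
loss_dict = criterion(outputs)            # 'loss' = clip_loss + weighted recon_loss
loss_dict['loss'].backward()
print({k: round(float(v), 4) for k, v in loss_dict.items()})
```

In an actual training run the token ids would come from the CLIP tokenizer, and the loop would be wrapped with the distributed helpers in `utils.py` (`init_distributed_mode`, `cosine_scheduler`, `MetricLogger`); at evaluation time `RILS.encode_image` / `encode_text` are used instead of the masked forward pass.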