├── models ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── __pycache__ │ │ ├── clip.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── simple_tokenizer.cpython-37.pyc │ ├── simple_tokenizer.py │ ├── clip.py │ └── model.py ├── __pycache__ │ ├── head.cpython-37.pyc │ ├── mae.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── clipmodel.cpython-37.pyc │ ├── swin_transformer.cpython-37.pyc │ └── vision_transformer.cpython-37.pyc ├── __init__.py ├── clipmodel.py ├── head.py └── mae.py ├── figures ├── logo.jpg ├── logo.png ├── data_show.jpg ├── firstIMG.jpg ├── framework.jpg ├── reconst_vis.jpg ├── attentionmaps.jpg └── VehicleMAE_tutorial_screeshot.png ├── VehicleMAE_poster.pdf ├── __pycache__ ├── loader.cpython-37.pyc ├── utils.cpython-37.pyc ├── utils.cpython-38.pyc ├── utils.cpython-39.pyc ├── datatxt.cpython-37.pyc ├── pos_embed.cpython-37.pyc └── datahouneed.cpython-37.pyc ├── OurDataset.py ├── masking_generator.py ├── pos_embed.py ├── requirements.txt ├── misc.py ├── utils.py ├── main.py └── README.md /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /figures/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/logo.jpg -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/logo.png -------------------------------------------------------------------------------- /VehicleMAE_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/VehicleMAE_poster.pdf -------------------------------------------------------------------------------- /figures/data_show.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/data_show.jpg -------------------------------------------------------------------------------- /figures/firstIMG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/firstIMG.jpg -------------------------------------------------------------------------------- /figures/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/framework.jpg -------------------------------------------------------------------------------- /figures/reconst_vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/reconst_vis.jpg -------------------------------------------------------------------------------- /figures/attentionmaps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/attentionmaps.jpg -------------------------------------------------------------------------------- /__pycache__/loader.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/loader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/datatxt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/datatxt.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/pos_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/pos_embed.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/datahouneed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/__pycache__/datahouneed.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/head.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mae.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/mae.cpython-37.pyc -------------------------------------------------------------------------------- /figures/VehicleMAE_tutorial_screeshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/figures/VehicleMAE_tutorial_screeshot.png -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/clipmodel.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/clipmodel.cpython-37.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/clip/__pycache__/clip.cpython-37.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/clip/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/clip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/swin_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/swin_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/vision_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/__pycache__/vision_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/clip/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Event-AHU/VehicleMAE/HEAD/models/clip/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #from .vision_transformer import VisionTransformer, vit_tiny, vit_small, vit_base, vit_large 2 | #from .swin_transformer import SwinTransformer, swin_tiny, swin_small, swin_base, swin_large 3 | from .mae import MaskedAutoencoderViT, mae_vit_base_patch16, mae_vit_large_patch16, mae_vit_huge_patch14 4 | 5 | -------------------------------------------------------------------------------- /OurDataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | import pickle 4 | import os 5 | import torch 6 | from typing import Tuple, Optional, Union 7 | 8 | class OurDataset(Dataset): 9 | def __init__(self, pkl_path, transform = None): 10 | dataset_info = pickle.load(open(pkl_path, 'rb+')) 11 | 12 | self.transform=transform 13 | self.rgb_image_name = dataset_info.image_rgb_name 14 | self.lunkuo_image_name = dataset_info.image_lunkuo_name 15 | def __getitem__(self, index)-> Tuple[torch.Tensor, ...]: 16 | 17 | seed = torch.randint(0, 100000, (1,)).item() 18 | rgb_image= self.rgb_image_name[index] 19 | lunkuo_image= self.lunkuo_image_name[index] 20 | 21 | rgb_img_pil = Image.open(rgb_image) 22 | lunkuo_img_pil = 
Image.open(lunkuo_image).convert('RGB') 23 | 24 | if self.transform is not None: 25 | torch.manual_seed(seed) 26 | rgb_image = self.transform(rgb_img_pil) 27 | torch.manual_seed(seed) 28 | lunkuo_image = self.transform(lunkuo_img_pil) 29 | 30 | return rgb_image,lunkuo_image 31 | def __len__(self): 32 | return len(self.rgb_image_name) 33 | -------------------------------------------------------------------------------- /masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #tube 3 | class TubeMaskingGenerator: 4 | def __init__(self, input_size, mask_ratio): 5 | self.frames, self.height, self.width = input_size 6 | self.num_patches_per_frame = self.height * self.width 7 | self.total_patches = self.frames * self.num_patches_per_frame 8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 9 | self.total_masks = self.frames * self.num_masks_per_frame 10 | 11 | def __repr__(self): 12 | repr_str = "Masks: total patches {}, mask patches {}".format( 13 | self.total_patches, self.total_masks 14 | ) 15 | return repr_str 16 | 17 | def __call__(self): 18 | mask_per_frame = np.hstack([ 19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 20 | np.ones(self.num_masks_per_frame), 21 | ]) 22 | np.random.shuffle(mask_per_frame) 23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten() 24 | return mask 25 | 26 | #random 27 | class RandomMaskingGenerator: 28 | def __init__(self, input_size, mask_ratio): 29 | self.frames, self.height, self.width = input_size 30 | self.num_patches_per_frame = self.height * self.width 31 | self.total_patches = self.frames * self.num_patches_per_frame 32 | self.total_masks = int(mask_ratio * self.total_patches) 33 | 34 | def __repr__(self): 35 | repr_str = "Masks: total patches {}, mask patches {}".format( 36 | self.total_patches, self.total_masks 37 | ) 38 | return repr_str 39 | 40 | def __call__(self): 41 | mask = np.hstack([ 42 | np.zeros(self.total_patches - self.total_masks), 43 | np.ones(self.total_masks), 44 | ]) 45 | np.random.shuffle(mask) 46 | return mask 47 | 48 | #empty (no masking applied) 49 | class EmptyMask: 50 | def __init__(self, input_size, mask_ratio): 51 | self.frames, self.height, self.width = input_size 52 | self.num_patches_per_frame = self.height * self.width 53 | self.total_patches = self.frames * self.num_patches_per_frame 54 | self.total_masks = int(mask_ratio * self.total_patches) 55 | 56 | def __repr__(self): 57 | repr_str = "Masks: total patches {}, mask patches {}".format( 58 | self.total_patches, self.total_masks 59 | ) 60 | return repr_str 61 | 62 | def __call__(self): 63 | return [] -------------------------------------------------------------------------------- /pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /models/clipmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | #from datatxt import HICO_INTERACTIONS 4 | from datahouneed import HICO_INTERACTIONS 5 | from models.clip import clip 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | def build_clip(args): 11 | 12 | device = torch.device(args.device) 13 | 14 | # build_clip 15 | model_path = args.clip_pre_model 16 | 17 | clip_model, preprocess = clip.load(model_path, device=device) #加载预训练好的模型 18 | 19 | print("Turning off gradients in both the image and the text encoder") 20 | for name, param in clip_model.named_parameters(): 21 | # if "prompt_learner" not in name: 22 | param.requires_grad_(False) 23 | #提取特征 24 | ao_pair = [(d['brand'], d['color'], d['energy'], d['level'], d['long'], d['width'], d['high'], d['doors'], d['seats'], d['wheelbase'], d['years']) for d in HICO_INTERACTIONS] 25 | #生成clip文本 26 | text_inputs = torch.cat( #将多个tensor拼接 27 | [clip.tokenize("a picture of a {} {} {} car ,it is a {} {} ,its length is {}, its width is {}, its height is {}, its wheelbase is {}, it has {} doors and {} seats".format(c,y,b,e,le,lo,w,h,w,d,s,w)) for b, c,e,le,lo,w,h,d,s,w,y in ao_pair]).to(device) #类别构建 28 | #生成每个文本的特征 29 | text_features = clip_model.encode_text(text_inputs) 30 | #归一化特征 31 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 32 | #将text_features加载到GPU中 33 | text_features = text_features.to(device) 34 | 35 | 36 | return clip_model,preprocess,text_features 37 | 38 | 39 | 40 | 
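# Illustrative usage sketch: how build_clip() above and ClipBaseModel below fit together.
# It assumes `args` exposes `device` and `clip_pre_model`, as build_clip() expects; the
# actual training driver (see main.py) may wire this differently.
#
#   clip_model, preprocess, text_features = build_clip(args)
#   teacher = ClipBaseModel(clip_model, text_features).to(args.device)
#   # images: a batch of vehicle images preprocessed for CLIP's image encoder
#   # student_feats: features produced by the VehicleMAE decoder
#   similarity_loss, kl_distance_loss = teacher(images, student_feats)
#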
class ClipBaseModel(nn.Module): 41 | """ 42 | Perform forward pass separately on each resolution input. 43 | The inputs corresponding to a single resolution are clubbed and single 44 | forward is run on the same resolution inputs. Hence we do several 45 | forward passes = number of different resolutions used. We then 46 | concatenate all the output features and run the head forward on these 47 | concatenated features. 48 | """ 49 | def __init__(self, clip_model ,text_features): 50 | super(ClipBaseModel, self).__init__() 51 | # disable layers dedicated to ImageNet labels classification 52 | self.clip = clip_model 53 | #self.preprocess = preprocess 54 | self.text_features = text_features 55 | #self.batch_size = batch_size 56 | #self.register_buffer("temperature", torch.tensor(temperature)) 57 | 58 | 59 | def Clip_Loss(self,image_features,student_image_features): 60 | #对学生网络输出特征进行处理,变为[B,P] 61 | #student_image_features = student_image_features[:,:,0:1] 62 | student_image_features = student_image_features.reshape(len(image_features),-1) 63 | 64 | #将特征l2归一化 65 | image_features_l2 = F.normalize(image_features, p=2, dim=1, eps=1e-12, out=None) 66 | student_image_features_l2 = F.normalize(student_image_features, p=2, dim=1, eps=1e-12, out=None) 67 | #((inputs - targets) ** 2).sum() / inputs.size(0) 68 | #计算相似性loss 69 | similarity_loss = torch.sum((image_features_l2-student_image_features_l2)** 2)/ image_features_l2.size(0) 70 | #F.kl_div(inputs, targets, reduction='batchmean') 71 | #计算图像及文本的交叉模态loss 72 | #余弦相似度作为 logits 73 | logit_scale = self.clip.logit_scale.exp() 74 | #计算每张图片和文本的余弦相似度 75 | logits_clip = logit_scale * image_features_l2 @ self.text_features.t() 76 | logits_mae = logit_scale * student_image_features_l2 @ self.text_features.t() 77 | 78 | #probs_clip = logits_clip.softmax(dim=-1).cpu().numpy() 79 | #probs_mae = logits_mae.softmax(dim=-1).cpu().detach().numpy() 80 | #计算教师网络和学生网络的KLloss 81 | kl_loss = nn.KLDivLoss(reduction='batchmean') 82 | #归一化 83 | temp = 1 84 | input = F.log_softmax(logits_mae/temp, dim=-1) 85 | target = F.softmax(logits_clip/temp, dim=-1) 86 | input_2 = F.softmax(logits_mae/temp, dim=0) 87 | #target = F.log_softmax(logits_clip, dim=-1) 88 | kl_distance_loss = kl_loss(input, target)-torch.mean(input_2 * torch.log2(input_2)) 89 | 90 | return similarity_loss,kl_distance_loss 91 | 92 | 93 | 94 | 95 | 96 | def forward(self,image,student_image_features): 97 | #图像处理 98 | #image = self.preprocess(image).unsqueeze(0) 99 | #生成图像特征 100 | image_features = self.clip.encode_image(image) 101 | 102 | #对图像特征进行归一化 103 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 104 | #计算每张图片和文本的余弦相似度 105 | similarity_loss,kl_distance_loss = self.Clip_Loss(image_features,student_image_features) 106 | 107 | return similarity_loss,kl_distance_loss 108 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2 ** 8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2 ** 8 + n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152 - 256 - 2 + 1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v + '' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile( 79 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 80 | re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + (token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token + '' 90 | 91 | while True: 92 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 108 | new_word.append(first + second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = 
''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=4.5=1_gnu 6 | absl-py=1.2.0=pypi_0 7 | addict=2.4.0=pypi_0 8 | aiohttp=3.8.4=pypi_0 9 | aiosignal=1.3.1=pypi_0 10 | altgraph=0.17.3=pypi_0 11 | anyconfig=0.13.0=pypi_0 12 | anyio=3.6.1=pypi_0 13 | apex=0.1=pypi_0 14 | astunparse=1.6.3=pypi_0 15 | async-timeout=4.0.2=pypi_0 16 | asynctest=0.13.0=pypi_0 17 | attrs=22.2.0=pypi_0 18 | backcall=0.2.0=pyh9f0ad1d_0 19 | backports=1.0=pyhd8ed1ab_3 20 | backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 21 | blessings=1.7=pypi_0 22 | ca-certificates=2023.5.7=hbcca054_0 23 | cachetools=5.2.0=pypi_0 24 | certifi=2023.5.7=pyhd8ed1ab_0 25 | charset-normalizer=2.0.12=pypi_0 26 | click=8.1.3=pypi_0 27 | colorama=0.4.6=pypi_0 28 | commonmark=0.9.1=pypi_0 29 | conda-pack=0.7.1=pypi_0 30 | contextlib2=21.6.0=pypi_0 31 | cxxfilt=0.3.0=pypi_0 32 | cycler=0.11.0=pypi_0 33 | datasets=2.10.1=pypi_0 34 | decorator=5.1.1=pyhd8ed1ab_0 35 | dill=0.3.6=pypi_0 36 | easydict=1.10=pypi_0 37 | einops=0.6.1=pypi_0 38 | entrypoints=0.4=pyhd8ed1ab_0 39 | et-xmlfile=1.1.0=pypi_0 40 | exceptiongroup=1.1.2=pypi_0 41 | filelock=3.8.0=pypi_0 42 | flask=2.2.5=pypi_0 43 | flatbuffers=23.5.26=pypi_0 44 | fonttools=4.33.3=pypi_0 45 | frozenlist=1.3.3=pypi_0 46 | fsspec=2023.1.0=pypi_0 47 | ftfy=6.1.1=pypi_0 48 | gast=0.4.0=pypi_0 49 | google-auth=2.11.0=pypi_0 50 | google-auth-oauthlib=0.4.6=pypi_0 51 | google-pasta=0.2.0=pypi_0 52 | gpustat=0.6.0=pypi_0 53 | grad-cam=1.4.8=pypi_0 54 | grpcio=1.48.0=pypi_0 55 | h11=0.12.0=pypi_0 56 | h5py=3.8.0=pypi_0 57 | helper=2.5.0=pypi_0 58 | httpcore=0.15.0=pypi_0 59 | httpx=0.23.0=pypi_0 60 | huggingface-hub=0.10.1=pypi_0 61 | idna=3.3=pypi_0 62 | imageio=2.22.4=pypi_0 63 | imgviz=1.7.2=pypi_0 64 | importlib-metadata=4.12.0=pypi_0 65 | imutils=0.5.4=pypi_0 66 | iniconfig=2.0.0=pypi_0 67 | ipykernel=5.5.5=py37h085eea5_0 68 | ipython=7.33.0=py37h89c1867_0 69 | ipython_genutils=0.2.0=py_1 70 | itsdangerous=2.1.2=pypi_0 71 | jedi=0.18.2=pyhd8ed1ab_0 72 | jinja2=3.1.2=pypi_0 73 | joblib=1.1.0=pypi_0 74 | jsonpatch=1.32=pypi_0 75 | jsonpointer=2.3=pypi_0 76 | jupyter-core=4.12.0=pypi_0 77 | jupyter_client=5.3.4=py37_0 78 | keras=2.11.0=pypi_0 79 | kiwisolver=1.4.2=pypi_0 80 | ld_impl_linux-64=2.35.1=h7274673_9 81 | libclang=14.0.6=pypi_0 82 | libffi=3.3=he6710b0_2 83 | libgcc-ng=9.3.0=h5101ec6_17 84 | libgomp=9.3.0=h5101ec6_17 85 | libsodium=1.0.18=h36c2ea0_1 86 | libstdcxx-ng=9.3.0=hd4cf53a_17 87 | lmdb=1.3.0=pypi_0 88 | lxml=4.9.3=pypi_0 89 | markdown=3.4.1=pypi_0 90 | markupsafe=2.1.1=pypi_0 91 | mat4py=0.5.0=pypi_0 92 | matplotlib=3.5.1=pypi_0 93 | matplotlib-inline=0.1.6=pyhd8ed1ab_0 94 | ml-collections=0.1.1=pypi_0 95 | mmcv=1.3.1=pypi_0 96 | mmcv-full=1.3.17=pypi_0 97 | mmdet=2.25.2=pypi_0 98 | mmengine=0.8.1=pypi_0 99 | mmpretrain=1.0.0=pypi_0 100 | model-index=0.1.11=pypi_0 101 | 
modelindex=0.0.2=pypi_0 102 | multidict=6.0.4=pypi_0 103 | multiprocess=0.70.14=pypi_0 104 | munkres=1.1.4=pypi_0 105 | natsort=8.1.0=pypi_0 106 | ncurses=6.3=h7f8727e_2 107 | networkx=2.6.3=pypi_0 108 | nltk=3.7=pypi_0 109 | numpy=1.21.6=pypi_0 110 | nvidia-ml-py3=7.352.0=pypi_0 111 | oauthlib=3.2.0=pypi_0 112 | opencv-python=4.6.0.66=pypi_0 113 | openmim=0.3.4=pypi_0 114 | openpyxl=3.1.2=pypi_0 115 | openssl=1.1.1u=h7f8727e_0 116 | opt-einsum=3.3.0=pypi_0 117 | ordered-set=4.1.0=pypi_0 118 | packaging=23.1=pypi_0 119 | pandas=1.3.5=pypi_0 120 | parso=0.8.3=pyhd8ed1ab_0 121 | pexpect=4.8.0=pyh1a96a4e_2 122 | pickleshare=0.7.5=py_1003 123 | pillow=9.1.0=pypi_0 124 | pip=21.2.2=py37h06a4308_0 125 | pluggy=1.2.0=pypi_0 126 | polygon=1.0.9=pypi_0 127 | polygon3=3.0.9.1=pypi_0 128 | prettytable=3.6.0=pypi_0 129 | prompt-toolkit=3.0.38=pyha770c72_0 130 | protobuf=3.19.4=pypi_0 131 | psutil=5.9.0=pypi_0 132 | ptflops=0.7=pypi_0 133 | ptyprocess=0.7.0=pyhd3deb0d_0 134 | pyarrow=11.0.0=pypi_0 135 | pyasn1=0.4.8=pypi_0 136 | pyasn1-modules=0.2.8=pypi_0 137 | pycocotools=2.0.4=pypi_0 138 | pygments=2.14.0=pypi_0 139 | pyinstaller=5.8.0=pypi_0 140 | pyinstaller-hooks-contrib=2023.0=pypi_0 141 | pyparsing=3.0.8=pypi_0 142 | pytest=7.4.0=pypi_0 143 | python=3.7.13=h12debd9_0 144 | python-dateutil=2.8.2=pyhd8ed1ab_0 145 | python_abi=3.7=2_cp37m 146 | pytz=2022.7.1=pypi_0 147 | pywavelets=1.3.0=pypi_0 148 | pyyaml=6.0=pypi_0 149 | pyzmq=22.3.0=pypi_0 150 | readline=8.1.2=h7f8727e_1 151 | regex=2022.8.17=pypi_0 152 | requests=2.27.1=pypi_0 153 | requests-oauthlib=1.3.1=pypi_0 154 | responses=0.18.0=pypi_0 155 | rfc3986=1.5.0=pypi_0 156 | rich=13.1.0=pypi_0 157 | rsa=4.9=pypi_0 158 | safetensors=0.3.1=pypi_0 159 | scikit-image=0.16.2=pypi_0 160 | scikit-learn=1.0.2=pypi_0 161 | scipy=1.7.3=pypi_0 162 | seaborn=0.12.2=pypi_0 163 | sentence-transformers=2.2.2=pypi_0 164 | sentencepiece=0.1.97=pypi_0 165 | setuptools=61.2.0=py37h06a4308_0 166 | shapely=1.8.4=pypi_0 167 | six=1.16.0=pyh6c4a22f_0 168 | sniffio=1.3.0=pypi_0 169 | sqlite=3.38.2=hc218d9a_0 170 | tabulate=0.9.0=pypi_0 171 | tensorboard=2.11.2=pypi_0 172 | tensorboard-data-server=0.6.1=pypi_0 173 | tensorboard-plugin-wit=1.8.1=pypi_0 174 | tensorboardx=2.5.1=pypi_0 175 | tensorflow=2.11.0=pypi_0 176 | tensorflow-estimator=2.11.0=pypi_0 177 | tensorflow-io-gcs-filesystem=0.34.0=pypi_0 178 | termcolor=2.3.0=pypi_0 179 | terminaltables=3.1.10=pypi_0 180 | thop=0.1.1-2209072238=pypi_0 181 | threadpoolctl=3.1.0=pypi_0 182 | timm=0.3.2=pypi_0 183 | tk=8.6.11=h1ccaba5_0 184 | tokenizers=0.13.2=pypi_0 185 | tomli=2.0.1=pypi_0 186 | torch=1.8.1+cu111=pypi_0 187 | torchaudio=0.8.1=pypi_0 188 | torchfile=0.1.0=pypi_0 189 | torchnet=0.0.4=pypi_0 190 | torchvision=0.9.1+cu111=pypi_0 191 | tornado=6.1=py37h5e8e339_1 192 | tqdm=4.64.0=pypi_0 193 | traitlets=5.9.0=pyhd8ed1ab_0 194 | transformers=4.24.0=pypi_0 195 | ttach=0.0.3=pypi_0 196 | typing-extensions=4.2.0=pypi_0 197 | urllib3=1.26.9=pypi_0 198 | utils=1.0.1=pypi_0 199 | visdom=0.1.8.9=pypi_0 200 | wcwidth=0.2.5=pypi_0 201 | websocket-client=1.3.2=pypi_0 202 | websockets=10.3=pypi_0 203 | werkzeug=2.2.2=pypi_0 204 | wheel=0.37.1=pyhd3eb1b0_0 205 | wrapt=1.15.0=pypi_0 206 | xlwt=1.3.0=pypi_0 207 | xxhash=3.2.0=pypi_0 208 | xz=5.2.5=h7b6447c_0 209 | yace=0.5.6=pypi_0 210 | yacs=0.1.8=pypi_0 211 | yapf=0.32.0=pypi_0 212 | yarl=1.8.2=pypi_0 213 | zeromq=4.3.4=h9c3ff4c_0 214 | zipp=3.8.1=pypi_0 215 | zlib=1.2.12=h7f8727e_2 216 | -------------------------------------------------------------------------------- 
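The listing above is a conda explicit environment file (note the pins: python=3.7.13, torch=1.8.1+cu111, torchvision=0.9.1+cu111, timm=0.3.2). A minimal way to recreate it, following the command given in the file's own header and assuming the listing is saved as requirements.txt (the environment name below is arbitrary):

$ conda create --name vehiclemae --file requirements.txt
$ conda activate vehiclemae

Entries whose build string is pypi_0 were installed through pip rather than conda, so conda may not resolve them from its channels; those packages may need to be installed separately with pip inside the activated environment.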
/models/head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | 11 | from utils import trunc_normal_ 12 | 13 | class CSyncBatchNorm(nn.SyncBatchNorm): 14 | def __init__(self, 15 | *args, 16 | with_var=False, 17 | **kwargs): 18 | super(CSyncBatchNorm, self).__init__(*args, **kwargs) 19 | self.with_var = with_var 20 | 21 | def forward(self, x): 22 | # center norm 23 | self.training = False 24 | if not self.with_var: 25 | self.running_var = torch.ones_like(self.running_var) 26 | normed_x = super(CSyncBatchNorm, self).forward(x) 27 | # udpate center 28 | self.training = True 29 | _ = super(CSyncBatchNorm, self).forward(x) 30 | return normed_x 31 | 32 | class PSyncBatchNorm(nn.SyncBatchNorm): 33 | def __init__(self, 34 | *args, 35 | bunch_size, 36 | **kwargs): 37 | procs_per_bunch = min(bunch_size, utils.get_world_size()) 38 | assert utils.get_world_size() % procs_per_bunch == 0 39 | n_bunch = utils.get_world_size() // procs_per_bunch 40 | # 41 | ranks = list(range(utils.get_world_size())) 42 | print('---ALL RANKS----\n{}'.format(ranks)) 43 | rank_groups = [ranks[i*procs_per_bunch: (i+1)*procs_per_bunch] for i in range(n_bunch)] 44 | print('---RANK GROUPS----\n{}'.format(rank_groups)) 45 | process_groups = [torch.distributed.new_group(pids) for pids in rank_groups] 46 | bunch_id = utils.get_rank() // procs_per_bunch 47 | process_group = process_groups[bunch_id] 48 | print('---CURRENT GROUP----\n{}'.format(process_group)) 49 | super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs) 50 | 51 | class CustomSequential(nn.Sequential): 52 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 53 | 54 | def forward(self, input): 55 | #print(input) 56 | for module in self: 57 | dim = len(input.shape) 58 | if isinstance(module, self.bn_types) and dim > 2: 59 | perm = list(range(dim - 1)); perm.insert(1, dim - 1) 60 | inv_perm = list(range(dim)) + [1]; inv_perm.pop(1) 61 | input = module(input.permute(*perm)).permute(*inv_perm) 62 | else: 63 | input = module(input) 64 | return input 65 | 66 | class DINOHead(nn.Module): 67 | def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None, 68 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs): 69 | super().__init__() 70 | norm = self._build_norm(norm, hidden_dim) 71 | last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs) 72 | act = self._build_act(act) 73 | 74 | nlayers = max(nlayers, 1) 75 | if nlayers == 1: 76 | if bottleneck_dim > 0: 77 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 78 | else: 79 | self.mlp = nn.Linear(in_dim, out_dim) 80 | else: 81 | layers = [nn.Linear(in_dim, hidden_dim)] 82 | if norm is not None: 83 | layers.append(norm) 84 | layers.append(act) 85 | for _ in range(nlayers - 2): 86 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 87 | if norm is not None: 88 | layers.append(norm) 89 | layers.append(act) 90 | if bottleneck_dim > 0: 91 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 92 | else: 93 | layers.append(nn.Linear(hidden_dim, out_dim)) 94 | self.mlp = CustomSequential(*layers) 95 | self.apply(self._init_weights) 96 | 97 | if bottleneck_dim > 0: 98 | self.last_layer = 
nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 99 | self.last_layer.weight_g.data.fill_(1) 100 | if norm_last_layer: 101 | self.last_layer.weight_g.requires_grad = False 102 | else: 103 | self.last_layer = None 104 | 105 | self.last_norm = last_norm 106 | 107 | def _init_weights(self, m): 108 | if isinstance(m, nn.Linear): 109 | trunc_normal_(m.weight, std=.02) 110 | if isinstance(m, nn.Linear) and m.bias is not None: 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward(self, x): 114 | x = self.mlp(x) 115 | if self.last_layer is not None: 116 | x = nn.functional.normalize(x, dim=-1, p=2) 117 | x = self.last_layer(x) 118 | if self.last_norm is not None: 119 | x = self.last_norm(x) 120 | return x 121 | 122 | def _build_norm(self, norm, hidden_dim, **kwargs): 123 | if norm == 'bn': 124 | norm = nn.BatchNorm1d(hidden_dim, **kwargs) 125 | elif norm == 'syncbn': 126 | norm = nn.SyncBatchNorm(hidden_dim, **kwargs) 127 | elif norm == 'csyncbn': 128 | norm = CSyncBatchNorm(hidden_dim, **kwargs) 129 | elif norm == 'psyncbn': 130 | norm = PSyncBatchNorm(hidden_dim, **kwargs) 131 | elif norm == 'ln': 132 | norm = nn.LayerNorm(hidden_dim, **kwargs) 133 | else: 134 | assert norm is None, "unknown norm type {}".format(norm) 135 | return norm 136 | 137 | def _build_act(self, act): 138 | if act == 'relu': 139 | act = nn.ReLU() 140 | elif act == 'gelu': 141 | act = nn.GELU() 142 | else: 143 | assert False, "unknown act type {}".format(act) 144 | return act 145 | 146 | class Head(DINOHead): 147 | 148 | def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None, 149 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, 150 | shared_head=False, **kwargs): 151 | 152 | super(Head, self).__init__(*args, 153 | norm=norm, 154 | act=act, 155 | last_norm=last_norm, 156 | nlayers=nlayers, 157 | hidden_dim=hidden_dim, 158 | bottleneck_dim=bottleneck_dim, 159 | norm_last_layer=norm_last_layer, 160 | **kwargs) 161 | 162 | if not shared_head: 163 | if bottleneck_dim > 0: 164 | self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False)) 165 | self.last_layer2.weight_g.data.fill_(1) 166 | if norm_last_layer: 167 | self.last_layer2.weight_g.requires_grad = False 168 | else: 169 | self.mlp2 = nn.Linear(hidden_dim, patch_out_dim) 170 | self.last_layer2 = None 171 | 172 | self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs) 173 | else: 174 | if bottleneck_dim > 0: 175 | self.last_layer2 = self.last_layer 176 | else: 177 | self.mlp2 = self.mlp[-1] 178 | self.last_layer2 = None 179 | 180 | self.last_norm2 = self.last_norm 181 | 182 | def forward(self, x): 183 | if len(x.shape) == 2: 184 | return super(Head, self).forward(x) 185 | 186 | if self.last_layer is not None: 187 | x = self.mlp(x) 188 | x = nn.functional.normalize(x, dim=-1, p=2) 189 | x1 = self.last_layer(x[:, 0]) 190 | x2 = self.last_layer2(x[:, 1:]) 191 | else: 192 | x = self.mlp[:-1](x) 193 | x1 = self.mlp[-1](x[:, 0]) 194 | x2 = self.mlp2(x[:, 1:]) 195 | 196 | if self.last_norm is not None: 197 | x1 = self.last_norm(x1) 198 | x2 = self.last_norm2(x2) 199 | 200 | return x1, x2 201 | -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from 
torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | __all__ = ["available_models", "load", "tokenize"] 26 | _tokenizer = _Tokenizer() 27 | 28 | _MODELS = { 29 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 30 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 31 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 33 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 34 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 35 | } 36 | 37 | 38 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 39 | os.makedirs(root, exist_ok=True) 40 | filename = os.path.basename(url) 41 | 42 | expected_sha256 = url.split("/")[-2] 43 | download_target = os.path.join(root, filename) 44 | 45 | if os.path.exists(download_target) and not os.path.isfile(download_target): 46 | raise RuntimeError(f"{download_target} exists and is not a regular file") 47 | 48 | if os.path.isfile(download_target): 49 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 50 | return download_target 51 | else: 52 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 53 | 54 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 55 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 56 | while True: 57 | buffer = source.read(8192) 58 | if not buffer: 59 | break 60 | 61 | output.write(buffer) 62 | loop.update(len(buffer)) 63 | 64 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 65 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 66 | 67 | return download_target 68 | 69 | 70 | def _transform(n_px): 71 | return Compose([ 72 | Resize(n_px, interpolation=BICUBIC), 73 | CenterCrop(n_px), 74 | lambda image: image.convert("RGB"), 75 | ToTensor(), 76 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 77 | ]) 78 | 79 | 80 | def available_models() -> List[str]: 81 | """Returns the names of available CLIP models""" 82 | return list(_MODELS.keys()) 83 | 84 | 85 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 86 | """Load a CLIP model 87 | 88 | Parameters 89 | ---------- 90 | name : str 91 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 92 | 字符串,用于指定CLIP使用的图像编码器模型 93 | 94 | 
device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 字符串或者torch.device的输出结果。用于指定加载模型的设备,gpu或者cpu 97 | 98 | jit : bool 99 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 100 | 布尔值,是否加载优化的JIT模型 101 | 102 | Returns:字符串。用于指定下载的模型的保存地址,默认值如下代码所示 103 | 104 | ------- 105 | model : torch.nn.Module 106 | The CLIP model, 107 | 108 | preprocess : Callable[[PIL.Image], torch.Tensor] 109 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 110 | """ 111 | if name in _MODELS: 112 | model_path = _download(_MODELS[name]) 113 | elif os.path.isfile(name): 114 | model_path = name 115 | else: 116 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 117 | 118 | try: 119 | # loading JIT archive 120 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 121 | state_dict = None 122 | except RuntimeError: 123 | # loading saved state dict 124 | if jit: 125 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 126 | jit = False 127 | state_dict = torch.load(model_path, map_location="cpu") 128 | 129 | if not jit: 130 | model = build_model(state_dict or model.state_dict()).to(device) 131 | if str(device) == "cpu": 132 | model.float() 133 | return model, _transform(model.visual.input_resolution) 134 | 135 | 136 | # patch the device names 137 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 138 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 139 | 140 | def patch_device(module): 141 | try: 142 | graphs = [module.graph] if hasattr(module, "graph") else [] 143 | except RuntimeError: 144 | graphs = [] 145 | 146 | if hasattr(module, "forward1"): 147 | graphs.append(module.forward1.graph) 148 | 149 | for graph in graphs: 150 | for node in graph.findAllNodes("prim::Constant"): 151 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 152 | node.copyAttributes(device_node) 153 | 154 | model.apply(patch_device) 155 | patch_device(model.encode_image) 156 | patch_device(model.encode_text) 157 | 158 | # patch dtype to float32 on CPU 159 | if str(device) == "cpu": 160 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 161 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 162 | float_node = float_input.node() 163 | 164 | def patch_float(module): 165 | try: 166 | graphs = [module.graph] if hasattr(module, "graph") else [] 167 | except RuntimeError: 168 | graphs = [] 169 | 170 | if hasattr(module, "forward1"): 171 | graphs.append(module.forward1.graph) 172 | 173 | for graph in graphs: 174 | for node in graph.findAllNodes("aten::to"): 175 | inputs = list(node.inputs()) 176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 177 | if inputs[i].node()["value"] == 5: 178 | inputs[i].node().copyAttributes(float_node) 179 | 180 | model.apply(patch_float) 181 | patch_float(model.encode_image) 182 | patch_float(model.encode_text) 183 | 184 | model.float() 185 | 186 | return model, _transform(model.input_resolution.item()) 187 | 188 | 189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 190 | """ 191 | Returns the tokenized representation of given input string(s) 192 | 193 | Parameters 194 | ---------- 195 | texts : Union[str, List[str]] 196 
| An input string or a list of input strings to tokenize 197 | 198 | context_length : int 199 | The context length to use; all CLIP models use 77 as the context length 200 | 201 | truncate: bool 202 | Whether to truncate the text in case its encoding is longer than the context length 203 | 204 | Returns 205 | ------- 206 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 207 | """ 208 | if isinstance(texts, str): 209 | texts = [texts] 210 | 211 | sot_token = _tokenizer.encoder["<|startoftext|>"] 212 | eot_token = _tokenizer.encoder["<|endoftext|>"] 213 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 214 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 215 | 216 | for i, tokens in enumerate(all_tokens): 217 | if len(tokens) > context_length: 218 | if truncate: 219 | tokens = tokens[:context_length] 220 | tokens[-1] = eot_token 221 | else: 222 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 223 | result[i, :len(tokens)] = torch.tensor(tokens) 224 | 225 | return result 226 | -------------------------------------------------------------------------------- /models/mae.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) ByteDance, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # -------------------------------------------------------- 8 | # References: 9 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | import torch 15 | import numpy as np 16 | import math 17 | import random 18 | import torch.nn as nn 19 | 20 | # Copyright (c) Meta Platforms, Inc. and affiliates. 21 | # All rights reserved. 22 | 23 | # This source code is licensed under the license found in the 24 | # LICENSE file in the root directory of this source tree. 
25 | # -------------------------------------------------------- 26 | # References: 27 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 28 | # DeiT: https://github.com/facebookresearch/deit 29 | # -------------------------------------------------------- 30 | 31 | from functools import partial 32 | 33 | import torch 34 | import torch.nn as nn 35 | 36 | from timm.models.vision_transformer import PatchEmbed, Block 37 | 38 | from pos_embed import get_2d_sincos_pos_embed 39 | 40 | class MaskedAutoencoderViT(nn.Module):#基于vit实现的mae 41 | """ Masked Autoencoder with VisionTransformer backbone 42 | img_size:输入图像宽和高。patch_size:每一个patch的宽和高。in_chans:输入通道数。 43 | embed_dim:mse编码器的Hidden size。depth:mae中transform的块数Layers的层数。num_heads编码器的头数 44 | 解码器的三个参数 45 | 编码器的输出需要降维 46 | """ 47 | 48 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 49 | embed_dim=768, depth=12, num_heads=16, 50 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 51 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,mask_ratio=0.75,use_learnable_pos_emb=True): 52 | super().__init__() 53 | 54 | # -------------------------------------------------------------------------- 55 | # MAE encoder specifics 56 | #图片喂入就可以得到patch的序列 57 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)#实例化需要传入照片的大小,patch的大小,照片输入的通道数,embed_dim 58 | #num_patches:得到块的数量 59 | num_patches = self.patch_embed.num_patches 60 | #实例化两个参数cls_token和pos_embed 61 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))#可训练的 62 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding不可训练的,位置编码requires_grad不可训练 63 | #定义编码器的transform blocks,使用Module的列表,不能使用普通列表,不然无法被Module识别 64 | self.blocks = nn.ModuleList([ 65 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#多头注意力头数num_heads,, qk_scale=None 66 | for i in range(depth)])#最新版timm要注释掉qk_scale 67 | self.norm = norm_layer(embed_dim)#对encode output做归一化 68 | # -------------------------------------------------------------------------- 69 | 70 | # -------------------------------------------------------------------------- 71 | # MAE decoder specifics 72 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)#构建一个线性映射层,将编码器的输出映射到解码器的特征维度上 73 | ''' 74 | # Probability prediction network 概率预测网络 75 | self.pos_embed_probs = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 76 | self.get_token_probs = nn.Sequential( 77 | Block(dim=embed_dim, num_heads=8, mlp_ratio=4., qkv_bias=False, 78 | norm_layer=nn.LayerNorm), 79 | nn.Linear(embed_dim, 1), 80 | torch.nn.Flatten(start_dim=1), 81 | ) #轻量级的多头注意网络(MHA) 82 | self.softmax = nn.Softmax(dim=-1) #mask中的softmax激活 83 | ''' 84 | #self.visible_patches = int(num_patches*(1-mask_ratio))#计算未被掩码的pach数 85 | 86 | self.apply(self._init_weights) 87 | 88 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))#可训练的token,用于替换掉那些被mask的块 89 | 90 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 91 | 92 | #定义解码器的transform blocks 93 | self.decoder_blocks = nn.ModuleList([ 94 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#qk_scale=None, 95 | for i in range(decoder_depth)]) 96 | 97 | self.decoder_norm = norm_layer(decoder_embed_dim)#对decoder output做归一化 98 | #patch_size**2 * in_chans,patch的面积乘上通道数,映射层 99 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * 
in_chans, bias=True) # decoder to patch 100 | self.decoder_image = nn.Linear(196, 1, bias=True) 101 | #self.decode_clip = nn.Linear(1024, 768, bias=True) 102 | # -------------------------------------------------------------------------- 103 | 104 | self.norm_pix_loss = norm_pix_loss#是否要对像素做归一化再去算loss 105 | self.use_learnable_pos_emb = use_learnable_pos_emb 106 | self.initialize_weights() 107 | #初始化权重 108 | def initialize_weights(self): 109 | # initialization 110 | # initialize (and freeze) pos_embed by sin-cos embedding对pos_embed进行初始化 111 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 112 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 113 | 114 | #对decode的pos_embed进行初始化 115 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 116 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 117 | 118 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)均匀分布的初始化 119 | w = self.patch_embed.proj.weight.data 120 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 121 | 122 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)高斯分布的初始化 123 | torch.nn.init.normal_(self.cls_token, std=.02) 124 | torch.nn.init.normal_(self.mask_token, std=.02) 125 | 126 | # initialize nn.Linear and nn.LayerNorm 127 | self.apply(self._init_weights)#接受的参数是一个函数,函数会作用在当前这个Module和当前这个Module的子Module 128 | 129 | def _init_weights(self, m): 130 | if isinstance(m, nn.Linear):#如果进入这个模块的是nn.Linear的实例,则会对权重做均匀分布的初始化 131 | # we use xavier_uniform following official JAX ViT: 132 | torch.nn.init.xavier_uniform_(m.weight) 133 | if isinstance(m, nn.Linear) and m.bias is not None:#如果有bias则做一个bias为0的初始化 134 | nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, nn.LayerNorm):#对层归一化逻辑进行判断,如果是则将权重和偏置做常数初始化 136 | nn.init.constant_(m.bias, 0) 137 | nn.init.constant_(m.weight, 1.0) 138 | 139 | def patchify(self, imgs):#把图片划分成块 140 | """ 141 | imgs: (N, 3, H, W) 142 | x: (N, L, patch_size**2 *3) patch_size**2 *3:单张图像的像素点个数,L:为图像尺寸,N为batch的大小 143 | """ 144 | p = self.patch_embed.patch_size[0] #p=patch_size的大小 145 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 #若img图像尺寸相对于patch_size不整除则跳出 146 | 147 | h = w = imgs.shape[2] // p 148 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 149 | x = torch.einsum('nchpwq->nhwpqc', x) 150 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))#batch_size*patch的数目*patch的大小 151 | return x 152 | 153 | def unpatchify(self, x):#把块的图片还原成图片 154 | """ 155 | x: (N, L, patch_size**2 *3) # 156 | imgs: (N, 3, H, W) 157 | """ 158 | p = self.patch_embed.patch_size[0] 159 | h = w = int(x.shape[1]**.5) 160 | assert h * w == x.shape[1] 161 | 162 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 163 | x = torch.einsum('nhwpqc->nchpwq', x) 164 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 165 | return imgs 166 | 167 | def random_masking(self, x, mask_ratio):#随机掩码 168 | """ 169 | Perform per-sample random masking by per-sample shuffling. 170 | Per-sample shuffling is done by argsort random noise. 
171 | x: [N, L, D], sequence 172 | """ 173 | N, L, D = x.shape # batch, length, dim 174 | len_keep = int(L * (1 - mask_ratio))#算出保留块的数目 175 | 176 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]生成随机矩阵,均匀分布 batch_size*length 177 | 178 | # sort noise for each sample 对每个样本的噪声进行排序 179 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove排序从小到大取掩码和被掩码时用的 180 | ids_restore = torch.argsort(ids_shuffle, dim=1)#对索引进行排序,还原序列时要用到的 181 | 182 | # keep the first subset 保留第一个子集 183 | ids_keep = ids_shuffle[:, :len_keep]# 184 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))#dim:维度。未被掩码的序列 185 | 186 | # generate the binary mask: 0 is keep, 1 is remove 187 | mask = torch.ones([N, L], device=x.device)#batch_size*lenth大小的提供给decode使用的全一矩阵 188 | mask[:, :len_keep] = 0#对未掩码设置为0 189 | # unshuffle to get the binary mask 190 | mask = torch.gather(mask, dim=1, index=ids_restore)#在原图中被掩码的位置 191 | 192 | return x_masked, mask, ids_restore 193 | 194 | def forward_encoder(self, x, mask_ratio):#x:输入图像。mask_ratio:掩码比例 195 | # embed patches 196 | x = self.patch_embed(x) 197 | 198 | # add pos embed w/o cls token 199 | x = x + self.pos_embed[:, 1:, :] 200 | 201 | x, mask, ids_restore = self.random_masking(x, mask_ratio)#随机掩码ids_restore:恢复的索引 202 | 203 | # append cls token 204 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 205 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 206 | x = torch.cat((cls_tokens, x), dim=1)#encode的输入 207 | 208 | # apply Transformer blocks 209 | for blk in self.blocks: 210 | x = blk(x) 211 | x = self.norm(x)#encode的输出 212 | 213 | return x, mask, ids_restore 214 | 215 | def forward_decoder(self, x, ids_restore): 216 | # embed tokens 217 | x = self.decoder_embed(x) 218 | # append mask tokens to sequence 219 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)#扩维 220 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token,拼接 221 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle,得到被还原后的顺序 222 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 223 | 224 | # add pos embed加上位置编码 225 | x = x + self.decoder_pos_embed 226 | 227 | 228 | # apply Transformer blocks对decode的所有block进行递归的调用 229 | for blk in self.decoder_blocks: 230 | x = blk(x) 231 | x = self.decoder_norm(x)#得到decode输出 232 | 233 | tezheng = x 234 | 235 | clip_tezheng = tezheng 236 | clip_tezheng = clip_tezheng[:, 1:, :] 237 | clip_tezheng = clip_tezheng.transpose(1,2) 238 | clip_tezheng = self.decoder_image(clip_tezheng) 239 | 240 | # predictor projection映射到像素的特征上 241 | x = self.decoder_pred(x) 242 | 243 | # remove cls token移除cls token 244 | x = x[:, 1:, :] 245 | 246 | return x ,tezheng,clip_tezheng 247 | 248 | def forward_loss(self, imgs, pred, mask):#平方差loss 249 | """ 250 | imgs: [N, 3, H, W] 251 | pred: [N, L, p*p*3] 252 | mask: [N, L], 0 is keep, 1 is remove, 253 | """ 254 | target = self.patchify(imgs) #原始图像patch后 255 | if self.norm_pix_loss: 256 | mean = target.mean(dim=-1, keepdim=True)#计算均值 257 | var = target.var(dim=-1, keepdim=True)#计算方差 258 | target = (target - mean) / (var + 1.e-6)**.5#对target做均值方差的归一化,1.e-6:防止方差为0 259 | 260 | loss = (pred - target) ** 2 #计算均方损失函数 261 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch,均值 262 | #loss = (loss * mask).sum() / mask.sum() # 将每个mask后的patch累加 263 | 264 | return loss 265 | 266 | 267 | def del_tensor_ele_n(self,arr, index, n): 268 | """ 269 | arr: 输入tensor 270 | index: 需要删除位置的索引 271 | n: 
number of rows to delete, starting from index
272 | """
273 | arr1 = arr[0:index]
274 | arr2 = arr[index+n:]
275 | return torch.cat((arr1, arr2), dim=0)
276 | 
277 | @torch.jit.ignore
278 | def no_weight_decay(self):
279 | return {'pos_embed', 'cls_token', 'mask_token'}
280 | 
281 | def forward(self, imgs, mask_ratio):  # run the encoder and decoder and return the combined outputs
282 | if mask_ratio != 0:
283 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
284 | pred, tezheng, clip_tezheng = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
285 | 
286 | loss = self.forward_loss(imgs, pred, mask)
287 | return tezheng, loss, mask, ids_restore, clip_tezheng
288 | else:
289 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
290 | latent = self.decoder_embed(latent)
291 | 
292 | return latent
293 | 
294 | # three MAE variants: mae_vit_base, mae_vit_large, mae_vit_huge
295 | def mae_vit_base_patch16_dec512d8b(**kwargs):
296 | model = MaskedAutoencoderViT(
297 | patch_size=16, embed_dim=768, depth=12, num_heads=12,
298 | decoder_embed_dim=512, decoder_num_heads=16,
299 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
300 | return model
301 | 
302 | #decoder_depth=8,
303 | def mae_vit_large_patch16_dec512d8b(**kwargs):
304 | model = MaskedAutoencoderViT(
305 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
306 | decoder_embed_dim=512, decoder_num_heads=16,
307 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
308 | return model
309 | 
310 | 
311 | def mae_vit_huge_patch14_dec512d8b(**kwargs):
312 | model = MaskedAutoencoderViT(
313 | patch_size=14, embed_dim=1280, depth=32, num_heads=16,
314 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
315 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
316 | return model
317 | 
318 | 
319 | # set recommended archs
320 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
321 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
322 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
323 | --------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # MAE: https://github.com/facebookresearch/mae
9 | # --------------------------------------------------------
10 | 
11 | import builtins
12 | import datetime
13 | from email.policy import strict
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 | 
19 | import torch
20 | import torch.distributed as dist
21 | from torch._six import inf
22 | 
23 | 
24 | class SmoothedValue(object):
25 | """Track a series of values and provide access to smoothed values over a
26 | window or the global series average.
27 | """
28 | 
29 | def __init__(self, window_size=20, fmt=None):
30 | if fmt is None:
31 | fmt = "{median:.4f} ({global_avg:.4f})"
32 | self.deque = deque(maxlen=window_size)
33 | self.total = 0.0
34 | self.count = 0
35 | self.fmt = fmt
36 | 
37 | def update(self, value, n=1):
38 | self.deque.append(value)
39 | self.count += n
40 | self.total += value * n
41 | 
42 | def synchronize_between_processes(self):
43 | """
44 | Warning: does not synchronize the deque!
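Only count and total are all-reduced across processes; the windowed deque, and hence median/avg, stays local to each rank.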
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t", fn=None): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | self.fn = fn 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 | type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | if is_main_process(): 154 | msg = log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time), 158 | memory=torch.cuda.max_memory_allocated() / MB) 159 | print(msg) 160 | if self.fn: 161 | with open(self.fn, mode="a", encoding="utf-8") as f: 162 | f.write(msg + '\n') 163 | 164 | else: 165 | if is_main_process(): 166 | msg = log_msg.format( 167 | i, len(iterable), eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), data=str(data_time)) 170 | 
print(msg)
171 | if self.fn:
172 | with open(self.fn, mode='a', encoding='utf-8') as f:
173 | f.write(msg + '\n')
174 | i += 1
175 | end = time.time()
176 | total_time = time.time() - start_time
177 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178 | print('{} Total time: {} ({:.4f} s / it)'.format(
179 | header, total_time_str, total_time / len(iterable)))
180 | 
181 | 
182 | def setup_for_distributed(is_master):
183 | """
184 | This function disables printing when not in master process
185 | """
186 | builtin_print = builtins.print
187 | 
188 | def print(*args, **kwargs):
189 | force = kwargs.pop('force', False)
190 | force = force or (get_world_size() > 8)
191 | if is_master or force:
192 | now = datetime.datetime.now().time()
193 | builtin_print('[{}] '.format(now), end='')  # print with time stamp
194 | builtin_print(*args, **kwargs)
195 | 
196 | builtins.print = print
197 | 
198 | 
199 | def is_dist_avail_and_initialized():
200 | if not dist.is_available():
201 | return False
202 | if not dist.is_initialized():
203 | return False
204 | return True
205 | 
206 | 
207 | def get_world_size():
208 | if not is_dist_avail_and_initialized():
209 | return 1
210 | return dist.get_world_size()
211 | 
212 | 
213 | def get_rank():
214 | if not is_dist_avail_and_initialized():
215 | return 0
216 | return dist.get_rank()
217 | 
218 | 
219 | def is_main_process():
220 | return get_rank() == 0
221 | 
222 | 
223 | def save_on_master(*args, **kwargs):
224 | if is_main_process():
225 | torch.save(*args, **kwargs)
226 | 
227 | 
228 | def init_distributed_mode(args):
229 | if args.dist_on_itp:
230 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
231 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
232 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
233 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
234 | os.environ['LOCAL_RANK'] = str(args.gpu)
235 | os.environ['RANK'] = str(args.rank)
236 | os.environ['WORLD_SIZE'] = str(args.world_size)
237 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
238 | # get the distributed information
239 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
240 | args.rank = int(os.environ["RANK"])
241 | args.world_size = int(os.environ['WORLD_SIZE'])
242 | args.gpu = int(os.environ['LOCAL_RANK'])  # ngpus_per_node----localrank
243 | elif 'SLURM_PROCID' in os.environ:
244 | args.rank = int(os.environ['SLURM_PROCID'])
245 | args.gpu = args.rank % torch.cuda.device_count()
246 | else:
247 | print('Not using distributed mode')
248 | setup_for_distributed(is_master=True)  # hack
249 | args.distributed = False
250 | return
251 | 
252 | args.distributed = True
253 | 
254 | # bind this process to its GPU; if not set, everything would run on device 0
255 | torch.cuda.set_device(args.gpu)
256 | args.dist_backend = 'nccl'
257 | print('| distributed init (rank {}): {}, gpu {}'.format(
258 | args.rank, args.dist_url, args.gpu), flush=True)
259 | # world_size: total number of processes taking part in the job
260 | # init_method: URL of the master node used to exchange rendezvous information
261 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
262 | world_size=args.world_size, rank=args.rank)
263 | torch.distributed.barrier()
264 | setup_for_distributed(args.rank == 0)
265 | 
266 | 
267 | class NativeScalerWithGradNormCount:
268 | state_dict_key = "amp_scaler"
269 | 
270 | def __init__(self):
271 | self._scaler = torch.cuda.amp.GradScaler()
272 | 
273 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
274 |
self._scaler.scale(loss).backward(create_graph=create_graph) 275 | if update_grad: 276 | if clip_grad is not None: 277 | assert parameters is not None 278 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 279 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 280 | else: 281 | self._scaler.unscale_(optimizer) 282 | norm = get_grad_norm_(parameters) 283 | self._scaler.step(optimizer) 284 | self._scaler.update() 285 | else: 286 | norm = None 287 | return norm 288 | 289 | def state_dict(self): 290 | return self._scaler.state_dict() 291 | 292 | def load_state_dict(self, state_dict): 293 | self._scaler.load_state_dict(state_dict) 294 | 295 | 296 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 297 | if isinstance(parameters, torch.Tensor): 298 | parameters = [parameters] 299 | parameters = [p for p in parameters if p.grad is not None] 300 | norm_type = float(norm_type) 301 | if len(parameters) == 0: 302 | return torch.tensor(0.) 303 | device = parameters[0].grad.device 304 | if norm_type == inf: 305 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 306 | else: 307 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 308 | return total_norm 309 | 310 | 311 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, last=False): 312 | output_dir = Path(args.output_dir) 313 | 314 | if last: 315 | epoch_name = "last" 316 | else: 317 | epoch_name = str(epoch + 1) 318 | if loss_scaler is not None: 319 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 320 | for checkpoint_path in checkpoint_paths: 321 | to_save = { 322 | 'model': model_without_ddp.state_dict(), 323 | 'optimizer': optimizer.state_dict(), 324 | 'epoch': epoch, 325 | 'scaler': loss_scaler.state_dict(), 326 | 'args': args, 327 | } 328 | 329 | save_on_master(to_save, checkpoint_path) 330 | else: 331 | client_state = {'epoch': epoch} 332 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 333 | 334 | 335 | def save_teacher_model(args, model, model_without_ddp, epoch=None, last=None): 336 | output_dir = Path(args.output_dir) 337 | if last is None: 338 | epoch_name = str(epoch + 1) 339 | checkpoint_path = output_dir / ('checkpoint-teacher-%s.pth' % epoch_name) 340 | elif epoch is None: 341 | checkpoint_path = output_dir / ('checkpoint-teacher-last.pth') 342 | to_save = {'model':model_without_ddp.state_dict()} 343 | save_on_master(to_save, checkpoint_path) 344 | 345 | 346 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 347 | if args.resume: 348 | if args.resume.startswith('https'): 349 | checkpoint = torch.hub.load_state_dict_from_url( 350 | args.resume, map_location='cpu', check_hash=True) 351 | else: 352 | checkpoint = torch.load(args.resume, map_location='cpu') 353 | msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 354 | print("Resume checkpoint %s" % args.resume) 355 | print("="*20 + ">") 356 | print("load model from resume") 357 | print(msg) 358 | if 'checkpoint-last.pth' in args.resume and 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 359 | optimizer.load_state_dict(checkpoint['optimizer']) 360 | args.start_epoch = checkpoint['epoch'] + 1 361 | if 'scaler' in checkpoint: 362 | loss_scaler.load_state_dict(checkpoint['scaler']) 363 | 364 | print("With optim & 
sched!") 365 | 366 | 367 | def load_start_epoch(args): 368 | if args.resume: 369 | if args.resume.startswith('https'): 370 | checkpoint = torch.hub.load_state_dict_from_url( 371 | args.resume, map_location='cpu', check_hash=True) 372 | else: 373 | checkpoint = torch.load(args.resume, map_location='cpu') 374 | 375 | if 'checkpoint-last.pth' in args.resume and 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 376 | args.start_epoch = checkpoint['epoch'] + 1 377 | print("find init epoch") 378 | 379 | def find_stage_index(epoch, args): 380 | for index, num in enumerate(args.stage_epochs): 381 | if epoch < num: 382 | return index-1 383 | 384 | 385 | def load_teacher_model(args, teacher_without_ddp): 386 | if args.teacher_resume: 387 | if args.teacher_resume.startswith('https'): 388 | checkpoint = torch.hub.load_state_dict_from_url( 389 | args.teacher_resume, map_location='cpu', check_hash=True) 390 | else: 391 | checkpoint = torch.load(args.teacher_resume, map_location='cpu') 392 | 393 | if "model" in checkpoint: 394 | checkpoint = checkpoint['model'] 395 | 396 | 397 | from util.pos_embed import interpolate_pos_embed 398 | if 'pos_embed' in checkpoint and checkpoint['pos_embed'] is not None and teacher_without_ddp.pos_embed is not None: 399 | interpolate_pos_embed(teacher_without_ddp, checkpoint) 400 | 401 | msg = teacher_without_ddp.load_state_dict(checkpoint, strict=False) 402 | print("="*20 + ">") 403 | print("load teacher from teacher resume") 404 | print(msg) 405 | print("Resume teacher checkpoint %s" % args.teacher_resume) 406 | 407 | 408 | 409 | def all_reduce_mean(x): 410 | world_size = get_world_size() 411 | if world_size > 1: 412 | x_reduce = torch.tensor(x).cuda() 413 | dist.all_reduce(x_reduce) 414 | x_reduce /= world_size 415 | return x_reduce.item() 416 | else: 417 | return x -------------------------------------------------------------------------------- /models/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat( 224 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 225 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 226 | x = x + self.positional_embedding.to(x.dtype) 227 | x = self.ln_pre(x) 228 | 229 | x = x.permute(1, 0, 2) # NLD -> LND 230 | x = self.transformer(x) 231 | x = x.permute(1, 0, 2) # LND -> NLD 232 | 233 | x = self.ln_post(x[:, 0, :]) 234 | 235 | if self.proj is not None: 236 | x = x @ self.proj 237 | 238 | return x 239 | 240 | 241 | class CLIP(nn.Module): 242 | def __init__(self, 243 | embed_dim: int, 244 | # vision 245 | image_resolution: int, 246 | vision_layers: Union[Tuple[int, int, int, int], int], 247 | vision_width: int, 248 | vision_patch_size: int, 249 | # text 250 | context_length: int, 251 | vocab_size: int, 252 | transformer_width: int, 253 | transformer_heads: int, 254 | transformer_layers: int 255 | ): 256 | super().__init__() 257 | 258 | self.context_length = context_length 259 | 260 | if isinstance(vision_layers, (tuple, list)): 261 | vision_heads = vision_width * 32 // 64 262 | self.visual = ModifiedResNet( 263 | layers=vision_layers, 264 | output_dim=embed_dim, 265 | heads=vision_heads, 266 | input_resolution=image_resolution, 267 | width=vision_width 268 | ) 269 | else: 270 | vision_heads = vision_width // 64 271 | self.visual = VisionTransformer( 272 | input_resolution=image_resolution, 273 | patch_size=vision_patch_size, 274 | width=vision_width, 275 | layers=vision_layers, 276 | heads=vision_heads, 277 | output_dim=embed_dim 278 | ) 279 | 280 | self.transformer = Transformer( 281 | width=transformer_width, 282 | layers=transformer_layers, 283 | heads=transformer_heads, 284 | attn_mask=self.build_attention_mask() 285 | ) 286 | 287 | self.vocab_size = vocab_size 288 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 289 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 290 | self.ln_final = LayerNorm(transformer_width) 291 | 292 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 
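# Note: logit_scale below is the learnable temperature of the contrastive loss; it is initialized to log(1/0.07), i.e. exp(logit_scale) starts at 1/0.07 ≈ 14.3, as in the original CLIP release.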
293 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 294 | 295 | self.initialize_parameters() 296 | 297 | def initialize_parameters(self): 298 | nn.init.normal_(self.token_embedding.weight, std=0.02) 299 | nn.init.normal_(self.positional_embedding, std=0.01) 300 | 301 | if isinstance(self.visual, ModifiedResNet): 302 | if self.visual.attnpool is not None: 303 | std = self.visual.attnpool.c_proj.in_features ** -0.5 304 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 306 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 308 | 309 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 310 | for name, param in resnet_block.named_parameters(): 311 | if name.endswith("bn3.weight"): 312 | nn.init.zeros_(param) 313 | 314 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 315 | attn_std = self.transformer.width ** -0.5 316 | fc_std = (2 * self.transformer.width) ** -0.5 317 | for block in self.transformer.resblocks: 318 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 319 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 320 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 321 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 322 | 323 | if self.text_projection is not None: 324 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 325 | 326 | def build_attention_mask(self): 327 | # lazily create causal attention mask, with full attention between the vision tokens 328 | # pytorch uses additive attention mask; fill with -inf 329 | mask = torch.empty(self.context_length, self.context_length) 330 | mask.fill_(float("-inf")) 331 | mask.triu_(1) # zero out the lower diagonal 332 | return mask 333 | 334 | @property 335 | def dtype(self): 336 | return self.visual.conv1.weight.dtype 337 | 338 | def encode_image(self, image): 339 | return self.visual(image.type(self.dtype)) 340 | 341 | def encode_text(self, text): 342 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 343 | 344 | x = x + self.positional_embedding.type(self.dtype) 345 | x = x.permute(1, 0, 2) # NLD -> LND 346 | x = self.transformer(x) 347 | x = x.permute(1, 0, 2) # LND -> NLD 348 | x = self.ln_final(x).type(self.dtype) 349 | 350 | # x.shape = [batch_size, n_ctx, transformer.width] 351 | # take features from the eot embedding (eot_token is the highest number in each sequence) 352 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 353 | 354 | return x 355 | 356 | def forward(self, image, text): 357 | image_features = self.encode_image(image) 358 | print('clip_image_features:') 359 | print(image_features.shape) 360 | text_features = self.encode_text(text) 361 | 362 | # normalized features,归一化特征 363 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 364 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 365 | 366 | # cosine similarity as logits,余弦相似度作为 logits 367 | logit_scale = self.logit_scale.exp() 368 | logits_per_image = logit_scale * image_features @ text_features.t() 369 | logits_per_text = logit_scale * text_features @ image_features.t() 370 | 371 | # shape = [global_batch_size, global_batch_size] 372 | return logits_per_image, logits_per_text 373 | 374 | 375 | def convert_weights(model: nn.Module): 376 | """Convert 
applicable model parameters to fp16""" 377 | 378 | def _convert_weights_to_fp16(l): 379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 380 | l.weight.data = l.weight.data.half() 381 | if l.bias is not None: 382 | l.bias.data = l.bias.data.half() 383 | 384 | if isinstance(l, nn.MultiheadAttention): 385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 386 | tensor = getattr(l, attr) 387 | if tensor is not None: 388 | tensor.data = tensor.data.half() 389 | 390 | for name in ["text_projection", "proj"]: 391 | if hasattr(l, name): 392 | attr = getattr(l, name) 393 | if attr is not None: 394 | attr.data = attr.data.half() 395 | 396 | model.apply(_convert_weights_to_fp16) 397 | 398 | 399 | def build_model(state_dict: dict): 400 | vit = "visual.proj" in state_dict 401 | 402 | if vit: 403 | vision_width = state_dict["visual.conv1.weight"].shape[0] 404 | vision_layers = len( 405 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 406 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 407 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 408 | image_resolution = vision_patch_size * grid_size 409 | else: 410 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in 411 | [1, 2, 3, 4]] 412 | vision_layers = tuple(counts) 413 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 414 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 415 | vision_patch_size = None 416 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 417 | image_resolution = output_width * 32 418 | 419 | embed_dim = state_dict["text_projection"].shape[1] 420 | context_length = state_dict["positional_embedding"].shape[0] 421 | vocab_size = state_dict["token_embedding.weight"].shape[0] 422 | transformer_width = state_dict["ln_final.weight"].shape[0] 423 | transformer_heads = transformer_width // 64 424 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 425 | 426 | model = CLIP( 427 | embed_dim, 428 | image_resolution, vision_layers, vision_width, vision_patch_size, 429 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 430 | ) 431 | 432 | for key in ["input_resolution", "context_length", "vocab_size"]: 433 | if key in state_dict: 434 | del state_dict[key] 435 | 436 | convert_weights(model) 437 | model.load_state_dict(state_dict) 438 | return model.eval() 439 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """ 8 | Mostly copy-paste from torchvision references or other public repos like DETR: 9 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 10 | """ 11 | 12 | import os 13 | import sys 14 | import time 15 | import math 16 | import json 17 | import random 18 | import datetime 19 | import subprocess 20 | import numpy as np 21 | import torch 22 | import torch.distributed as dist 23 | 24 | from collections import defaultdict, deque 25 | from pathlib import Path 26 | from torch import nn 27 | from PIL import ImageFilter, ImageOps, Image, ImageDraw 28 | 29 | 30 | def save_model(args, epoch, model,model_without_ddp, optimizer, loss_scaler): 31 | output_dir = Path(args.output_dir) 32 | epoch_name = str(epoch) 33 | if loss_scaler is not None: 34 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 35 | for checkpoint_path in checkpoint_paths: 36 | to_save = { 37 | 'model': model_without_ddp.state_dict(), 38 | 'optimizer': optimizer.state_dict(), 39 | 'epoch': epoch, 40 | 'scaler': loss_scaler.state_dict(), 41 | 'args': args, 42 | } 43 | 44 | save_on_master(to_save, checkpoint_path) 45 | else: 46 | client_state = {'epoch': epoch} 47 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 48 | 49 | 50 | def clip_gradients(model, clip): 51 | norms = [] 52 | for name, p in model.named_parameters(): 53 | if p.grad is not None: 54 | param_norm = p.grad.data.norm(2) 55 | norms.append(param_norm.item()) 56 | clip_coef = clip / (param_norm + 1e-6) 57 | if clip_coef < 1: 58 | p.grad.data.mul_(clip_coef) 59 | return norms 60 | 61 | 62 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 63 | if epoch >= freeze_last_layer: 64 | return 65 | for n, p in model.named_parameters(): 66 | if "last_layer" in n: 67 | p.grad = None 68 | 69 | 70 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 71 | """ 72 | Re-start from checkpoint 73 | """ 74 | if not os.path.isfile(ckp_path): 75 | return 76 | print("Found checkpoint at {}".format(ckp_path)) 77 | 78 | # open checkpoint file 79 | checkpoint = torch.load(ckp_path, map_location="cpu") 80 | 81 | # key is what to look for in the checkpoint file 82 | # value is the object to load 83 | # example: {'state_dict': model} 84 | for key, value in kwargs.items(): 85 | if key in checkpoint and value is not None: 86 | try: 87 | msg = value.load_state_dict(checkpoint[key], strict=False) 88 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 89 | except TypeError: 90 | try: 91 | msg = value.load_state_dict(checkpoint[key]) 92 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) 93 | except ValueError: 94 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path)) 95 | else: 96 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) 97 | 98 | # re load variable important for the run 99 | if run_variables is not None: 100 | for var_name in run_variables: 101 | if var_name in checkpoint: 102 | run_variables[var_name] = checkpoint[var_name] 103 | 104 | 105 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 106 | warmup_schedule = np.array([]) 107 | warmup_iters = warmup_epochs * niter_per_ep 108 | if warmup_epochs > 0: 109 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 110 | 111 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 112 | schedule = final_value + 0.5 * (base_value 
- final_value) * (1 + np.cos(np.pi * iters / len(iters))) 113 | 114 | schedule = np.concatenate((warmup_schedule, schedule)) 115 | assert len(schedule) == epochs * niter_per_ep 116 | return schedule 117 | 118 | 119 | def bool_flag(s): 120 | """ 121 | Parse boolean arguments from the command line. 122 | """ 123 | FALSY_STRINGS = {"off", "false", "0"} 124 | TRUTHY_STRINGS = {"on", "true", "1"} 125 | if s.lower() in FALSY_STRINGS: 126 | return False 127 | elif s.lower() in TRUTHY_STRINGS: 128 | return True 129 | else: 130 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 131 | 132 | 133 | def fix_random_seeds(seed=31): 134 | """ 135 | Fix random seeds. 136 | """ 137 | random.seed(seed) 138 | os.environ['PYTHONHASHSEED'] = str(seed) 139 | torch.manual_seed(seed) 140 | torch.cuda.manual_seed_all(seed) 141 | np.random.seed(seed) 142 | return seed 143 | 144 | 145 | class SmoothedValue(object): 146 | """Track a series of values and provide access to smoothed values over a 147 | window or the global series average. 148 | """ 149 | 150 | def __init__(self, window_size=20, fmt=None): 151 | if fmt is None: 152 | fmt = "{median:.6f} ({global_avg:.6f})" 153 | self.deque = deque(maxlen=window_size) 154 | self.total = 0.0 155 | self.count = 0 156 | self.fmt = fmt 157 | 158 | def update(self, value, n=1): 159 | self.deque.append(value) 160 | self.count += n 161 | self.total += value * n 162 | 163 | def synchronize_between_processes(self): 164 | """ 165 | Warning: does not synchronize the deque! 166 | """ 167 | if not is_dist_avail_and_initialized(): 168 | return 169 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 170 | dist.barrier() 171 | dist.all_reduce(t) 172 | t = t.tolist() 173 | self.count = int(t[0]) 174 | self.total = t[1] 175 | 176 | @property 177 | def median(self): 178 | d = torch.tensor(list(self.deque)) 179 | return d.median().item() 180 | 181 | @property 182 | def avg(self): 183 | d = torch.tensor(list(self.deque), dtype=torch.float32) 184 | return d.mean().item() 185 | 186 | @property 187 | def global_avg(self): 188 | return self.total / self.count 189 | 190 | @property 191 | def max(self): 192 | return max(self.deque) 193 | 194 | @property 195 | def value(self): 196 | return self.deque[-1] 197 | 198 | def __str__(self): 199 | return self.fmt.format( 200 | median=self.median, 201 | avg=self.avg, 202 | global_avg=self.global_avg, 203 | max=self.max, 204 | value=self.value) 205 | 206 | 207 | class MetricLogger(object): 208 | def __init__(self, delimiter="\t"): 209 | self.meters = defaultdict(SmoothedValue) 210 | self.delimiter = delimiter 211 | 212 | def update(self, **kwargs): 213 | for k, v in kwargs.items(): 214 | if isinstance(v, torch.Tensor): 215 | v = v.item() 216 | assert isinstance(v, (float, int)) 217 | self.meters[k].update(v) 218 | 219 | def __getattr__(self, attr): 220 | if attr in self.meters: 221 | return self.meters[attr] 222 | if attr in self.__dict__: 223 | return self.__dict__[attr] 224 | raise AttributeError("'{}' object has no attribute '{}'".format( 225 | type(self).__name__, attr)) 226 | 227 | def __str__(self): 228 | loss_str = [] 229 | for name, meter in self.meters.items(): 230 | loss_str.append( 231 | "{}: {}".format(name, str(meter)) 232 | ) 233 | return self.delimiter.join(loss_str) 234 | 235 | def synchronize_between_processes(self): 236 | for meter in self.meters.values(): 237 | meter.synchronize_between_processes() 238 | 239 | def add_meter(self, name, meter): 240 | self.meters[name] = meter 241 | 242 | def 
log_every(self, iterable, print_freq, header=None): 243 | i = 0 244 | if not header: 245 | header = '' 246 | start_time = time.time() 247 | end = time.time() 248 | iter_time = SmoothedValue(fmt='{avg:.6f}') 249 | data_time = SmoothedValue(fmt='{avg:.6f}') 250 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 251 | if torch.cuda.is_available(): 252 | log_msg = self.delimiter.join([ 253 | header, 254 | '[{0' + space_fmt + '}/{1}]', 255 | 'eta: {eta}', 256 | '{meters}', 257 | 'time: {time}', 258 | 'data: {data}', 259 | 'max mem: {memory:.0f}' 260 | ]) 261 | else: 262 | log_msg = self.delimiter.join([ 263 | header, 264 | '[{0' + space_fmt + '}/{1}]', 265 | 'eta: {eta}', 266 | '{meters}', 267 | 'time: {time}', 268 | 'data: {data}' 269 | ]) 270 | MB = 1024.0 * 1024.0 271 | for obj in iterable: 272 | data_time.update(time.time() - end) 273 | yield obj 274 | iter_time.update(time.time() - end) 275 | if i % print_freq == 0 or i == len(iterable) - 1: 276 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 277 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 278 | if torch.cuda.is_available(): 279 | print(log_msg.format( 280 | i, len(iterable), eta=eta_string, 281 | meters=str(self), 282 | time=str(iter_time), data=str(data_time), 283 | memory=torch.cuda.max_memory_allocated() / MB)) 284 | else: 285 | print(log_msg.format( 286 | i, len(iterable), eta=eta_string, 287 | meters=str(self), 288 | time=str(iter_time), data=str(data_time))) 289 | i += 1 290 | end = time.time() 291 | total_time = time.time() - start_time 292 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 293 | print('{} Total time: {} ({:.6f} s / it)'.format( 294 | header, total_time_str, total_time / len(iterable))) 295 | 296 | 297 | def get_sha(): 298 | cwd = os.path.dirname(os.path.abspath(__file__)) 299 | 300 | def _run(command): 301 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 302 | sha = 'N/A' 303 | diff = "clean" 304 | branch = 'N/A' 305 | try: 306 | sha = _run(['git', 'rev-parse', 'HEAD']) 307 | subprocess.check_output(['git', 'diff'], cwd=cwd) 308 | diff = _run(['git', 'diff-index', 'HEAD']) 309 | diff = "has uncommited changes" if diff else "clean" 310 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 311 | except Exception: 312 | pass 313 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 314 | return message 315 | 316 | 317 | def is_dist_avail_and_initialized(): 318 | if not dist.is_available(): 319 | return False 320 | if not dist.is_initialized(): 321 | return False 322 | return True 323 | 324 | 325 | def get_world_size(): 326 | if not is_dist_avail_and_initialized(): 327 | return 1 328 | return dist.get_world_size() 329 | 330 | 331 | def get_rank(): 332 | if not is_dist_avail_and_initialized(): 333 | return 0 334 | return dist.get_rank() 335 | 336 | 337 | def is_main_process(): 338 | return get_rank() == 0 339 | 340 | 341 | def save_on_master(*args, **kwargs): 342 | if is_main_process(): 343 | torch.save(*args, **kwargs) 344 | 345 | 346 | def setup_for_distributed(is_master): 347 | """ 348 | This function disables printing when not in master process 349 | """ 350 | import builtins as __builtin__ 351 | builtin_print = __builtin__.print 352 | 353 | def print(*args, **kwargs): 354 | force = kwargs.pop('force', False) 355 | if is_master or force: 356 | builtin_print(*args, **kwargs) 357 | 358 | __builtin__.print = print 359 | 360 | 361 | def init_distributed_mode(args): 362 | # launched with torch.distributed.launch 363 | if 'RANK' 
in os.environ and 'WORLD_SIZE' in os.environ: 364 | args.rank = int(os.environ["RANK"]) 365 | args.world_size = int(os.environ['WORLD_SIZE']) 366 | args.gpu = int(os.environ['LOCAL_RANK']) 367 | # launched with submitit on a slurm cluster 368 | elif 'SLURM_PROCID' in os.environ: 369 | args.rank = int(os.environ['SLURM_PROCID']) 370 | args.gpu = args.rank % torch.cuda.device_count() 371 | # launched naively with `python main_dino.py` 372 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 373 | elif torch.cuda.is_available(): 374 | print('Will run the code on one GPU.') 375 | args.rank, args.gpu, args.world_size = 0, 0, 1 376 | os.environ['MASTER_ADDR'] = '127.0.0.1' 377 | os.environ['MASTER_PORT'] = '29506' 378 | else: 379 | print('Does not support training without GPU.') 380 | sys.exit(1) 381 | 382 | dist.init_process_group( 383 | backend="nccl", 384 | init_method=args.dist_url, 385 | world_size=args.world_size, 386 | rank=args.rank, 387 | ) 388 | 389 | torch.cuda.set_device(args.gpu) 390 | print('| distributed init (rank {}): {}'.format( 391 | args.rank, args.dist_url), flush=True) 392 | dist.barrier() 393 | setup_for_distributed(args.rank == 0) 394 | 395 | 396 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 397 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 398 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 399 | def norm_cdf(x): 400 | # Computes standard normal cumulative distribution function 401 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 402 | 403 | if (mean < a - 2 * std) or (mean > b + 2 * std): 404 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 405 | "The distribution of values may be incorrect.", 406 | stacklevel=2) 407 | 408 | with torch.no_grad(): 409 | # Values are generated by using a truncated uniform distribution and 410 | # then using the inverse CDF for the normal distribution. 411 | # Get upper and lower cdf values 412 | l = norm_cdf((a - mean) / std) 413 | u = norm_cdf((b - mean) / std) 414 | 415 | # Uniformly fill tensor with values from [l, u], then translate to 416 | # [2l-1, 2u-1]. 
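# erfinv applies the inverse normal CDF (Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1)), so the uniform draw on [2l-1, 2u-1] becomes a standard normal truncated to the rescaled bounds before the mul/add below.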
417 | tensor.uniform_(2 * l - 1, 2 * u - 1) 418 | 419 | # Use inverse cdf transform for normal distribution to get truncated 420 | # standard normal 421 | tensor.erfinv_() 422 | 423 | # Transform to proper mean, std 424 | tensor.mul_(std * math.sqrt(2.)) 425 | tensor.add_(mean) 426 | 427 | # Clamp to ensure it's in the proper range 428 | tensor.clamp_(min=a, max=b) 429 | return tensor 430 | 431 | 432 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 433 | # type: (Tensor, float, float, float, float) -> Tensor 434 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 435 | 436 | 437 | class LARS(torch.optim.Optimizer): 438 | """ 439 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 440 | """ 441 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 442 | weight_decay_filter=None, lars_adaptation_filter=None): 443 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 444 | eta=eta, weight_decay_filter=weight_decay_filter, 445 | lars_adaptation_filter=lars_adaptation_filter) 446 | super().__init__(params, defaults) 447 | 448 | @torch.no_grad() 449 | def step(self): 450 | for g in self.param_groups: 451 | for p in g['params']: 452 | dp = p.grad 453 | 454 | if dp is None: 455 | continue 456 | 457 | if p.ndim != 1: 458 | dp = dp.add(p, alpha=g['weight_decay']) 459 | 460 | if p.ndim != 1: 461 | param_norm = torch.norm(p) 462 | update_norm = torch.norm(dp) 463 | one = torch.ones_like(param_norm) 464 | q = torch.where(param_norm > 0., 465 | torch.where(update_norm > 0, 466 | (g['eta'] * param_norm / update_norm), one), one) 467 | dp = dp.mul(q) 468 | 469 | param_state = self.state[p] 470 | if 'mu' not in param_state: 471 | param_state['mu'] = torch.zeros_like(p) 472 | mu = param_state['mu'] 473 | mu.mul_(g['momentum']).add_(dp) 474 | 475 | p.add_(mu, alpha=-g['lr']) 476 | 477 | def create_ds_config(args): 478 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 479 | with open(args.deepspeed_config, mode="w") as writer: 480 | ds_config = { 481 | "train_batch_size": args.batch_size * get_world_size(), 482 | "train_micro_batch_size_per_gpu": args.batch_size, 483 | "steps_per_print": 1000, 484 | "optimizer": { 485 | "type": "Adam", 486 | "adam_w_mode": True, 487 | "params": { 488 | "lr": args.lr, 489 | "weight_decay": args.weight_decay, 490 | "bias_correction": True, 491 | "betas": [ 492 | 0.9, 493 | 0.999 494 | ], 495 | "eps": 1e-8 496 | } 497 | }, 498 | "fp16": { 499 | "enabled": True, 500 | "loss_scale": 0, 501 | "initial_scale_power": 7, 502 | "loss_scale_window": 128 503 | } 504 | } 505 | 506 | writer.write(json.dumps(ds_config, indent=2)) 507 | 508 | class MultiCropWrapper(nn.Module): 509 | """ 510 | Perform forward pass separately on each resolution input. 511 | The inputs corresponding to a single resolution are clubbed and single 512 | forward is run on the same resolution inputs. Hence we do several 513 | forward passes = number of different resolutions used. We then 514 | concatenate all the output features and run the head forward on these 515 | concatenated features. 
516 | """ 517 | def __init__(self, backbone, head=None): 518 | super(MultiCropWrapper, self).__init__() 519 | # disable layers dedicated to ImageNet labels classification 520 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 521 | self.backbone = backbone 522 | if head is None: 523 | self.head = nn.Identity() 524 | else: 525 | self.head = head 526 | 527 | def forward(self, x, mask_ratio=None, return_backbone_feat=False): 528 | # convert to list 529 | if mask_ratio != 0: 530 | t = 0 531 | _out,loss,out_mask,ids_restore,clip_tezheng = self.backbone(x, mask_ratio) 532 | t = _out.size(0) 533 | 534 | output_new = self.head(_out) 535 | 536 | return output_new,loss,out_mask,ids_restore,t,clip_tezheng 537 | else: 538 | _out = self.backbone(x, mask_ratio) 539 | 540 | output_new = self.head(_out) 541 | 542 | return output_new 543 | 544 | def get_params_groups(model): 545 | regularized = [] 546 | not_regularized = [] 547 | for name, param in model.named_parameters(): 548 | if not param.requires_grad: 549 | continue 550 | # we do not regularize biases nor Norm parameters 551 | if name.endswith(".bias") or len(param.shape) == 1: 552 | not_regularized.append(param) 553 | else: 554 | regularized.append(param) 555 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 556 | 557 | 558 | def has_batchnorms(model): 559 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 560 | for name, module in model.named_modules(): 561 | if isinstance(module, bn_types): 562 | return True 563 | return False 564 | 565 | 566 | def concat_all_gather(tensor): 567 | """ 568 | Performs all_gather operation on the provided tensors. 569 | *** Warning ***: torch.distributed.all_gather has no gradient. 570 | """ 571 | tensors_gather = [torch.ones_like(tensor) 572 | for _ in range(torch.distributed.get_world_size())] 573 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 574 | 575 | output = torch.cat(tensors_gather, dim=0) 576 | return output 577 | 578 | 579 | class PCA(): 580 | """ 581 | Class to compute and apply PCA. 582 | """ 583 | def __init__(self, dim=256, whit=0.5): 584 | self.dim = dim 585 | self.whit = whit 586 | self.mean = None 587 | 588 | def train_pca(self, cov): 589 | """ 590 | Takes a covariance matrix (np.ndarray) as input. 591 | """ 592 | d, v = np.linalg.eigh(cov) 593 | eps = d.max() * 1e-5 594 | n_0 = (d < eps).sum() 595 | if n_0 > 0: 596 | d[d < eps] = eps 597 | 598 | # total energy 599 | totenergy = d.sum() 600 | 601 | # sort eigenvectors with eigenvalues order 602 | idx = np.argsort(d)[::-1][:self.dim] 603 | d = d[idx] 604 | v = v[:, idx] 605 | 606 | print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0)) 607 | 608 | # for the whitening 609 | d = np.diag(1. 
/ d**self.whit)
610 | 
611 | # principal components
612 | self.dvt = np.dot(d, v.T)
613 | 
614 | def apply(self, x):
615 | # input is from numpy
616 | if isinstance(x, np.ndarray):
617 | if self.mean is not None:
618 | x -= self.mean
619 | return np.dot(self.dvt, x.T).T
620 | 
621 | # input is from torch and is on GPU
622 | if x.is_cuda:
623 | if self.mean is not None:
624 | x -= torch.cuda.FloatTensor(self.mean)
625 | return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
626 | 
627 | # input is from torch, on CPU
628 | if self.mean is not None:
629 | x -= torch.FloatTensor(self.mean)
630 | return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
631 | 
632 | 
633 | def compute_ap(ranks, nres):
634 | """
635 | Computes average precision for given ranked indexes.
636 | Arguments
637 | ---------
638 | ranks : zero-based ranks of positive images
639 | nres : number of positive images
640 | Returns
641 | -------
642 | ap : average precision
643 | """
644 | 
645 | # number of images ranked by the system
646 | nimgranks = len(ranks)
647 | 
648 | # accumulate trapezoids in PR-plot
649 | ap = 0
650 | 
651 | recall_step = 1. / nres
652 | 
653 | for j in np.arange(nimgranks):
654 | rank = ranks[j]
655 | 
656 | if rank == 0:
657 | precision_0 = 1.
658 | else:
659 | precision_0 = float(j) / rank
660 | 
661 | precision_1 = float(j + 1) / (rank + 1)
662 | 
663 | ap += (precision_0 + precision_1) * recall_step / 2.
664 | 
665 | return ap
666 | 
667 | 
668 | def compute_map(ranks, gnd, kappas=[]):
669 | """
670 | Computes the mAP for a given set of returned results.
671 | Usage:
672 | map = compute_map (ranks, gnd)
673 | computes mean average precision (map) only
674 | map, aps, pr, prs = compute_map (ranks, gnd, kappas)
675 | computes mean average precision (map), average precision (aps) for each query
676 | computes mean precision at kappas (pr), precision at kappas (prs) for each query
677 | Notes:
678 | 1) ranks starts from 0, ranks.shape = db_size X #queries
679 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array
680 | 3) If there are no positive images for some query, that query is excluded from the evaluation
681 | """
682 | 
683 | map = 0.
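# ranks holds zero-based database indices sorted by decreasing similarity, shape (db_size, n_queries); gnd[i]['ok'] / gnd[i]['junk'] list the positive / junk database indices for query i.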
684 | nq = len(gnd) # number of queries 685 | aps = np.zeros(nq) 686 | pr = np.zeros(len(kappas)) 687 | prs = np.zeros((nq, len(kappas))) 688 | nempty = 0 689 | 690 | for i in np.arange(nq): 691 | qgnd = np.array(gnd[i]['ok']) 692 | 693 | # no positive images, skip from the average 694 | if qgnd.shape[0] == 0: 695 | aps[i] = float('nan') 696 | prs[i, :] = float('nan') 697 | nempty += 1 698 | continue 699 | 700 | try: 701 | qgndj = np.array(gnd[i]['junk']) 702 | except: 703 | qgndj = np.empty(0) 704 | 705 | # sorted positions of positive and junk images (0 based) 706 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 707 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 708 | 709 | k = 0; 710 | ij = 0; 711 | if len(junk): 712 | # decrease positions of positives based on the number of 713 | # junk images appearing before them 714 | ip = 0 715 | while (ip < len(pos)): 716 | while (ij < len(junk) and pos[ip] > junk[ij]): 717 | k += 1 718 | ij += 1 719 | pos[ip] = pos[ip] - k 720 | ip += 1 721 | 722 | # compute ap 723 | ap = compute_ap(pos, len(qgnd)) 724 | map = map + ap 725 | aps[i] = ap 726 | 727 | # compute precision @ k 728 | pos += 1 # get it to 1-based 729 | for j in np.arange(len(kappas)): 730 | kq = min(max(pos), kappas[j]); 731 | prs[i, j] = (pos <= kq).sum() / kq 732 | pr = pr + prs[i, :] 733 | 734 | map = map / (nq - nempty) 735 | pr = pr / (nq - nempty) 736 | 737 | return map, aps, pr, prs 738 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import datetime 11 | import time 12 | import math 13 | import json 14 | import numpy as np 15 | import utils 16 | import models 17 | import torch 18 | import torch.nn as nn 19 | import torch.distributed as dist 20 | import torch.backends.cudnn as cudnn 21 | import torch.nn.functional as F 22 | import OurDataset 23 | from pathlib import Path 24 | from PIL import Image 25 | from torchvision import datasets, transforms 26 | from torchvision.transforms import InterpolationMode 27 | from tensorboardX import SummaryWriter 28 | from models.head import Head 29 | from models.clipmodel import ClipBaseModel,build_clip 30 | 31 | 32 | def get_args_parser(): 33 | parser = argparse.ArgumentParser('VehicleMAE', add_help=False) 34 | 35 | # Model parameters 36 | parser.add_argument('--arch', default='mae_vit_base_patch16', type=str, 37 | choices=['mae_vit_base_patch16', 'mae_vit_large_patch16', 'mae_vit_huge_patch14'], 38 | help="""Name of architecture to train. For quick experiments with ViTs, 39 | we recommend using vit_tiny or vit_small.""") 40 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels 41 | of input square patches - default 16 (for 16x16 patches). Using smaller 42 | values leads to better performance but requires more memory. Applies only 43 | for ViTs (vit_tiny, vit_small and vit_base). 
If <16, we recommend disabling
        mixed precision training (--use_fp16 false) to avoid instabilities.""")
    parser.add_argument('--out_dim', default=8192, type=int, help="""Dimensionality of
        output for [CLS] token.""")
    parser.add_argument('--patch_out_dim', default=8192, type=int, help="""Dimensionality of
        output for patch tokens.""")
    parser.add_argument('--shared_head', default=False, type=utils.bool_flag, help="""Whether to share
        the same head for [CLS] token output and patch tokens output. When set to false, patch_out_dim
        is ignored and enforced to be the same as out_dim. (Default: False)""")
    parser.add_argument('--shared_head_teacher', default=True, type=utils.bool_flag, help="""See above.
        Only works for the teacher model. (Default: True)""")
    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
        help="""Whether or not to weight normalize the last layer of the head.
        Not normalizing leads to better performance but can make the training unstable.
        In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with a cosine schedule.
        We recommend setting a higher value with small batches: for example use 0.9995 with a batch size of 256.""")
    parser.add_argument('--norm_in_head', default=None,
        help="Whether to use batch normalization in the projection head (Default: None)")
    parser.add_argument('--act_in_head', default='gelu',
        help="Activation function used in the projection head (Default: gelu)")
    parser.add_argument('--use_masked_im_modeling', default=True, type=utils.bool_flag,
        help="Whether to use masked image modeling (MIM) in the backbone (Default: True)")
    parser.add_argument('--pred_ratio', default=0.3, type=float, nargs='+', help="""Ratio of partial prediction.
        If a list of ratios is specified, one of them will be randomly chosen for each patch.""")
    parser.add_argument('--pred_ratio_var', default=0, type=float, nargs='+', help="""Variance of the partial prediction
        ratio. Length should be identical to the length of pred_ratio. 0 for disabling.""")
    parser.add_argument('--pred_shape', default='block', type=str, help="""Shape of partial prediction.""")
    parser.add_argument('--pred_start_epoch', default=0, type=int, help="""Start epoch to perform masked
        image prediction. We typically set this to 50 for swin transformer.
(Default: 0)""") 74 | parser.add_argument('--lambda1', default=1.0, type=float, help="""loss weight for dino 75 | loss over [CLS] tokens (Default: 1.0)""") 76 | parser.add_argument('--lambda2', default=1.0, type=float, help="""loss weight for beit 77 | loss over masked patch tokens (Default: 1.0)""") 78 | parser.add_argument('--norm_pix_loss', action='store_true', 79 | help='Use (per-patch) normalized pixels as targets for computing loss') 80 | parser.set_defaults(norm_pix_loss=False) 81 | parser.add_argument('--input_size', default=224, type=int, 82 | help='images input size') 83 | parser.add_argument('--pin_mem', action='store_true', 84 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 85 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 86 | parser.set_defaults(pin_mem=True) 87 | parser.add_argument('--mask_ratio', default=0.75, type=float, 88 | help='Masking ratio (percentage of removed patches).') 89 | parser.add_argument('--use_learnable_pos_emb', default=False, 90 | type=str, help='masked strategy of video tokens/patches False') 91 | 92 | 93 | # Temperature teacher parameters 94 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float, 95 | help="""Initial value for the teacher temperature: 0.04 works well in most cases. 96 | Try decreasing it if the training loss does not decrease.""") 97 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup) 98 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend 99 | starting with the default value of 0.04 and increase this slightly if needed.""") 100 | parser.add_argument('--warmup_teacher_patch_temp', default=0.04, type=float, help="""See 101 | `--warmup_teacher_temp`""") 102 | parser.add_argument('--teacher_patch_temp', default=0.07, type=float, help=""""See 103 | `--teacher_temp`""") 104 | parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, 105 | help='Number of warmup epochs for the teacher temperature (Default: 30).') 106 | parser.add_argument('--tlayernorm', type=int, default=0, choices=[0, 1], 107 | help="0: without teache rlayernorm \ 108 | 1:with vit original self.norm") 109 | 110 | # Training/Optimization parameters 111 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not 112 | to use half precision for training. Improves training time and memory requirements, 113 | but can provoke instability and slight decay of performance. We recommend disabling 114 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""") #True 115 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the 116 | weight decay. With ViT, a smaller value at the beginning of training works well.""") 117 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the 118 | weight decay. We use a cosine schedule for WD and using a larger decay by 119 | the end of training improves performance for ViTs.""") 120 | parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter 121 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 122 | help optimization for larger ViT architectures. 
0 for disabling.""")
    parser.add_argument('--batch_size_per_gpu', default=128, type=int,
        help='Per-GPU batch size: number of distinct images loaded on one GPU.')
    parser.add_argument('--epochs', default=300, type=int, help='Number of epochs of training.')
    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")
    parser.add_argument("--lr", default=0.00025, type=float, help="""Learning rate at the end of #0.0005
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""")
    parser.add_argument("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the #1e-6
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--load_from', default=None, help="""Path to load checkpoints to resume training.""")  # load a previously trained checkpoint to resume training after an interruption
    parser.add_argument('--drop_path', type=float, default=0.1, help="""Drop path rate for student network.""")

    # CLIP
    parser.add_argument('--clip_backbone', default='ViT-B-16', choices=['RN50', 'RN50x16', 'RN101', 'ViT-B-32', 'ViT-B-16'])

    # Multi-crop parameters
    parser.add_argument('--global_crops_number', type=int, default=2, help="""Number of global
        views to generate. Default is to use two global crops.""")
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.14, 1.),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=0, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend using "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for small local view cropping of multi-crop.""")
    parser.add_argument('--device', default='cuda',
        help='device to use for training / testing')

    # Learnable masking parameters
    parser.add_argument('--softmax_temp', type=float, default=1e-2, metavar='Learnable_Mask',
        help='Softmax temperature used to compute probability values for each patch')

    # Misc
    parser.add_argument('--pkl_path', default='/home/lcl_d/wuwentao/data/ourdata.pkl', type=str, help='Path to the pickle file of the pre-training data.')
    parser.add_argument('--output_dir', default="/home/lcl_d/wuwentao/VehicleMAE/output", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--clip_pre_model', default="/home/lcl_d/wuwentao/maeclip/clip_pre_model/ViT-B-16.pt", type=str, help='Path to the pre-trained CLIP model.')
    parser.add_argument('--saveckp_freq', default=50, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser


def train_vehicle(args):
    utils.init_distributed_mode(args)
    seed = utils.fix_random_seeds(args.seed)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    device = torch.device(args.device)

    # ============ preparing data ... ============
    transform_train = transforms.Compose([
        # randomly crop a region, then resize it to 224x224
        transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),  # 3 is bicubic
        # transforms.Resize([224, 224]),
        # random horizontal flip
        transforms.RandomHorizontalFlip(),
        # convert the image to float values in [0, 1]
        transforms.ToTensor(),
        # normalize with dataset-specific mean and std
        transforms.Normalize(mean=[0.446, 0.452, 0.466], std=[0.277, 0.278, 0.276])])  # [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]

    dataset_train = OurDataset.OurDataset(args.pkl_path, transform=transform_train)

    sampler_train = torch.utils.data.DistributedSampler(dataset_train, shuffle=True)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )  # every item drawn from this loader is a mini-batch

    print(f"Data loaded: there are {len(dataset_train)} images.")

    # ============ building student, teacher_clip and teacher networks ... ============
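    # NOTE (added commentary): three networks are assembled below.
    #   * student      -- the MAE-style ViT that is actually optimized; its decoder
    #                     (decoder_depth=8) reconstructs the masked vehicle patches.
    #   * teacher      -- the same architecture with decoder_depth=0 and mask_ratio=0;
    #                     it sees the unmasked profile/sketch images and provides the
    #                     [CLS]/patch targets. It is updated only by the EMA step in
    #                     train_one_epoch, never by backpropagation.
    #   * teacher_clip -- a frozen CLIP model (plus text features) used to compute the
    #                     similarity and KL-distance losses against the student features.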
    # we changed the name DeiT-S for ViT-S to avoid confusions

    clip_model, preprocess, text_features = build_clip(args)

    student = models.__dict__[args.arch](decoder_depth=8, norm_pix_loss=args.norm_pix_loss)

    teacher = models.__dict__[args.arch](decoder_depth=0, mask_ratio=0)

    embed_dim = 512

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(student, Head(
        embed_dim,
        args.out_dim,
        patch_out_dim=args.patch_out_dim,
        norm=args.norm_in_head,
        act=args.act_in_head,
        norm_last_layer=args.norm_last_layer,
        shared_head=args.shared_head,
        use_learnable_pos_emb=args.use_learnable_pos_emb
    ))
    teacher = utils.MultiCropWrapper(teacher, Head(
        embed_dim,
        args.out_dim,
        patch_out_dim=args.patch_out_dim,
        norm=args.norm_in_head,
        act=args.act_in_head,
        shared_head=args.shared_head_teacher,
    ))
    teacher_clip = ClipBaseModel(
        clip_model,
        text_features,
    )

    student.to(device)
    teacher.to(device)
    teacher_clip.to(device)

    student_without_ddp = student

    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        teacher_clip = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_clip)

        # we need the DDP wrapper to have synchronized batch norms working
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu], broadcast_buffers=False, find_unused_parameters=False)
        teacher_without_ddp = teacher.module
        student_without_ddp = student.module
        teacher_clip = nn.parallel.DistributedDataParallel(teacher_clip, device_ids=[args.gpu], broadcast_buffers=False, find_unused_parameters=False)
        teacher_clip_without_ddp = teacher_clip.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
        teacher_clip_without_ddp = teacher_clip
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu], broadcast_buffers=False, find_unused_parameters=False)
    # the teacher and the student start from the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict(), strict=False)
    # there is no backpropagation through the teacher, so no gradients are needed
    for p in teacher.parameters():
        p.requires_grad = False
    for p in teacher_clip.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} networks.")

    # ============ preparing loss ... ============
    same_dim = args.shared_head or args.shared_head_teacher
    VehicleMAE_loss = VehicleMAELoss(
        args.out_dim,
        args.out_dim if same_dim else args.patch_out_dim,
        args.global_crops_number,
        args.local_crops_number,
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_patch_temp,
        args.teacher_patch_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
        lambda1=args.lambda1,
        lambda2=args.lambda2,
        mim_start_epoch=args.pred_start_epoch,
    ).cuda()

    if utils.is_main_process():  # Tensorboard configuration
        local_runs = os.path.join(args.output_dir, 'tf_logs')
        writer = SummaryWriter(logdir=local_runs)

    # ============ preparing optimizer ...
============ 300 | params_groups = utils.get_params_groups(student) 301 | if args.optimizer == "adamw": 302 | optimizer = torch.optim.AdamW(params_groups) # to use with ViTs 303 | elif args.optimizer == "sgd": 304 | optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler 305 | elif args.optimizer == "lars": 306 | optimizer = utils.LARS(params_groups) # to use with convnet and large batches 307 | # for mixed precision training 308 | fp16_scaler = None 309 | if args.use_fp16: 310 | fp16_scaler = torch.cuda.amp.GradScaler() 311 | 312 | # ============ init schedulers ... ============ 313 | lr_schedule = utils.cosine_scheduler( 314 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256, # linear scaling rule 315 | args.min_lr, 316 | args.epochs, len(data_loader_train), 317 | warmup_epochs=args.warmup_epochs, 318 | ) 319 | wd_schedule = utils.cosine_scheduler( 320 | args.weight_decay, 321 | args.weight_decay_end, 322 | args.epochs, len(data_loader_train), 323 | ) 324 | # momentum parameter is increased to 1. during training with a cosine schedule 325 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, 326 | args.epochs, len(data_loader_train)) 327 | 328 | print(f"Loss, optimizer and schedulers ready.") 329 | 330 | # ============ optionally resume training ... ============ 331 | to_restore = {"epoch": 0} 332 | if args.load_from: 333 | utils.restart_from_checkpoint( 334 | os.path.join(args.output_dir, args.load_from), 335 | run_variables=to_restore, 336 | student=student, 337 | teacher=teacher, 338 | optimizer=optimizer, 339 | fp16_scaler=fp16_scaler, 340 | vehiclemae_loss=VehicleMAE_loss, 341 | ) 342 | start_epoch = to_restore["epoch"] 343 | 344 | start_time = time.time() 345 | print("Starting our training!") 346 | for epoch in range(start_epoch, args.epochs): 347 | data_loader_train.sampler.set_epoch(epoch) 348 | 349 | # ============ training one epoch of iBOT ... ============ 350 | 351 | train_stats = train_one_epoch(student, teacher,teacher_clip, teacher_without_ddp,teacher_clip_without_ddp, VehicleMAE_loss, 352 | data_loader_train, optimizer, device,lr_schedule, wd_schedule, momentum_schedule, 353 | epoch, fp16_scaler, args) 354 | 355 | 356 | # ============ writing logs ... 
============ 357 | save_dict = { 358 | 'student': student.state_dict(), 359 | 'teacher': teacher.state_dict(), 360 | 'teacher_clip': teacher_clip.state_dict(), 361 | 'optimizer': optimizer.state_dict(), 362 | 'epoch': epoch + 1, 363 | 'args': args, 364 | 'VehicleMAE_loss': VehicleMAE_loss.state_dict(), 365 | } 366 | if fp16_scaler is not None: 367 | save_dict['fp16_scaler'] = fp16_scaler.state_dict() 368 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth')) 369 | if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs -1): 370 | utils.save_model( 371 | args=args, model=student, model_without_ddp=student_without_ddp, optimizer=optimizer, 372 | #args=args, model=student, optimizer=optimizer, 373 | loss_scaler=VehicleMAE_loss, epoch=epoch) 374 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 375 | 'epoch': epoch} 376 | 377 | 378 | if utils.is_main_process(): 379 | with (Path(args.output_dir) / "log.txt").open("a") as f: 380 | f.write(json.dumps(log_stats) + "\n") 381 | for k, v in train_stats.items(): 382 | writer.add_scalar(k, v, epoch) 383 | 384 | total_time = time.time() - start_time 385 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 386 | print('Training time {}'.format(total_time_str)) 387 | 388 | 389 | def train_one_epoch(student, teacher,teacher_clip, teacher_without_ddp,teacher_clip_without_ddp, VehicleMAE_loss, data_loader, 390 | optimizer, device: torch.device,lr_schedule, wd_schedule, momentum_schedule,epoch, 391 | fp16_scaler, args): 392 | 393 | metric_logger = utils.MetricLogger(delimiter=" ") 394 | header = 'Epoch: [{}]'.format(epoch) 395 | 396 | # common params 397 | names_q, params_q, names_k, params_k = [], [], [], [] 398 | for name_q, param_q in student.module.named_parameters(): 399 | names_q.append(name_q) 400 | params_q.append(param_q) 401 | for name_k, param_k in teacher_without_ddp.named_parameters(): 402 | names_k.append(name_k) 403 | params_k.append(param_k) 404 | names_common = list(set(names_q) & set(names_k)) 405 | params_q = [param_q for name_q, param_q in zip(names_q, params_q) if name_q in names_common] 406 | params_k = [param_k for name_k, param_k in zip(names_k, params_k) if name_k in names_common] 407 | 408 | for it, (rgb_samples, tir_samples) in enumerate(metric_logger.log_every(data_loader, 20, header)): 409 | # update weight decay and learning rate according to their schedule 410 | it = len(data_loader) * epoch + it # global training iteration 411 | for i, param_group in enumerate(optimizer.param_groups): 412 | param_group["lr"] = lr_schedule[it] 413 | if i == 0: # only the first group is regularized 414 | param_group["weight_decay"] = wd_schedule[it] 415 | 416 | # move images to gpu 417 | images = rgb_samples.to(device, non_blocking=True) 418 | images_lunkuo =tir_samples.to(device, non_blocking=True) 419 | 420 | with torch.cuda.amp.autocast(fp16_scaler is not None): 421 | student_output,student_loss,masks,_,_,tezheng = student(images, mask_ratio=args.mask_ratio) 422 | teacher_output = teacher(images_lunkuo,mask_ratio = 0) 423 | similarity_loss,kl_distance_loss = teacher_clip(images,tezheng) 424 | 425 | student_loss = ((student_loss * masks).sum() / masks.sum())*4 426 | 427 | masks = masks.type(torch.bool) 428 | all_loss = VehicleMAE_loss(student_output, teacher_output, masks, epoch,student_loss,similarity_loss,kl_distance_loss) 429 | loss = all_loss.pop('loss') 430 | 431 | #loss_value = loss.item() #取具体的数值 432 | 433 | 434 | if not math.isfinite(loss.item()): 435 | print("Loss is 
{}, stopping training".format(loss.item()), force=True) 436 | sys.exit(1) 437 | 438 | # log statistics 439 | probs1 = teacher_output[0].chunk(args.global_crops_number) 440 | probs2 = student_output[0].chunk(args.global_crops_number) 441 | pred1 = utils.concat_all_gather(probs1[0].max(dim=1)[1]) 442 | pred2 = utils.concat_all_gather(probs2[1].max(dim=1)[1]) 443 | acc = (pred1 == pred2).sum() / pred1.size(0) 444 | 445 | optimizer.zero_grad() 446 | 447 | param_norms = None 448 | if fp16_scaler is None: 449 | loss.backward() 450 | if args.clip_grad: 451 | param_norms = utils.clip_gradients(student, args.clip_grad) 452 | utils.cancel_gradients_last_layer(epoch, student, 453 | args.freeze_last_layer) 454 | optimizer.step() 455 | else: 456 | fp16_scaler.scale(loss).backward() 457 | if args.clip_grad: 458 | 459 | fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 460 | param_norms = utils.clip_gradients(student, args.clip_grad) 461 | utils.cancel_gradients_last_layer(epoch, student, 462 | args.freeze_last_layer) 463 | fp16_scaler.step(optimizer) 464 | fp16_scaler.update() 465 | 466 | # EMA update for the teacher动态更新教师网络 467 | with torch.no_grad(): 468 | m = momentum_schedule[it] # momentum parameter 469 | for param_q, param_k in zip(params_q, params_k): 470 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 471 | 472 | # logging 473 | torch.cuda.synchronize() 474 | metric_logger.update(loss=loss.item()) 475 | for key, value in all_loss.items(): 476 | metric_logger.update(**{key: value.item()}) 477 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 478 | metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"]) 479 | metric_logger.update(acc=acc) 480 | 481 | metric_logger.synchronize_between_processes() 482 | print("Averaged stats:", metric_logger) 483 | return_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 484 | return return_dict 485 | 486 | 487 | class VehicleMAELoss(nn.Module): 488 | def __init__(self, out_dim, patch_out_dim, ngcrops, nlcrops, warmup_teacher_temp, 489 | teacher_temp, warmup_teacher_temp2, teacher_temp2, 490 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 491 | center_momentum=0.9, center_momentum2=0.9, 492 | lambda1=1.0, lambda2=1.0, mim_start_epoch=0): 493 | super().__init__() 494 | self.student_temp = student_temp 495 | self.center_momentum = center_momentum 496 | self.center_momentum2 = center_momentum2 497 | self.ngcrops = ngcrops 498 | self.nlcrops = nlcrops 499 | self.ncrops = ngcrops + nlcrops 500 | self.register_buffer("center", torch.zeros(1, out_dim)) 501 | self.register_buffer("center2", torch.zeros(1, 1, patch_out_dim)) 502 | self.lambda1 = lambda1 503 | self.lambda2 = lambda2 504 | 505 | # we apply a warm up for the teacher temperature because 506 | # a too high temperature makes the training instable at the beginning 507 | self.teacher_temp_schedule = np.concatenate(( 508 | np.linspace(warmup_teacher_temp, 509 | teacher_temp, warmup_teacher_temp_epochs), 510 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 511 | )) 512 | self.teacher_temp2_schedule = np.concatenate(( 513 | np.linspace(warmup_teacher_temp2, 514 | teacher_temp2, warmup_teacher_temp_epochs), 515 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp2 516 | )) if mim_start_epoch == 0 else np.concatenate(( 517 | np.ones(mim_start_epoch) * warmup_teacher_temp2, 518 | np.linspace(warmup_teacher_temp2, 519 | teacher_temp2, warmup_teacher_temp_epochs), 520 | np.ones(nepochs - 
warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2 521 | )) 522 | 523 | def forward(self, student_output, teacher_output, student_mask, epoch,mae_loss,similarity_loss,kl_distance_loss): 524 | """ 525 | Cross-entropy between softmax outputs of the teacher and student networks. 526 | """ 527 | 528 | student_cls, student_patch = student_output 529 | teacher_cls, teacher_patch = teacher_output 530 | 531 | # [CLS] and patch for global patches 532 | student_cls = student_cls / self.student_temp 533 | student_cls_c = student_cls.chunk(self.ncrops) 534 | student_patch = student_patch / self.student_temp 535 | student_patch_c = student_patch.chunk(self.ngcrops) 536 | 537 | # teacher centering and sharpening 538 | temp = self.teacher_temp_schedule[epoch] 539 | temp2 = self.teacher_temp2_schedule[epoch] 540 | teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, dim=-1) 541 | teacher_cls_c = teacher_cls_c.detach().chunk(self.ngcrops) 542 | teacher_patch_c = F.softmax((teacher_patch - self.center2) / temp2, dim=-1) 543 | teacher_patch_c = teacher_patch_c.detach().chunk(self.ngcrops) 544 | 545 | 546 | total_loss1, n_loss_terms1 = 0, 0 547 | total_loss2, n_loss_terms2 = 0, 0 548 | for q in range(len(teacher_patch_c)): 549 | for v in range(len(student_patch_c)): 550 | if v == q: 551 | loss2 = torch.sum(-teacher_patch_c[q] * F.log_softmax(student_patch_c[v], dim=-1), dim=-1) 552 | mask = student_mask[v] 553 | mask = ~mask 554 | loss2 = torch.sum(loss2 * mask.float(), dim=-1) / mask.sum(dim=-1).clamp(min=1.0) 555 | total_loss2 += loss2.mean() #mean均值函数 556 | n_loss_terms2 += 1 557 | else: 558 | loss1 = torch.sum(-teacher_cls_c[q] * F.log_softmax(student_cls_c[v], dim=-1), dim=-1) 559 | total_loss1 += loss1.mean() 560 | n_loss_terms1 += 1 561 | 562 | total_loss1 = total_loss1 / n_loss_terms1 * self.lambda1*0.02 563 | total_loss2 = total_loss2 / n_loss_terms2 * self.lambda2*0.02 564 | 565 | kl_distance_loss = kl_distance_loss*0.1 566 | similarity_loss = similarity_loss*2 567 | 568 | total_loss = dict( cls =total_loss1, patch=total_loss2,mae =mae_loss,similarity = similarity_loss,kl = kl_distance_loss, loss=total_loss1+total_loss2+ mae_loss+similarity_loss+kl_distance_loss) 569 | 570 | self.update_center(teacher_cls, teacher_patch) 571 | return total_loss 572 | 573 | @torch.no_grad() 574 | 575 | def update_center(self, teacher_cls, teacher_patch): 576 | """ 577 | Update center used for teacher output. 578 | """ 579 | cls_center = torch.sum(teacher_cls, dim=0, keepdim=True) 580 | dist.all_reduce(cls_center) 581 | cls_center = cls_center / (len(teacher_cls) * dist.get_world_size()) 582 | self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum) 583 | 584 | patch_center = torch.sum(teacher_patch.mean(1), dim=0, keepdim=True) 585 | dist.all_reduce(patch_center) 586 | patch_center = patch_center / (len(teacher_patch) * dist.get_world_size()) 587 | self.center2 = self.center2 * self.center_momentum2 + patch_center * (1 - self.center_momentum2) 588 | 589 | 590 | 591 | if __name__ == '__main__': 592 | parser = argparse.ArgumentParser('VehicleMAE', parents=[get_args_parser()]) 593 | args = parser.parse_args() 594 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 595 | train_vehicle(args) 596 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | firstIMG 4 |

Official PyTorch implementation of **Structural Information Guided Multimodal Pre-training for Vehicle-centric Perception**, Xiao Wang, Wentao Wu, Chenglong Li, Zhicheng Zhao, Zhe Chen, Yukai Shi, Jin Tang, AAAI-2024
[[arXiv](https://arxiv.org/pdf/2312.09812.pdf)]
[[Poster](https://github.com/Event-AHU/VehicleMAE/blob/main/VehicleMAE_poster.pdf)]

## Abstract

Understanding vehicles in images is important for various applications such as intelligent transportation and self-driving systems. Existing vehicle-centric works typically pre-train models on large-scale classification datasets and then fine-tune them for specific downstream tasks. However, they neglect the specific characteristics of vehicle perception in different tasks and might thus lead to sub-optimal performance. To address this issue, we propose a novel vehicle-centric pre-training framework called VehicleMAE, which incorporates structural information, including the spatial structure from vehicle profile information and the semantic structure from informative high-level natural language descriptions, for effective masked vehicle appearance reconstruction. To be specific, we explicitly extract the sketch lines of vehicles as a form of spatial structure to guide vehicle reconstruction. More comprehensive knowledge, distilled from the large CLIP model based on the similarity between paired/unpaired vehicle image-text samples, is further taken into consideration to help achieve a better understanding of vehicles. A large-scale dataset, termed Autobot1M, is built to pre-train our model; it contains about 1M vehicle images and 12,693 text descriptions. Extensive experiments on four vehicle-based downstream tasks fully validate the effectiveness of our VehicleMAE.

20 | firstIMG 21 |

22 | 23 | 24 | 25 | 26 | ## Video Tutorial 27 | 28 | * **Video Tutorial for this work can be found by clicking the image below:** 29 |

30 | 31 | Tutorials 32 | 33 |

34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | ## Our Proposed Framework VehicleMAE 42 | 43 |

![framework](figures/framework.jpg)

## Environment Setting

Configure the environment according to the contents of the requirements.txt file.
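
For reference, a typical setup might look like the commands below (the environment name and the Python version are illustrative assumptions, not fixed by this repository; the bundled `.pyc` files suggest Python 3.7):

```bash
# Create and activate a fresh environment (name is arbitrary).
conda create -n vehiclemae python=3.7 -y
conda activate vehiclemae

# Install the pinned dependencies shipped with this repository.
pip install -r requirements.txt
```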

## Dataset Download

![data_show](figures/data_show.jpg)

Baidu Netdisk link: [download](https://pan.baidu.com/s/1VE0V6cuimfZ3qZCILwwORA?pwd=tpds)

Extraction code: tpds

## Pre-trained Model Download

Pre-trained Model | ViT-Base
---- | -----
Pre-trained checkpoint | [download](https://pan.baidu.com/s/1wB2QdzItdVVYQQ491ZOOrA?pwd=6zkx)
Extraction code | 6zkx
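
To reuse the pre-trained weights in your own fine-tuning code, a minimal loading sketch is shown below. It assumes the checkpoint follows the dictionary layout written by `main.py` (keys such as `student`, `teacher`, `epoch`); the released file may be organized differently, so inspect its keys first. The file name is a placeholder.

```python
import torch

ckpt = torch.load("checkpoint.pth", map_location="cpu")
print(ckpt.keys())  # e.g. dict_keys(['student', 'teacher', 'teacher_clip', 'optimizer', 'epoch', ...])

# The student branch holds the ViT encoder that is normally transferred to downstream
# tasks; strip the DDP / wrapper prefixes before loading it into a plain ViT-Base model.
state_dict = {k.replace("module.", "").replace("backbone.", ""): v
              for k, v in ckpt["student"].items()}
```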

## Training

```bash
# Pre-train VehicleMAE with a single GPU:
CUDA_VISIBLE_DEVICES=0 python main.py

# Pre-train VehicleMAE with multiple GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
```

## Experimental Results

We used full fine-tuning to test the pre-trained model on four downstream tasks. The results are shown in the table below.

Method | Dataset | VAR (mA) | VAR (Acc) | VAR (F1) | V-Reid (mAP) | V-Reid (R1) | VFR (Acc) | VPS (mIoU) | VPS (mAcc)
---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ----
Scratch | - | 84.67 | 80.86 | 84.90 | 35.3 | 57.3 | 24.8 | 49.36 | 59.22
MoCov3 | ImageNet-1K | 90.38 | 93.88 | 95.33 | 75.5 | 94.4 | 91.3 | 73.17 | 78.60
DINO | ImageNet-1K | 89.92 | 91.09 | 93.11 | 64.3 | 91.5 | - | 68.43 | 73.37
IBOT | ImageNet-1K | 89.51 | 90.17 | 92.37 | 68.9 | 92.6 | 81.1 | 66.03 | 71.06
MAE | ImageNet-1K | 89.69 | 93.60 | 95.08 | 76.7 | 95.8 | 91.2 | 69.54 | 75.36
MAE | Autobot1M | 90.19 | 94.06 | 95.43 | 75.5 | 95.4 | 91.3 | 69.00 | 75.36
VehicleMAE | Autobot1M | 92.21 | 94.91 | 96.17 | 85.6 | 97.9 | 94.5 | 73.29 | 80.22
The four downstream tasks are vehicle attribute recognition (VAR), vehicle re-identification (V-Reid), vehicle fine-grained recognition (VFR), and vehicle partial segmentation (VPS).

## Visual Results

![reconst_vis](figures/reconst_vis.jpg)

519 | 520 |

![attentionmaps](figures/attentionmaps.jpg)

## Acknowledgement

[[MAE](https://github.com/facebookresearch/mae)]
[[BDCN](https://github.com/pkuCactus/BDCN)]
[[CLIP](https://github.com/openai/CLIP)]

## Citation

If you find this work helpful for your research, please cite the following paper and give us a star.

```bibtex
@misc{wang2023structural,
      title={Structural Information Guided Multimodal Pre-training for Vehicle-centric Perception},
      author={Xiao Wang and Wentao Wu and Chenglong Li and Zhicheng Zhao and Zhe Chen and Yukai Shi and Jin Tang},
      year={2023},
      eprint={2312.09812},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

If you have any problems with this work, please open an issue.
--------------------------------------------------------------------------------