├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py ├── clip.py └── model.py ├── requirements.txt ├── adem ├── __init__.py ├── tokenizer.py ├── adapter.py ├── build.py └── model.py ├── demo.py ├── engine.py ├── eval_caption.py ├── eval_mme.py ├── util ├── datasets.py ├── coco_karpathy_dataset.py ├── base_prompt.py ├── randaugment.py └── misc.py ├── README.md ├── eval_sqa.py └── train.py /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | sentencepiece==0.1.99 3 | timm==0.6.5 4 | ftfy 5 | regex 6 | pandas -------------------------------------------------------------------------------- /adem/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ModelArgs, Transformer 2 | from .tokenizer import Tokenizer 3 | 4 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hao840/ADEM-VL/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /adem/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from sentencepiece import SentencePieceProcessor 5 | 6 | 7 | class Tokenizer: 8 | def __init__(self, model_path: str): 9 | # reload tokenizer 10 | assert os.path.isfile(model_path), model_path 11 | self.sp_model = SentencePieceProcessor(model_file=model_path) 12 | print(f"Reloaded SentencePiece model from {model_path}") 13 | 14 | # BOS / EOS token IDs 15 | self.n_words: int = self.sp_model.vocab_size() 16 | self.bos_id: int = self.sp_model.bos_id() 17 | self.eos_id: int = self.sp_model.eos_id() 18 | self.pad_id: int = self.sp_model.pad_id() 19 | print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") 20 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 21 | 22 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 23 | assert type(s) is str 24 | t = self.sp_model.encode(s) 25 | if bos: 26 | t = [self.bos_id] + t 27 | if eos: 28 | t = t + [self.eos_id] 29 | return t 30 | 31 | def decode(self, t: List[int]) -> str: 32 | return self.sp_model.decode(t) 33 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | import os 4 | 5 | from PIL import Image 6 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | import torch 8 | from torchvision.transforms import transforms 9 | 10 | from adem.build import create_model 11 | from adem.tokenizer import Tokenizer 12 | 13 | 14 | @dataclass 15 | class ModelArgs: 16 | llama_model_path = './data/weights/' 17 | llm_model = '7B' 18 | max_seq_len = 512 19 | hidden_proj = 128 20 | cpu_load = False 21 | alpha = 0.1 22 | adapter_dim = 12 23 | gradient_checkpointing = False 24 | is_train = False 25 | data_root = './data/' 26 | clip = 'ViT-L/14' 27 | clip_root = './clip' 28 | down_sample_num = [256, 64] 29 | no_cls = False 30 | 
drop_ratio = 0.1 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--data_root', type=str, default='./data') 35 | parser.add_argument('--clip', type=str, default='ViT-L/14') 36 | parser.add_argument('--clip_root', type=str, default='./clip') 37 | parser.add_argument('--llm_model', type=str, default='7B') 38 | parser.add_argument('--adapter_path', type=str, default='./output_dir') 39 | 40 | parser.add_argument('--batch_size', type=int, default=4) 41 | parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64]) 42 | parser.add_argument('--alpha', type=float, default=0.1) 43 | parser.add_argument('--beta', type=float, default=0.01) 44 | parser.add_argument('--drop_ratio', type=float, default=0.1) 45 | parser.add_argument('--no_cls', action='store_true') 46 | 47 | args = parser.parse_args() 48 | 49 | model_args = ModelArgs() 50 | model_args.llama_model_path = os.path.join(args.data_root, "weights/") 51 | model_args.llm_model = args.llm_model 52 | model_args.alpha = args.alpha 53 | model_args.beta = args.beta 54 | model_args.data_root = args.data_root 55 | model_args.clip = args.clip 56 | model_args.clip_root = args.clip_root 57 | model_args.down_sample_num = args.down_sample_num 58 | model_args.no_cls = args.no_cls 59 | model_args.drop_ratio = args.drop_ratio 60 | 61 | llama = create_model(model_args) 62 | adapter = torch.load(os.path.join(args.adapter_path, 'checkpoint-14.pth'))['model'] 63 | sd = {} 64 | for k in adapter: 65 | sd[k.replace('module.', '')] = adapter[k] 66 | _IncompatibleKeys = llama.load_state_dict(sd, False) 67 | print(_IncompatibleKeys) 68 | 69 | tokenizer = Tokenizer(model_path=os.path.join(args.llama_model_path, 'weights/tokenizer.model')) 70 | vis_processor = transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC), transforms.ToTensor(), 71 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 72 | 73 | image = '/cache/data/data/cocoimages/test/COCO_test2014_000000000069.jpg' 74 | prompt = 'Describe this image.' 
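# The raw question above is wrapped into an 'Instruction: ...\nResponse:' template on the
# next line, so generation continues right after 'Response:'. The indicator passed to
# generate() below marks whether a real image accompanies the prompt (1) or a zero tensor
# is used as a text-only placeholder (0).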
75 | prompt = f'Instruction: {prompt}\nResponse:' 76 | 77 | if image is not None: 78 | raw_image = Image.open(image).convert('RGB') 79 | image = vis_processor(raw_image).unsqueeze(0).cuda() 80 | indicator = 1 81 | else: 82 | image = torch.Tensor(torch.zeros(3, 224, 224)).cuda() 83 | indicator = 0 84 | 85 | outputs = llama.generate( 86 | prompts=[prompt], 87 | images=[image], 88 | indicators=[indicator], 89 | max_gen_len=384, 90 | tokenizer=tokenizer, 91 | temperature=0.1, 92 | top_p=0.75, 93 | ) 94 | -------------------------------------------------------------------------------- /adem/adapter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.cuda.amp import autocast 6 | from torch.nn import functional as F 7 | 8 | from clip.model import ResidualAttentionBlock 9 | import adem 10 | 11 | 12 | class Adapter(nn.Module): 13 | def __init__( 14 | self, 15 | in_features=768, 16 | hidden_dim=8, 17 | drop_ratio=0 18 | ): 19 | super().__init__() 20 | if hidden_dim > 0: 21 | self.fc1 = nn.Linear(in_features, hidden_dim, bias=False) 22 | self.fc2 = nn.Linear(hidden_dim, in_features, bias=False) 23 | self.hidden_dim = hidden_dim 24 | nn.init.zeros_(self.fc2.weight) 25 | self.dropout = nn.Dropout(0.1) 26 | self.drop_ratio = drop_ratio 27 | 28 | def forward(self, x, vis_weight): 29 | with autocast(): 30 | if vis_weight is not None: 31 | image_embeds, adapter_emb1, adapter_emb2 = vis_weight 32 | x = (F.silu(x)) @ (F.silu(image_embeds + adapter_emb1).permute(0, 2, 1)) 33 | if self.drop_ratio > 0: 34 | score, _ = x.sort(dim=2) 35 | threshold = score[:, :, int(self.drop_ratio * score.size(2))].unsqueeze(2) 36 | mask = torch.ones_like(x) 37 | mask[torch.where(x < threshold.expand_as(x))] = 0 38 | x *= mask 39 | x = x @ (image_embeds + adapter_emb2) 40 | else: 41 | x = self.fc1(x) 42 | x = self.dropout(F.gelu(x)) 43 | x = self.fc2(x) 44 | return x 45 | 46 | 47 | def checkpoint(func, enable, training, *args, **kwargs): 48 | if enable and training: 49 | return torch.utils.checkpoint.checkpoint(func, *args, **kwargs) 50 | else: 51 | return func(*args, **kwargs) 52 | 53 | 54 | def forward_llama(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], 55 | vis_weight): 56 | x_norm = self.attention_norm(x) 57 | h = x + checkpoint(self.attention, self.gradient_checkpointing, self.training, x_norm, start_pos, freqs_cis, mask) 58 | h_norm = self.ffn_norm(h) 59 | out = h + checkpoint(self.feed_forward, self.gradient_checkpointing, self.training, h_norm) + self.adapter_mlp( 60 | h_norm, vis_weight) * self.s 61 | return out 62 | 63 | 64 | def forward_clip(self, x: torch.Tensor): 65 | x = x + self.attention(self.ln_1(x)) 66 | x = x + self.mlp(self.ln_2(x)) + self.adapter_mlp(self.ln_2(x), None) * self.s 67 | return x 68 | 69 | 70 | def set_Llama_Adapter(model, s=1, gradient_checkpointing=False, drop_ratio=0): 71 | for _ in model.children(): 72 | if type(_) == adem.model.TransformerBlock: 73 | _.adapter_mlp = Adapter(_.dim, hidden_dim=0, drop_ratio=drop_ratio) 74 | _.s = s 75 | _.gradient_checkpointing = gradient_checkpointing 76 | bound_method = forward_llama.__get__(_, _.__class__) 77 | setattr(_, 'forward', bound_method) 78 | elif len(list(_.children())) != 0: 79 | set_Llama_Adapter(_, s, gradient_checkpointing=gradient_checkpointing, drop_ratio=drop_ratio) 80 | 81 | 82 | def set_Clip_Adapter(model, dim=8, s=0.1): 83 | for _ in model.children(): 84 | if 
type(_) == ResidualAttentionBlock: 85 | _.adapter_mlp = Adapter(1024, hidden_dim=dim) 86 | _.s = s 87 | bound_method = forward_clip.__get__(_, _.__class__) 88 | setattr(_, 'forward', bound_method) 89 | elif len(list(_.children())) != 0: 90 | set_Clip_Adapter(_, dim, s) 91 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import math 4 | import torch 5 | 6 | import util.misc as misc 7 | 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | 23 | 24 | def train_one_epoch(model: torch.nn.Module, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, 27 | log_writer=None, 28 | args=None): 29 | model.train(True) 30 | metric_logger = misc.MetricLogger(delimiter=" ") 31 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 32 | header = 'Epoch: [{}]'.format(epoch) 33 | print_freq = 100 34 | 35 | accum_iter = args.accum_iter 36 | 37 | optimizer.zero_grad() 38 | 39 | if log_writer is not None: 40 | print('log_dir: {}'.format(log_writer.log_dir)) 41 | 42 | prefix_img = torch.tensor(data_loader.dataset.tokenizer.encode("Image: ", bos=False, eos=False), dtype=torch.int64) 43 | prefix_nonimg = torch.tensor(data_loader.dataset.tokenizer.encode("Image: N/A", bos=False, eos=False), 44 | dtype=torch.int64) 45 | 46 | for data_iter_step, (examples, labels, example_mask, images, indicators) in enumerate( 47 | metric_logger.log_every(data_loader, print_freq, header)): 48 | # we use a per iteration (instead of per epoch) lr scheduler 49 | if data_iter_step % accum_iter == 0: 50 | adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 51 | 52 | prefix_img = prefix_img.to(examples.device) 53 | prefix_nonimg = prefix_nonimg.to(examples.device) 54 | c_loss = model(examples, labels, images=images, example_mask=example_mask, prefix_img=prefix_img, 55 | prefix_nonimg=prefix_nonimg, 56 | indicators=indicators) 57 | 58 | if torch.isnan(c_loss): 59 | print('nan') 60 | c_loss = torch.nan_to_num(c_loss) * 0 61 | loss = c_loss 62 | loss_value = loss.item() 63 | c_loss_value = c_loss.item() 64 | loss = loss / accum_iter 65 | loss_scaler(loss, optimizer, parameters=model.parameters(), 66 | update_grad=(data_iter_step + 1) % accum_iter == 0, clip_grad=args.clip_grad) 67 | 68 | if (data_iter_step + 1) % accum_iter == 0: 69 | optimizer.zero_grad() 70 | 71 | torch.cuda.synchronize() 72 | 73 | metric_logger.update(closs=c_loss_value) 74 | 75 | lr = optimizer.param_groups[0]["lr"] 76 | metric_logger.update(lr=lr) 77 | 78 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value) 79 | 80 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 81 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 82 | log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x) 83 | log_writer.add_scalar('lr', lr, 
epoch_1000x) 84 | 85 | metric_logger.synchronize_between_processes() 86 | print("Averaged stats:", metric_logger) 87 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 88 | -------------------------------------------------------------------------------- /eval_caption.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | import json 4 | import os 5 | import re 6 | 7 | import torch.utils.data 8 | 9 | from adem.build import create_model 10 | from adem.tokenizer import Tokenizer 11 | from util.coco_karpathy_dataset import coco_caption_eval, coco_karpathy_caption_eval 12 | from util.misc import MetricLogger 13 | 14 | 15 | @dataclass 16 | class ModelArgs: 17 | llama_model_path = './data/weights/' 18 | llm_model = '7B' 19 | max_seq_len = 512 20 | hidden_proj = 128 21 | cpu_load = False 22 | alpha = 0.1 23 | adapter_dim = 12 24 | gradient_checkpointing = False 25 | is_train = False 26 | data_root = './data/' 27 | clip = 'ViT-L/14' 28 | clip_root = './clip' 29 | down_sample_num = [256, 64] 30 | no_cls = False 31 | drop_ratio = 0.1 32 | 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--data_root', type=str, default='./data') 36 | parser.add_argument('--clip', type=str, default='ViT-L/14') 37 | parser.add_argument('--clip_root', type=str, default='./clip') 38 | parser.add_argument('--llm_model', type=str, default='7B') 39 | parser.add_argument('--adapter_path', type=str, default='./output_dir') 40 | parser.add_argument('--log_dir', type=str, default='./output_dir') 41 | 42 | parser.add_argument('--batch_size', type=int, default=4) 43 | parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64]) 44 | parser.add_argument('--alpha', type=float, default=0.1) 45 | parser.add_argument('--beta', type=float, default=0.01) 46 | parser.add_argument('--drop_ratio', type=float, default=0.1) 47 | parser.add_argument('--no_cls', action='store_true') 48 | 49 | args = parser.parse_args() 50 | log_dir = args.log_dir if args.log_dir is not None else './logs' 51 | os.makedirs(log_dir, exist_ok=True) 52 | llama_model_path = os.path.join(args.data_root, "weights/") 53 | 54 | model_args = ModelArgs() 55 | model_args.llama_model_path = llama_model_path 56 | model_args.llm_model = args.llm_model 57 | model_args.alpha = args.alpha 58 | model_args.beta = args.beta 59 | model_args.data_root = args.data_root 60 | model_args.clip = args.clip 61 | model_args.clip_root = args.clip_root 62 | model_args.down_sample_num = args.down_sample_num 63 | model_args.no_cls = args.no_cls 64 | model_args.drop_ratio = args.drop_ratio 65 | 66 | llama = create_model(model_args) 67 | adapter = torch.load(os.path.join(args.adapter_path, 'checkpoint-4.pth'))['model'] 68 | sd = {} 69 | for k in adapter: 70 | sd[k.replace('module.', '')] = adapter[k] 71 | _IncompatibleKeys = llama.load_state_dict(sd, False) 72 | print(_IncompatibleKeys) 73 | 74 | tokenizer = Tokenizer(model_path=os.path.join(args.llama_model_path, 'tokenizer.model')) 75 | 76 | dataset_test = coco_karpathy_caption_eval(image_root=os.path.join(args.data_root, 'images'), 77 | ann_root=os.path.join(args.data_root, 'coco_caption')) 78 | 79 | data_loader_test = torch.utils.data.DataLoader( 80 | dataset_test, 81 | batch_size=args.batch_size, 82 | shuffle=False, 83 | drop_last=False, 84 | ) 85 | 86 | llama.eval() 87 | 88 | pattern = re.compile(r'picture of (.+)') 89 | 90 | metric_logger = MetricLogger(delimiter=" ") 91 | header = 
'Caption generation:' 92 | print_freq = 100 93 | 94 | result = [] 95 | prompt = 'a picture of' 96 | for image, image_id in metric_logger.log_every(data_loader_test, print_freq, header): 97 | 98 | captions = llama.generate( 99 | [prompt] * image.size(0), images=image, indicators=[1] * image.size(0), max_gen_len=20, tokenizer=tokenizer, 100 | temperature=0.0 101 | ) 102 | 103 | matched_caption = [] 104 | for c in captions: 105 | pred = pattern.findall(c) 106 | if len(pred) >= 1: 107 | pred = pred[0] 108 | else: 109 | print(c) 110 | pred = c 111 | matched_caption.append(pred) 112 | 113 | for caption, img_id in zip(matched_caption, image_id): 114 | result.append({"image_id": img_id.item(), "caption": caption}) 115 | 116 | result_file = os.path.join(log_dir, 'test_result.json') 117 | json.dump(result, open(result_file, 'w')) 118 | 119 | coco_test = coco_caption_eval(os.path.join(args.data_root, 'coco_caption/'), result_file, split='val') 120 | 121 | log_stats = {**{f'test_{k}': v for k, v in coco_test.eval.items()}} 122 | 123 | with open(os.path.join(log_dir, "evaluate.txt"), "a") as f: 124 | f.write(json.dumps(log_stats) + "\n") 125 | -------------------------------------------------------------------------------- /adem/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import json 4 | from adem import ModelArgs, Tokenizer, Transformer 5 | from adem.adapter import set_Llama_Adapter, set_Clip_Adapter 6 | 7 | from pathlib import Path 8 | 9 | 10 | def _load_and_redistribute_checkpoint(llama_model_path, model_name): 11 | with open(Path(llama_model_path) / model_name / 'params.json') as f: 12 | params = json.load(f) 13 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model')) 14 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name)) 15 | if model_name == '7B': 16 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu") 17 | return checkpoint, tokenizer, params 18 | 19 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth') 20 | checkpoints = sorted(checkpoints) 21 | 22 | loaded = [] 23 | for x in checkpoints: 24 | print('loading from', x) 25 | loaded.append(torch.load(x, map_location='cpu')) 26 | 27 | full_state_dict = {} 28 | split_dims = {} 29 | 30 | def add_weight_with_split_dim(name, dim): 31 | if dim < 0: # bcast without split 32 | full_state_dict[name] = loaded[0][name].clone() 33 | else: 34 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim) 35 | for x in loaded: 36 | del x[name] 37 | split_dims[name] = dim 38 | 39 | add_weight_with_split_dim('tok_embeddings.weight', 1) 40 | add_weight_with_split_dim('norm.weight', -1) 41 | add_weight_with_split_dim('output.weight', 0) 42 | for i in range(params['n_layers']): 43 | print('gathering layer %d of %d' % (i, params['n_layers'])) 44 | layer_prefix = f'layers.{i}.' 
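        # Reassemble this layer from the model-parallel shards: the norm weights are
        # replicated across shards and taken from shard 0 (dim=-1), column-parallel weights
        # (wq/wk/wv, w1/w3) are concatenated along dim 0, and row-parallel weights (wo, w2)
        # are concatenated along dim 1. For the two-shard 13B checkpoints, for example,
        # each shard stores attention.wq.weight as [2560, 5120]; concatenating along dim 0
        # restores the full [5120, 5120] matrix.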
45 | bcast_names = [ 46 | 'attention_norm.weight', 47 | 'ffn_norm.weight', 48 | ] 49 | column_parallel_names = [ 50 | 'attention.wq.weight', 51 | 'attention.wk.weight', 52 | 'attention.wv.weight', 53 | 'feed_forward.w1.weight', 54 | 'feed_forward.w3.weight', 55 | ] 56 | row_parallel_names = [ 57 | 'attention.wo.weight', 58 | 'feed_forward.w2.weight', 59 | ] 60 | for key in bcast_names: 61 | add_weight_with_split_dim(layer_prefix + key, -1) 62 | for key in column_parallel_names: 63 | add_weight_with_split_dim(layer_prefix + key, 0) 64 | for key in row_parallel_names: 65 | add_weight_with_split_dim(layer_prefix + key, 1) 66 | 67 | checkpoint = full_state_dict 68 | 69 | return checkpoint, tokenizer, params 70 | 71 | 72 | def create_model(args): 73 | llama_model_path = args.llama_model_path 74 | model_name = args.llm_model 75 | 76 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(llama_model_path, model_name) 77 | 78 | model_args: ModelArgs = ModelArgs( 79 | max_seq_len=args.max_seq_len, max_batch_size=64, hidden_proj=args.hidden_proj, 80 | is_train=args.is_train, **params 81 | ) 82 | 83 | model_args.vocab_size = tokenizer.n_words 84 | model_args.clip = args.clip 85 | model_args.clip_root = args.clip_root 86 | model_args.beta = args.beta 87 | model_args.down_sample_num = args.down_sample_num 88 | if len(model_args.down_sample_num) == 1: 89 | model_args.down_sample_num = model_args.down_sample_num[0] 90 | model_args.with_cls = not args.no_cls 91 | 92 | if args.cpu_load: 93 | # cpu load is slow, but is freindly for GPU with limited memory. 94 | torch.set_default_tensor_type(torch.HalfTensor) 95 | else: 96 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 97 | 98 | llama = Transformer(model_args) 99 | 100 | # delete language encoder 101 | del llama.backbone.transformer 102 | 103 | torch.set_default_tensor_type(torch.FloatTensor) 104 | 105 | llama.load_state_dict(checkpoint, strict=False) 106 | 107 | 108 | set_Llama_Adapter(llama, s=args.alpha, gradient_checkpointing=args.gradient_checkpointing, drop_ratio=args.drop_ratio) 109 | set_Clip_Adapter(llama.backbone.visual, dim=args.adapter_dim, s=0.1) 110 | 111 | learnable_keys = ['adapter'] 112 | total = 0. 113 | trainable_names = [] 114 | for name, param in llama.named_parameters(): 115 | for key in learnable_keys: 116 | if key in name: 117 | param.requires_grad = True 118 | param.data = param.data.float() 119 | total += param.nelement() 120 | trainable_names.append(name) 121 | else: 122 | param.requires_grad = False 123 | print(trainable_names) 124 | print(' + Number of trainable params: %.2fM' % (total / 1e6)) 125 | return llama 126 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] 
for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /eval_mme.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | import os 4 | 5 | from PIL import Image 6 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | import torch 8 | from torchvision.transforms import transforms 9 | 10 | from adem.build import create_model 11 | from adem.tokenizer import Tokenizer 12 | 13 | 14 | @dataclass 15 | class ModelArgs: 16 | llama_model_path = './data/weights/' 17 | llm_model = '7B' 18 | max_seq_len = 512 19 | hidden_proj = 128 20 | cpu_load = False 21 | alpha = 0.1 22 | adapter_dim = 12 23 | gradient_checkpointing = False 24 | is_train = False 25 | data_root = './data/' 26 | clip = 'ViT-L/14' 27 | clip_root = './clip' 28 | down_sample_num = [256, 64] 29 | no_cls = False 30 | drop_ratio = 0.1 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--data_root', type=str, default='./data') 35 | parser.add_argument('--clip', type=str, default='ViT-L/14') 36 | parser.add_argument('--clip_root', type=str, default='./clip') 37 | parser.add_argument('--llm_model', type=str, default='7B') 38 | parser.add_argument('--adapter_path', type=str, default='./output_dir') 39 | parser.add_argument('--log_dir', type=str, default='./output_dir') 40 | 41 | parser.add_argument('--batch_size', type=int, default=4) 42 | parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64]) 43 | parser.add_argument('--alpha', type=float, default=0.1) 44 | parser.add_argument('--beta', type=float, default=0.01) 45 | parser.add_argument('--drop_ratio', type=float, default=0.1) 46 | parser.add_argument('--no_cls', action='store_true') 47 | 48 | args = parser.parse_args() 49 | log_dir = args.log_dir if args.log_dir is not None else './logs' 50 | os.makedirs(log_dir, exist_ok=True) 51 | llama_model_path = os.path.join(args.data_root, "weights/") 52 | 53 | model_args = ModelArgs() 54 | model_args.llama_model_path = llama_model_path 55 | model_args.llm_model = args.llm_model 56 | model_args.alpha = args.alpha 57 | model_args.beta = args.beta 58 | model_args.data_root = args.data_root 59 | model_args.clip = args.clip 60 | model_args.clip_root = args.clip_root 61 | model_args.down_sample_num = args.down_sample_num 62 | model_args.no_cls = args.no_cls 63 | model_args.drop_ratio = args.drop_ratio 64 | 65 | llama = create_model(model_args) 66 | adapter = torch.load(os.path.join(args.adapter_path, 'checkpoint-14.pth'))['model'] 67 | sd = {} 68 | for k in adapter: 69 | sd[k.replace('module.', '')] = adapter[k] 70 | _IncompatibleKeys = llama.load_state_dict(sd, False) 71 | print(_IncompatibleKeys) 72 | 73 | tokenizer = Tokenizer(model_path=os.path.join(args.llama_model_path, 'tokenizer.model')) 74 | vis_processor = transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC), transforms.ToTensor(), 75 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 76 | 77 | task_meta = { 78 | 'artwork': dict(qa='questions_answers_YN', vis='images', suffix='jpg'), 79 | 'celebrity': dict(qa='questions_answers_YN', vis='images', suffix='jpg'), 80 | 'code_reasoning': dict(qa='', vis='', suffix='png'), 81 | 'color': dict(qa='', vis='', suffix='jpg'), 82 | 'commonsense_reasoning': 
dict(qa='', vis='', suffix='png'), 83 | 'count': dict(qa='', vis='', suffix='jpg'), 84 | 'existence': dict(qa='', vis='', suffix='jpg'), 85 | 'landmark': dict(qa='questions_answers_YN', vis='images', suffix='jpg'), 86 | 'numerical_calculation': dict(qa='', vis='', suffix='png'), 87 | 'OCR': dict(qa='', vis='', suffix='jpg'), 88 | 'position': dict(qa='', vis='', suffix='jpg'), 89 | 'posters': dict(qa='questions_answers_YN', vis='images', suffix='jpg'), 90 | 'scene': dict(qa='questions_answers_YN', vis='images', suffix='jpg'), 91 | 'text_translation': dict(qa='', vis='', suffix='png'), 92 | } 93 | 94 | os.makedirs(os.path.join(log_dir, 'output')) 95 | for task_idx, task in enumerate(task_meta): 96 | qa_path = os.path.join(args.data_root, 'MME_Benchmark_release_version', task, task_meta[task]['qa']) 97 | vis_path = os.path.join(args.data_root, 'MME_Benchmark_release_version', task, task_meta[task]['vis']) 98 | suffix = task_meta[task]['suffix'] 99 | 100 | results = [] 101 | for qa_name in os.listdir(qa_path): 102 | if not qa_name.split('.')[-1] == 'txt': 103 | continue 104 | vis_name = qa_name.split('.')[0] + f'.{suffix}' 105 | image = Image.open(os.path.join(vis_path, vis_name)).convert('RGB') 106 | image = vis_processor(image).unsqueeze(0) 107 | 108 | with open(os.path.join(qa_path, qa_name)) as f: 109 | items = [l[:-1] if '\n' in l else l for l in f.readlines()] 110 | for item in items: 111 | q, gt = item.split('\t') 112 | 113 | prompt = f'Instruction: {q}\nResponse:' 114 | 115 | output = llama.generate([prompt], images=image, indicators=[1], max_gen_len=10, tokenizer=tokenizer, 116 | temperature=0.0) 117 | 118 | output = output[0].replace('\n', '') 119 | 120 | results.append('\t'.join([vis_name, q, gt, output]) + '\n') 121 | 122 | with open(os.path.join(log_dir, f'output/{task}.txt'), mode='w') as f: 123 | f.writelines(results) 124 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import random 5 | 6 | from PIL import Image 7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | import torch 9 | import torch.utils.data as Data 10 | from torchvision.transforms import transforms 11 | 12 | from adem import Tokenizer 13 | from util.base_prompt import * 14 | 15 | 16 | class ScienceQADataSet(Data.Dataset): 17 | def __init__(self, args, split, model_path, max_words=512, max_image_feats=1): 18 | super(ScienceQADataSet, self).__init__() 19 | self.args = args 20 | self.problems = json.load(open(os.path.join(args.data_root, 'problems.json'))) 21 | pid_splits = json.load(open(os.path.join(args.data_root, 'pid_splits.json'))) 22 | captions = json.load(open(args.caption_file))["captions"] 23 | self.image_path = os.path.join(args.data_root, 'images', split) 24 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model') 25 | self.max_words = max_words 26 | self.max_image_feats = max_image_feats 27 | self.split = split 28 | for qid in self.problems: 29 | self.problems[qid]['caption'] = captions[qid] if qid in captions else "" 30 | 31 | self.qids = pid_splits['%s' % (split)] 32 | 33 | print(f"number of problems in split {split}: {len(self.qids)}\n") 34 | 35 | self.transforms = transforms.Compose( 36 | [transforms.Resize((224, 224), interpolation=Image.BICUBIC), transforms.ToTensor(), 37 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 38 | 39 | def tokenize(self, 
prompt, answer): 40 | example = prompt + answer 41 | prompt = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64) 42 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64) 43 | padding = self.max_words - example.shape[0] 44 | if padding > 0: 45 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 46 | elif padding < 0: 47 | example = example[:self.max_words] 48 | labels = copy.deepcopy(example) 49 | labels[:len(prompt)] = -1 50 | example_mask = example.ge(0) 51 | label_mask = labels.ge(0) 52 | example[~example_mask] = 0 53 | labels[~label_mask] = 0 54 | example_mask = example_mask.float() 55 | label_mask = label_mask.float() 56 | return example, labels, example_mask, label_mask 57 | 58 | def __getitem__(self, idx): 59 | 60 | prompt_question, prompt_answer = build_prompt(self.problems, self.qids[idx], self.args) 61 | answer, choices, qid = self.problems[self.qids[idx]]["answer"], self.problems[self.qids[idx]]["choices"], \ 62 | self.qids[idx] 63 | 64 | if self.problems[self.qids[idx]]['image'] is not None: 65 | image = Image.open(os.path.join(self.image_path, self.qids[idx], 'image.png')).convert('RGB') 66 | image = self.transforms(image) 67 | indicator = 1 68 | else: 69 | image = torch.Tensor(torch.zeros(3, 224, 224).float()) 70 | indicator = 0 71 | 72 | example, labels, example_mask, label_mask = self.tokenize(prompt_question, prompt_answer) 73 | 74 | if isinstance(image, list): 75 | return example, labels, example_mask, *image, indicator 76 | return example, labels, example_mask, image, indicator 77 | 78 | def __len__(self): 79 | return len(self.qids) 80 | 81 | def shuffle_list(self, list): 82 | random.shuffle(list) 83 | 84 | 85 | class InstrcutDataSet(Data.Dataset): 86 | def __init__(self, args, split, model_path, max_words=512, max_image_feats=1): 87 | super(InstrcutDataSet, self).__init__() 88 | self.args = args 89 | self.data = json.load(open(os.path.join(args.data_root, 'all_data.json')))[split] 90 | self.sqa_train_path = os.path.join(args.data_root, 'images', 'train') 91 | self.coco_train_path = os.path.join(args.data_root, 'cocoimages', 'train') 92 | 93 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model') 94 | self.max_words = max_words 95 | self.max_image_feats = max_image_feats 96 | self.split = split 97 | self.qids = [item['qid'] for item in self.data] 98 | 99 | print(f"number of problems in split {split}: {len(self.qids)}\n") 100 | 101 | self.transforms = transforms.Compose( 102 | [transforms.Resize((224, 224), interpolation=Image.BICUBIC), transforms.ToTensor(), 103 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 104 | 105 | def tokenize(self, prompt, answer, max_words=512): 106 | example = prompt + answer 107 | # print(prompt) 108 | prompt = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64) 109 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64) 110 | padding = max_words - example.shape[0] 111 | if padding > 0: 112 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 113 | elif padding < 0: 114 | example = example[:self.max_words] 115 | labels = copy.deepcopy(example) 116 | labels[:len(prompt)] = -1 117 | example_mask = example.ge(0) 118 | label_mask = labels.ge(0) 119 | example[~example_mask] = 0 120 | labels[~label_mask] = 0 121 | example_mask = example_mask.float() 122 | label_mask = label_mask.float() 123 | return example, labels, 
example_mask, label_mask 124 | 125 | def __getitem__(self, idx): 126 | 127 | prompt_question = self.data[idx]['instruction'] 128 | prompt_answer = self.data[idx]['answer'] 129 | 130 | if self.data[idx]['image'] is not None: 131 | # image_path='../data/images/train' if self.data[idx]['image_source']=='sqa' else '../data/images/train2014' 132 | if self.data[idx]['image_source'] == 'sqa': 133 | image = Image.open(os.path.join(self.sqa_train_path, self.qids[idx], 'image.png')).convert('RGB') 134 | else: 135 | image = Image.open( 136 | os.path.join(self.coco_train_path, 'COCO_train2014_' + self.data[idx]['image'])).convert('RGB') 137 | image = self.transforms(image) 138 | indicator = 1 139 | else: 140 | image = torch.Tensor(torch.zeros(3, 224, 224).float()) 141 | indicator = 0 142 | 143 | # print(prompt_question,prompt_answer) 144 | example, labels, example_mask, label_mask = self.tokenize(prompt_question, prompt_answer) 145 | 146 | return example, labels, example_mask, image, indicator 147 | 148 | def __len__(self): 149 | return len(self.qids) 150 | 151 | def shuffle_list(self, list): 152 | random.shuffle(list) 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADEM-VL 2 | 3 | Official code of paper: [ADEM-VL: Adaptive and Embedded Fusion for Efficient Vision-Language Tuning](https://www.arxiv.org/pdf/2410.17779). 4 | 5 | Zhiwei Hao, Jianyuan Guo, Li Shen, Yong Luo, Han Hu*, Yonggang Wen 6 | 7 | ## Preparation 8 | ```bash 9 | conda create -n adem python=3.8 -y 10 | conda activate adem 11 | 12 | # install pytorch 13 | conda install pytorch==1.13.1 torchvision==0.14.1 -c pytorch 14 | 15 | # install dependencies 16 | pip install -r requirements.txt 17 | ``` 18 | **Data Preparation** 19 | 20 | *the data preparation instruction is borrowed from [LaVIN](https://github.com/luogen1996/LaVIN/tree/main)*. 21 | 22 | - For ScienceQA, please prepare the dataset from the [official repo](https://github.com/lupantech/ScienceQA). 23 | - For Multimodal Chatbot, download the images in _train2014_ split from [MSCOCO](http://images.cocodataset.org/zips/train2014.zip), and obtain the prepared 52k text-only and 158k text-image instruction-following data from [here](https://drive.google.com/file/d/1gORDPruqwXbgy6NYmhpDXO7t089yzsg3/view?usp=share_link). 24 | - Obtain the weights of LLaMA from [this form](https://forms.gle/jk851eBVbX1m5TAv5) (official) or Download [LLaMA-7B](https://huggingface.co/nyanko7/LLaMA-7B/tree/main) and [LLaMA-13B](https://huggingface.co/TheBloke/llama-13b) from HuggingFace (unofficial). 25 | 26 | After that, the file structure should look like: 27 | 28 | ```bash 29 | ADEM-VL/ 30 | |-- adem 31 | |-- train.py 32 | ...... 33 | |-- data/ 34 | |-- problem.json 35 | |-- pid_splits.json 36 | |-- captions.json 37 | |-- all_data.json 38 | |-- images 39 | |-- train2014 # MSCOCO 2014 40 | |-- val2014 # MSCOCO 2014 41 | |-- train # ScienceQA train image 42 | |-- val # ScienceQA val image 43 | |-- test # ScienceQA test image 44 | |-- weights 45 | |-- tokenizer.model 46 | |--7B 47 | |-- params.json 48 | |-- consolidated.00.pth 49 | |--13B 50 | |-- params.json 51 | |-- consolidated.00.pth 52 | |-- consolidated.01.pth 53 | ``` 54 | ## Fine-tuning 55 | Reproduce the performance of LaVIN-7B. 
56 | 57 | **ScienceQA** 58 | 59 | ```shell 60 | torchrun --nproc_per_node 8 train.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --caption_file /path/to/data/captions.json --llama_model_path /path/to/data/weights/ --llm_model 7B --max_seq_len 512 --batch_size 2 --accum_iter 2 --epochs 20 --warmup_epochs 2 --blr 9e-3 --weight_decay 0.02 --adapter_dim 12 --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 --dataset sqa 61 | ``` 62 | 63 | **COCO caption** 64 | 65 | ```shell 66 | torchrun --nproc_per_node 8 train.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --caption_file /path/to/data/captions.json --llama_model_path /path/to/data/weights/ --llm_model 7B --max_seq_len 512 --batch_size 2 --accum_iter 2 --epochs 5 --warmup_epochs 0.1 --blr 9e-3 --weight_decay 0.02 --adapter_dim 12 --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 --dataset coco_caption 67 | ``` 68 | 69 | **Instruction following** 70 | 71 | ```shell 72 | torchrun --nproc_per_node 8 train.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --caption_file /path/to/data/captions.json --llama_model_path /path/to/data/weights/ --llm_model 7B --max_seq_len 512 --batch_size 2 --accum_iter 2 --epochs 15 --warmup_epochs 0.2 --blr 9e-3 --weight_decay 0.02 --adapter_dim 12 --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 --dataset instruction 73 | ``` 74 | 75 | To train on fewer GPUs, reduce the number of GPUs in the scripts and increase gradient accumulation via ```--accum_iter``` to keep the total batch size at 32. 76 | 77 | ## Evaluation 78 | 79 | Evaluate the fine-tuned model on each task. 80 | 81 | **ScienceQA** 82 | 83 | ```shell 84 | python eval_sqa.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --llm_model 7B --adapter_path ./output_dir --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 85 | ``` 86 | 87 | **COCO caption** 88 | 89 | ```shell 90 | # prepare required packages 91 | pip install pycocoevalcap pycocotools 92 | 93 | python eval_caption.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --llm_model 7B --adapter_path ./output_dir --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 94 | ``` 95 | 96 | **Instruction following** 97 | 98 | - **MME** 99 | 100 | 1. Download MME images and eval_tool from the [MME repo](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/blob/Evaluation/README.md). 101 | 2. Run the following command to obtain model predictions: 102 | 103 | ```shell 104 | python eval_mme.py --data_root /path/to/data/ --clip_root /path/to/data/weights/clip/ --llm_model 7B --adapter_path ./output_dir --alpha 0.1 --beta 0.01 --drop_ratio 0.1 --down_sample_num 256 64 105 | ``` 106 | 107 | 3. Calculate MME results by executing the calculation script that comes with the MME eval_tool (a sketch is given at the end of this section). 108 | 109 | - **More tasks** 110 | 111 | Evaluation on more tasks can be achieved in a similar way as MME based on toolkits like [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [vlm-evaluation](https://github.com/TRI-ML/vlm-evaluation).
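For reference, `eval_mme.py` (step 2 above) writes one prediction file per MME subtask to `<log_dir>/output/<task>.txt` (`./output_dir/output/` by default). Below is a minimal sketch of the scoring step; the script name `calculation.py` and the `--results_dir` flag are assumptions about the downloaded eval_tool and may need adjusting to your version:

```shell
# run the MME eval_tool's calculation script on the directory of prediction files
# produced by eval_mme.py (adjust the script name/path to the eval_tool you downloaded)
python calculation.py --results_dir /path/to/ADEM-VL/output_dir/output
```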
112 | 113 | ## Model Zoo 114 | | Model | Task | Results | Weights | Training log | 115 | | -------- | --------------------- | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 116 | | LLaMA-7B | ScienceQA | Averaged accuracy=94.01 | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/checkpoint_7B_sqa.pth) | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/train_log_7B_sqa.txt) | 117 | | LLaMA-7B | COCO caption | BLEU-4=38.5, CIDEr=130.1 | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/checkpoint_7B_caption.pth) | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/train_log_7B_caption.txt) | 118 | | LLaMA-7B | Instruction following | MME-P=969.7, MME-C=258.9 | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/checkpoint_7B_instruction.pth) | [[Link]](https://github.com/Hao840/ADEM-VL/releases/download/checkpoint/train_log_7B_instruction.txt) | 119 | 120 | ## Citation 121 | If you find this work helpful, please cite our paper: 122 | ```BibTeX 123 | @misc{hao2024ademvladaptiveembeddedfusion, 124 | title={ADEM-VL: Adaptive and Embedded Fusion for Efficient Vision-Language Tuning}, 125 | author={Zhiwei Hao and Jianyuan Guo and Li Shen and Yong Luo and Han Hu and Yonggang Wen}, 126 | year={2024}, 127 | eprint={2410.17779}, 128 | archivePrefix={arXiv}, 129 | primaryClass={cs.CV}, 130 | url={https://arxiv.org/abs/2410.17779}, 131 | } 132 | ``` 133 | 134 | ## Acknowledgement 135 | This repo borrows some data and codes from [LaVIN](https://github.com/luogen1996/LaVIN/tree/main), [MemVP](https://github.com/JieShibo/MemVP), and [BLIP](https://github.com/salesforce/BLIP). Thanks for their great works. 136 | -------------------------------------------------------------------------------- /util/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import re 5 | 6 | from PIL import Image 7 | from pycocoevalcap.eval import COCOEvalCap 8 | from pycocotools.coco import COCO 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision.datasets.utils import download_url 12 | from torchvision.transforms import transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | 15 | from adem import Tokenizer 16 | from util.randaugment import RandomAugment 17 | 18 | 19 | def pre_caption(caption, max_words=50): 20 | caption = re.sub( 21 | r"([.!\"()*#:;~])", 22 | ' ', 23 | caption.lower(), 24 | ) 25 | caption = re.sub( 26 | r"\s{2,}", 27 | ' ', 28 | caption, 29 | ) 30 | caption = caption.rstrip('\n') 31 | caption = caption.strip(' ') 32 | 33 | # truncate caption 34 | caption_words = caption.split(' ') 35 | if len(caption_words) > max_words: 36 | caption = ' '.join(caption_words[:max_words]) 37 | 38 | return caption 39 | 40 | 41 | class coco_karpathy_train(Dataset): 42 | def __init__(self, image_root, ann_root, model_root, img_size=224, max_words=30, prompt=''): 43 | ''' 44 | image_root (string): Root directory of images (e.g. 
coco/images/) 45 | ann_root (string): directory to store the annotation file 46 | ''' 47 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 48 | filename = 'coco_karpathy_train.json' 49 | 50 | download_url(url, ann_root) 51 | 52 | transform_train = transforms.Compose([ 53 | transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0), 54 | interpolation=InterpolationMode.BICUBIC), 55 | transforms.RandomHorizontalFlip(), 56 | RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize', 57 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 60 | ]) 61 | 62 | self.annotation = json.load(open(os.path.join(ann_root, filename), 'r')) 63 | self.transform = transform_train 64 | self.image_root = image_root 65 | self.max_words = max_words 66 | self.prompt = prompt 67 | self.tokenizer = Tokenizer(model_path=model_root + '/tokenizer.model') 68 | 69 | self.img_ids = {} 70 | n = 0 71 | for ann in self.annotation: 72 | img_id = ann['image_id'] 73 | if img_id not in self.img_ids.keys(): 74 | self.img_ids[img_id] = n 75 | n += 1 76 | 77 | def tokenize(self, prompt, answer): 78 | example = prompt + answer 79 | prompt = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64) 80 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64) 81 | padding = self.max_words - example.shape[0] 82 | if padding > 0: 83 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 84 | elif padding < 0: 85 | example = example[:self.max_words] 86 | labels = copy.deepcopy(example) 87 | labels[:len(prompt)] = -1 88 | example_mask = example.ge(0) 89 | label_mask = labels.ge(0) 90 | example[~example_mask] = 0 91 | labels[~label_mask] = 0 92 | example_mask = example_mask.float() 93 | label_mask = label_mask.float() 94 | return example, labels, example_mask, label_mask 95 | 96 | def __len__(self): 97 | return len(self.annotation) 98 | 99 | def __getitem__(self, index): 100 | 101 | ann = self.annotation[index] 102 | 103 | image_path = os.path.join(self.image_root, ann['image']) 104 | image = Image.open(image_path).convert('RGB') 105 | image = self.transform(image) 106 | 107 | caption = self.prompt + pre_caption(ann['caption'], self.max_words) 108 | 109 | example, labels, example_mask, label_mask = self.tokenize(self.prompt, caption) 110 | 111 | return example, labels, example_mask, image, 1 112 | 113 | 114 | class coco_karpathy_caption_eval(Dataset): 115 | def __init__(self, image_root, ann_root, img_size=224, split='val'): 116 | ''' 117 | image_root (string): Root directory of images (e.g. 
coco/images/) 118 | ann_root (string): directory to store the annotation file 119 | split (string): val or test 120 | ''' 121 | urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 122 | 'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 123 | filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'} 124 | 125 | download_url(urls[split], ann_root) 126 | 127 | transform_test = transforms.Compose([ 128 | transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC), 129 | transforms.ToTensor(), 130 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 131 | ]) 132 | 133 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r')) 134 | self.transform = transform_test 135 | self.image_root = image_root 136 | 137 | def __len__(self): 138 | return len(self.annotation) 139 | 140 | def __getitem__(self, index): 141 | ann = self.annotation[index] 142 | 143 | image_path = os.path.join(self.image_root, ann['image']) 144 | image = Image.open(image_path).convert('RGB') 145 | image = self.transform(image) 146 | 147 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 148 | 149 | return image, int(img_id) 150 | 151 | 152 | def coco_caption_eval(coco_gt_root, results_file, split): 153 | urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 154 | 'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 155 | filenames = {'val': 'coco_karpathy_val_gt.json', 'test': 'coco_karpathy_test_gt.json'} 156 | 157 | download_url(urls[split], coco_gt_root) 158 | annotation_file = os.path.join(coco_gt_root, filenames[split]) 159 | 160 | # create coco object and coco_result object 161 | coco = COCO(annotation_file) 162 | coco_result = coco.loadRes(results_file) 163 | 164 | # create coco_eval object by taking coco and coco_result 165 | coco_eval = COCOEvalCap(coco, coco_result) 166 | 167 | # evaluate on a subset of images by setting 168 | # coco_eval.params['image_id'] = coco_result.getImgIds() 169 | # please remove this line when evaluating the full validation set 170 | # coco_eval.params['image_id'] = coco_result.getImgIds() 171 | 172 | # evaluate results 173 | # SPICE will take a few minutes the first time, but speeds up due to caching 174 | coco_eval.evaluate() 175 | 176 | # print output evaluation scores 177 | for metric, score in coco_eval.eval.items(): 178 | print(f'{metric}: {score:.3f}') 179 | 180 | return coco_eval 181 | -------------------------------------------------------------------------------- /util/base_prompt.py: -------------------------------------------------------------------------------- 1 | def get_question_text(problem): 2 | question = problem['question'] 3 | return question 4 | 5 | 6 | def get_context_text(problem, use_caption): 7 | txt_context = problem['hint'] 8 | img_context = problem['caption'] if use_caption else "" 9 | context = " ".join([txt_context, img_context]).strip() 10 | if context == "": 11 | context = "N/A" 12 | return context 13 | 14 | 15 | def get_choice_text(probelm, options): 16 | choices = probelm['choices'] 17 | choice_list = [] 18 | for i, c in enumerate(choices): 19 | choice_list.append("({}) {}".format(options[i], c)) 20 | choice_txt = " ".join(choice_list) 21 | return choice_txt 22 | 23 | 24 | def get_answer(problem, options): 25 | 
return options[problem['answer']] 26 | 27 | 28 | def get_lecture_text(problem): 29 | # \\n: GPT-3 can generate the lecture with more tokens. 30 | lecture = problem['lecture'].replace("\n", "\\n") 31 | return lecture 32 | 33 | 34 | def get_solution_text(problem): 35 | # \\n: GPT-3 can generate the solution with more tokens 36 | solution = problem['solution'].replace("\n", "\\n") 37 | return solution 38 | 39 | 40 | def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True): 41 | input_format, output_format = format.split("-") 42 | 43 | ## Inputs 44 | if input_format == "CQM": 45 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" 46 | elif input_format == "QCM": 47 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" 48 | # upper bound experiment 49 | elif input_format == "QCML": 50 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" 51 | elif input_format == "QCME": 52 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" 53 | elif input_format == "QCMLE": 54 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" 55 | 56 | elif input_format == "QCLM": 57 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" 58 | elif input_format == "QCEM": 59 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" 60 | elif input_format == "QCLEM": 61 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" 62 | 63 | # Outputs 64 | if test_example: 65 | output = "Answer:" 66 | elif output_format == 'A': 67 | output = f"Answer: The answer is {answer}." 68 | 69 | elif output_format == 'AL': 70 | output = f"Answer: The answer is {answer}. BECAUSE: {solution}" 71 | elif output_format == 'AE': 72 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture}" 73 | elif output_format == 'ALE': 74 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}" 75 | elif output_format == 'AEL': 76 | output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}" 77 | 78 | elif output_format == 'LA': 79 | output = f"Answer: {lecture} The answer is {answer}." 80 | elif output_format == 'EA': 81 | output = f"Answer: {solution} The answer is {answer}." 82 | elif output_format == 'LEA': 83 | output = f"Answer: {lecture} {solution} The answer is {answer}." 84 | elif output_format == 'ELA': 85 | output = f"Answer: {solution} {lecture} The answer is {answer}." 
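    # With test_example=True (the default), output is simply "Answer:", so the model is
    # expected to generate the answer itself; the output_format branches above only apply
    # when test_example is False.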
86 | 87 | text = input + output 88 | text = text.replace(" ", " ").strip() 89 | if text.endswith("BECAUSE:"): 90 | text = text.replace("BECAUSE:", "").strip() 91 | return text 92 | 93 | 94 | def create_training_example(format, question, context, choice, answer, lecture, solution): 95 | input_format, output_format = format.split("-") 96 | 97 | ## Inputs 98 | if input_format == "Q": 99 | input = f"Question: {question}\n" 100 | elif input_format == "QM": 101 | input = f"Question: {question}\nOptions: {choice}\n" 102 | elif input_format == "CQM": 103 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" 104 | elif input_format == "QCM": 105 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" 106 | # upper bound experiment 107 | elif input_format == "QCML": 108 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" 109 | elif input_format == "QCME": 110 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" 111 | elif input_format == "QCMLE": 112 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" 113 | 114 | elif input_format == "QCLM": 115 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" 116 | elif input_format == "QCEM": 117 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" 118 | elif input_format == "QCLEM": 119 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" 120 | 121 | input += "Response:" 122 | input = '\n' + input 123 | 124 | # Outputs 125 | if output_format == 'A': 126 | output = f"The answer is {answer}." 127 | 128 | elif output_format == 'AL': 129 | output = f"The answer is {answer}. BECAUSE: {solution}" 130 | elif output_format == 'AE': 131 | output = f"The answer is {answer}. BECAUSE: {lecture}" 132 | elif output_format == 'ALE': 133 | output = f"The answer is {answer}. BECAUSE: {lecture} {solution}" 134 | elif output_format == 'AEL': 135 | output = f"The answer is {answer}. BECAUSE: {solution} {lecture}" 136 | 137 | elif output_format == 'LA': 138 | output = f"{lecture} The answer is {answer}." 139 | elif output_format == 'EA': 140 | output = f"{solution} The answer is {answer}." 141 | elif output_format == 'LEA': 142 | output = f"{lecture} {solution} The answer is {answer}." 143 | elif output_format == 'ELA': 144 | output = f"{solution} {lecture} The answer is {answer}." 
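    # Worked example with the 'QCM-A' format (the prompt_format set in eval_sqa.py's PromptArgs):
    #   input  = '\nQuestion: ...\nContext: ...\nOptions: (A) ... (B) ...\nResponse:'
    #   output = 'The answer is B.'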
145 | 146 | input = input.replace(" ", " ").strip() 147 | output = output.replace(" ", " ").strip() 148 | if output.endswith("BECAUSE:"): 149 | output = output.replace("BECAUSE:", "").strip() 150 | 151 | # print(input) 152 | return input, output 153 | 154 | 155 | def build_prompt(problems, test_qid, args): 156 | # test example 157 | question = get_question_text(problems[test_qid]) 158 | context = get_context_text(problems[test_qid], args.use_caption) 159 | choice = get_choice_text(problems[test_qid], args.options) 160 | answer = get_answer(problems[test_qid], args.options) 161 | lecture = get_lecture_text(problems[test_qid]) 162 | solution = get_solution_text(problems[test_qid]) 163 | 164 | test_example = create_training_example(args.prompt_format, 165 | question, 166 | context, 167 | choice, 168 | answer, 169 | lecture, 170 | solution) 171 | return test_example 172 | -------------------------------------------------------------------------------- /eval_sqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | import json 4 | import os 5 | import re 6 | 7 | import pandas as pd 8 | from PIL import Image 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | import torch 11 | from torchvision.transforms import transforms 12 | from tqdm import tqdm 13 | 14 | from adem.build import create_model 15 | from adem.tokenizer import Tokenizer 16 | from util.base_prompt import build_prompt 17 | 18 | 19 | @dataclass 20 | class PromptArgs: 21 | prompt_format = 'QCM-A' 22 | use_caption = True 23 | options = ["A", "B", "C", "D", "E"] 24 | 25 | 26 | def get_acc_with_contion(res_pd, key, values): 27 | if isinstance(values, list): 28 | total_pd = res_pd[res_pd[key].isin(values)] 29 | else: 30 | total_pd = res_pd[res_pd[key] == values] 31 | correct_pd = total_pd[total_pd['true_false'] == True] 32 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) 33 | return acc 34 | 35 | 36 | def get_scores(result_file, data_file): 37 | # read result file 38 | results = json.load(open(result_file)) 39 | num = len(results) 40 | assert num == 4241 41 | 42 | sqa_data = json.load(open(data_file)) 43 | 44 | # construct pandas data 45 | sqa_pd = pd.DataFrame(sqa_data).T 46 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set 47 | 48 | # update data 49 | for index, row in res_pd.iterrows(): 50 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False 51 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False 52 | res_pd.loc[index, 'has_image'] = True if row['image'] else False 53 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False 54 | 55 | label = row['answer'] 56 | pred = int(results[index]) 57 | res_pd.loc[index, 'pred'] = pred 58 | res_pd.loc[index, 'true_false'] = (label == pred) 59 | 60 | # accuracy scores 61 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100 62 | 63 | scores = { 64 | 'acc_natural': 65 | get_acc_with_contion(res_pd, 'subject', 'natural science'), 66 | 'acc_social': 67 | get_acc_with_contion(res_pd, 'subject', 'social science'), 68 | 'acc_language': 69 | get_acc_with_contion(res_pd, 'subject', 'language science'), 70 | 'acc_has_text': 71 | get_acc_with_contion(res_pd, 'has_text', True), 72 | 'acc_has_image': 73 | get_acc_with_contion(res_pd, 'has_image', True), 74 | 'acc_no_context': 75 | get_acc_with_contion(res_pd, 'no_context', True), 76 | 'acc_grade_1_6': 77 |
get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']), 78 | 'acc_grade_7_12': 79 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']), 80 | 'acc_average': 81 | "{:.2f}".format(acc_average), 82 | } 83 | 84 | return scores 85 | 86 | 87 | def print_scores(scores): 88 | latex_output = "" 89 | for key, score in scores.items(): 90 | print(f"{key[4:]}: \t{score}") 91 | latex_output += f"& {score} " 92 | latex_output += "\\\\" 93 | print(latex_output) 94 | 95 | 96 | def get_pred_idx(prediction, choices, options): 97 | """ 98 | Get the index (e.g. 2) from the prediction (e.g. 'C') 99 | """ 100 | if prediction in options[:len(choices)]: 101 | return options.index(prediction) 102 | else: 103 | return -1 # return random.choice(range(len(choices))) 104 | 105 | 106 | @dataclass 107 | class ModelArgs: 108 | llama_model_path = './data/weights/' 109 | llm_model = '7B' 110 | max_seq_len = 512 111 | hidden_proj = 128 112 | cpu_load = False 113 | alpha = 0.1 114 | adapter_dim = 12 115 | gradient_checkpointing = False 116 | is_train = False 117 | data_root = './data/' 118 | clip = 'ViT-L/14' 119 | clip_root = './clip' 120 | down_sample_num = [256, 64] 121 | no_cls = False 122 | drop_ratio = 0.1 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--data_root', type=str, default='./data') 128 | parser.add_argument('--clip', type=str, default='ViT-L/14') 129 | parser.add_argument('--clip_root', type=str, default='./clip') 130 | parser.add_argument('--llm_model', type=str, default='7B') 131 | parser.add_argument('--adapter_path', type=str, default='./output_dir') 132 | parser.add_argument('--log_dir', type=str, default='./output_dir') 133 | 134 | parser.add_argument('--batch_size', type=int, default=4) 135 | parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64]) 136 | parser.add_argument('--alpha', type=float, default=0.1) 137 | parser.add_argument('--beta', type=float, default=0.01) 138 | parser.add_argument('--drop_ratio', type=float, default=0.1) 139 | parser.add_argument('--no_cls', action='store_true') 140 | 141 | args = parser.parse_args() 142 | log_dir = args.log_dir if args.log_dir is not None else './logs' 143 | os.makedirs(log_dir, exist_ok=True) 144 | llama_model_path = os.path.join(args.data_root, "weights/") 145 | 146 | model_args = ModelArgs() 147 | model_args.llama_model_path = llama_model_path 148 | model_args.llm_model = args.llm_model 149 | model_args.alpha = args.alpha 150 | model_args.beta = args.beta 151 | model_args.data_root = args.data_root 152 | model_args.clip = args.clip 153 | model_args.clip_root = args.clip_root 154 | model_args.down_sample_num = args.down_sample_num 155 | model_args.no_cls = args.no_cls 156 | model_args.drop_ratio = args.drop_ratio 157 | 158 | llama = create_model(model_args) 159 | adapter = torch.load(os.path.join(args.adapter_path, 'checkpoint-19.pth'))['model'] 160 | sd = {} 161 | for k in adapter: 162 | sd[k.replace('module.', '')] = adapter[k] 163 | _IncompatibleKeys = llama.load_state_dict(sd, False) 164 | print(_IncompatibleKeys) 165 | 166 | tokenizer = Tokenizer(model_path=os.path.join(llama_model_path, 'tokenizer.model')) 167 | 168 | split = 'test' 169 | print('split: ', split) 170 | problems = json.load(open(os.path.join(args.data_root, 'problems.json'))) 171 | pid_splits = json.load(open(os.path.join(args.data_root, 'pid_splits.json'))) 172 | captions = 
json.load(open(os.path.join(args.data_root, 'captions.json')))["captions"] 173 | image_path = os.path.join(args.data_root, 'images', split) 174 | qids = pid_splits['%s' % (split)] 175 | total_items = len(qids) 176 | for qid in problems: 177 | problems[qid]['caption'] = captions[qid] if qid in captions else "" 178 | 179 | print('total_items: ', total_items) 180 | 181 | image_transforms = transforms.Compose( 182 | [transforms.Resize((224, 224), interpolation=Image.BICUBIC), transforms.ToTensor(), 183 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 184 | 185 | prompt_args = PromptArgs() 186 | 187 | pattern = re.compile(r'([A-Z])') 188 | 189 | answers = [] 190 | preds = [] 191 | 192 | print_freq = 100 193 | with tqdm(total=total_items // args.batch_size + 1, ncols=0) as pbar: 194 | for i in range(total_items // args.batch_size + 1): 195 | if i % print_freq == 0: 196 | pbar.update(print_freq) 197 | 198 | batch_qids = qids[i * args.batch_size:(i + 1) * args.batch_size] 199 | if len(batch_qids) == 0: 200 | break 201 | indicators = [] 202 | prompts = [] 203 | images = [] 204 | for qid in batch_qids: 205 | prompt, _ = build_prompt(problems, qid, prompt_args) 206 | prompt += 'The answer is' 207 | answer = problems[qid]["answer"] 208 | if problems[qid]['image'] is not None: 209 | image = Image.open(os.path.join(image_path, qid, 'image.png')).convert('RGB') 210 | image = image_transforms(image) 211 | indicator = 1 212 | else: 213 | image = torch.Tensor(torch.zeros(3, 224, 224).float()) 214 | indicator = 0 215 | prompts.append(prompt) 216 | answers.append(answer) 217 | images.append(image) 218 | indicators.append(indicator) 219 | 220 | images = torch.stack(images) 221 | results = llama.generate( 222 | prompts, images=images, indicators=indicators, max_gen_len=1, tokenizer=tokenizer, temperature=0.0 223 | ) 224 | 225 | for result in results: 226 | pred = pattern.findall(result) 227 | 228 | if len(pred) >= 1: 229 | pred = pred[0] # 'A', 'B', ... 
230 | else: 231 | # print(result) 232 | pred = "FAILED" 233 | preds.append(pred) 234 | 235 | # evaluations 236 | results = {} 237 | correct = 0 238 | for i, prediction in enumerate(preds): 239 | pred_idx = get_pred_idx(prediction, problems[qids[i]]["choices"], 240 | prompt_args.options) # 0, 1, ..., 4 241 | if pred_idx == answers[i]: 242 | correct += 1 243 | results[qids[i]] = pred_idx 244 | acc = correct / len(results) * 100 245 | print('overall accuracy: ', acc) 246 | 247 | with open(os.path.join(log_dir, 'preds.json'), 'w') as f: 248 | json.dump(results, f) 249 | 250 | scores = get_scores(os.path.join(log_dir, 'preds.json'), os.path.join(args.data_root, 'problems.json')) 251 | print(scores) 252 | with open(os.path.join(log_dir, 'eval_log.txt'), 'w') as f: 253 | f.write(str(scores)) 254 | 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("Pytorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise 
RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | from pathlib import Path 6 | import random 7 | import time 8 | 9 | import numpy as np 10 | import timm.optim.optim_factory as optim_factory 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from engine import train_one_epoch 16 | from adem.build import create_model 17 | from util.coco_karpathy_dataset import coco_karpathy_train 18 | from util.datasets import InstrcutDataSet, ScienceQADataSet 19 | import util.misc as misc 20 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 21 | 22 | 23 | def get_args_parser(): 24 | parser = argparse.ArgumentParser() 25 | 26 | # data 27 | parser.add_argument('--dataset', type=str, default='sqa') 28 | parser.add_argument('--data_root', type=str, default='./data') 29 | parser.add_argument('--clip', type=str, default='ViT-L/14') 30 | parser.add_argument('--clip_root', type=str, default='./clip') 31 | parser.add_argument('--llm_model', type=str, default='7B') 32 | parser.add_argument('--output_dir', type=str, default='./output_dir', 33 | help='path where to save, empty for no saving') 34 | parser.add_argument('--log_dir', type=str, default='./output_dir', help='path where to tensorboard log') 35 | parser.add_argument('--prompt_format', 36 | type=str, 37 | default='CQM-A', 38 | choices=[ 39 | 'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A', 40 | 'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A', 41 | 'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE', 'Q-A', 'QM-A', 'Q-AL', 'QM-EA' 42 | ], 43 | help='prompt format template') 44 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 45 | parser.add_argument('--caption_file', type=str, default='./data/captions.json') 46 | parser.add_argument('--use_caption', action='store_true', help='use image captions or not') 47 | parser.add_argument('--num_workers', default=10, type=int) 48 | parser.add_argument('--pin_mem', action='store_true', 49 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 50 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 51 | parser.set_defaults(pin_mem=True) 52 | 53 | # model 54 | parser.add_argument('--adapter_dim', type=int, default=8, help='the dims of adapter layer') 55 | parser.add_argument('--hidden_proj', 
type=int, default=128, 56 | help='the visual adapter dim') 57 | parser.add_argument('--max_seq_len', type=int, default=512, help='the maximum sequence length') 58 | parser.add_argument('--seed', default=42, type=int) 59 | parser.add_argument('--resume', default='', help='resume from checkpoint') 60 | 61 | parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64]) 62 | parser.add_argument('--alpha', type=float, default=0.1) 63 | parser.add_argument('--beta', type=float, default=0.01) 64 | parser.add_argument('--drop_ratio', type=float, default=0.1) 65 | parser.add_argument('--no_cls', action='store_true') 66 | 67 | # optim 68 | parser.add_argument('--epochs', default=20, type=int) 69 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') 70 | parser.add_argument('--batch_size', default=2, type=int, 71 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 72 | parser.add_argument('--accum_iter', default=2, type=int, 73 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 74 | parser.add_argument('--lr', type=float, default=None, 75 | help='learning rate (absolute lr)') 76 | parser.add_argument('--blr', type=float, default=1e-3, 77 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 78 | parser.add_argument('--min_lr', type=float, default=0., 79 | help='lower lr bound for cyclic schedulers that hit 0') 80 | parser.add_argument('--warmup_epochs', type=float, default=2, 81 | help='epochs to warmup LR') 82 | parser.add_argument('--weight_decay', type=float, default=0.05, 83 | help='weight decay (default: 0.05)') 84 | parser.add_argument('--clip_grad', type=float, default=None, 85 | help='clips gradient norm of an iterable of parameters') 86 | parser.add_argument('--device', default='cuda', help='device to use for training / testing') 87 | parser.add_argument('--cpu_load', action='store_true', help='load the model on cpu and avoid OOM on gpu') 88 | parser.add_argument('--gradient_checkpointing', action='store_true', 89 | help='saving memory costs via gradient_checkpointing') 90 | 91 | # distributed training parameters 92 | parser.add_argument('--world_size', default=1, type=int, 93 | help='number of distributed processes') 94 | parser.add_argument('--local_rank', default=-1, type=int) 95 | parser.add_argument('--dist_on_itp', action='store_true') 96 | parser.add_argument('--dist_url', default='env://', 97 | help='url used to set up distributed training') 98 | 99 | return parser 100 | 101 | 102 | def main(args): 103 | misc.init_distributed_mode(args) 104 | 105 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 106 | print("{}".format(args).replace(', ', ',\n')) 107 | 108 | device = torch.device(args.device) 109 | 110 | # fix the seed for reproducibility 111 | seed = args.seed + misc.get_rank() 112 | torch.manual_seed(seed) 113 | np.random.seed(seed) 114 | g = torch.Generator() 115 | g.manual_seed(seed) 116 | random.seed(seed) 117 | 118 | cudnn.benchmark = False 119 | cudnn.deterministic = True 120 | 121 | args.is_train = True 122 | 123 | llama_model_path = os.path.join(args.data_root, "weights/") 124 | if args.dataset == 'sqa': 125 | dataset_train = ScienceQADataSet(args, 'train', llama_model_path, args.max_seq_len) 126 | elif args.dataset == 'coco_caption': 127 | dataset_train = coco_karpathy_train(image_root=os.path.join(args.data_root, 'images'), 128 | ann_root=os.path.join(args.data_root, 'coco_caption'), 129 | 
model_root=llama_model_path, 130 | prompt='a picture of ') 131 | elif args.dataset == 'instruction': 132 | dataset_train = InstrcutDataSet(args, 'all', llama_model_path, args.max_seq_len) 133 | else: 134 | raise RuntimeError 135 | 136 | num_tasks = misc.get_world_size() 137 | global_rank = misc.get_rank() 138 | sampler_train = torch.utils.data.DistributedSampler( 139 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 140 | ) 141 | 142 | print("Sampler_train = %s" % str(sampler_train)) 143 | 144 | if global_rank == 0 and args.log_dir is not None: 145 | os.makedirs(args.log_dir, exist_ok=True) 146 | log_writer = SummaryWriter(log_dir=args.log_dir) 147 | else: 148 | log_writer = None 149 | 150 | data_loader_train = torch.utils.data.DataLoader( 151 | dataset_train, sampler=sampler_train, 152 | batch_size=args.batch_size, 153 | num_workers=args.num_workers, 154 | pin_memory=args.pin_mem, 155 | drop_last=True, 156 | generator=g, 157 | ) 158 | 159 | # define the model 160 | model = create_model(args) 161 | model.to(device) 162 | 163 | # for debug. print the data type. 164 | # for name, param in model.named_parameters(): 165 | # print(name, param.dtype) 166 | 167 | model_without_ddp = model 168 | 169 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 170 | 171 | if args.lr is None: # only base_lr is specified 172 | args.lr = args.blr * eff_batch_size / 256 173 | 174 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 175 | print("actual lr: %.2e" % args.lr) 176 | 177 | print("accumulate grad iterations: %d" % args.accum_iter) 178 | print("effective batch size: %d" % eff_batch_size) 179 | 180 | if args.distributed: 181 | print(args.gpu) 182 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 183 | model_without_ddp = model.module 184 | 185 | # following timm: set wd as 0 for bias and norm layers 186 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 187 | 188 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 189 | print(optimizer) 190 | 191 | # mixed precision scaler 192 | loss_scaler = NativeScaler() 193 | 194 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 195 | 196 | print(f"Start training for {args.epochs} epochs") 197 | start_time = time.time() 198 | for epoch in range(args.start_epoch, args.epochs): 199 | 200 | if args.distributed: 201 | data_loader_train.sampler.set_epoch(epoch) 202 | 203 | train_stats = train_one_epoch( 204 | model, data_loader_train, 205 | optimizer, device, epoch, loss_scaler, 206 | log_writer=log_writer, 207 | args=args 208 | ) 209 | 210 | if args.output_dir: 211 | misc.save_model( 212 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 213 | loss_scaler=loss_scaler, epoch=epoch) 214 | 215 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 216 | 'epoch': epoch, } 217 | 218 | if args.output_dir and misc.is_main_process(): 219 | if log_writer is not None: 220 | log_writer.flush() 221 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 222 | f.write(json.dumps(log_stats) + "\n") 223 | 224 | total_time = time.time() - start_time 225 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 226 | print('Training time {}'.format(total_time_str)) 227 | 228 | 229 | if __name__ == '__main__': 230 | 231 | args = get_args_parser() 232 | args = 
args.parse_args() 233 | if args.output_dir: 234 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 235 | main(args) 236 | -------------------------------------------------------------------------------- /util/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -np.float64(low) * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Brightness 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences between this result and PIL's are all on the 4 boundaries; the center 134 | areas are the same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def
level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | 212 | return level_to_args 213 | 214 | 215 | def shear_level_to_args(MAX_LEVEL, replace_value): 216 | def level_to_args(level): 217 | level = (level / MAX_LEVEL) * 0.3 218 | if np.random.random() > 0.5: level = -level 219 | return (level, replace_value) 220 | 221 | return level_to_args 222 | 223 | 224 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 225 | def level_to_args(level): 226 | level = (level / MAX_LEVEL) * float(translate_const) 227 | if np.random.random() > 0.5: level = -level 228 | return (level, replace_value) 229 | 230 | return level_to_args 231 | 232 | 233 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 234 | def level_to_args(level): 235 | level = int((level / MAX_LEVEL) * cutout_const) 236 | return (level, replace_value) 237 | 238 | return level_to_args 239 | 240 | 241 | def solarize_level_to_args(MAX_LEVEL): 242 | def level_to_args(level): 243 | level = int((level / MAX_LEVEL) * 256) 244 | return (level,) 245 | 246 | return level_to_args 247 | 248 | 249 | def none_level_to_args(level): 250 | return () 251 | 252 | 253 | def posterize_level_to_args(MAX_LEVEL): 254 | def level_to_args(level): 255 | level = int((level / MAX_LEVEL) * 4) 256 | return (level,) 257 | 258 | return level_to_args 259 | 260 | 261 | def rotate_level_to_args(MAX_LEVEL, replace_value): 262 | def level_to_args(level): 263 | level = (level / MAX_LEVEL) * 30 264 | if np.random.random() < 0.5: 265 | level = -level 266 | return (level, replace_value) 267 | 268 | return level_to_args 269 | 270 | 271 | func_dict = { 272 | 'Identity': identity_func, 273 | 'AutoContrast': autocontrast_func, 274 | 'Equalize': equalize_func, 275 | 'Rotate': rotate_func, 276 | 'Solarize': solarize_func, 277 | 'Color': color_func, 278 | 'Contrast': contrast_func, 279 | 'Brightness': brightness_func, 280 | 'Sharpness': sharpness_func, 281 | 'ShearX': shear_x_func, 282 | 'TranslateX': translate_x_func, 283 | 'TranslateY': translate_y_func, 284 | 'Posterize': posterize_func, 285 | 'ShearY': shear_y_func, 286 | } 287 | 288 | translate_const = 10 289 | MAX_LEVEL = 10 290 | replace_value = (128, 128, 128) 291 | arg_dict = { 292 | 'Identity': none_level_to_args, 293 | 'AutoContrast': none_level_to_args, 294 | 'Equalize': none_level_to_args, 295 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 296 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 297 | 'Color': enhance_level_to_args(MAX_LEVEL), 298 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 299 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 300 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 301 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 302 | 'TranslateX': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'TranslateY': translate_level_to_args( 306 | translate_const, MAX_LEVEL, replace_value 307 | ), 308 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 309 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 310 | } 311 | 312 | 313 | class RandomAugment(object): 314 | 315 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 316 | self.N = N 317 | self.M = M 318 | self.isPIL = isPIL 319 | if augs: 320 | self.augs = augs 321 | else: 322 | self.augs = list(arg_dict.keys()) 323 | 324 | def get_random_ops(self): 325 | sampled_ops = np.random.choice(self.augs, self.N) 326 | return [(op, 0.5, self.M) for op in sampled_ops] 327 | 328 | def __call__(self, img): 329 | if self.isPIL: 330 | img = 
np.array(img) 331 | ops = self.get_random_ops() 332 | for name, prob, level in ops: 333 | if np.random.random() > prob: 334 | continue 335 | args = arg_dict[name](level) 336 | img = func_dict[name](img, *args) 337 | return img 338 | 339 | 340 | if __name__ == '__main__': 341 | a = RandomAugment() 342 | img = np.random.randn(32, 32, 3) 343 | a(img) 344 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | from collections import defaultdict, deque 3 | import datetime 4 | import os 5 | from pathlib import Path 6 | import time 7 | 8 | import torch 9 | from torch import inf 10 | import torch.distributed as dist 11 | 12 | 13 | class SmoothedValue(object): 14 | """Track a series of values and provide access to smoothed values over a 15 | window or the global series average. 16 | """ 17 | 18 | def __init__(self, window_size=20, fmt=None): 19 | if fmt is None: 20 | fmt = "{median:.4f} ({global_avg:.4f})" 21 | self.deque = deque(maxlen=window_size) 22 | self.total = 0.0 23 | self.count = 0 24 | self.fmt = fmt 25 | 26 | def update(self, value, n=1): 27 | self.deque.append(value) 28 | self.count += n 29 | self.total += value * n 30 | 31 | def synchronize_between_processes(self): 32 | """ 33 | Warning: does not synchronize the deque! 34 | """ 35 | if not is_dist_avail_and_initialized(): 36 | return 37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 38 | 39 | dist.barrier() 40 | dist.all_reduce(t) 41 | t = t.tolist() 42 | self.count = int(t[0]) 43 | self.total = t[1] 44 | 45 | @property 46 | def median(self): 47 | d = torch.tensor(list(self.deque)) 48 | return d.median().item() 49 | 50 | @property 51 | def avg(self): 52 | d = torch.tensor(list(self.deque), dtype=torch.float32) 53 | return d.mean().item() 54 | 55 | @property 56 | def global_avg(self): 57 | return self.total / self.count 58 | 59 | @property 60 | def max(self): 61 | return max(self.deque) 62 | 63 | @property 64 | def value(self): 65 | return self.deque[-1] 66 | 67 | def __str__(self): 68 | return self.fmt.format( 69 | median=self.median, 70 | avg=self.avg, 71 | global_avg=self.global_avg, 72 | max=self.max, 73 | value=self.value) 74 | 75 | 76 | class MetricLogger(object): 77 | def __init__(self, delimiter="\t"): 78 | self.meters = defaultdict(SmoothedValue) 79 | self.delimiter = delimiter 80 | 81 | def update(self, **kwargs): 82 | for k, v in kwargs.items(): 83 | if v is None: 84 | continue 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | assert isinstance(v, (float, int)) 88 | self.meters[k].update(v) 89 | 90 | def __getattr__(self, attr): 91 | if attr in self.meters: 92 | return self.meters[attr] 93 | if attr in self.__dict__: 94 | return self.__dict__[attr] 95 | raise AttributeError("'{}' object has no attribute '{}'".format( 96 | type(self).__name__, attr)) 97 | 98 | def __str__(self): 99 | loss_str = [] 100 | for name, meter in self.meters.items(): 101 | loss_str.append( 102 | "{}: {}".format(name, str(meter)) 103 | ) 104 | return self.delimiter.join(loss_str) 105 | 106 | def synchronize_between_processes(self): 107 | for meter in self.meters.values(): 108 | meter.synchronize_between_processes() 109 | 110 | def add_meter(self, name, meter): 111 | self.meters[name] = meter 112 | 113 | def log_every(self, iterable, print_freq, header=None): 114 | i = 0 115 | if not header: 116 | header = '' 117 | start_time = time.time() 118 | end = time.time() 
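# Note (added comment): iter_time / data_time below are windowed SmoothedValue timers for the full iteration and for data loading; log_every reports them every print_freq iterations together with the ETA and, on CUDA, peak GPU memory.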
119 | iter_time = SmoothedValue(fmt='{avg:.4f}') 120 | data_time = SmoothedValue(fmt='{avg:.4f}') 121 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 122 | log_msg = [ 123 | header, 124 | '[{0' + space_fmt + '}/{1}]', 125 | 'eta: {eta}', 126 | '{meters}', 127 | 'time: {time}', 128 | 'data: {data}' 129 | ] 130 | if torch.cuda.is_available(): 131 | log_msg.append('max mem: {memory:.0f}') 132 | log_msg = self.delimiter.join(log_msg) 133 | MB = 1024.0 * 1024.0 134 | for obj in iterable: 135 | data_time.update(time.time() - end) 136 | yield obj 137 | iter_time.update(time.time() - end) 138 | if i % print_freq == 0 or i == len(iterable) - 1: 139 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 140 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 141 | if torch.cuda.is_available(): 142 | print(log_msg.format( 143 | i, len(iterable), eta=eta_string, 144 | meters=str(self), 145 | time=str(iter_time), data=str(data_time), 146 | memory=torch.cuda.max_memory_allocated() / MB)) 147 | else: 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time))) 152 | i += 1 153 | end = time.time() 154 | total_time = time.time() - start_time 155 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 156 | print('{} Total time: {} ({:.4f} s / it)'.format( 157 | header, total_time_str, total_time / len(iterable))) 158 | 159 | 160 | def setup_for_distributed(is_master): 161 | """ 162 | This function disables printing when not in master process 163 | """ 164 | builtin_print = builtins.print 165 | 166 | def print(*args, **kwargs): 167 | force = kwargs.pop('force', False) 168 | force = force or (get_world_size() > 8) 169 | if is_master or force: 170 | now = datetime.datetime.now().time() 171 | builtin_print('[{}] '.format(now), end='') # print with time stamp 172 | builtin_print(*args, **kwargs) 173 | 174 | builtins.print = print 175 | 176 | 177 | def is_dist_avail_and_initialized(): 178 | if not dist.is_available(): 179 | return False 180 | if not dist.is_initialized(): 181 | return False 182 | return True 183 | 184 | 185 | def get_world_size(): 186 | if not is_dist_avail_and_initialized(): 187 | return 1 188 | return dist.get_world_size() 189 | 190 | 191 | def get_rank(): 192 | if not is_dist_avail_and_initialized(): 193 | return 0 194 | return dist.get_rank() 195 | 196 | 197 | def is_main_process(): 198 | return get_rank() == 0 199 | 200 | 201 | def save_on_master(*args, **kwargs): 202 | if is_main_process(): 203 | torch.save(*args, **kwargs) 204 | 205 | 206 | def init_distributed_mode(args): 207 | if args.dist_on_itp: 208 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 209 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 210 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 211 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 212 | os.environ['LOCAL_RANK'] = str(args.gpu) 213 | os.environ['RANK'] = str(args.rank) 214 | os.environ['WORLD_SIZE'] = str(args.world_size) 215 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 216 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 217 | args.rank = int(os.environ["RANK"]) 218 | args.world_size = int(os.environ['WORLD_SIZE']) 219 | args.gpu = int(os.environ['LOCAL_RANK']) 220 | elif 'SLURM_PROCID' in os.environ: 221 | args.rank = int(os.environ['SLURM_PROCID']) 222 | args.gpu = args.rank % torch.cuda.device_count() 223 | else: 224 | print('Not using 
distributed mode') 225 | setup_for_distributed(is_master=True) # hack 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}, gpu {}'.format( 234 | args.rank, args.dist_url, args.gpu), flush=True) 235 | 236 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 237 | world_size=args.world_size, rank=args.rank) 238 | torch.distributed.barrier() 239 | setup_for_distributed(args.rank == 0) 240 | 241 | 242 | class NativeScalerWithGradNormCount: 243 | state_dict_key = "amp_scaler" 244 | 245 | def __init__(self): 246 | self._scaler = torch.cuda.amp.GradScaler() 247 | 248 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 249 | if not torch.isnan(loss): 250 | self._scaler.scale(loss).backward(create_graph=create_graph) 251 | if update_grad: 252 | if clip_grad is not None: 253 | assert parameters is not None 254 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 255 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 256 | else: 257 | self._scaler.unscale_(optimizer) 258 | norm = get_grad_norm_(parameters) 259 | self._scaler.step(optimizer) 260 | self._scaler.update() 261 | else: 262 | norm = None 263 | return norm 264 | 265 | def state_dict(self): 266 | return self._scaler.state_dict() 267 | 268 | def load_state_dict(self, state_dict): 269 | self._scaler.load_state_dict(state_dict) 270 | 271 | 272 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 273 | if isinstance(parameters, torch.Tensor): 274 | parameters = [parameters] 275 | parameters = [p for p in parameters if p.grad is not None] 276 | norm_type = float(norm_type) 277 | if len(parameters) == 0: 278 | return torch.tensor(0.) 
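# Note (added comment): with at least one gradient present, the total norm is computed below on the device of the first parameter's gradient (max-abs for the inf norm, otherwise the norm of the stacked per-parameter norms).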
279 | device = parameters[0].grad.device 280 | if norm_type == inf: 281 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 282 | else: 283 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 284 | norm_type) 285 | return total_norm 286 | 287 | 288 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 289 | output_dir = Path(args.output_dir) 290 | epoch_name = str(epoch) 291 | model_without_ddp.eval() 292 | trainable = {} 293 | for n, p in model.named_parameters(): 294 | if 'adapter' in n: 295 | trainable[n] = p.data 296 | # if loss_scaler is not None: 297 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 298 | for checkpoint_path in checkpoint_paths: 299 | to_save = { 300 | 'model': trainable, 301 | 'optimizer': optimizer.state_dict(), 302 | 'epoch': epoch, 303 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, 304 | 'args': args, 305 | } 306 | save_on_master(to_save, checkpoint_path) 307 | 308 | 309 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 310 | if args.resume: 311 | if args.resume.startswith('https'): 312 | checkpoint = torch.hub.load_state_dict_from_url( 313 | args.resume, map_location='cpu', check_hash=True) 314 | else: 315 | checkpoint = torch.load(args.resume, map_location='cpu') 316 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 317 | print("Resume checkpoint %s" % args.resume) 318 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 319 | optimizer.load_state_dict(checkpoint['optimizer']) 320 | args.start_epoch = checkpoint['epoch'] + 1 321 | if 'scaler' in checkpoint: 322 | loss_scaler.load_state_dict(checkpoint['scaler']) 323 | print("With optim & sched!") 324 | 325 | 326 | def all_reduce_mean(x): 327 | world_size = get_world_size() 328 | if world_size > 1: 329 | x_reduce = torch.tensor(x).cuda() 330 | dist.all_reduce(x_reduce) 331 | x_reduce /= world_size 332 | return x_reduce.item() 333 | else: 334 | return x 335 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | from torch.cuda.amp import autocast 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + 
self.attention(self.ln_1(x)) 191 | x = x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | for block in self.resblocks:  # nn.ModuleList is not callable directly; apply the blocks in order 204 | x = block(x) 205 | return x 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | 224 | def forward(self, x: torch.Tensor): 225 | x = self.conv1(x) # shape = [*, width, grid, grid] 226 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 227 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 228 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 229 | x = x + self.positional_embedding.to(x.dtype) 230 | 231 | x = self.ln_pre(x) 232 | 233 | x = x.permute(1, 0, 2) # NLD -> LND 234 | for i, layer in enumerate(self.transformer.resblocks[:-1]): 235 | x = layer(x) 236 | 237 | x = x.permute(1, 0, 2) # LND -> NLD 238 | x = self.ln_post(x) 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 | layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final
= LayerNorm(transformer_width) 293 | 294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_text(self, text): 344 | print('encode text') 345 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 346 | 347 | x = x + self.positional_embedding.type(self.dtype) 348 | x = x.permute(1, 0, 2) # NLD -> LND 349 | x = self.transformer(x) 350 | x = x.permute(1, 0, 2) # LND -> NLD 351 | # x = self.ln_final(x).type(self.dtype) 352 | 353 | # x.shape = [batch_size, n_ctx, transformer.width] 354 | # take features from the eot embedding (eot_token is the highest number in each sequence) 355 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 356 | 357 | return x[:,1:] 358 | 359 | def forward(self, image, text): 360 | image_features = self.encode_image(image) 361 | text_features = self.encode_text(text) 362 | 363 | # normalized features 364 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 365 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 366 | 367 | # cosine similarity as logits 368 | logit_scale = self.logit_scale.exp() 369 | logits_per_image = logit_scale * image_features @ text_features.t() 370 | logits_per_text = logits_per_image.t() 371 | 372 | # shape = [global_batch_size, global_batch_size] 373 | return logits_per_image, logits_per_text 374 | 375 | 376 | def 
convert_weights(model: nn.Module): 377 | """Convert applicable model parameters to fp16""" 378 | 379 | def _convert_weights_to_fp16(l): 380 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 381 | l.weight.data = l.weight.data.half() 382 | if l.bias is not None: 383 | l.bias.data = l.bias.data.half() 384 | 385 | if isinstance(l, nn.MultiheadAttention): 386 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 387 | tensor = getattr(l, attr) 388 | if tensor is not None: 389 | tensor.data = tensor.data.half() 390 | 391 | for name in ["text_projection", "proj"]: 392 | if hasattr(l, name): 393 | attr = getattr(l, name) 394 | if attr is not None: 395 | attr.data = attr.data.half() 396 | 397 | model.apply(_convert_weights_to_fp16) 398 | 399 | 400 | def build_model(state_dict: dict): 401 | vit = "visual.proj" in state_dict 402 | 403 | if vit: 404 | vision_width = state_dict["visual.conv1.weight"].shape[0] 405 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 406 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 407 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 408 | image_resolution = vision_patch_size * grid_size 409 | else: 410 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 411 | vision_layers = tuple(counts) 412 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 413 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 414 | vision_patch_size = None 415 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 416 | image_resolution = output_width * 32 417 | 418 | embed_dim = state_dict["text_projection"].shape[1] 419 | context_length = state_dict["positional_embedding"].shape[0] 420 | vocab_size = state_dict["token_embedding.weight"].shape[0] 421 | transformer_width = state_dict["ln_final.weight"].shape[0] 422 | transformer_heads = transformer_width // 64 423 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 424 | 425 | model = CLIP( 426 | embed_dim, 427 | image_resolution, vision_layers, vision_width, vision_patch_size, 428 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 429 | ) 430 | 431 | for key in ["input_resolution", "context_length", "vocab_size"]: 432 | if key in state_dict: 433 | del state_dict[key] 434 | model.float() 435 | convert_weights(model) 436 | model.load_state_dict(state_dict,strict=False) 437 | return model.eval() 438 | -------------------------------------------------------------------------------- /adem/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import math 5 | import torch 6 | from torch import nn 7 | from torch.cuda.amp import autocast 8 | from torch.nn import Embedding, Linear 9 | import torch.nn.functional as F 10 | 11 | import clip 12 | 13 | 14 | @dataclass 15 | class ModelArgs: 16 | dim: int = 512 17 | n_layers: int = 8 18 | n_heads: int = 8 19 | vocab_size: int = -1 # defined later by tokenizer 20 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 21 | norm_eps: float = 1e-5 22 | hidden_proj: int = 128 23 | max_batch_size: int = 32 24 | max_seq_len: int = 
2048 25 | is_train: bool = True 26 | clip: str = 'ViT-L/14' 27 | clip_root: str = './clip' 28 | beta: float = 0.01 29 | down_sample_num = [256, 64] 30 | with_cls = True 31 | 32 | 33 | class Avg2dSampler(nn.Module): 34 | def __init__( 35 | self, 36 | grid_size, 37 | ): 38 | super().__init__() 39 | self.grid_size = grid_size 40 | 41 | def forward(self, x): 42 | B, N, C = x.shape 43 | ori_grid_size = int(N ** 0.5) 44 | assert ori_grid_size ** 2 == N 45 | 46 | x = x.reshape(B, ori_grid_size, ori_grid_size, C).permute(0, 3, 1, 2) 47 | 48 | out = [] 49 | for gs in self.grid_size: 50 | if not (gs == 16): 51 | x_ = F.adaptive_avg_pool2d(x, gs) 52 | else: 53 | x_ = x 54 | x_ = x_.reshape(B, C, -1).permute(0, 2, 1) 55 | out.append(x_) 56 | 57 | x = torch.cat(out, dim=1) 58 | 59 | return x 60 | 61 | 62 | class RMSNorm(torch.nn.Module): 63 | def __init__(self, dim: int, eps: float = 1e-6): 64 | super().__init__() 65 | self.eps = eps 66 | self.weight = nn.Parameter(torch.ones(dim)) 67 | 68 | def _norm(self, x): 69 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 70 | 71 | def forward(self, x): 72 | output = self._norm(x.float()).type_as(x) 73 | return output * self.weight 74 | 75 | 76 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 77 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 78 | t = torch.arange(end, device=freqs.device) # type: ignore 79 | freqs = torch.outer(t, freqs).float() # type: ignore 80 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 81 | return freqs_cis 82 | 83 | 84 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 85 | ndim = x.ndim 86 | assert 0 <= 1 < ndim 87 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 88 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 89 | return freqs_cis.view(*shape) 90 | 91 | 92 | def apply_rotary_emb( 93 | xq: torch.Tensor, 94 | xk: torch.Tensor, 95 | freqs_cis: torch.Tensor, 96 | ) -> Tuple[torch.Tensor, torch.Tensor]: 97 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 98 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 99 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 100 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 101 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 102 | return xq_out.type_as(xq), xk_out.type_as(xk) 103 | 104 | 105 | class Attention(nn.Module): 106 | def __init__(self, args: ModelArgs): 107 | super().__init__() 108 | 109 | self.n_local_heads = args.n_heads 110 | self.head_dim = args.dim // args.n_heads 111 | 112 | # modified bias for reparameterizing 113 | self.wq = Linear( 114 | args.dim, 115 | args.n_heads * self.head_dim, 116 | bias=False 117 | ) 118 | self.wk = Linear( 119 | args.dim, 120 | args.n_heads * self.head_dim, 121 | bias=False 122 | ) 123 | self.wv = Linear( 124 | args.dim, 125 | args.n_heads * self.head_dim, 126 | bias=False 127 | ) 128 | self.wo = Linear( 129 | args.n_heads * self.head_dim, 130 | args.dim, 131 | bias=False 132 | ) 133 | if not args.is_train: 134 | self.cache_k = torch.zeros( 135 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 136 | ).cuda() 137 | self.cache_v = torch.zeros( 138 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 139 | ).cuda() 140 | else: 141 | self.cache_k = None 142 | self.cache_v = None 143 | 144 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: 
Optional[torch.Tensor]): 145 | bsz, seqlen, _ = x.shape 146 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 147 | 148 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 149 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) 150 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) 151 | 152 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 153 | 154 | if self.cache_k is not None: 155 | self.cache_k = self.cache_k.to(xq) 156 | self.cache_v = self.cache_v.to(xq) 157 | 158 | self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk 159 | self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv 160 | 161 | keys = self.cache_k[:bsz, : start_pos + seqlen] 162 | values = self.cache_v[:bsz, : start_pos + seqlen] 163 | 164 | else: 165 | keys = xk 166 | values = xv 167 | xq = xq.transpose(1, 2) 168 | keys = keys.transpose(1, 2) 169 | values = values.transpose(1, 2) 170 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 171 | if mask is not None: 172 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 173 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 174 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) 175 | output = output.transpose( 176 | 1, 2 177 | ).contiguous().view(bsz, seqlen, -1) 178 | 179 | return self.wo(output) 180 | 181 | 182 | class FeedForward(nn.Module): 183 | def __init__( 184 | self, 185 | dim: int, 186 | hidden_dim: int, 187 | multiple_of: int, 188 | ): 189 | super().__init__() 190 | hidden_dim = int(2 * hidden_dim / 3) 191 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 192 | 193 | self.w1 = Linear( 194 | dim, hidden_dim, bias=False 195 | ) 196 | self.w2 = Linear( 197 | hidden_dim, dim, bias=False 198 | ) 199 | self.w3 = Linear( 200 | dim, hidden_dim, bias=False 201 | ) 202 | 203 | def forward(self, x): 204 | x = self.w2(F.silu(self.w1(x), inplace=False) * self.w3(x)) 205 | return x 206 | 207 | 208 | class TransformerBlock(nn.Module): 209 | def __init__(self, layer_id: int, args: ModelArgs): 210 | super().__init__() 211 | self.n_heads = args.n_heads 212 | self.dim = args.dim 213 | self.head_dim = args.dim // args.n_heads 214 | self.attention = Attention(args) 215 | self.feed_forward = FeedForward( 216 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of 217 | ) 218 | self.layer_id = layer_id 219 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 220 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 221 | self.drop_path = nn.Identity() 222 | 223 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): 224 | h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) 225 | out = h + self.feed_forward(self.ffn_norm(h)) 226 | return out 227 | 228 | 229 | class Projector(nn.Module): 230 | """ Pytorch Implemention of RepAdapter for 1d tensor""" 231 | 232 | def __init__( 233 | self, 234 | in_features=768, 235 | hidden_dim=128, 236 | out_features=4096 237 | ): 238 | super().__init__() 239 | self.fc1 = nn.Linear(in_features, hidden_dim) 240 | self.fc2 = nn.Linear(hidden_dim, out_features) 241 | nn.init.xavier_uniform_(self.fc1.weight) 242 | nn.init.zeros_(self.fc1.bias) 243 | nn.init.xavier_uniform_(self.fc2.weight) 244 | nn.init.zeros_(self.fc2.bias) 245 | 246 | def forward(self, x): 247 | with autocast(): 248 | x = self.fc2(F.silu(self.fc1(x))) 249 | return x 250 | 251 | 252 | class Transformer(nn.Module): 253 | def __init__(self, params: 
ModelArgs): 254 | super().__init__() 255 | self.params = params 256 | self.vocab_size = params.vocab_size 257 | self.n_layers = params.n_layers 258 | self.tok_embeddings = Embedding( 259 | params.vocab_size, params.dim 260 | ) 261 | 262 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 263 | 264 | self.layers = torch.nn.ModuleList() 265 | for layer_id in range(params.n_layers): 266 | self.layers.append(TransformerBlock(layer_id, params)) 267 | 268 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 269 | self.output = Linear( 270 | params.dim, params.vocab_size, bias=False 271 | ) 272 | 273 | self.freqs_cis = precompute_freqs_cis( 274 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 275 | ) 276 | 277 | grid_size = [] 278 | for d in params.down_sample_num: 279 | gs = int(d ** 0.5) 280 | assert gs ** 2 == d 281 | grid_size.append(gs) 282 | self.down_sampler = Avg2dSampler(grid_size) 283 | 284 | self.backbone = clip.load(params.clip, download_root=params.clip_root)[0] 285 | 286 | self.adapter_emb1 = nn.Parameter(torch.randn(1, sum(params.down_sample_num), params.dim) * 0.02) 287 | self.adapter_emb2 = nn.Parameter(torch.zeros(1, sum(params.down_sample_num), params.dim)) 288 | self.adapter_proj = Projector(1024, params.hidden_proj, params.dim).float() 289 | if params.with_cls: 290 | self.adapter_proj_cls = Projector(1024, params.hidden_proj, params.dim).float() 291 | 292 | def insert_image_embeds(self, examples, labels, image_embeds, eos_idxes, prefix_img, prefix_nonimg, indicators): 293 | _bsz, seqlen, _ = examples.shape 294 | new_examples = [] 295 | new_labels = [] 296 | new_eos_idxes = [] 297 | for i, (example, label) in enumerate(zip(examples, labels)): 298 | if indicators[i] > 0.: 299 | new_example = torch.cat([example[:1], prefix_img, image_embeds[i], example[1:]], 0) 300 | new_label = torch.cat([label[:1], 301 | torch.zeros(prefix_img.shape[0] + image_embeds.shape[1]).to( 302 | examples.device).type_as(labels), 303 | label[1:]]) 304 | eos_idx = eos_idxes[i] + prefix_img.shape[0] + image_embeds.shape[1] 305 | new_example = new_example[:seqlen] 306 | new_label = new_label[:seqlen] 307 | if eos_idx > seqlen - 1: 308 | eos_idx = -1 309 | else: 310 | new_example = torch.cat([example[:1], prefix_nonimg, example[1:]], 0) 311 | new_label = torch.cat([label[:1], 312 | torch.zeros(prefix_nonimg.shape[0]).to(examples.device).type_as(labels), 313 | label[1:]]) 314 | eos_idx = eos_idxes[i] + prefix_nonimg.shape[0] 315 | new_example = new_example[:seqlen] 316 | new_label = new_label[:seqlen] 317 | if eos_idx > seqlen - 1: 318 | eos_idx = -1 319 | new_examples.append(new_example.unsqueeze(0)) 320 | new_labels.append(new_label.unsqueeze(0)) 321 | new_eos_idxes.append(eos_idx) 322 | new_examples = torch.cat(new_examples, 0) 323 | new_labels = torch.cat(new_labels, 0) 324 | return new_examples, new_labels, new_eos_idxes 325 | 326 | def forward(self, examples, labels, images=None, example_mask=None, prefix_img=None, prefix_nonimg=None, 327 | indicators=None): 328 | 329 | eos_idxes = (example_mask.sum(1).long() - 1).tolist() 330 | 331 | feats = self.backbone.encode_image(images).half() 332 | image_embeds = self.adapter_proj(self.down_sampler(feats[:, 1:, :])) 333 | 334 | if isinstance(indicators, list): 335 | indicators = torch.Tensor(indicators).to(images.device).long() 336 | 337 | image_embeds *= self.params.beta * indicators.half().view(-1, 1, 1) 338 | vis_weight = [image_embeds, self.adapter_emb1, self.adapter_emb2] 339 | 340 | _bsz, seqlen = examples.shape 341 | 342 | 
examples = self.tok_embeddings(examples) 343 | if self.params.with_cls: 344 | cls_tokes = self.adapter_proj_cls(feats[:, [0], :]) 345 | prefix_img = self.tok_embeddings(prefix_img.unsqueeze(0)).squeeze(0) 346 | prefix_nonimg = self.tok_embeddings(prefix_nonimg.unsqueeze(0)).squeeze(0) 347 | 348 | h, labels, eos_idxes = self.insert_image_embeds(examples, labels, cls_tokes, eos_idxes, prefix_img, 349 | prefix_nonimg, indicators) 350 | else: 351 | h = examples 352 | 353 | seqlen = (labels > 0).float().nonzero(as_tuple=False)[:, 1].max() + 1 354 | h = h[:, :seqlen] 355 | labels = labels[:, :seqlen] 356 | seqlen = h.size(1) 357 | freqs_cis = self.freqs_cis.to(h.device) 358 | freqs_cis = freqs_cis[:seqlen] 359 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device) 360 | mask = torch.triu(mask, diagonal=0 + 1).type_as(h) 361 | 362 | start_pos = 0 363 | for i, layer in enumerate(self.layers): 364 | h = layer(h, start_pos, freqs_cis, mask, vis_weight) 365 | h = self.norm(h) 366 | output = self.output(h) 367 | output = output[:, :-1, :].reshape(-1, self.vocab_size) 368 | labels = labels[:, 1:].flatten() 369 | 370 | c_loss = self.criterion(output, labels) 371 | return c_loss 372 | 373 | @torch.inference_mode() 374 | def generate( 375 | self, 376 | prompts, 377 | images, 378 | indicators, 379 | max_gen_len, 380 | tokenizer=None, 381 | temperature: float = 0, 382 | top_p: float = 0.95, 383 | ): 384 | bsz = len(prompts) 385 | params = self.params 386 | self.eval() 387 | 388 | prefix_img_token = tokenizer.encode("Image: ", bos=True, eos=False) 389 | non_prefix_img_token = tokenizer.encode("Image: N/A", bos=True, eos=False) 390 | 391 | images = images.cuda() 392 | self.backbone.cuda() 393 | 394 | feats = self.backbone.encode_image(images).half() 395 | image_embeds = self.adapter_proj(self.down_sampler(feats[:, 1:, :])) 396 | 397 | indicators = torch.Tensor(indicators).cuda().long() 398 | 399 | image_embeds *= self.params.beta * indicators.half().view(-1, 1, 1) 400 | vis_weight = [image_embeds, self.adapter_emb1, self.adapter_emb2] 401 | 402 | prompt_tokens = [] 403 | for i, x in enumerate(prompts): 404 | if self.params.with_cls: 405 | cls_tokes = self.adapter_proj_cls(feats[:, [0], :]) 406 | if indicators[i] == 1: 407 | token_idx = prefix_img_token + [0] * cls_tokes.size(1) + tokenizer.encode(x, bos=False, eos=False) 408 | else: 409 | token_idx = non_prefix_img_token + tokenizer.encode(x, bos=False, eos=False) 410 | else: 411 | token_idx = tokenizer.encode(x, bos=True, eos=False) 412 | prompt_tokens.append(token_idx) 413 | 414 | max_prompt_size = max([len(t) for t in prompt_tokens]) 415 | total_len = min(512, max_gen_len + max_prompt_size) 416 | 417 | tokens = torch.full((bsz, total_len), 0).cuda().long() 418 | mask = torch.full((bsz, 1, total_len, total_len), float("-inf"), device=tokens.device) 419 | mask = torch.triu(mask, diagonal=1) 420 | 421 | for k, t in enumerate(prompt_tokens): 422 | t = t[:total_len - max_gen_len] 423 | tokens[k, -len(t) - max_gen_len:- max_gen_len] = torch.tensor(t).long() 424 | mask[k, :, -len(t) - max_gen_len:, :-len(t) - max_gen_len] = float("-inf") 425 | 426 | token_embeds = self.tok_embeddings(tokens) 427 | if self.params.with_cls: 428 | for i in range(len(token_embeds)): 429 | if indicators[i] == 1: 430 | pos = len(prefix_img_token) # with bos 431 | if pos - len(prompt_tokens[i]) - max_gen_len < -511: 432 | continue 433 | token_embeds[i, pos - len(prompt_tokens[i]) - max_gen_len:pos - len( 434 | prompt_tokens[i]) - max_gen_len + cls_tokes.size(1)] = 
cls_tokes[i] 435 | 436 | mask = mask.type_as(token_embeds) 437 | 438 | start_pos = min(max_prompt_size, 512 - max_gen_len) 439 | stop_flag = torch.ones([bsz], dtype=torch.long).cuda() 440 | 441 | prev_pos = 0 442 | for cur_pos in range(start_pos, total_len): 443 | h = token_embeds[:, prev_pos:cur_pos] 444 | 445 | mask_input = mask[:, :, prev_pos:cur_pos, :cur_pos] 446 | 447 | with autocast(): 448 | _bsz, seqlen, _ = h.shape 449 | self.freqs_cis = self.freqs_cis.to(h.device) 450 | freqs_cis = self.freqs_cis[prev_pos: prev_pos + seqlen] 451 | 452 | if mask_input is None and seqlen > 1: 453 | mask_input = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device) 454 | mask_input = torch.triu(mask_input, diagonal=prev_pos + 1).type_as(h) 455 | 456 | for i, layer in enumerate(self.layers): 457 | h = layer(h, prev_pos, freqs_cis, mask_input, vis_weight) 458 | 459 | h = self.norm(h) 460 | output = self.output(h[:, -1, :]) # only compute last logits 461 | logits = output.float() 462 | 463 | if temperature > 0: 464 | probs = torch.softmax(logits / temperature, dim=-1) 465 | next_token = sample_top_p(probs, top_p) 466 | else: 467 | next_token = torch.argmax(logits, dim=-1) 468 | next_token = next_token.reshape(-1) 469 | stop_flag *= (next_token != tokenizer.eos_id).long() 470 | if stop_flag.sum() == 0: 471 | tokens[:, cur_pos] = next_token 472 | break 473 | 474 | next_token_embeds = self.tok_embeddings(next_token) 475 | 476 | token_embeds[:, cur_pos] = next_token_embeds 477 | tokens[:, cur_pos] = next_token 478 | 479 | prev_pos = cur_pos 480 | 481 | decoded = [] 482 | for i, t in enumerate(tokens.tolist()): 483 | try: 484 | t = t[- max_gen_len:] 485 | t = t[: t.index(tokenizer.eos_id)] 486 | except ValueError: 487 | pass 488 | decoded.append(tokenizer.decode(t)) 489 | 490 | return decoded 491 | 492 | 493 | def sample_top_p(probs, p): 494 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 495 | probs_sum = torch.cumsum(probs_sort, dim=-1) 496 | mask = probs_sum - probs_sort > p 497 | probs_sort[mask] = 0.0 498 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 499 | next_token = torch.multinomial(probs_sort, num_samples=1) 500 | next_token = torch.gather(probs_idx, -1, next_token) 501 | return next_token 502 | --------------------------------------------------------------------------------
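Note on the visual token counts used in adem/model.py above: with the default CLIP ViT-L/14 backbone at 224x224 input, encode_image returns a CLS token plus a 16x16 grid of 1024-dimensional patch tokens, and down_sample_num = [256, 64] maps to grid sizes [16, 8] in Transformer.__init__. A minimal shape check for Avg2dSampler under those assumptions (random features stand in for real CLIP output):

import torch
from adem.model import Avg2dSampler

feats = torch.randn(2, 256, 1024)          # (batch, 16*16 patch tokens, CLIP ViT-L/14 width)
sampler = Avg2dSampler(grid_size=[16, 8])  # the 16x16 grid is kept as-is, 8x8 comes from adaptive average pooling
out = sampler(feats)
print(out.shape)                           # torch.Size([2, 320, 1024]) -> 256 + 64 visual tokens per image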
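Transformer.forward and Transformer.generate pass a fifth argument, vis_weight = [image_embeds, adapter_emb1, adapter_emb2], to every TransformerBlock, while the block's forward defined in this file only accepts (x, start_pos, freqs_cis, mask); the extra argument is expected to be handled by the adapter wiring in adem/adapter.py. Purely as a hypothetical sketch of what such a patched block forward could look like, the snippet below uses an illustrative cross-attention-style injection; it is an assumption for clarity, not the repository's actual fusion rule:

from typing import List, Optional
import torch

def patched_block_forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                          mask: Optional[torch.Tensor], vis_weight: Optional[List[torch.Tensor]] = None):
    # Hypothetical fusion only; the real logic lives in adem/adapter.py.
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    if vis_weight is not None:
        image_embeds, emb1, emb2 = vis_weight   # (B, N_vis, dim), (1, N_vis, dim), (1, N_vis, dim)
        keys = image_embeds + emb1              # learned key offset, broadcast over the batch
        values = image_embeds + emb2            # learned value offset
        attn = torch.softmax(h @ keys.transpose(1, 2) / h.shape[-1] ** 0.5, dim=-1)
        h = h + attn @ values                   # inject the down-sampled visual tokens into every text position
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

# If used, this would be bound before the forward pass, e.g. TransformerBlock.forward = patched_block_forward.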