├── .DS_Store ├── .idea ├── .gitignore ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── ofa.iml ├── sshConfigs.xml ├── vcs.xml └── webServers.xml ├── __pycache__ ├── config.cpython-37.pyc └── data_load.cpython-37.pyc ├── config.py ├── data └── .DS_Store ├── data_load.py ├── eval.py ├── evaluation ├── .DS_Store ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── bleu │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── bleu.cpython-37.pyc │ │ └── bleu_scorer.cpython-37.pyc │ ├── bleu.py │ └── bleu_scorer.py ├── cider │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── cider.cpython-37.pyc │ │ └── cider_scorer.cpython-37.pyc │ ├── cider.py │ └── cider_scorer.py ├── meteor │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── meteor.cpython-37.pyc │ └── meteor.py ├── rouge │ ├── __init__.py │ └── rouge.py ├── stanford-corenlp-3.4.1.jar └── tokenizer.py ├── knowcap.png ├── models ├── .DS_Store ├── BLIP │ ├── __init__.py │ ├── blip.py │ ├── blip_itm.py │ ├── blip_nlvr.py │ ├── blip_pretrain.py │ ├── blip_retrieval.py │ ├── blip_vqa.py │ ├── caption_coco.yaml │ ├── caption_coco_teacher.yaml │ ├── med.py │ ├── med_config.json │ ├── nlvr_encoder.py │ └── vit.py ├── GIT │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── git_model.cpython-37.pyc │ ├── git.py │ └── git_model.py ├── OFA │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── ofa.cpython-37.pyc │ │ └── ofa_model.cpython-37.pyc │ ├── ofa.py │ └── ofa_model.py └── Transformer │ ├── __init__.py │ └── transformer.py ├── readme.md ├── requirements.txt ├── test.py ├── test_knowcap.py ├── train_multitask.py └── utils ├── .DS_Store ├── __pycache__ ├── beamsearch.cpython-37.pyc ├── eval.cpython-37.pyc ├── import_models.cpython-37.pyc └── vocab.cpython-37.pyc ├── beamsearch.py ├── cc12m.py ├── convert_ofa.py ├── eval.py ├── import_models.py ├── knowcap.py ├── log.py ├── loss.py ├── optimizer_tools.py ├── prepro_data.py ├── prepro_ref_pycoco.py ├── prepro_rwcap.py └── vocab.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 23 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/ofa.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 20 | 21 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/data_load.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/__pycache__/data_load.cpython-37.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | 5 | parser.add_argument('--seed', type=int, default=222) 6 | parser.add_argument('--id', type=str, default='test') 7 | parser.add_argument('--mode', type=str, default='train') 8 | parser.add_argument('--model', type=str, default=None) 9 | parser.add_argument('--test', type=str, default='test') 10 | parser.add_argument('--epochs', type=int, default=100) 11 | parser.add_argument('--ft_epoch', type=int, default=3) 12 | parser.add_argument('--ckpts_id', type=str, default=None) 13 | parser.add_argument('--step', type=int, default=0) 14 | 15 | parser.add_argument('--local_rank', type=int, default=-1) 16 | parser.add_argument('--nproc_per_node', type=int, default=-1) 17 | 18 | parser.add_argument('--trained_ckpts', default='/home/chengkz/checkpoints/ofa/log/ofa_m1.0_t16_k1.0_222/model/model_300.pt') 19 | parser.add_argument('--ofa_ckpts', default='/home/data_ti4_c/chengkz/scripts/OFA-large') 20 | parser.add_argument('--ofa_ckpts_distill', default='/home/chengkz/checkpoints/ofa/OFA-large-caption-XEfinetuned') 21 | parser.add_argument('--git', default="microsoft/git-large") 22 | parser.add_argument('--git_distill', default="microsoft/git-large-coco") 23 | parser.add_argument('--config_blip', default='./models/BLIP/caption_coco.yaml') 24 | parser.add_argument('--config_blip_t', default='./models/BLIP/caption_coco_teacher.yaml') 25 | parser.add_argument('--data_dir', default='./data') 26 | parser.add_argument('--vocab', default='./data/vocab.pkl') 27 | parser.add_argument('--train', default='./data/train.json') 28 | parser.add_argument('--train_mix', default='./data/train_mix_cc12m_keyword_large.json') 29 | parser.add_argument('--knowcap240', default='/home/chengkz/checkpoints/KnowCap_240') 30 | parser.add_argument('--data_mode', default='mix') 31 | parser.add_argument('--samples_dir', 
default='./examples/example_images')
32 | parser.add_argument('--samples_out', default=None)
33 | parser.add_argument('--pretrain_model', default=None)
34 |
35 | parser.add_argument('--save_loss_freq', type=int, default=20)
36 | parser.add_argument('--save_model_freq', type=int, default=100)
37 | parser.add_argument('--log_dir', default='/home/chengkz/checkpoints/ofa/log/{}')
38 |
39 | parser.add_argument('--batch_size', type=int, default=60)
40 | parser.add_argument('--val_batch_size', type=int, default=25)
41 | parser.add_argument('--num_workers', type=int, default=1)
42 | parser.add_argument('--fixed_len', type=int, default=20)
43 | parser.add_argument('--lr_enc', type=float, default=2e-5)
44 | parser.add_argument('--learning_rate', type=float, default=7e-6)
45 | parser.add_argument('--grad_clip', type=float, default=0.1)
46 | parser.add_argument('--beam_num', type=int, default=5)
47 | parser.add_argument('--gen_num', type=int, default=5)
48 | parser.add_argument('--beam_alpha', type=float, default=1.0)
49 | parser.add_argument('--length_penalty', type=float, default=1.0)
50 | parser.add_argument('--multitask_weight', type=float, default=0.5)
51 | parser.add_argument('--knowdistill_weight', type=float, default=1.0)
52 | parser.add_argument('--data_ratio', type=float, default=1.0)
53 | parser.add_argument('--label_smoothing', type=float, default=0.0)
54 | parser.add_argument('--KD_temperature', type=float, default=8.0)
55 |
56 | parser.add_argument('--image_dim', type=int, default=2048)
57 | parser.add_argument('--embed_dim', type=int, default=512)
58 | parser.add_argument('--hidden_dim', type=int, default=512)
59 | parser.add_argument('--att_dim', type=int, default=1024)
60 |
61 | parser.add_argument('--method', type=str, default=None)
62 |
63 | config = parser.parse_args()
64 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/data/.DS_Store
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
1 | # Dataloaders used for training
2 | # Different models use different preprocessing
3 |
4 | import torch
5 | import numpy as np
6 | import json
7 | import os
8 | import pickle
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from PIL import Image
12 | from torchvision.transforms import InterpolationMode
13 | import torchvision.transforms as transforms
14 | from utils.vocab import Vocabulary
15 |
16 | from transformers.models.ofa.tokenization_ofa import OFATokenizer
17 | from transformers import AutoProcessor
18 | from transformers import BertTokenizer
19 |
20 | from models.BLIP.blip import init_tokenizer
21 |
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 |
25 | class IC_data(Dataset):
26 | """Dataset used for val and test"""
27 | def __init__(self, config, dir, mode):
28 | super(IC_data, self).__init__()
29 | self.config = config
30 | self.data = json.load(open(dir, 'r'))
31 | self.model = config.model
32 | # Choose the transforms according to the model
33 | self.patch_resize_transform = self.get_transforms(self.model)
34 | if self.model == 'OFA':
35 | self.ofa_ckpt = config.ofa_ckpts
36 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
37 | elif self.model == 'BLIP':
38 | self.tokenizer = init_tokenizer()
39 | elif self.model == 'GIT':
40 | self.processor = AutoProcessor.from_pretrained(config.git_distill, local_files_only=True)
41 | self.tokenizer = self.processor.tokenizer
42 |
43 | self.mode = mode
44 |
45 | def get_transforms(self, model):
46 | if model == 'OFA':
47 | self.resolution = 480
48 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
49 | patch_resize_transform = transforms.Compose([
50 | lambda image: image.convert("RGB"),
51 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean=self.mean, std=self.std)])
54 | elif model == 'BLIP':
55 | self.resolution = 384
56 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
57 | patch_resize_transform = transforms.Compose([
58 | lambda image: image.convert("RGB"),
59 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
60 | transforms.ToTensor(),
61 | transforms.Normalize(mean=self.mean, std=self.std)])
62 | elif model == 'GIT':
63 | patch_resize_transform = lambda img: self.processor(images=img, return_tensors='pt').pixel_values[0]
64 | return patch_resize_transform
65 |
66 | def __getitem__(self, item):
67 | if self.mode == 'train':
68 | """"""
69 | else:
70 | image_path = self.data[item]['filename']
71 | img = Image.open(image_path)
72 | patch_img = self.patch_resize_transform(img)
73 | image_id = self.data[item]['image_id']
74 | return image_id, patch_img
75 |
76 | def collate_fn_train(self, batch_data):
77 | """"""
78 |
79 | def collate_fn_eval(self, batch_data):
80 | image_id, image = zip(*batch_data)
81 | image = torch.stack(image, dim=0)
82 | image_feature = {'patch_image': image}
83 | return image_id, image_feature
84 |
85 | def __len__(self):
86 | return len(self.data)
87 |
88 |
89 | class RWConcept_data(Dataset):
90 |
91 | def __init__(self, config, dir, mode):
92 | super(RWConcept_data, self).__init__()
93 | self.config = config
94 | self.data = json.load(open(dir, 'r'))
95 | self.model = config.model
96 | # Choose the transforms according to the model
97 | self.patch_resize_transform = self.get_transforms(config.model)
98 | if self.model == 'OFA':
99 | self.ofa_ckpt = config.ofa_ckpts
100 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
101 | elif self.model == 'BLIP':
102 | self.tokenizer = init_tokenizer()
103 | elif self.model == 'GIT':
104 | self.processor = AutoProcessor.from_pretrained(config.git_distill, local_files_only=True)
105 | self.tokenizer = self.processor.tokenizer
106 | self.mode = mode
107 |
108 | def get_transforms(self, model):
109 | if model == 'OFA':
110 | self.resolution = 480
111 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
112 | patch_resize_transform = transforms.Compose([
113 | lambda image: image.convert("RGB"),
114 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
115 | transforms.ToTensor(),
116 | transforms.Normalize(mean=self.mean, std=self.std)])
117 | elif model == 'BLIP':
118 | self.resolution = 384
119 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
120 | patch_resize_transform = transforms.Compose([
121 | lambda image: image.convert("RGB"),
122 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
123 | transforms.ToTensor(),
124 | transforms.Normalize(mean=self.mean, std=self.std)])
125 | elif model == 'GIT':
126 | patch_resize_transform = lambda img: self.processor(images=img, return_tensors='pt').pixel_values[0]
127 | return patch_resize_transform
128 |
129 | def __getitem__(self, item):
130 | if self.mode == 'train':
131 | caption = self.data[item]['caption']
132 | # Different models use different prefixes
133 | if self.model == 'OFA':
134 | caption = ' '+caption
135 | elif self.model == "BLIP":
136 | caption = ' a picture of ' + caption
137 | elif self.model == 'GIT':
138 | caption = ' '+caption
139 |
140 | # Different models tokenize in different ways
141 | if self.model == 'OFA':
142 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
143 | elif self.model == 'BLIP':
144 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=20, return_tensors="pt")
145 | cap_id = text.input_ids[0]
146 | cap_id[0] = self.tokenizer.bos_token_id
147 | elif self.model == 'GIT':
148 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
149 |
150 | cap_len = cap_id.shape[0]
151 | if cap_len < self.config.fixed_len:
152 | if self.model == 'OFA':
153 | cap_id = torch.cat([cap_id, torch.ones([self.config.fixed_len-cap_len])], dim=0)
154 | elif self.model == 'BLIP':
155 | cap_id = torch.cat([cap_id, torch.zeros([self.config.fixed_len-cap_len])], dim=0)
156 | elif self.model == 'GIT':
157 | cap_id = torch.cat([cap_id, torch.zeros([self.config.fixed_len-cap_len])], dim=0)
158 | att_mask = torch.cat([torch.ones([cap_len]), torch.zeros([self.config.fixed_len-cap_len])], dim=0)
159 | else:
160 | cap_id = cap_id[:self.config.fixed_len]
161 | cap_len = self.config.fixed_len
162 | att_mask = torch.ones(cap_id.shape)
163 |
164 | image_path = self.data[item]['filename']
165 | img = Image.open(image_path)
166 | patch_img = self.patch_resize_transform(img)
167 | label = 0 if self.data[item]['data'] == 'coco' else 1
168 | return patch_img, cap_id, att_mask, cap_len, label, self.data[item]
169 | else:
170 | image_path = self.data[item]['filename']
171 | img = Image.open(image_path)
172 | patch_img = self.patch_resize_transform(img)
173 | image_id = self.data[item]['image_id']
174 | return image_id, patch_img
175 |
176 | def collate_fn_train(self, batch_data):
177 | image, cap_id, att_mask, cap_len, label, data_item = zip(*batch_data)
178 | image = torch.stack(image, dim=0)
179 | image_feature = {'patch_image': image}
180 | cap_id = torch.stack(cap_id, dim=0)
181 | att_mask = torch.stack(att_mask, dim=0)
182 | cap_len = torch.Tensor(cap_len).int()
183 | label = torch.Tensor(label).int()
184 | return image_feature, cap_id.long(), att_mask.long(), cap_len, label, list(data_item)
185 |
186 | def collate_fn_eval(self, batch_data):
187 | image_id, image = zip(*batch_data)
188 | image = torch.stack(image, dim=0)
189 | image_feature = {'patch_image': image}
190 | return image_id, image_feature
191 |
192 | def __len__(self):
193 | return len(self.data)
194 |
195 |
196 | class RWConcept_data_EWC(Dataset):
197 |
198 | def __init__(self, config, dir, mode='train'):
199 | super(RWConcept_data_EWC, self).__init__()
200 | self.config = config
201 | self.data = json.load(open(dir, 'r'))
202 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
203 | self.resolution = 480
204 | self.patch_resize_transform = transforms.Compose([
205 | lambda image: image.convert("RGB"),
206 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
207 | transforms.ToTensor(),
208 | transforms.Normalize(mean=self.mean, std=self.std)])
209 | self.ofa_ckpt = config.ofa_ckpts
210 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
211 | self.mode = mode
212 |
213 | def __getitem__(self, item):
214 | if self.mode == 'train':
215 | caption = ' '+self.data[item]['caption']
216 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
217 | keyword = ' '+self.data[item]['keyword']
218 | keyword_id = self.tokenizer([keyword], return_tensors="pt").input_ids[0]
219 | keyword_id = keyword_id[keyword_id > 2]
220 | cap_len = cap_id.shape[0]
221 | if cap_len < self.config.fixed_len:
222 | cap_id = torch.cat([cap_id, torch.ones([self.config.fixed_len-cap_len])], dim=0)
223 | att_mask = torch.cat([torch.ones([cap_len]), torch.zeros([self.config.fixed_len-cap_len])], dim=0)
224 | else:
225 | cap_id = cap_id[:self.config.fixed_len]
226 | cap_len = self.config.fixed_len
227 | att_mask = torch.ones(cap_id.shape)
228 | if_keyword = torch.Tensor([True if (item in keyword_id) else False for item in cap_id])
229 | image_path = self.data[item]['filename']
230 | img = Image.open(image_path)
231 | patch_img = self.patch_resize_transform(img)
232 | label = 0 if self.data[item]['data'] == 'coco' else 1
233 | return patch_img, cap_id, att_mask, cap_len, label, if_keyword
234 |
235 | def collate_fn_train(self, batch_data):
236 | image, cap_id, att_mask, cap_len, label, if_keyword = zip(*batch_data)
237 | image = torch.stack(image, dim=0)
238 | image_feature = {'patch_image': image}
239 | cap_id = torch.stack(cap_id, dim=0)
240 | if_keyword = torch.stack(if_keyword, dim=0)
241 | att_mask = torch.stack(att_mask, dim=0)
242 | cap_len = torch.Tensor(cap_len).int()
243 | label = torch.Tensor(label).int()
244 | return image_feature, cap_id.long(), att_mask.long(), cap_len, label, if_keyword
245 |
246 | def __len__(self):
247 | return len(self.data)
248 |
249 |
250 | def data_load_rwc_EWC(config, dir, mode):
251 | dataset = RWConcept_data_EWC(config, dir, mode)
252 | data_loader = DataLoader(dataset=dataset,
253 | batch_size=config.batch_size,
254 | shuffle=True,
255 | collate_fn=dataset.collate_fn_train,
256 | num_workers=config.num_workers,
257 | pin_memory=True,
258 | )
259 | return data_loader
260 |
261 |
262 | def data_load(config, dir, mode):
263 | if mode == 'train':
264 | print("warning: the train_loader does not exist")
265 | dataset = IC_data(config, dir, mode)
266 | data_loader = DataLoader(dataset=dataset,
267 | batch_size=config.batch_size if mode == 'train' else config.val_batch_size,
268 | shuffle=True if mode == 'train' else False,
269 | collate_fn=dataset.collate_fn_train if mode == 'train' else dataset.collate_fn_eval,
270 | num_workers=config.num_workers,
271 | pin_memory=True,
272 | )
273 | return data_loader
274 |
275 | def data_load_rwc(config, dir, mode):
276 | dataset = RWConcept_data(config, dir, mode)
277 | data_loader = DataLoader(dataset=dataset,
278 | batch_size=config.batch_size if mode == 'train' else config.val_batch_size,
279 | shuffle=False,
280 | collate_fn=dataset.collate_fn_train if mode == 'train' else dataset.collate_fn_eval,
281 | num_workers=config.num_workers,
282 | pin_memory=True,
283 | )
284 | return data_loader
285 |
286 |
287 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | from .tokenizer.ptbtokenizer import PTBTokenizer
3 | from .bleu.bleu import Bleu
4 | from .meteor.meteor import Meteor
5 | from .rouge.rouge import Rouge
6 | from .cider.cider import Cider
7 | from .spice.spice import Spice
8 |
9 |
10 | class COCOEvalCap:
11 | def __init__(self, coco, cocoRes):
12 | self.evalImgs = []
13 | self.eval = {}
14 | self.imgToEval = {}
15 | self.coco = coco
16 | self.cocoRes = cocoRes
17 | # self.params = {'image_id':
coco.getImgIds()} 18 | 19 | def evaluate(self): 20 | imgIds = self.params['image_id'] 21 | # imgIds = self.coco.getImgIds() 22 | gts = {} 23 | res = {} 24 | for imgId in imgIds: 25 | gts[imgId] = self.coco.imgToAnns[imgId] 26 | res[imgId] = self.cocoRes.imgToAnns[imgId] 27 | 28 | # ================================================= 29 | # Set up scorers 30 | # ================================================= 31 | print('tokenization...') 32 | tokenizer = PTBTokenizer() 33 | gts = tokenizer.tokenize(gts) 34 | res = tokenizer.tokenize(res) 35 | 36 | # ================================================= 37 | # Set up scorers 38 | # ================================================= 39 | print('setting up scorers...') 40 | scorers = [ 41 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 42 | (Meteor(),"METEOR"), 43 | (Rouge(), "ROUGE_L"), 44 | (Cider(), "CIDEr"), 45 | (Spice(), "SPICE") 46 | ] 47 | 48 | # ================================================= 49 | # Compute scores 50 | # ================================================= 51 | for scorer, method in scorers: 52 | print('computing %s score...'%(scorer.method())) 53 | score, scores = scorer.compute_score(gts, res) 54 | if type(method) == list: 55 | for sc, scs, m in zip(score, scores, method): 56 | self.setEval(sc, m) 57 | self.setImgToEvalImgs(scs, gts.keys(), m) 58 | print("%s: %0.3f"%(m, sc)) 59 | else: 60 | self.setEval(score, method) 61 | self.setImgToEvalImgs(scores, gts.keys(), method) 62 | print("%s: %0.3f"%(method, score)) 63 | self.setEvalImgs() 64 | 65 | def evaluate_diy(self, gts, res): 66 | """ 67 | imgIds = self.params['image_id'] 68 | # imgIds = self.coco.getImgIds() 69 | gts = {} 70 | res = {} 71 | for imgId in imgIds: 72 | gts[imgId] = self.coco.imgToAnns[imgId] 73 | res[imgId] = self.cocoRes.imgToAnns[imgId] 74 | """ 75 | # ================================================= 76 | # Set up scorers 77 | # ================================================= 78 | print('tokenization...') 79 | tokenizer = PTBTokenizer() 80 | gts = tokenizer.tokenize(gts) 81 | res = tokenizer.tokenize(res) 82 | 83 | # ================================================= 84 | # Set up scorers 85 | # ================================================= 86 | print('setting up scorers...') 87 | my_score = {} 88 | scorers = [ 89 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 90 | (Meteor(),"METEOR"), 91 | (Rouge(), "ROUGE_L"), 92 | (Cider(), "CIDEr"), 93 | # (Spice(), "SPICE") 94 | ] 95 | 96 | # ================================================= 97 | # Compute scores 98 | # ================================================= 99 | for scorer, method in scorers: 100 | print('computing %s score...'%(scorer.method())) 101 | score, scores = scorer.compute_score(gts, res) 102 | if type(method) == list: 103 | for sc, scs, m in zip(score, scores, method): 104 | self.setEval(sc, m) 105 | my_score[m] = sc 106 | self.setImgToEvalImgs(scs, gts.keys(), m) 107 | print("%s: %0.3f"%(m, sc)) 108 | else: 109 | self.setEval(score, method) 110 | my_score[method] = score 111 | self.setImgToEvalImgs(scores, gts.keys(), method) 112 | print("%s: %0.3f"%(method, score)) 113 | self.setEvalImgs() 114 | return my_score 115 | 116 | def evaluate_diy_every(self, gts, res): 117 | """ 118 | imgIds = self.params['image_id'] 119 | # imgIds = self.coco.getImgIds() 120 | gts = {} 121 | res = {} 122 | for imgId in imgIds: 123 | gts[imgId] = self.coco.imgToAnns[imgId] 124 | res[imgId] = self.cocoRes.imgToAnns[imgId] 125 | """ 126 | # ================================================= 127 
| # Set up scorers 128 | # ================================================= 129 | print('tokenization...') 130 | tokenizer = PTBTokenizer() 131 | gts = tokenizer.tokenize(gts) 132 | res = tokenizer.tokenize(res) 133 | 134 | # ================================================= 135 | # Set up scorers 136 | # ================================================= 137 | print('setting up scorers...') 138 | my_score = {} 139 | my_score_every = {} 140 | scorers = [ 141 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 142 | (Meteor(),"METEOR"), 143 | (Rouge(), "ROUGE_L"), 144 | (Cider(), "CIDEr"), 145 | # (Spice(), "SPICE") 146 | ] 147 | 148 | # ================================================= 149 | # Compute scores 150 | # ================================================= 151 | for scorer, method in scorers: 152 | print('computing %s score...'%(scorer.method())) 153 | score, scores = scorer.compute_score(gts, res) 154 | if type(method) == list: 155 | for sc, scs, m in zip(score, scores, method): 156 | self.setEval(sc, m) 157 | my_score[m] = sc 158 | my_score_every[m] = scs 159 | self.setImgToEvalImgs(scs, gts.keys(), m) 160 | print("%s: %0.3f"%(m, sc)) 161 | else: 162 | self.setEval(score, method) 163 | my_score[method] = score 164 | my_score_every[method] = scores 165 | self.setImgToEvalImgs(scores, gts.keys(), method) 166 | print("%s: %0.3f"%(method, score)) 167 | self.setEvalImgs() 168 | return my_score, my_score_every 169 | 170 | def evaluate_diy_test(self, gts, res): 171 | """ 172 | imgIds = self.params['image_id'] 173 | # imgIds = self.coco.getImgIds() 174 | gts = {} 175 | res = {} 176 | for imgId in imgIds: 177 | gts[imgId] = self.coco.imgToAnns[imgId] 178 | res[imgId] = self.cocoRes.imgToAnns[imgId] 179 | """ 180 | # ================================================= 181 | # Set up scorers 182 | # ================================================= 183 | print('tokenization...') 184 | tokenizer = PTBTokenizer() 185 | gts = tokenizer.tokenize(gts) 186 | res = tokenizer.tokenize(res) 187 | 188 | # ================================================= 189 | # Set up scorers 190 | # ================================================= 191 | print('setting up scorers...') 192 | my_score = {} 193 | scorers = [ 194 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 195 | (Meteor(),"METEOR"), 196 | (Rouge(), "ROUGE_L"), 197 | (Cider(), "CIDEr"), 198 | # (Spice(), "SPICE") 199 | ] 200 | 201 | # ================================================= 202 | # Compute scores 203 | # ================================================= 204 | for scorer, method in scorers: 205 | if not method == "CIDEr": 206 | continue 207 | print('computing %s score...'%(scorer.method())) 208 | score, scores = scorer.compute_score(gts, res) 209 | if type(method) == list: 210 | for sc, scs, m in zip(score, scores, method): 211 | self.setEval(sc, m) 212 | my_score[m] = sc 213 | self.setImgToEvalImgs(scs, gts.keys(), m) 214 | print("%s: %0.3f"%(m, sc)) 215 | else: 216 | self.setEval(score, method) 217 | my_score[method] = score 218 | self.setImgToEvalImgs(scores, gts.keys(), method) 219 | print("%s: %0.3f"%(method, score)) 220 | self.setEvalImgs() 221 | return scores 222 | 223 | def setEval(self, score, method): 224 | self.eval[method] = score 225 | 226 | def setImgToEvalImgs(self, scores, imgIds, method): 227 | for imgId, score in zip(imgIds, scores): 228 | if not imgId in self.imgToEval: 229 | self.imgToEval[imgId] = {} 230 | self.imgToEval[imgId]["image_id"] = imgId 231 | self.imgToEval[imgId][method] = score 232 | 233 | def 
setEvalImgs(self): 234 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /evaluation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/.DS_Store -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | #from .bleu import Bleu 2 | #from .meteor import Meteor 3 | #from .rouge import Rouge 4 | from .cider import Cider 5 | #from .tokenizer import PTBTokenizer 6 | """ 7 | def compute_scores(gts, gen): 8 | metrics = (Bleu(), Meteor(), Rouge(), Cider()) 9 | all_score = {} 10 | all_scores = {} 11 | for metric in metrics: 12 | score, scores = metric.compute_score(gts, gen) 13 | all_score[str(metric)] = score 14 | all_scores[str(metric)] = scores 15 | 16 | return all_score, all_scores 17 | """ -------------------------------------------------------------------------------- /evaluation/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/bleu.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/bleu.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! 
(bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/cider.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/cider.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 | assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna 
Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 
94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /evaluation/meteor/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/meteor/__pycache__/__init__.cpython-37.pyc 
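Editor's usage note (a hedged sketch, not a file in this repository; the image ids and captions below are invented for illustration): the CIDEr scorer above takes gts and res as dictionaries keyed by image id, where each value is a list of plain caption strings and res holds exactly one hypothesis per key (CiderScorer reads test[k][0]); Bleu, Meteor and Rouge in this package expose the same compute_score(gts, res) interface. A minimal invocation, assuming it is run from the repository root:

# Hedged example; not part of the repository.
from evaluation.cider import Cider

gts = {
    "img_1": ["a dog runs across the grass", "a brown dog running outside"],
    "img_2": ["a man rides a bicycle down the street", "a person riding a bike"],
}
res = {
    "img_1": ["a dog running on the grass"],
    "img_2": ["a man riding a bike"],
}

mean_cider, per_image_scores = Cider().compute_score(gts, res)
print(mean_cider)        # corpus-level CIDEr (mean of the per-image scores)
print(per_image_scores)  # numpy array with one score per image, in gts key order

In the repository's own pipeline (eval.py above), gts and res are first run through PTBTokenizer before being handed to the scorers; the raw strings here are only for illustration.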
-------------------------------------------------------------------------------- /evaluation/meteor/__pycache__/meteor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/meteor/__pycache__/meteor.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | if not os.path.isfile(jar_path): 19 | if not os.path.isfile(gz_path): 20 | download_from_url(METEOR_GZ_URL, gz_path) 21 | tar = tarfile.open(gz_path, "r") 22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | tar.close() 24 | os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /evaluation/rouge/rouge.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 
99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /evaluation/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /knowcap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/knowcap.png -------------------------------------------------------------------------------- /models/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/.DS_Store -------------------------------------------------------------------------------- /models/BLIP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/BLIP/__init__.py -------------------------------------------------------------------------------- /models/BLIP/blip_itm.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /models/BLIP/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig 2 | from models.nlvr_encoder import BertModel 3 | from models.vit import 
interpolate_pos_embed 4 | from models.blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub import download_cached_file 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | class BLIP_NLVR(nn.Module): 15 | def __init__(self, 16 | med_config = 'configs/med_config.json', 17 | image_size = 480, 18 | vit = 'base', 19 | vit_grad_ckpt = False, 20 | vit_ckpt_layer = 0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | self.cls_head = nn.Sequential( 37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(self.text_encoder.config.hidden_size, 2) 40 | ) 41 | 42 | def forward(self, image, text, targets, train=True): 43 | 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 47 | 48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 49 | text.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = [image0_embeds,image1_embeds], 54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 55 | image_atts[image0_embeds.size(0):]], 56 | return_dict = True, 57 | ) 58 | hidden_state = output.last_hidden_state[:,0,:] 59 | prediction = self.cls_head(hidden_state) 60 | 61 | if train: 62 | loss = F.cross_entropy(prediction, targets) 63 | return loss 64 | else: 65 | return prediction 66 | 67 | def blip_nlvr(pretrained='',**kwargs): 68 | model = BLIP_NLVR(**kwargs) 69 | if pretrained: 70 | model,msg = load_checkpoint(model,pretrained) 71 | print("missing keys:") 72 | print(msg.missing_keys) 73 | return model 74 | 75 | 76 | def load_checkpoint(model,url_or_filename): 77 | if is_url(url_or_filename): 78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 79 | checkpoint = torch.load(cached_file, map_location='cpu') 80 | elif os.path.isfile(url_or_filename): 81 | checkpoint = torch.load(url_or_filename, map_location='cpu') 82 | else: 83 | raise RuntimeError('checkpoint url or path is invalid') 84 | state_dict = checkpoint['model'] 85 | 86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 87 | 88 | for key in list(state_dict.keys()): 89 | if 'crossattention.self.' in key: 90 | new_key0 = key.replace('self','self0') 91 | new_key1 = key.replace('self','self1') 92 | state_dict[new_key0] = state_dict[key] 93 | state_dict[new_key1] = state_dict[key] 94 | elif 'crossattention.output.dense.' 
in key: 95 | new_key0 = key.replace('dense','dense0') 96 | new_key1 = key.replace('dense','dense1') 97 | state_dict[new_key0] = state_dict[key] 98 | state_dict[new_key1] = state_dict[key] 99 | 100 | msg = model.load_state_dict(state_dict,strict=False) 101 | print('load checkpoint from %s'%url_or_filename) 102 | return model,msg 103 | -------------------------------------------------------------------------------- /models/BLIP/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | from .med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from .blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_Retrieval(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | queue_size = 57600, 19 | momentum = 0.995, 20 | negative_all_rank = False, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | text_width = self.text_encoder.config.hidden_size 37 | 38 | self.vision_proj = nn.Linear(vision_width, embed_dim) 39 | self.text_proj = nn.Linear(text_width, embed_dim) 40 | 41 | self.itm_head = nn.Linear(text_width, 2) 42 | 43 | # create momentum encoders 44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 45 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 47 | self.text_proj_m = nn.Linear(text_width, embed_dim) 48 | 49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 50 | [self.vision_proj,self.vision_proj_m], 51 | [self.text_encoder,self.text_encoder_m], 52 | [self.text_proj,self.text_proj_m], 53 | ] 54 | self.copy_params() 55 | 56 | # create the queue 57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 61 | 62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 64 | 65 | self.queue_size = queue_size 66 | self.momentum = momentum 67 | self.temp = nn.Parameter(0.07*torch.ones([])) 68 | 69 | self.negative_all_rank = negative_all_rank 70 | 71 | 72 | def forward(self, image, caption, alpha, idx): 73 | with torch.no_grad(): 74 | self.temp.clamp_(0.001,0.5) 75 | 76 | image_embeds = self.visual_encoder(image) 77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 79 | 80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 81 | return_tensors="pt").to(image.device) 82 | 83 | 
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 84 | return_dict = True, mode = 'text') 85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 86 | 87 | ###============== Image-text Contrastive Learning ===================### 88 | idx = idx.view(-1,1) 89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 90 | pos_idx = torch.eq(idx, idx_all).float() 91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 92 | 93 | # get momentum features 94 | with torch.no_grad(): 95 | self._momentum_update() 96 | image_embeds_m = self.visual_encoder_m(image) 97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 99 | 100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 101 | return_dict = True, mode = 'text') 102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 104 | 105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 107 | 108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 110 | 111 | sim_i2t = image_feat @ text_feat_m_all / self.temp 112 | sim_t2i = text_feat @ image_feat_m_all / self.temp 113 | 114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 116 | 117 | loss_ita = (loss_i2t+loss_t2i)/2 118 | 119 | idxs = concat_all_gather(idx) 120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 121 | 122 | ###============== Image-text Matching ===================### 123 | encoder_input_ids = text.input_ids.clone() 124 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 125 | 126 | # forward the positve image-text pair 127 | bs = image.size(0) 128 | output_pos = self.text_encoder(encoder_input_ids, 129 | attention_mask = text.attention_mask, 130 | encoder_hidden_states = image_embeds, 131 | encoder_attention_mask = image_atts, 132 | return_dict = True, 133 | ) 134 | 135 | 136 | if self.negative_all_rank: 137 | # compute sample similarity 138 | with torch.no_grad(): 139 | mask = torch.eq(idx, idxs.t()) 140 | 141 | image_feat_world = concat_all_gather(image_feat) 142 | text_feat_world = concat_all_gather(text_feat) 143 | 144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 146 | 147 | weights_i2t = F.softmax(sim_i2t,dim=1) 148 | weights_i2t.masked_fill_(mask, 0) 149 | 150 | weights_t2i = F.softmax(sim_t2i,dim=1) 151 | weights_t2i.masked_fill_(mask, 0) 152 | 153 | image_embeds_world = all_gather_with_grad(image_embeds) 154 | 155 | # select a negative image (from all ranks) for each text 156 | image_embeds_neg = [] 157 | for b in range(bs): 158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 159 | image_embeds_neg.append(image_embeds_world[neg_idx]) 160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 161 | 162 | # select a negative text (from all ranks) for each image 163 | input_ids_world = concat_all_gather(encoder_input_ids) 164 | att_mask_world = concat_all_gather(text.attention_mask) 165 | 166 | text_ids_neg = [] 
167 | text_atts_neg = [] 168 | for b in range(bs): 169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 170 | text_ids_neg.append(input_ids_world[neg_idx]) 171 | text_atts_neg.append(att_mask_world[neg_idx]) 172 | 173 | else: 174 | with torch.no_grad(): 175 | mask = torch.eq(idx, idx.t()) 176 | 177 | sim_i2t = image_feat @ text_feat.t() / self.temp 178 | sim_t2i = text_feat @ image_feat.t() / self.temp 179 | 180 | weights_i2t = F.softmax(sim_i2t,dim=1) 181 | weights_i2t.masked_fill_(mask, 0) 182 | 183 | weights_t2i = F.softmax(sim_t2i,dim=1) 184 | weights_t2i.masked_fill_(mask, 0) 185 | 186 | # select a negative image (from same rank) for each text 187 | image_embeds_neg = [] 188 | for b in range(bs): 189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 190 | image_embeds_neg.append(image_embeds[neg_idx]) 191 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 192 | 193 | # select a negative text (from same rank) for each image 194 | text_ids_neg = [] 195 | text_atts_neg = [] 196 | for b in range(bs): 197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 198 | text_ids_neg.append(encoder_input_ids[neg_idx]) 199 | text_atts_neg.append(text.attention_mask[neg_idx]) 200 | 201 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 202 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 203 | 204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 206 | 207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 209 | 210 | output_neg = self.text_encoder(text_ids_all, 211 | attention_mask = text_atts_all, 212 | encoder_hidden_states = image_embeds_all, 213 | encoder_attention_mask = image_atts_all, 214 | return_dict = True, 215 | ) 216 | 217 | 218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 219 | vl_output = self.itm_head(vl_embeddings) 220 | 221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 222 | dim=0).to(image.device) 223 | loss_itm = F.cross_entropy(vl_output, itm_labels) 224 | 225 | return loss_ita, loss_itm 226 | 227 | 228 | @torch.no_grad() 229 | def copy_params(self): 230 | for model_pair in self.model_pairs: 231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 232 | param_m.data.copy_(param.data) # initialize 233 | param_m.requires_grad = False # not update by gradient 234 | 235 | 236 | @torch.no_grad() 237 | def _momentum_update(self): 238 | for model_pair in self.model_pairs: 239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 240 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 241 | 242 | 243 | @torch.no_grad() 244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 245 | # gather keys before updating queue 246 | image_feats = concat_all_gather(image_feat) 247 | text_feats = concat_all_gather(text_feat) 248 | 249 | 250 | batch_size = image_feats.shape[0] 251 | 252 | ptr = int(self.ptr_queue) 253 | assert self.queue_size % batch_size == 0 # for simplicity 254 | 255 | # replace the keys at ptr (dequeue and enqueue) 256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 259 | ptr = (ptr + batch_size) % self.queue_size # move pointer 260 | 261 | self.ptr_queue[0] = ptr 262 | 263 | 264 | def blip_retrieval(pretrained='',**kwargs): 265 | model = BLIP_Retrieval(**kwargs) 266 | if pretrained: 267 | model,msg = load_checkpoint(model,pretrained) 268 | print("missing keys:") 269 | print(msg.missing_keys) 270 | return model 271 | 272 | 273 | @torch.no_grad() 274 | def concat_all_gather(tensor): 275 | """ 276 | Performs all_gather operation on the provided tensors. 277 | *** Warning ***: torch.distributed.all_gather has no gradient. 278 | """ 279 | tensors_gather = [torch.ones_like(tensor) 280 | for _ in range(torch.distributed.get_world_size())] 281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 282 | 283 | output = torch.cat(tensors_gather, dim=0) 284 | return output 285 | 286 | 287 | class GatherLayer(torch.autograd.Function): 288 | """ 289 | Gather tensors from all workers with support for backward propagation: 290 | This implementation does not cut the gradients as torch.distributed.all_gather does. 291 | """ 292 | 293 | @staticmethod 294 | def forward(ctx, x): 295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 296 | torch.distributed.all_gather(output, x) 297 | return tuple(output) 298 | 299 | @staticmethod 300 | def backward(ctx, *grads): 301 | all_gradients = torch.stack(grads) 302 | torch.distributed.all_reduce(all_gradients) 303 | return all_gradients[torch.distributed.get_rank()] 304 | 305 | 306 | def all_gather_with_grad(tensors): 307 | """ 308 | Performs all_gather operation on the provided tensors. 309 | Graph remains connected for backward grad computation. 
310 | """ 311 | # Queue the gathered tensors 312 | world_size = torch.distributed.get_world_size() 313 | # There is no need for reduction in the single-proc case 314 | if world_size == 1: 315 | return tensors 316 | 317 | tensor_all = GatherLayer.apply(tensors) 318 | 319 | return torch.cat(tensor_all, dim=0) 320 | -------------------------------------------------------------------------------- /models/BLIP/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel, BertLMHeadModel 2 | from models.blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | class BLIP_VQA(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 480, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | ): 18 | """ 19 | Args: 20 | med_config (str): path for the mixture of encoder-decoder model's configuration file 21 | image_size (int): input image size 22 | vit (str): model size of vision transformer 23 | """ 24 | super().__init__() 25 | 26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 27 | self.tokenizer = init_tokenizer() 28 | 29 | encoder_config = BertConfig.from_json_file(med_config) 30 | encoder_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 32 | 33 | decoder_config = BertConfig.from_json_file(med_config) 34 | self.text_decoder = BertLMHeadModel(config=decoder_config) 35 | 36 | 37 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 38 | 39 | image_embeds = self.visual_encoder(image) 40 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 41 | 42 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 43 | return_tensors="pt").to(image.device) 44 | question.input_ids[:,0] = self.tokenizer.enc_token_id 45 | 46 | if train: 47 | ''' 48 | n: number of answers for each question 49 | weights: weight for each answer 50 | ''' 51 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 52 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 53 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 54 | 55 | question_output = self.text_encoder(question.input_ids, 56 | attention_mask = question.attention_mask, 57 | encoder_hidden_states = image_embeds, 58 | encoder_attention_mask = image_atts, 59 | return_dict = True) 60 | 61 | question_states = [] 62 | question_atts = [] 63 | for b, n in enumerate(n): 64 | question_states += [question_output.last_hidden_state[b]]*n 65 | question_atts += [question.attention_mask[b]]*n 66 | question_states = torch.stack(question_states,0) 67 | question_atts = torch.stack(question_atts,0) 68 | 69 | answer_output = self.text_decoder(answer.input_ids, 70 | attention_mask = answer.attention_mask, 71 | encoder_hidden_states = question_states, 72 | encoder_attention_mask = question_atts, 73 | labels = answer_targets, 74 | return_dict = True, 75 | reduction = 'none', 76 | ) 77 | 78 | loss = weights * answer_output.loss 79 | loss = loss.sum()/image.size(0) 80 | 81 | return loss 82 | 83 | 84 | else: 85 | question_output = 
self.text_encoder(question.input_ids, 86 | attention_mask = question.attention_mask, 87 | encoder_hidden_states = image_embeds, 88 | encoder_attention_mask = image_atts, 89 | return_dict = True) 90 | 91 | if inference=='generate': 92 | num_beams = 3 93 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 94 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 95 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 96 | 97 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 98 | 99 | outputs = self.text_decoder.generate(input_ids=bos_ids, 100 | max_length=10, 101 | min_length=1, 102 | num_beams=num_beams, 103 | eos_token_id=self.tokenizer.sep_token_id, 104 | pad_token_id=self.tokenizer.pad_token_id, 105 | **model_kwargs) 106 | 107 | answers = [] 108 | for output in outputs: 109 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 110 | answers.append(answer) 111 | return answers 112 | 113 | elif inference=='rank': 114 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 115 | answer.input_ids, answer.attention_mask, k_test) 116 | return max_ids 117 | 118 | 119 | 120 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 121 | 122 | num_ques = question_states.size(0) 123 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 124 | 125 | start_output = self.text_decoder(start_ids, 126 | encoder_hidden_states = question_states, 127 | encoder_attention_mask = question_atts, 128 | return_dict = True, 129 | reduction = 'none') 130 | logits = start_output.logits[:,0,:] # first token's logit 131 | 132 | # topk_probs: top-k probability 133 | # topk_ids: [num_question, k] 134 | answer_first_token = answer_ids[:,1] 135 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 136 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 137 | 138 | # answer input: [num_question*k, answer_len] 139 | input_ids = [] 140 | input_atts = [] 141 | for b, topk_id in enumerate(topk_ids): 142 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 143 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 144 | input_ids = torch.cat(input_ids,dim=0) 145 | input_atts = torch.cat(input_atts,dim=0) 146 | 147 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 148 | 149 | # repeat encoder's output for top-k answers 150 | question_states = tile(question_states, 0, k) 151 | question_atts = tile(question_atts, 0, k) 152 | 153 | output = self.text_decoder(input_ids, 154 | attention_mask = input_atts, 155 | encoder_hidden_states = question_states, 156 | encoder_attention_mask = question_atts, 157 | labels = targets_ids, 158 | return_dict = True, 159 | reduction = 'none') 160 | 161 | log_probs_sum = -output.loss 162 | log_probs_sum = log_probs_sum.view(num_ques,k) 163 | 164 | max_topk_ids = log_probs_sum.argmax(dim=1) 165 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 166 | 167 | return max_ids 168 | 169 | 170 | def blip_vqa(pretrained='',**kwargs): 171 | model = BLIP_VQA(**kwargs) 172 | if pretrained: 173 | model,msg = load_checkpoint(model,pretrained) 174 | # assert(len(msg.missing_keys)==0) 175 | return model 176 | 177 | 178 | def tile(x, dim, n_tile): 179 | init_dim = x.size(dim) 180 | repeat_idx = [1] * x.dim() 181 | repeat_idx[dim] = n_tile 182 | x = 
x.repeat(*(repeat_idx)) 183 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 184 | return torch.index_select(x, dim, order_index.to(x.device)) 185 | 186 | -------------------------------------------------------------------------------- /models/BLIP/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | # image_root: './export/share/datasets/vision/coco/images/' 2 | image_root: './' 3 | ann_root: 'annotation' 4 | coco_gt_root: 'annotation/coco_gt' 5 | 6 | # set pretrained as a file path or an url 7 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth' 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 9 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth' 10 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth' 11 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth' 12 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' 13 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 14 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 15 | 16 | # size of vit model; base or large 17 | vit: 'large' 18 | vit_grad_ckpt: False 19 | vit_ckpt_layer: 0 20 | batch_size: 32 21 | init_lr: 1e-5 22 | 23 | # vit: 'large' 24 | # vit_grad_ckpt: True 25 | # vit_ckpt_layer: 5 26 | # batch_size: 16 27 | # init_lr: 2e-6 28 | 29 | image_size: 384 30 | 31 | # generation configs 32 | max_length: 20 33 | min_length: 5 34 | num_beams: 3 35 | prompt: 'a picture of ' 36 | 37 | # optimizer 38 | weight_decay: 0.05 39 | min_lr: 0 40 | max_epoch: 5 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/BLIP/caption_coco_teacher.yaml: -------------------------------------------------------------------------------- 1 | # image_root: './export/share/datasets/vision/coco/images/' 2 | image_root: './' 3 | ann_root: 'annotation' 4 | coco_gt_root: 'annotation/coco_gt' 5 | 6 | # set pretrained as a file path or an url 7 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth' 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 9 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth' 10 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth' 11 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth' 12 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' 13 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 14 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 15 | 16 | # size of vit model; base or large 17 | vit: 'large' 18 | vit_grad_ckpt: False 19 | vit_ckpt_layer: 0 20 | batch_size: 32 21 | init_lr: 1e-5 22 | 23 | # 
vit: 'large' 24 | # vit_grad_ckpt: True 25 | # vit_ckpt_layer: 5 26 | # batch_size: 16 27 | # init_lr: 2e-6 28 | 29 | image_size: 384 30 | 31 | # generation configs 32 | max_length: 20 33 | min_length: 5 34 | num_beams: 3 35 | prompt: 'a picture of ' 36 | 37 | # optimizer 38 | weight_decay: 0.05 39 | min_lr: 0 40 | max_epoch: 5 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/BLIP/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /models/BLIP/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return 
self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, 
eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 
1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if 
orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /models/GIT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__init__.py -------------------------------------------------------------------------------- /models/GIT/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/GIT/__pycache__/git_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__pycache__/git_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/GIT/git.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/data_ti4_c/chengkz/ofa/models/GIT') 3 | import torch 4 | import torch.nn as nn 5 | from git_model import GitForCausalLM 6 | from transformers import AutoProcessor 7 | import torch.nn.functional as F 8 | from utils.beamsearch import beam_search, beam_search_scst 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | class GIT(nn.Module): 13 | 14 | def __init__(self, config, distill_model=False): 15 | super(GIT, self).__init__() 16 | self.config = config 17 | self.processor = AutoProcessor.from_pretrained("microsoft/git-large-coco", local_files_only=False) 18 | self.tokenizer = self.processor.tokenizer 19 | if distill_model: 20 | self.git_model = GitForCausalLM.from_pretrained(config.git_distill) 21 | else: 22 | self.git_model = GitForCausalLM.from_pretrained(config.git) 23 | 24 | def get_enc_output(self, patch_img): 25 | # To implement the decode_step function, the visual encoding is factored out here so that the visual encoder is not forwarded again at every decoding step 26 | # The forward interface of GitModel is also rewritten to allow passing in precomputed visual features and avoid redundant computation 27 | projected_visual_features = None 28 | if patch_img is not None: 29 | if patch_img.ndim == 4: 30 | # here we assume patch_img is of shape (batch_size, num_channels, height, width) 31 | visual_features = self.git_model.git.image_encoder(patch_img).last_hidden_state 32 | elif patch_img.ndim == 5: 33 | # here we assume patch_img is of shape (batch_size, num_frames, num_channels, height, width) 34 | visual_features = [] 35 | for frame_idx in range(patch_img.shape[1]): 36 | visual_features_frame = self.git_model.git.image_encoder(patch_img[:,
frame_idx, :, :]).last_hidden_state 37 | visual_features_frame += self.git_model.git.img_temperal_embedding[frame_idx] 38 | visual_features.append(visual_features_frame) 39 | 40 | # finally, concatenate all features along sequence dimension 41 | visual_features = torch.cat(visual_features, dim=1) 42 | else: 43 | raise ValueError("patch_img must be of rank 4 or 5") 44 | projected_visual_features = self.git_model.git.visual_projection(visual_features) 45 | return projected_visual_features 46 | 47 | def forward(self, patch_img, cap, att_mask, cap_len): 48 | batch_size = patch_img.shape[0] 49 | with torch.no_grad(): 50 | visual_features = self.get_enc_output(patch_img) 51 | logits = self.git_model(input_ids=cap, attention_mask=att_mask, visual_features=visual_features, pixel_values=patch_img).logits 52 | logits = logits[:, -20:, :] 53 | return logits 54 | 55 | def decode_step(self, input_ids, context): 56 | visual_features = context[0] 57 | patch_img = context[1] 58 | att_mask = torch.ones(input_ids.shape).long().to(device) 59 | logits = self.git_model(input_ids=input_ids, attention_mask=att_mask, visual_features=visual_features, pixel_values=patch_img).logits 60 | return logits, None 61 | 62 | def greedy_search(self, patch_img, mode='max'): 63 | """ 64 | patch_img: [batch_size, *img_patch_size] 65 | """ 66 | # Greedy search; the returned tokens should include the BOS and EOS tokens so that they can be used as pseudo-captions 67 | fixed_len = self.config.fixed_len 68 | gen_num = self.config.beam_num if mode == 'prob' else 1 69 | batch_size = patch_img.shape[0]*gen_num 70 | # The BOS token of the GIT model is BERT's [CLS] token, id 101 71 | sentences = torch.full((batch_size, 1), self.tokenizer.cls_token_id).long().to(device) 72 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device) 73 | cap_len = torch.LongTensor([fixed_len for i in range(batch_size)]).to(device) 74 | 75 | with torch.no_grad(): 76 | visual_features = self.get_enc_output(patch_img) 77 | 78 | for i in range(fixed_len): 79 | attention_mask = torch.ones(sentences.shape).long().to(device) 80 | logits_all = self.git_model(input_ids=sentences, attention_mask=attention_mask, visual_features=visual_features, 81 | pixel_values=patch_img).logits 82 | logits = logits_all[:, -1, :] 83 | probs = F.softmax(logits, dim=-1) 84 | if mode == 'prob': 85 | token_id = torch.multinomial(probs, 1)[:, 0] 86 | else: 87 | score, token_id = torch.max(probs, dim=-1) 88 | for j in range(batch_size): # record the length of each generated sentence during decoding 89 | if token_id[j].item() == self.tokenizer.sep_token_id and cap_len[j].item() == fixed_len: 90 | cap_len[j] = i + 1 91 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1) 92 | token_id = token_id.unsqueeze(1) 93 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1) 94 | 95 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)] 96 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)] 97 | return all_tokens, all_logprob 98 | 99 | def generate_caption_batchbs(self, patch_img): 100 | batch_size = patch_img.shape[0] 101 | with torch.no_grad(): 102 | visual_features = self.get_enc_output(patch_img) 103 | visual_features = visual_features.repeat_interleave(self.config.beam_num, dim=0) 104 | 105 | vocab_size = 30522 106 | captions = beam_search('Transformer', [visual_features, patch_img], self, batch_size, self.config.fixed_len, self.config.beam_num, 107 | vocab_size, self.config.length_penalty, bos_token_id=self.tokenizer.cls_token_id, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.sep_token_id) 108 | return captions 109 |
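# ----------------------------------------------------------------------------
# Editor's addition (not part of the original git.py): a minimal usage sketch of
# the GIT wrapper defined above. The config fields mirror the ones this class
# reads (fixed_len, beam_num, length_penalty, git, git_distill); the concrete
# values and the image path "example.jpg" are illustrative assumptions, and the
# exact format of `captions` depends on utils.beamsearch.beam_search.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from PIL import Image

    cfg = SimpleNamespace(fixed_len=20, beam_num=3, length_penalty=1.0,
                          git="microsoft/git-large", git_distill="microsoft/git-large-coco")
    model = GIT(cfg, distill_model=True).to(device).eval()

    # Preprocess one image with the bundled processor and run batched beam search.
    pixel_values = model.processor(images=Image.open("example.jpg").convert("RGB"),
                                   return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        captions = model.generate_caption_batchbs(pixel_values)
    print(captions)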
110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/OFA/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/.DS_Store -------------------------------------------------------------------------------- /models/OFA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__init__.py -------------------------------------------------------------------------------- /models/OFA/__pycache__/ofa.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__pycache__/ofa.cpython-37.pyc -------------------------------------------------------------------------------- /models/OFA/__pycache__/ofa_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__pycache__/ofa_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/OFA/ofa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/data_ti4_c/chengkz/ofa/models/OFA') 3 | import torch 4 | import torch.nn as nn 5 | from ofa_model import OFAModel 6 | from transformers.models.ofa.tokenization_ofa import OFATokenizer 7 | import torch.nn.functional as F 8 | from utils.beamsearch import beam_search, beam_search_scst 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | class OFA(nn.Module): 13 | 14 | def __init__(self, config, distill_model=False): 15 | super(OFA, self).__init__() 16 | self.config = config 17 | self.ofa_ckpts = config.ofa_ckpts 18 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpts) 19 | if distill_model: 20 | self.ofa_model = OFAModel.from_pretrained(self.config.ofa_ckpts_distill, use_cache=False).to(device) 21 | else: 22 | self.ofa_model = OFAModel.from_pretrained(self.config.ofa_ckpts, use_cache=False).to(device) 23 | #self.ofa_encoder = self.ofa_model.encoder 24 | self.prompt = " what does the image describe?" 
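        # The fixed captioning prompt above is tokenized once below and reused for every image;
        # gen_enc_output expands it to the batch size at encoding time.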
25 | self.prompt_input = self.tokenizer([self.prompt], return_tensors="pt").input_ids.to(device) 26 | self.frozen() 27 | # self.re_init() 28 | 29 | def frozen(self): 30 | for name, params in self.named_parameters(): 31 | if 'encoder' in name: 32 | params.requires_grad = False 33 | 34 | def un_frozen(self): 35 | for name, params in self.named_parameters(): 36 | if 'encoder' in name: 37 | params.requires_grad = True 38 | 39 | def re_init(self): 40 | print("reinit decoder") 41 | self.ofa_model.decoder.init_weights() 42 | 43 | def gen_enc_output(self, patch_img): 44 | """ 45 | patch_img: [batch_size, *img_patch_size] 46 | return: [batch_size, 908, 1024] 47 | """ 48 | batch_size = patch_img.shape[0] 49 | prompt_input = self.prompt_input.expand([batch_size, self.prompt_input.shape[1]]) 50 | encoder_outputs = self.ofa_model.encoder(input_ids=prompt_input, patch_images=patch_img) 51 | return encoder_outputs 52 | 53 | def forward(self, patch_img, cap, att_mask, cap_len): 54 | batch_size = patch_img.shape[0] 55 | # with torch.no_grad(): 56 | enc_output = self.gen_enc_output(patch_img) 57 | sentences = cap 58 | attention_mask = att_mask 59 | logits = self.ofa_model(decoder_input_ids=sentences, # [batch_size, cap_len, vocab_size] 60 | attention_mask=attention_mask, encoder_outputs=enc_output).logits 61 | return logits 62 | 63 | def decode_step(self, input_ids, context): 64 | enc_output = context[0] 65 | sentences = input_ids 66 | attention_mask = torch.ones(sentences.shape).long().to(device) 67 | logits = self.ofa_model(decoder_input_ids=sentences, # [batch_size, cap_len, vocab_size] 68 | attention_mask=attention_mask, encoder_outputs=enc_output).logits 69 | return logits, None 70 | 71 | 72 | def greedy_search(self, patch_img, mode='max'): 73 | """ 74 | patch_img: [batch_size, *img_patch_size] 75 | """ 76 | # 贪心搜索,返回的tokens应该是带有开始符和结束符的,以便用作pseudo-caption 77 | fixed_len = self.config.fixed_len 78 | gen_num = self.config.beam_num if mode == 'prob' else 1 79 | batch_size = patch_img.shape[0]*gen_num 80 | # OFA模型的bos符是0 81 | sentences = torch.zeros([batch_size, 1]).long().to(device) 82 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device) 83 | cap_len = torch.LongTensor([fixed_len for i in range(batch_size)]).to(device) 84 | 85 | with torch.no_grad(): 86 | enc_output = self.gen_enc_output(patch_img) # [batch_size, 908, 1024] 87 | if mode == 'prob': 88 | enc_output.last_hidden_state = enc_output.last_hidden_state.repeat(1, gen_num, 1). \ 89 | view(enc_output.last_hidden_state.shape[0] * gen_num, 90 | enc_output.last_hidden_state.shape[1], enc_output.last_hidden_state.shape[2]) 91 | enc_output.position_embedding = enc_output.position_embedding.repeat(1, gen_num, 1). \ 92 | view(enc_output.position_embedding.shape[0] * gen_num, 93 | enc_output.position_embedding.shape[1], enc_output.position_embedding.shape[2]) 94 | enc_output.padding_mask = enc_output.padding_mask.repeat(1, gen_num). 
\ 95 | view(enc_output.padding_mask.shape[0] * gen_num, enc_output.padding_mask.shape[1]) 96 | 97 | for i in range(fixed_len): 98 | attention_mask = torch.ones(sentences.shape).long().to(device) 99 | logits_all = self.ofa_model(decoder_input_ids=sentences, # [batch_size, 1, vocab_size] 100 | attention_mask=attention_mask, encoder_outputs=enc_output).logits 101 | logits = logits_all[:, -1, :] 102 | probs = F.softmax(logits, dim=-1) 103 | if mode == 'prob': 104 | token_id = torch.multinomial(probs, 1)[:, 0] 105 | else: 106 | score, token_id = torch.max(probs, dim=-1) 107 | for j in range(batch_size): # 生成过程中记录生成句子长度 108 | if token_id[j].item() == 2 and cap_len[j].item() == fixed_len: 109 | cap_len[j] = i + 1 110 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1) 111 | token_id = token_id.unsqueeze(1) 112 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1) 113 | 114 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)] 115 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)] 116 | return all_tokens, all_logprob 117 | 118 | def generate_caption_batchbs(self, patch_img): 119 | batch_size = patch_img.shape[0] 120 | with torch.no_grad(): 121 | enc_output = self.gen_enc_output(patch_img) 122 | enc_output.last_hidden_state = enc_output.last_hidden_state.repeat(1, self.config.beam_num, 1).\ 123 | view(enc_output.last_hidden_state.shape[0]*self.config.beam_num, enc_output.last_hidden_state.shape[1], enc_output.last_hidden_state.shape[2]) 124 | enc_output.position_embedding = enc_output.position_embedding.repeat(1, self.config.beam_num, 1).\ 125 | view(enc_output.position_embedding.shape[0]*self.config.beam_num, enc_output.position_embedding.shape[1], enc_output.position_embedding.shape[2]) 126 | enc_output.padding_mask = enc_output.padding_mask.repeat(1, self.config.beam_num).\ 127 | view(enc_output.padding_mask.shape[0]*self.config.beam_num, enc_output.padding_mask.shape[1]) 128 | vocab_size = 59457 129 | captions = beam_search('Transformer', [enc_output], self, batch_size, self.config.fixed_len, self.config.beam_num, 130 | vocab_size, self.config.length_penalty, bos_token_id=0, pad_token_id=1, eos_token_id=2) 131 | return captions -------------------------------------------------------------------------------- /models/Transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/Transformer/__init__.py -------------------------------------------------------------------------------- /models/Transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # 基于Transformer架构的图像描述模型 2 | # 包含使用faster-rcnn特征作为输入和cnn特征作为输入 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.utils.weight_norm import weight_norm 7 | import torch.nn.functional as F 8 | import pickle 9 | import math 10 | from utils.beamsearch import beam_search, beam_search_scst 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | 14 | class PositionalEncoding(nn.Module): 15 | 16 | def __init__(self, d_model, dropout=0.1, max_len=30): 17 | super(PositionalEncoding, self).__init__() 18 | self.dropout = nn.Dropout(p=dropout) 19 | 20 | pe = torch.zeros(max_len, d_model) 21 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 22 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
(-math.log(10000.0) / d_model)) 23 | pe[:, 0::2] = torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | pe = pe.unsqueeze(0).transpose(0, 1) 26 | self.register_buffer('pe', pe) 27 | 28 | def forward(self, x): 29 | x = x + self.pe[:x.size(0), :] 30 | return self.dropout(x) 31 | 32 | 33 | class Transformer_Encoder(nn.Module): 34 | 35 | def __init__(self, config): 36 | super(Transformer_Encoder, self).__init__() 37 | self.config = config 38 | self.image_dim = config.image_dim 39 | self.embed_dim = config.embed_dim 40 | self.fea2embed = nn.Linear(self.image_dim, self.embed_dim) 41 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8) 42 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 43 | 44 | def forward(self, fea_maps): 45 | fea_maps = self.fea2embed(fea_maps) 46 | fea_maps_seq = fea_maps.permute(1, 0, 2) 47 | memory = self.transformer_encoder(src=fea_maps_seq) 48 | return memory 49 | 50 | 51 | class Transformer_Decoder(nn.Module): 52 | 53 | def __init__(self, config): 54 | super(Transformer_Decoder, self).__init__() 55 | self.config = config 56 | self.vocab = pickle.load(open(self.config.vocab, 'rb')) 57 | self.vocab_size = self.vocab.get_size() 58 | self.embed_dim = config.embed_dim 59 | 60 | self.embed = nn.Embedding(self.vocab_size, self.embed_dim) 61 | 62 | self.pos_encoder = PositionalEncoding(self.embed_dim) 63 | decoder_layer = nn.TransformerDecoderLayer(d_model=self.embed_dim, nhead=8) 64 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 65 | 66 | self.fc = weight_norm(nn.Linear(self.embed_dim, self.vocab_size)) 67 | self.dropout = nn.Dropout(0.5) 68 | 69 | def gen_tgt_mask(self, length): 70 | mask = torch.triu(torch.ones(length, length)).permute(1, 0).to(device) 71 | mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask==1, float(0.0)) 72 | return mask 73 | 74 | def forward(self, memory, cap, cap_len): 75 | cap = cap.permute(1, 0) 76 | tgt_pos_embedding = self.pos_encoder(self.embed(cap)*math.sqrt(self.embed_dim)) 77 | tgt_mask = self.gen_tgt_mask(tgt_pos_embedding.shape[0]) 78 | out = self.transformer_decoder(tgt=tgt_pos_embedding, memory=memory, tgt_mask=tgt_mask) 79 | 80 | pred = self.fc(self.dropout(out)) 81 | pred = pred.permute(1, 0, 2) 82 | 83 | return pred 84 | 85 | def decode_step(self, input_ids, context): 86 | memory = context[0] 87 | cap = input_ids.permute(1, 0) 88 | tgt_pos_embedding = self.pos_encoder(self.embed(cap) * math.sqrt(self.embed_dim)) 89 | tgt_mask = self.gen_tgt_mask(tgt_pos_embedding.shape[0]) 90 | out = self.transformer_decoder(tgt=tgt_pos_embedding, memory=memory, tgt_mask=tgt_mask) 91 | 92 | pred = self.fc(self.dropout(out)) 93 | pred = pred.permute(1, 0, 2) 94 | return pred, None 95 | 96 | 97 | class Transformer_Cap(nn.Module): 98 | 99 | def __init__(self, config): 100 | super(Transformer_Cap, self).__init__() 101 | self.config = config 102 | self.transformer_encoder = Transformer_Encoder(self.config) 103 | self.transformer_decoder = Transformer_Decoder(self.config) 104 | 105 | def forward(self, image_feature, cap, cap_len, mode='xe'): 106 | if mode == 'xe': 107 | fea_maps = image_feature['feature_map'] 108 | memory = self.transformer_encoder(fea_maps) 109 | logit = self.transformer_decoder(memory, cap, cap_len) 110 | return logit 111 | elif mode == 'vanilla_scst': 112 | return self.greedy_search(image_feature, 'prob') 113 | 114 | def beam_search(self, image_feature): 115 | fea_maps = image_feature['feature_map'] 
116 | batch_size = fea_maps.shape[0] 117 | memory = self.transformer_encoder(fea_maps) 118 | memory = memory.repeat(1, 1, self.config.beam_num).view(memory.shape[0], memory.shape[1]*self.config.beam_num, memory.shape[2]) 119 | captions, all_tokens, all_logprob = beam_search_scst('Transformer', [memory], self.transformer_decoder, batch_size, self.config.fixed_len, self.config.beam_num, 120 | self.transformer_decoder.vocab_size, self.config.length_penalty) 121 | return captions, all_tokens, all_logprob 122 | 123 | def greedy_search(self, image_feature, mode='max'): 124 | # greedy search或多项式采样search 125 | fea_maps = image_feature['feature_map'] 126 | # 对一个样本采样beam_num个结果 127 | gen_num = self.config.beam_num if mode == 'prob' else 1 128 | fea_maps = fea_maps.unsqueeze(dim=1) 129 | fea_maps = fea_maps.expand([fea_maps.shape[0], gen_num, fea_maps.shape[2], fea_maps.shape[3]]) 130 | fea_maps = fea_maps.reshape(fea_maps.shape[0] * fea_maps.shape[1], fea_maps.shape[2], fea_maps.shape[3]) 131 | batch_size = fea_maps.shape[0] 132 | 133 | sentences = torch.ones([batch_size, 1]).to(device).long() 134 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device) 135 | cap_len = torch.LongTensor([20 for i in range(batch_size)]).to(device) 136 | 137 | memory = self.transformer_encoder(fea_maps) 138 | context = [memory] 139 | for i in range(self.config.fixed_len): 140 | outputs, _ = self.transformer_decoder.decode_step(sentences, context) 141 | logits = outputs[:, -1, :] 142 | probs = F.softmax(logits, dim=-1) 143 | if mode == 'prob': 144 | token_id = torch.multinomial(probs, 1)[:, 0] 145 | else: 146 | score, token_id = torch.max(probs, dim=-1) 147 | for j in range(batch_size): # 生成过程中记录生成句子长度 148 | if token_id[j].item() == 2 and cap_len[j].item() == 20: 149 | cap_len[j] = i + 1 150 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1) 151 | token_id = token_id.unsqueeze(1) 152 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1) 153 | 154 | # 利用生成句子长度mask 155 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)] 156 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)] 157 | 158 | return all_tokens, all_logprob 159 | 160 | def generate_caption_batchbs(self, image_feature): 161 | fea_maps = image_feature['feature_map'] 162 | batch_size = fea_maps.shape[0] 163 | memory = self.transformer_encoder(fea_maps) 164 | memory = memory.repeat(1, 1, self.config.beam_num).view(memory.shape[0], memory.shape[1]*self.config.beam_num, memory.shape[2]) 165 | caption = beam_search('Transformer', [memory], self.transformer_decoder, batch_size, self.config.fixed_len, self.config.beam_num, 166 | self.transformer_decoder.vocab_size, self.config.length_penalty) 167 | return caption 168 | 169 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model 2 | This repo provides the source code & data of our paper: [Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model. 
(ACMMM 23)](https://arxiv.org/abs/2308.01126) 3 | ``` 4 | @misc{cheng2023generic, 5 | title={Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model}, 6 | author={Kanzhi Cheng and Wenpo Song and Zheng Ma and Wenhao Zhu and Zixuan Zhu and Jianbing Zhang}, 7 | year={2023}, 8 | eprint={2308.01126}, 9 | archivePrefix={arXiv}, 10 | primaryClass={cs.CV} 11 | } 12 | ``` 13 | ### Code Structure 14 | *** 15 | ```` 16 | ├── config.py # config 17 | ├── data # coco data & knowcap data 18 | │   ├── data_cc12m_SelectForRreplay.json 19 | │   ├── dataset_coco.json 20 | │   ├── test.json 21 | │   ├── train.json 22 | │   ├── val.json 23 | │   ├── knowcap_240.json 24 | │   ├── knowcap_240_test.json 25 | │   ├── knowcap_240_test_unseen.json 26 | │   ├── knowcap_240_val.json 27 | │   ├── train_mix_32000.json 28 | │   └── ... 29 | ├── data_load.py # dataloader 30 | ├── test.py # evaluation on coco 31 | ├── test_knowcap.py # evaluation on knowcap 32 | ├── models # models (OFA,BLIP,GIT) 33 | │   ├── OFA 34 | │   ├── BLIP 35 | │   └── GIT 36 | ├── train_multitask.py # K-Replay training 37 | └── utils # support codes & tools 38 | ├── beamsearch.py # beamsearch 39 | ├── cc12m.py # filter relay data from cc12m 40 | ├── convert_ofa.py # ckpts convert 41 | ├── eval.py # generate captions & calculate metrics 42 | ├── import_models.py 43 | ├── log.py 44 | ├── loss.py # loss function of K-Replay 45 | ├── optimizer_tools.py 46 | └── prepro_data.py # construct the data in ./data 47 | ```` 48 | ### KnowCap Dataset 49 | *** 50 | KnowCap is a new dataset for the evaluation of knowledge-enhanced image captioning, containing 1424 images and 4156 reference descriptions 51 | carefully written by human annotators. 52 | 53 | ![](knowcap.png) 54 | 55 | Download the images and annotations of [KnowCap](https://drive.google.com/file/d/1DOk5WZZgHyO6tKT8A135hMgePid-akFq/view?usp=drive_link). 56 | ### Preparing Data&Model 57 | *** 58 | #### Step1: 59 | Download the images of: 60 | * [COCO2014](https://github.com/ruotianluo/ImageCaptioning.pytorch/blob/master/data/README.md) 61 | * [KnowCap](https://drive.google.com/file/d/1DOk5WZZgHyO6tKT8A135hMgePid-akFq/view?usp=drive_link) 62 | * [Replay images selected from cc12m](https://drive.google.com/file/d/1tdVZ1rUpr5va-NwInMwBglRpSGOzUoMu/view?usp=drive_link) 63 | #### Step2: 64 | `prepro_data.py`, Collate and split coco and knowcap datasets in ./data. 65 | 66 | Alternatively, we provide the processed [data](https://drive.google.com/file/d/1DBdnqcH_lOm--t5pZOlac1j1my4kVgrP/view?usp=drive_link) that can be put into . /data directory. Note that the file_path in each dataset needs to be modified according to the path of the downloaded image in step1. Similarly, some of the parameters in config need to be modified depending on your own. 67 | 68 | #### Step3: 69 | Prepare the ckpts of VLP models (take OFA as an example) for training and testing. 70 | 1. Download the transformers version ckpts of [OFA](https://huggingface.co/OFA-Sys/ofa-large) 71 | 2. However, since there are some [problems](https://github.com/OFA-Sys/OFA/issues/296) with the official ckpts in transformers, we manually replaced the original parameters with the official ckpts in fairseq using `convert_ofa.py` 72 | 73 | Alternatively, we provide the converted [ckpts](https://drive.google.com/file/d/1QQZ9eyO63JBBtyK5YIKA4CJ3jjAPuhQM/view?usp=drive_link). 
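For reference, the conversion in `convert_ofa.py` boils down to loading both checkpoints and copying the fairseq parameters into the transformers state dict; the condensed sketch below uses placeholder paths, and the actual script additionally renames a few attention/layer-norm keys that differ between the two versions.
```
import torch

# placeholder paths: transformers-version and fairseq-version checkpoints
model_t = torch.load('OFA-large/pytorch_model.bin')
model_f = torch.load('OFA-large-fairseq/ofa_large.pt')['model']

# copy every parameter whose name matches exactly
for k in set(model_t) & set(model_f):
    model_t[k] = model_f[k]

torch.save(model_t, 'OFA-large-caption-trainedenc/pytorch_model.bin')
```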
74 | ### Reproduce the main results 75 | *** 76 | The baseline result of *OFA* on KnowCap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --ofa_ckpts xxx --length_penalty 1.0`; the `ofa_ckpts` is obtained in Step3. 77 | 78 | The *OFA+K-Replay* result on KnowCap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --trained_ckpts xxx --length_penalty 1.0`; the `trained_ckpts` can be downloaded [here](https://drive.google.com/file/d/1z2InwjGOcmTOFGr25nIFI_tGCPNBnc1H/view?usp=drive_link). 79 | 80 | To evaluate on COCO, use `test.py` instead of `test_knowcap.py`. 81 | 82 | > #### Tips: 83 | To eliminate the need for `coco_id` in the evaluation, we customized the COCOEval function in `eval.py`. 84 | Therefore, `xxx/site-packages/pycocoevalcap/eval.py` needs to be replaced with (or modified according to) our `eval.py` before using the current evaluation code. 85 | ### Training with K-Replay 86 | *** 87 | #### Step4: 88 | Start training with K-Replay: 89 | `CUDA_VISIBLE_DEVICES=0 python train_multitask.py --mode train --model OFA --id ofa_kreplay --batch_size 60 --learning_rate 7e-6 --label_smoothing 0.1 --multitask_weight 1.0 --KD_temperature 16.0 --knowdistill_weight 1.0 --save_model_freq 100 --ofa_ckpts /home/chengkz/checkpoints/ofa/OFA-large-caption-trainedenc --ofa_ckpts_distill /home/chengkz/checkpoints/ofa/OFA-large-caption-XEfinetuned --train_mix ./data/train_mix_32000.json --method XEdistill`. 90 | 91 | The `ofa_ckpts` and `ofa_ckpts_distill` are obtained in Step3, and `train_mix_32000.json` is obtained in Step2. 92 | #### Step5: 93 | Evaluation on COCO: 94 | `CUDA_VISIBLE_DEVICES=0 python test.py --model OFA --id ofa_kreplay --step 300 --length_penalty 1.0`. 95 | 96 | Evaluation on KnowCap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --id ofa_kreplay --step 300 --length_penalty 1.0`. 97 | > #### Tips: 98 | OFA uses `resnet` as the backbone of its visual encoder. In our experiments, we found that the `batchnorm` layers in the resnet backbone do not give good estimates of the `mean` and `std` because of the small batch size we used, which degrades model performance. Therefore, we fixed the `mean` and `std` of these layers during training by setting `momentum=0.0` in `./transformers/models/ofa/resnet.py`.
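For reference, PyTorch updates batch-norm running statistics as `running = (1 - momentum) * running + momentum * batch_stat`, so `momentum=0.0` makes the update a no-op and the pretrained `running_mean`/`running_var` are kept unchanged throughout training. A minimal standalone illustration (not tied to the OFA resnet):
```
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8, momentum=0.0)        # momentum=0.0 -> running stats are never updated
before = bn.running_mean.clone()
bn.train()
_ = bn(torch.randn(4, 8, 16, 16))           # a training-mode forward pass
print(torch.equal(before, bn.running_mean))  # True: the running estimates are unchanged
```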
99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiohttp 3 | astunparse==1.6.3 4 | async-timeout==3.0.1 5 | attrs 6 | blinker==1.4 7 | brotlipy==0.7.0 8 | cached-property==1.5.2 9 | cachetools 10 | certifi==2021.5.30 11 | cffi==1.14.0 12 | chardet 13 | click 14 | clip 15 | coverage 16 | cryptography 17 | cycler==0.10.0 18 | Cython==0.29.24 19 | decorator==4.4.2 20 | dnspython==2.2.1 21 | echo1-coco-split==0.1.5 22 | et-xmlfile==1.1.0 23 | eventlet==0.33.2 24 | fairscale==0.4.6 25 | filelock==3.3.2 26 | flatbuffers==23.5.26 27 | ftfy==6.1.1 28 | funcy==1.17 29 | future==0.18.2 30 | gast==0.4.0 31 | google-auth 32 | google-auth-oauthlib==0.4.1 33 | google-pasta==0.2.0 34 | greenlet==2.0.1 35 | grpcio 36 | h5py==3.3.0 37 | huggingface-hub==0.13.3 38 | idna 39 | imageio==2.9.0 40 | importlib-metadata==5.1.0 41 | jieba==0.42.1 42 | joblib==1.1.0 43 | jsonlines==3.0.0 44 | keras==2.11.0 45 | kiwisolver 46 | libclang==16.0.6 47 | llvmlite==0.39.1 48 | Markdown 49 | matplotlib==3.4.2 50 | mkl-service==2.3.0 51 | multidict 52 | networkx==2.5.1 53 | nltk==3.4.5 54 | numba==0.56.4 55 | numpy==1.21.0 56 | oauthlib==3.1.0 57 | olefile==0.46 58 | opencv-python==4.5.5.64 59 | openpyxl==3.0.10 60 | opt-einsum==3.3.0 61 | packaging==21.2 62 | pandas==1.3.5 63 | Pillow==8.3.0 64 | protobuf==3.14.0 65 | pyasn1==0.4.8 66 | pyasn1-modules==0.2.8 67 | pycocoevalcap==1.2 68 | pycocotools==2.0.2 69 | pycparser 70 | PyJWT==1.7.1 71 | pynndescent==0.5.8 72 | pyOpenSSL 73 | pyparsing 74 | PySocks 75 | python-dateutil 76 | pytz 77 | PyWavelets==1.1.1 78 | PyYAML==6.0 79 | regex==2021.11.1 80 | requests 81 | requests-oauthlib==1.3.0 82 | rouge==1.0.1 83 | rsa 84 | sacremoses==0.0.46 85 | scikit-image==0.18.2 86 | scikit-learn==1.0.2 87 | scipy==1.7.0 88 | seaborn==0.12.1 89 | sentencepiece==0.1.99 90 | six 91 | sklearn==0.0 92 | tensorboard==2.11.2 93 | tensorboard-data-server==0.6.1 94 | tensorboard-plugin-wit==1.6.0 95 | tensorflow==2.11.0 96 | tensorflow-estimator==2.11.0 97 | tensorflow-io-gcs-filesystem==0.33.0 98 | termcolor==2.3.0 99 | threadpoolctl==3.1.0 100 | tifffile==2021.7.2 101 | timeout-decorator==0.5.0 102 | timm==0.6.13 103 | tokenizers==0.12.1 104 | torch==1.8.1+cu101 105 | torchaudio==0.8.1 106 | torchvision==0.9.1+cu101 107 | tornado 108 | tqdm==4.65.0 109 | transformers==4.28.0.dev0 110 | typing-extensions 111 | umap==0.1.1 112 | umap-learn==0.5.3 113 | urllib3==1.25.8 114 | wcwidth==0.2.5 115 | Werkzeug 116 | wget==3.2 117 | wrapt==1.15.0 118 | XlsxWriter==3.0.3 119 | yarl 120 | zipp 121 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # 测试,通过test参数指定在testa还是testb上进行测试 2 | 3 | import torch 4 | import random 5 | import numpy as np 6 | import json 7 | 8 | from config import config 9 | 10 | from utils.import_models import construct_model 11 | from utils.eval import generate_captions, eval_pycoco 12 | from utils.vocab import Vocabulary 13 | 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | # 随机种子 18 | seed = config.seed 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | 25 | # model 26 | model = construct_model(config).to(device) 27 | log_path = 
config.log_dir.format(config.id) 28 | trained_model_path = log_path + '/model/model_' + str(config.step) + '.pt' 29 | model.load_state_dict(torch.load(trained_model_path)) 30 | model.eval() 31 | with torch.no_grad(): 32 | gen_pycoco_path = generate_captions(config, model, config.step, config.test, final_test=True) 33 | 34 | if False: # 一些官方的ckpts会出现多余字符需要后处理 35 | pycoco = json.load(open(gen_pycoco_path, 'r')) 36 | for k, v in pycoco.items(): 37 | caption_origin = v[0]["caption"] 38 | caption_new = caption_origin.replace(')', '').replace('\\', '').replace('}', '').replace(']', '').strip() 39 | v[0]["caption"] = caption_new 40 | json.dump(pycoco, open(gen_pycoco_path, 'w')) 41 | 42 | pycoco_results = eval_pycoco(config, gen_pycoco_path, config.test) 43 | print(pycoco_results) 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /train_multitask.py: -------------------------------------------------------------------------------- 1 | # training with K-Replay 2 | import yaml 3 | import json 4 | import torch 5 | import random 6 | import numpy as np 7 | 8 | import time 9 | import torch.nn as nn 10 | 11 | from config import config 12 | from data_load import data_load_rwc 13 | 14 | from utils.import_models import construct_model 15 | from utils.loss import Cross_Entropy, Loss_SCST_OFA, Sent_Level_Concept_Coverage, Loss_Params_Regular, Loss_KD 16 | from utils.log import Log_Writer, train_print 17 | from utils.eval import generate_captions, eval_pycoco 18 | from utils.optimizer_tools import adjust_weight, adjust_lr, cal_fisher_coco, cal_fisher_downtask_mask, adjust_mask, model_grad_mask, RecAdam, cal_fisher_downtask, ratio_dataset 19 | from utils.vocab import Vocabulary 20 | from test_knowcap import cal_knowcap 21 | from models.OFA.ofa import OFA 22 | from models.BLIP.blip import blip_decoder 23 | from models.GIT.git import GIT 24 | 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # 随机种子 29 | seed = config.seed 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | np.random.seed(seed) 33 | random.seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | 36 | # log 37 | writer = Log_Writer(config) 38 | global_step = 0 39 | loss_avg = 0 40 | loss_ce_avg = 0 41 | loss_rwc_avg = 0 42 | mt_weight = config.multitask_weight 43 | kd_weight = config.knowdistill_weight 44 | 45 | train_mix = config.train_mix # 用于训练的数据集,既可以是混合的,也可以是coco单独的 46 | if config.data_ratio != 1.0: # 可调整coco和其他的比例 47 | train_mix_data_new = ratio_dataset(train_mix, config.data_ratio) 48 | train_mix = './data/train_mix_cc12m_keyword_'+str(config.data_ratio)+'.json' 49 | json.dump(train_mix_data_new, open(train_mix, 'w')) 50 | 51 | data_mode = config.data_mode # 和train_mix配合使用,决定训练数据和模式,mix|single 52 | method = config.method # 比较的各种方法 53 | model_type = config.model 54 | # data_loader 55 | train_loader = data_load_rwc(config, train_mix, 'train') 56 | 57 | # model 58 | model = construct_model(config).to(device) 59 | if method == 'XEdistill': 60 | if model_type == 'OFA': 61 | model_t = OFA(config, distill_model=True) 62 | elif model_type == 'BLIP': 63 | argst = yaml.load(open(config.config_blip_t, 'r'), Loader=yaml.Loader) 64 | model_t = blip_decoder(pretrained=argst['pretrained'], config=config, image_size=argst['image_size'], 65 | vit=argst['vit'], 66 | vit_grad_ckpt=argst['vit_grad_ckpt'], vit_ckpt_layer=argst['vit_ckpt_layer'], 67 | prompt=argst['prompt']) 68 | elif model_type == 'GIT': 69 | model_t = GIT(config, 
distill_model=True) 70 | model_t = model_t.to(device) 71 | loss_distill = Loss_KD(config.KD_temperature) 72 | if method == 'Adapter': # Adapter使得只有小部分模型参数参与训练 73 | for name, p in model.named_parameters(): 74 | if p.requires_grad == True: 75 | if 'adapter_ln1' in name or 'adapter_ln2' in name: 76 | p.requires_grad = True 77 | print(name) 78 | else: 79 | p.requires_grad = False 80 | 81 | # optimizer 82 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) 83 | if method == 'RecAdam': # Recall and Learn利用优化器做正则化 84 | pretrain_params = [] 85 | for name, p in model.named_parameters(): 86 | pretrain_params.append(p) 87 | optimizer = RecAdam(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.98), eps=1e-9, anneal_k=1.0, anneal_t0=100, pretrain_params=pretrain_params) 88 | 89 | # loss 90 | loss_cov = Sent_Level_Concept_Coverage() 91 | loss_fn = Cross_Entropy(label_smoothing=config.label_smoothing) 92 | 93 | if method == 'child-tuning': # Child-tuning和EWC与参数在对应任务梯度有关,因此先计算相关梯度 94 | grads_mask_coco = cal_fisher_downtask_mask(config, model) 95 | # grads_mask_coco = adjust_mask(grads_mask_coco) 96 | elif method == 'k-tuning': 97 | grads_mask_knowledge = cal_fisher_downtask_mask(config, model) # 找到和知识相关的参数 98 | elif method == 'EWC': 99 | params_fisher = cal_fisher_downtask(config, model) 100 | params_init = dict() 101 | for name, params in model.named_parameters(): 102 | if params.requires_grad == True: 103 | params_init[name] = params 104 | loss_params_regular = Loss_Params_Regular(params_init, params_fisher) 105 | 106 | if config.step != 0: 107 | log_path = config.log_dir.format(config.ckpts_id) 108 | trained_model_path = log_path + '/model/model_' + str(config.step) + '.pt' 109 | model.load_state_dict(torch.load(trained_model_path)) 110 | global_step = config.step 111 | 112 | for epoch in range(config.epochs): 113 | if global_step == 800: 114 | break 115 | model.train() 116 | totel_step = len(train_loader) 117 | epoch_time = time.time() 118 | step_time = time.time() 119 | 120 | optimizer = adjust_lr(optimizer, epoch) 121 | for step, (image_feature, cap, att_mask, cap_len, labels, data_item) in enumerate(train_loader): 122 | 123 | data_mode = config.data_mode 124 | global_step += 1 125 | optimizer.zero_grad() 126 | 127 | patch_image = image_feature['patch_image'] 128 | patch_image = patch_image.to(device) 129 | cap = cap.to(device) 130 | cap_len = cap_len.to(device) 131 | labels = labels.to(device) 132 | att_mask = att_mask.to(device) 133 | 134 | if labels.sum().item() == 0: 135 | data_mode = 'single' 136 | 137 | if data_mode == 'mix': # 找到其中rwconcept的样本,构建伪 pair进行训练 138 | index_rwc = torch.nonzero(labels==1).squeeze().long() 139 | if index_rwc.shape == torch.Size([]): 140 | index_rwc = index_rwc.unsqueeze(0) 141 | index_coco = torch.nonzero(labels==0).squeeze(dim=1).long() 142 | # 保存原caption以作为label 143 | cap_rwc_label = cap[index_rwc] 144 | # 为这些样本用当前模型生成伪caption 145 | if index_rwc.shape != torch.Size([0]): 146 | with torch.no_grad(): 147 | patch_image_rwc = patch_image[index_rwc] 148 | all_tokens, all_logprob = model.greedy_search(patch_image_rwc, 'max') 149 | cap_new = [] 150 | att_mask_new = [] 151 | cap_len_new = [] 152 | for cap_id in all_tokens: 153 | cap_len_g = cap_id.shape[0] 154 | if cap_len_g < config.fixed_len: 155 | if model_type == 'OFA': 156 | cap_id = torch.cat([cap_id, torch.ones([config.fixed_len - cap_len_g]).to(device)], dim=0) 157 | elif model_type == 'BLIP': 158 | cap_id = torch.cat([cap_id, torch.zeros([config.fixed_len - 
cap_len_g]).to(device)], dim=0) 159 | elif model_type == 'GIT': 160 | cap_id = torch.cat([cap_id, torch.zeros([config.fixed_len - cap_len_g]).to(device)], dim=0) 161 | att_mask_g = torch.cat([torch.ones([cap_len_g]).to(device), torch.zeros([config.fixed_len - cap_len_g]).to(device)], dim=0) 162 | else: 163 | cap_id = cap_id[:config.fixed_len] 164 | cap_len_g = config.fixed_len 165 | att_mask_g = torch.ones(cap_id.shape).to(device) 166 | cap_new.append(cap_id) 167 | att_mask_new.append(att_mask_g) 168 | cap_len_new.append(cap_len_g) 169 | cap_new = torch.stack(cap_new, dim=0).long() 170 | att_mask_new = torch.stack(att_mask_new, dim=0).long() 171 | cap_len_new = torch.Tensor(cap_len_new).int() 172 | # 将伪caption放回原数据中一起进行forward 173 | cap[index_rwc] = cap_new.to(device) 174 | att_mask[index_rwc] = att_mask_new.to(device) 175 | cap_len[index_rwc] = cap_len_new.to(device) 176 | # 知识蒸馏,用teacher进行一次前向传播获得logit 177 | if method == 'XEdistill': 178 | with torch.no_grad(): 179 | logit_t = model_t(patch_image[index_rwc], cap[index_rwc], att_mask[index_rwc], cap_len[index_rwc]) 180 | 181 | logit = model(patch_image, cap, att_mask, cap_len) 182 | if data_mode == 'single': 183 | loss = loss_fn(logit, cap, cap_len) 184 | loss_avg += loss.item() 185 | elif data_mode == 'mix': 186 | loss_ce = loss_fn(logit[index_coco], cap[index_coco], cap_len[index_coco]) 187 | loss_rwc = loss_cov(logit[index_rwc], cap_rwc_label, cap_len[index_rwc], model_type) 188 | loss = loss_ce + mt_weight * loss_rwc 189 | if method == 'XEdistill': 190 | loss_kd = loss_distill(logit[index_rwc], logit_t, cap_len[index_rwc]) 191 | loss += kd_weight*loss_kd 192 | loss_ce_avg += loss_ce.item() 193 | loss_rwc_avg += loss_rwc.item() 194 | loss_avg += loss.item() 195 | 196 | if method == 'EWC': 197 | loss = loss + loss_params_regular(model) 198 | 199 | loss.backward() 200 | nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) 201 | if method == 'child-tuning': 202 | model_grad_mask(model, grads_mask_coco) 203 | optimizer.step() 204 | 205 | if global_step % config.save_loss_freq == 0: 206 | writer.write_tensorboard('loss', loss_avg/config.save_loss_freq, global_step) 207 | loss_avg = 0 208 | if data_mode == 'mix': 209 | writer.write_tensorboard('loss_ce', loss_ce_avg/config.save_loss_freq, global_step) 210 | writer.write_tensorboard('loss_rwc', loss_rwc_avg/config.save_loss_freq, global_step) 211 | loss_ce_avg = 0 212 | loss_rwc_avg = 0 213 | 214 | train_print(loss.item(), step, totel_step, epoch, time.time() - step_time, time.time() - epoch_time) 215 | step_time = time.time() 216 | 217 | if global_step % config.save_model_freq == 0: 218 | print("Evaluating...") 219 | 220 | # 保存模型 221 | if global_step % 100 == 0: 222 | writer.save_model(model, global_step) 223 | 224 | # validation 225 | model.eval() 226 | with torch.no_grad(): 227 | gen_pycoco_path = generate_captions(config, model, global_step, 'val') 228 | pycoco_results = eval_pycoco(config, gen_pycoco_path, 'val') 229 | pycoco_results_knowcap, acc = cal_knowcap(model, global_step) 230 | writer.write_metrics(pycoco_results, global_step) 231 | writer.write_metrics(pycoco_results_knowcap, global_step) 232 | writer.write_metrics(acc, global_step) 233 | 234 | model.train() 235 | 236 | if global_step == 800: 237 | break -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__pycache__/beamsearch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/beamsearch.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/eval.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/import_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/import_models.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vocab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/vocab.cpython-37.pyc -------------------------------------------------------------------------------- /utils/beamsearch.py: -------------------------------------------------------------------------------- 1 | # batch beamsearch 2 | # 参照huggingface的实现 https://zhuanlan.zhihu.com/p/167072494 http://www.wuyuanhao.com/2020/03/20/解读beam-search-1-2/ 3 | # 除了支持以batch形式一次为多个样本进行beamsearch,与传统beamsearch的最大不同在于: 4 | # 对于beam中的序列,即使生成了end标识符,beam的宽度也不会减小;而是将生成完成的序列存入BeamHypotheses,并向beam中补充一个新的未生成完成序列, 5 | # 并继续宽度为beam的搜索过程,期间不断用新生成完成的序列更新BeamHypotheses 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | class BeamHypotheses(object): 14 | # 每个样本绑定一个,其中维护num_beams个当前最优的序列;可向其中添加新序列并自动踢掉分数最低的 15 | def __init__(self, num_beams, max_length, length_penalty): 16 | # 初始化 17 | self.max_length = max_length - 1 18 | self.num_beams = num_beams 19 | self.length_penalty = length_penalty 20 | self.beams = [] 21 | self.worst_score = 1e9 22 | 23 | def __len__(self): 24 | return len(self.beams) 25 | 26 | def add(self, hyp, sum_logprobs): 27 | # 长度惩罚,可自定义 28 | score = sum_logprobs / len(hyp) ** self.length_penalty 29 | # score = sum_logprobs / (pow((5+len(hyp)+1), self.length_penalty)/pow(5+1, self.length_penalty)) 30 | if len(self) < self.num_beams or score > self.worst_score: 31 | # 可添加 32 | self.beams.append((score, hyp)) 33 | if len(self) > self.num_beams: 34 | # 需要删掉一个 35 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 36 | del self.beams[sorted_scores[0][1]] 37 | self.worst_score = sorted_scores[1][0] 38 | else: 39 | self.worst_score = min(score, self.worst_score) 40 | 41 | def add_scst(self, hyp, logprob, sum_logprobs): 42 | # 长度惩罚,可自定义 43 | score = sum_logprobs / len(hyp) ** self.length_penalty 44 | # score = sum_logprobs / (pow((5+len(hyp)+1), self.length_penalty)/pow(5+1, self.length_penalty)) 45 | if len(self) < self.num_beams or score > self.worst_score: 46 | # 可添加 47 | self.beams.append((score, hyp, logprob)) 48 | if len(self) > self.num_beams: 49 | # 
需要删掉一个 50 | sorted_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) 51 | del self.beams[sorted_scores[0][1]] 52 | self.worst_score = sorted_scores[1][0] 53 | else: 54 | self.worst_score = min(score, self.worst_score) 55 | 56 | def is_done(self, best_sum_logprobs, cur_len=None): 57 | # 样本是否已经生成完成,关键:并非生成beam个完成的序列,而是新一时刻beam宽度个结果中的最高分不如之前保存的最低分 58 | # best_sum_logprobs是新的候选序列中的最高得分 59 | if len(self) < self.num_beams: 60 | return False 61 | else: 62 | if cur_len is None: 63 | cur_len = self.max_length 64 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 65 | # cur_score = best_sum_logprobs / (pow((5+cur_len+1), self.length_penalty)/pow(5+1, self.length_penalty)) 66 | # 如果最高分比保存的最低分还差,则结束 67 | ret = self.worst_score >= cur_score 68 | return ret 69 | 70 | 71 | def beam_search(mode, context, model, batch_size, max_length, num_beams, vocab_size, length_penalty, 72 | bos_token_id=1, pad_token_id=0, eos_token_id=2, prompt=None): 73 | # batch beamsearch 74 | # 记录每个样本的已生成序列,已生成序列得分和是否已生成完成 75 | generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty) for _ in range(batch_size)] 76 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float).to(device) 77 | beam_scores[:, 1:] = -1e9 # 否则t=1时刻取到的num_beams个最大的将都是同一个词,从而导致后面所有num_beams个结果均相同 78 | beam_scores = beam_scores.view(-1) 79 | done = [False for _ in range(batch_size)] 80 | 81 | # 初始input和当前长度 82 | if prompt == None: 83 | input_ids = torch.full((batch_size*num_beams, 1), bos_token_id, dtype=torch.long).to(device) 84 | else: 85 | input_ids = prompt.repeat([num_beams, 1]) 86 | cur_len = 1 87 | 88 | # 初始状态 hidden: (batch_size*num_beams, *) 89 | # 对于LSTM-based模型来说,hidden是解码器的隐藏层状态,需要在每个时刻更新;而对于Transformer-based模型来说,hidden是编码端的输出,解码所有时刻保持不变 90 | # hidden = context 91 | 92 | while cur_len < max_length: 93 | # 需要模型实现一个接口:根据hidden状态,以及当前已生成的序列,生成下一时刻的词表概率分布(以及LSTM-based模型需要更新后的hidden) 94 | outputs, hidden = model.decode_step(input_ids, context) 95 | next_token_logits = outputs[:, -1, :] 96 | 97 | scores = F.log_softmax(next_token_logits, dim=-1) 98 | next_scores = scores + beam_scores[:, None].expand_as(scores) 99 | next_scores = next_scores.view(batch_size, num_beams*vocab_size) # 便于用topk为batch内的每个样本选最大 100 | 101 | # next_scores/next_tokens: (batch_size, num_beams) 102 | # 关键:这里保留了2*num_beams个结果,目的是即使有beam生成了eos,依然能找到num_beams可以继续生成的选项 103 | next_scores, next_tokens = torch.topk(next_scores, 2*num_beams, dim=1, largest=True, sorted=True) 104 | 105 | next_batch_beam = [] # 为下一时刻准备 (分数, token_id, beam_id) 106 | for batch_idx in range(batch_size): 107 | if done[batch_idx]: # 如果当前batch已经完成,直接补pad 108 | next_batch_beam.extend([(0, pad_token_id, 0)]*num_beams) 109 | continue 110 | next_sent_beam = [] # 记录一个batch内beam_num个最好的(且没有生成完成的)结果 111 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( 112 | zip(next_tokens[batch_idx], next_scores[batch_idx]) 113 | ): 114 | beam_id = beam_token_id // vocab_size # beam_id:属于当前batch的第几个beam 115 | token_id = beam_token_id % vocab_size 116 | effective_beam_id = batch_idx * num_beams + beam_id # 在原始(batch_size*num_beams, *)中的位置 117 | if token_id.item() == eos_token_id: 118 | # 生成eos,将当前beam的句子存入 119 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams 120 | if is_beam_token_worse_than_top_num_beams: 121 | continue 122 | # 存入时不包含eos 123 | generated_hyps[batch_idx].add(input_ids[effective_beam_id].clone(), beam_token_score.item()) 124 | else: 125 | # 保存生成后的状态 126 | next_sent_beam.append((beam_token_score, token_id, 
effective_beam_id)) 127 | 128 | if len(next_sent_beam) == num_beams: # 当前batch不管有没有、生成了几个eos,依然会保留num_beams个可扩展的序列 129 | break 130 | 131 | # 什么情况算生成完成?已经生成了num_beams个完整句子,且当前时刻生成的结果(可能是完整句子,也可能不是)没有新的更好的 132 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx]\ 133 | .is_done(next_scores[batch_idx].max().item(), cur_len) 134 | 135 | next_batch_beam.extend(next_sent_beam) 136 | 137 | if all(done): 138 | break 139 | 140 | # 准备下一时刻 141 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 142 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) 143 | beam_idx = input_ids.new([x[2] for x in next_batch_beam]) 144 | 145 | input_ids = input_ids[beam_idx, :] 146 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) 147 | 148 | if mode == 'LSTM': # LSTM需要更新隐藏层状态 149 | hidden = [item[beam_idx, :] for item in hidden] 150 | #h, c = hidden 151 | #h = h[beam_idx, :] 152 | #c = c[beam_idx, :] 153 | #hidden = (h, c) 154 | context[-1] = hidden 155 | 156 | cur_len += 1 157 | 158 | # 手动结束没有生成eos的样本 159 | for batch_idx in range(batch_size): 160 | if done[batch_idx]: 161 | continue 162 | for beam_id in range(num_beams): 163 | # 对于需要手动结束的样本,全部尝试加入 164 | effective_beam_id = batch_idx*num_beams+beam_id 165 | final_score = beam_scores[effective_beam_id].item() 166 | final_tokens = input_ids[effective_beam_id] 167 | generated_hyps[batch_idx].add(final_tokens, final_score) 168 | 169 | # 至此,generated_hyps中保存着每个样本的num_beams个最优序列 170 | best = [] 171 | for i, hypotheses in enumerate(generated_hyps): 172 | sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) 173 | best_hyp = sorted_hyps.pop()[1] 174 | best.append(best_hyp) 175 | 176 | return best 177 | 178 | 179 | def beam_search_scst(mode, context, model, batch_size, max_length, num_beams, vocab_size, length_penalty, 180 | bos_token_id=1, pad_token_id=0, eos_token_id=2): 181 | # batch beamsearch 182 | # 记录每个样本的已生成序列,已生成序列得分和是否已生成完成 183 | # 在beamseach的每个时刻,保存当前最优beam个从开始到当前所有时刻的logprob 184 | generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty) for _ in range(batch_size)] 185 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float).to(device) 186 | beam_scores[:, 1:] = -1e9 # 否则t=1时刻取到的num_beams个最大的将都是同一个词,从而导致后面所有num_beams个结果均相同 187 | beam_scores = beam_scores.view(-1) 188 | done = [False for _ in range(batch_size)] 189 | 190 | # 初始input和当前长度 191 | input_ids = torch.full((batch_size*num_beams, 1), bos_token_id, dtype=torch.long).to(device) 192 | ids_logprob = torch.full((batch_size*num_beams, 0), 0.0).to(device) 193 | cur_len = 1 194 | 195 | # 初始状态 hidden: (batch_size*num_beams, *) 196 | # 对于LSTM-based模型来说,hidden是解码器的隐藏层状态,需要在每个时刻更新;而对于Transformer-based模型来说,hidden是编码端的输出,解码所有时刻保持不变 197 | # hidden = context 198 | 199 | while cur_len < max_length: 200 | # 需要模型实现一个接口:根据hidden状态,以及当前已生成的序列,生成下一时刻的词表概率分布(以及LSTM-based模型需要更新后的hidden) 201 | outputs, hidden = model.decode_step(input_ids, context) 202 | next_token_logits = outputs[:, -1, :] 203 | scores = F.log_softmax(next_token_logits, dim=-1) 204 | next_scores = scores + beam_scores[:, None].expand_as(scores) 205 | next_scores = next_scores.view(batch_size, num_beams*vocab_size) # 便于用topk为batch内的每个样本选最大 206 | scores = scores.view(batch_size, num_beams*vocab_size) # 便于根据取出topk的id取出对应的概率 207 | 208 | # next_scores/next_tokens: (batch_size, num_beams) 209 | # 关键:这里保留了2*num_beams个结果,目的是即使有beam生成了eos,依然能找到num_beams可以继续生成的选项 210 | next_scores, next_tokens = torch.topk(next_scores, 2*num_beams, dim=1, largest=True, sorted=True) 211 | 212 | 
next_batch_beam = [] # 为下一时刻准备 (分数, token_id, beam_id) 213 | for batch_idx in range(batch_size): 214 | if done[batch_idx]: # 如果当前batch已经完成,直接补pad 215 | next_batch_beam.extend([(0, pad_token_id, 0, 0)]*num_beams) 216 | continue 217 | next_sent_beam = [] # 记录一个batch内beam_num个最好的(且没有生成完成的)结果 218 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( 219 | zip(next_tokens[batch_idx], next_scores[batch_idx]) 220 | ): 221 | beam_id = beam_token_id // vocab_size # beam_id:属于当前batch的第几个beam 222 | token_id = beam_token_id % vocab_size 223 | logprob = scores[batch_idx][beam_token_id] 224 | effective_beam_id = batch_idx * num_beams + beam_id # 在原始(batch_size*num_beams, *)中的位置 225 | if token_id.item() == eos_token_id: 226 | # 生成eos,将当前beam的句子存入 227 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams 228 | if is_beam_token_worse_than_top_num_beams: 229 | continue 230 | # 存入时不包含eos 231 | logprob_add = torch.cat([ids_logprob[effective_beam_id].clone(), logprob.unsqueeze(0)], dim=0) 232 | generated_hyps[batch_idx].add_scst(input_ids[effective_beam_id].clone(), logprob_add, beam_token_score.item()) 233 | 234 | else: 235 | # 保存生成后的状态 236 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id, logprob)) 237 | 238 | if len(next_sent_beam) == num_beams: # 当前batch不管有没有、生成了几个eos,依然会保留num_beams个可扩展的序列 239 | break 240 | 241 | # 什么情况算生成完成?已经生成了num_beams个完整句子,且当前时刻生成的结果(可能是完整句子,也可能不是)没有新的更好的 242 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx]\ 243 | .is_done(next_scores[batch_idx].max().item(), cur_len) 244 | 245 | next_batch_beam.extend(next_sent_beam) 246 | 247 | if all(done): 248 | break 249 | 250 | # 准备下一时刻 251 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 252 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) 253 | beam_idx = input_ids.new([x[2] for x in next_batch_beam]) 254 | beam_logprob = ids_logprob.new([x[3] for x in next_batch_beam]) 255 | 256 | input_ids = input_ids[beam_idx, :] 257 | ids_logprob = ids_logprob[beam_idx, :] 258 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) 259 | ids_logprob = torch.cat([ids_logprob, beam_logprob.unsqueeze(1)], dim=-1) 260 | 261 | if mode == 'LSTM': # LSTM需要更新隐藏层状态 262 | hidden = [item[beam_idx, :] for item in hidden] 263 | #h, c = hidden 264 | #h = h[beam_idx, :] 265 | #c = c[beam_idx, :] 266 | #hidden = (h, c) 267 | context[-1] = hidden 268 | 269 | cur_len += 1 270 | 271 | # 手动结束没有生成eos的样本 272 | for batch_idx in range(batch_size): 273 | if done[batch_idx]: 274 | continue 275 | for beam_id in range(num_beams): 276 | # 对于需要手动结束的样本,全部尝试加入 277 | effective_beam_id = batch_idx*num_beams+beam_id 278 | final_score = beam_scores[effective_beam_id].item() 279 | final_tokens = input_ids[effective_beam_id] 280 | final_logprob = ids_logprob[effective_beam_id] 281 | generated_hyps[batch_idx].add_scst(final_tokens, final_logprob, final_score) 282 | 283 | # 至此,generated_hyps中保存着每个样本的num_beams个最优序列 284 | best = [] 285 | all_tokens = [] 286 | all_logprob = [] 287 | for i, hypotheses in enumerate(generated_hyps): 288 | sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) 289 | best_hyp = sorted_hyps[-1][1] 290 | best.append(best_hyp) 291 | all_tokens.extend([item[1] for item in sorted_hyps]) 292 | all_logprob.extend([item[2] for item in sorted_hyps]) 293 | 294 | return best, all_tokens, all_logprob 295 | 296 | -------------------------------------------------------------------------------- /utils/cc12m.py: 
-------------------------------------------------------------------------------- 1 | # Automatically filter some data by keywords from cc12m 2 | 3 | import csv 4 | from tqdm import tqdm 5 | import json 6 | import os 7 | import xlsxwriter 8 | from PIL import Image 9 | import random 10 | import re 11 | from torchvision import transforms 12 | import requests 13 | 14 | landmarks_replay = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple'] 15 | brands_replay = ['iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines'] 16 | foods_replay = ['sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck'] 17 | charas_replay = ['batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh'] 18 | keywords = landmarks_replay+brands_replay+foods_replay+charas_replay 19 | print(keywords) 20 | print(len(keywords)) 21 | input() 22 | 23 | """ 24 | # cc12m 25 | cc12m_data = [] 26 | cc12m_path = '/Users/cckevin/Downloads/cc12m.tsv' 27 | with open(cc12m_path, 'r') as f: 28 | text = f.read() 29 | lines = text.split('\n') 30 | for line in lines: 31 | cc12m_data.append(line.split('\t')) 32 | print("Num: "+str(len(cc12m_data))) 33 | 34 | # random.shuffle(cc12m_data) 35 | # cc12m_data_tiny = cc12m_data[:50000] 36 | """ 37 | 38 | """ 39 | # filter in cc12m 40 | keywords = [item.lower() for item in keywords] 41 | keywords_num = {keyword: 0 for keyword in keywords} 42 | 43 | cc12m_select = [] 44 | for item in tqdm(cc12m_data): 45 | try: 46 | img_dir = item[0] 47 | caption = item[1] 48 | caption = caption.lower() 49 | for keyword in keywords: 50 | if re.search(keyword, caption) != None: 51 | if keywords_num[keyword] < 1000: 52 | keywords_num[keyword] += 1 53 | cc12m_select.append([img_dir, caption, keyword]) 54 | break 55 | except: 56 | continue 57 | 58 | print("Num of select: "+str(len(cc12m_select))) 59 | print(keywords_num) 60 | cc12m_data_path = '/Users/cckevin/Downloads/cc12m_select.json' 61 | with open(cc12m_data_path, 'w') as f: 62 | json.dump(cc12m_select, f) 63 | """ 64 | 65 | 66 | # download images 67 | cc12m_select = 
json.load(open('/home/data_ti4_c/chengkz/scripts/cc12m_select.json', 'r')) 68 | print(len(cc12m_select)) 69 | download_img_dir = '/home/chengkz/checkpoints/ofa/cc12m_select' 70 | cc12m_select = cc12m_select[:] 71 | 72 | for i, item in tqdm(enumerate(cc12m_select)): 73 | url = item[0] 74 | filename = str(i)+'.jpg' 75 | download_img_path = os.path.join(download_img_dir, filename) 76 | if os.path.exists(download_img_path) == False: 77 | try: 78 | download_file = requests.get(url, timeout=5) 79 | open(download_img_path, 'wb').write(download_file.content) 80 | except: 81 | continue 82 | 83 | 84 | 85 | # Filter out the images that can be used as replay data 86 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 87 | resolution = 480 88 | patch_resize_transform = transforms.Compose([ 89 | lambda image: image.convert("RGB"), 90 | transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=mean, std=std) 93 | ]) 94 | 95 | data_cc12m = [] 96 | rwconcept_num = {keyword.lower(): 0 for keyword in keywords} 97 | for i, item in tqdm(enumerate(cc12m_select)): 98 | filename = str(i)+'.jpg' 99 | img_path = os.path.join(download_img_dir, filename) 100 | if os.path.exists(img_path) == False: 101 | continue 102 | try: 103 | img = Image.open(img_path) 104 | patch_img = patch_resize_transform(img) 105 | except: 106 | continue 107 | else: 108 | caption = item[1] 109 | keyword = item[2] 110 | rwconcept_num[keyword] += 1 111 | caption = caption.lower() 112 | data_cc12m.append({"filename": img_path, "caption": caption, "keyword": keyword, 'data': 'cc12m'}) 113 | 114 | print(rwconcept_num) 115 | print("Num of select success: "+str(len(data_cc12m))) 116 | json.dump(data_cc12m, open('/home/chengkz/checkpoints/ofa/data_cc12m_SelectForReplay.json', 'w'), ensure_ascii=False) 117 | -------------------------------------------------------------------------------- /utils/convert_ofa.py: -------------------------------------------------------------------------------- 1 | # convert the official fairseq version ckpts to the transformers version ckpts 2 | # notice that our K-Replay train the OFA begin with a ckpts with fine-tuned encoder+pre-trained decoder 3 | # eg: 4 | # 1.download official transformers version ckpts in https://huggingface.co/OFA-Sys/ofa-large 5 | # 2.download official fairseq version ckpts in https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt 6 | # 3.using the following code to obtain the correct transformers version ckpts 7 | import torch 8 | """ 9 | model_t = torch.load('/home/chengkz/checkpoints/ofa/OFA-large/pytorch_model.bin') 10 | model_f = torch.load('/home/chengkz/checkpoints/ofa/OFA-large-fairseq/ofa_large.pt')['model'] 11 | 12 | key_t = set([k for k in model_t.keys()]) 13 | key_f = set([k for k in model_f.keys()]) 14 | print(len(key_t), len(key_f)) 15 | common_key = key_t.intersection(key_f) 16 | print(len(common_key)) 17 | 18 | for k in model_t.keys(): 19 | # if 'encoder' in k: 20 | if k in common_key: 21 | model_t[k] = model_f[k] 22 | del model_f[k] 23 | key_t.remove(k) 24 | key_f.remove(k) 25 | print(len(key_t), len(key_f)) 26 | 27 | for k in model_f.keys(): 28 | #if 'encoder' in k: 29 | k_pred = k.replace('ffn_layernorm', 'ffn_layer_norm') 30 | k_pred = k_pred.replace('self_attn_ln', 'self_attn_mid_layer_norm') 31 | k_pred = k_pred.replace('cross_attn_ln', 'cross_attn_mid_layer_norm') 32 | k_pred = k_pred.replace('encoder_attn', 'cross_attn') 33 | k_pred = k_pred.replace('attn_ln', 
'self_attn_mid_layer_norm') 34 | if k_pred in key_t: 35 | model_t[k_pred] = model_f[k] 36 | key_t.remove(k_pred) 37 | key_f.remove(k) 38 | print(len(key_t), len(key_f)) 39 | print(key_f) 40 | 41 | torch.save(model_t, '/home/chengkz/checkpoints/ofa/OFA-large-caption-trainedenc/pytorch_model.bin') 42 | """ 43 | 44 | """ 45 | code for BLIP 46 | model_pretrain = torch.load('/home/chengkz/.cache/torch/hub/checkpoints/model_large.pth') 47 | model_ft = torch.load('/home/chengkz/.cache/torch/hub/checkpoints/model_large_caption.pth')['model'] 48 | key_ft = set([k for k in model_ft.keys()]) 49 | key_ft_vision = {item for item in key_ft if 'visual_encoder' in item} 50 | for k in key_ft_vision: 51 | model_pretrain['model'][k] = model_ft[k] 52 | 53 | torch.save(model_pretrain, '/home/chengkz/.cache/torch/hub/checkpoints/model_large_trainedenc.pth') 54 | """ -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | # 测试模型 2 | # 为验证、测试集生成句子并保存为可用pycoco直接计算指标的格式 3 | # 用保存的句子计算指标 4 | 5 | import os 6 | import torch 7 | import pickle 8 | import json 9 | import numpy as np 10 | 11 | from data_load import data_load 12 | from tqdm import tqdm 13 | from pycocoevalcap.eval import COCOEvalCap 14 | from evaluation import Cider 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def generate_captions(config, model, step, mode, final_test=False): 20 | print("Generating captions...") 21 | 22 | log_path = config.log_dir.format(config.id) 23 | result_dir = os.path.join(log_path, 'generated') 24 | if not os.path.exists(result_dir): 25 | os.makedirs(result_dir) 26 | gen_pycoco_path = os.path.join(result_dir, mode+'_'+str(step)+'.json') 27 | 28 | data_dir = os.path.join(config.data_dir, mode+'.json') 29 | 30 | eval_loader = data_load(config, data_dir, mode) 31 | model.eval() 32 | gen_pycoco = {} 33 | 34 | for i, (image_id, image_feature) in tqdm(enumerate(eval_loader)): 35 | patch_image = image_feature['patch_image'] 36 | patch_image = patch_image.to(device) 37 | batch_size = len(image_id) 38 | if not final_test: 39 | captions, _ = model.greedy_search(patch_image) 40 | else: 41 | captions = model.generate_caption_batchbs(patch_image) 42 | for j, cap_id in enumerate(captions): 43 | if config.model == 'OFA': 44 | gen = cap_id.unsqueeze(0) 45 | caption = model.tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip() 46 | elif config.model == 'BLIP': 47 | caption = model.tokenizer.decode(cap_id, skip_special_tokens=True) 48 | caption = caption[len(model.prompt):] 49 | elif config.model == 'GIT': 50 | gen = cap_id.unsqueeze(0) 51 | caption = model.tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip() 52 | refs = [] 53 | ref = {'image_id': image_id[j], 'id': i * batch_size + j, 'caption': caption} 54 | refs.append(ref) 55 | gen_pycoco[i * batch_size + j] = refs 56 | if not final_test: 57 | if len(gen_pycoco) >= 200: 58 | break 59 | 60 | json.dump(gen_pycoco, open(gen_pycoco_path, 'w'), ensure_ascii=False) 61 | 62 | return gen_pycoco_path 63 | 64 | 65 | def eval_pycoco(config, gen_pycoco_path, mode): 66 | print("Calculating pycoco...") 67 | ref_pycoco_path = os.path.join(config.data_dir, mode+'_pycoco.json') 68 | ref_pycoco = json.load(open(ref_pycoco_path, 'r')) 69 | gen_pycoco = json.load(open(gen_pycoco_path, 'r')) 70 | num = len(gen_pycoco) 71 | ref_pycoco = {int(k): v for k, v in ref_pycoco.items() if int(k) < num} # 
keys are parsed as str when loading from json, which causes problems when computing SPICE 72 | gen_pycoco = {int(k): v for k, v in gen_pycoco.items() if int(k) < num} 73 | """ 74 | ref_cider = {int(k): [item["caption"] for item in v] for k, v in ref_pycoco.items()} 75 | gen_cider = {int(k): [v[0]["caption"]] for k, v in gen_pycoco.items()} 76 | reward = cider_train.compute_score(ref_cider, gen_cider)[1].astype(np.float32) 77 | reward = torch.from_numpy(reward).to(device).view(-1) 78 | print("CIDEr: "+str(reward.mean())) 79 | """ 80 | cocoEval = COCOEvalCap('diy', 'diy') 81 | pycoco_results = cocoEval.evaluate_diy(ref_pycoco, gen_pycoco) 82 | 83 | return pycoco_results 84 | 85 | -------------------------------------------------------------------------------- /utils/import_models.py: -------------------------------------------------------------------------------- 1 | # build the model specified by the command-line arguments 2 | 3 | import os 4 | import yaml 5 | from pathlib import Path 6 | from models.Transformer.transformer import Transformer_Cap 7 | from models.OFA.ofa import OFA 8 | from models.BLIP.blip import blip_decoder 9 | from models.GIT.git import GIT 10 | 11 | 12 | def construct_model(config): 13 | if config.model == 'Transformer': 14 | model = Transformer_Cap(config) 15 | elif config.model == 'OFA': 16 | model = OFA(config) 17 | elif config.model == 'BLIP': 18 | args = yaml.load(open(config.config_blip, 'r'), Loader=yaml.Loader) 19 | model = blip_decoder(pretrained='/home/chengkz/.cache/torch/hub/checkpoints/model_large_trainedenc.pth', config=config, image_size=args['image_size'], 20 | vit=args['vit'], 21 | vit_grad_ckpt=args['vit_grad_ckpt'], vit_ckpt_layer=args['vit_ckpt_layer'], 22 | prompt=args['prompt']) 23 | elif config.model == 'GIT': 24 | model = GIT(config) 25 | else: 26 | print("model "+str(config.model)+" not found") 27 | return None 28 | return model -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | # training log utilities 2 | # write tensorboard summaries 3 | # save model checkpoints 4 | 5 | import time 6 | import os 7 | import json 8 | import sys 9 | import shutil 10 | import torch 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | 14 | def train_print(loss, step, total_step, epoch, step_time, epoch_time): 15 | epoch_time = time.localtime(epoch_time) 16 | min = epoch_time.tm_min 17 | sec = epoch_time.tm_sec 18 | print(f"\rloss:{format(loss, '.8f')} |" 19 | f"step: {step}/{total_step} |" 20 | f"epoch: {epoch} |" 21 | f"step time:{format(step_time, '.2f')}secs |", 22 | f"epoch time: {min}min {sec}sec", end='') 23 | 24 | 25 | class Log_Writer(): 26 | 27 | def __init__(self, config): 28 | super(Log_Writer, self).__init__() 29 | 30 | print("Creating Log dir...") 31 | self.log_path = config.log_dir.format(config.id) 32 | if not os.path.exists(self.log_path): # create the log directory 33 | os.makedirs(self.log_path) 34 | 35 | para_path = os.path.join(self.log_path, 'para.json') # save the command-line arguments 36 | with open(para_path, 'w') as f: 37 | json.dump(sys.argv, f) 38 | shutil.copy('./config.py', self.log_path) # save a copy of the config file 39 | 40 | self.writer = SummaryWriter(self.log_path) # tensorboard writer 41 | 42 | def write_tensorboard(self, scalar_name, scalar, step): 43 | self.writer.add_scalar(scalar_name, scalar, step) 44 | 45 | def write_metrics(self, pycoco_results, step): 46 | # metrics_list = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", "METEOR", "ROUGE_L", "CIDEr"] 47 | for metric in pycoco_results: 48 | self.write_tensorboard(metric, pycoco_results[metric], step) 49 | 50 | def save_model(self, model, global_step): 51 | model_path = os.path.join(self.log_path, 'model') 52 | if not os.path.exists(model_path): 53 | os.makedirs(model_path) 54 | save_path = os.path.join(model_path, f'model_{global_step}.pt') 55 | torch.save(model.state_dict(), save_path) 56 | 57 |
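Taken together, import_models.py and eval.py form the evaluation entry point: construct_model builds the captioner named by --model, generate_captions decodes a split into a pycoco-style json, and eval_pycoco scores it. A minimal usage sketch (assuming the repository root is on PYTHONPATH and the checkpoint paths configured in config.py exist; import paths and the device handling mirror the code above):

import torch
from config import config                               # argparse namespace defined in config.py
from utils.import_models import construct_model
from utils.eval import generate_captions, eval_pycoco

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = construct_model(config).to(device)              # e.g. run with --model OFA
gen_path = generate_captions(config, model, step=0, mode='val')   # writes <log_dir>/generated/val_0.json
results = eval_pycoco(config, gen_path, mode='val')               # BLEU/METEOR/ROUGE-L/CIDEr via pycocoevalcap
print(results)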
-------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence 4 | from evaluation import Cider 5 | import numpy as np 6 | import pickle 7 | import json 8 | from transformers.models.ofa.tokenization_ofa import OFATokenizer 9 | import torch.nn.functional as F 10 | 11 | from config import config 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | # used for XE distillation (XEdistill) 16 | class Loss_KD(nn.Module): 17 | 18 | def __init__(self, KD_T=8): 19 | super(Loss_KD, self).__init__() 20 | self.softmax = nn.Softmax(dim=-1) 21 | self.temperature = KD_T 22 | 23 | def forward(self, logit, logit_teacher, cap_len): 24 | prob = self.softmax(logit / self.temperature) 25 | prob_teacher = self.softmax(logit_teacher / self.temperature) 26 | 27 | pred = pack_padded_sequence(prob, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0] 28 | target = pack_padded_sequence(prob_teacher, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0] 29 | 30 | loss_kl = F.kl_div(pred.log(), target, reduction='sum') / logit.shape[0] 31 | return loss_kl 32 | 33 | 34 | # Label Smoothing 35 | class LabelSmoothingCrossEntropy(nn.Module): 36 | def __init__(self, epsilon: float = 0.1, reduction='mean'): 37 | super().__init__() 38 | self.epsilon = epsilon 39 | self.reduction = reduction 40 | 41 | def linear_combination(self, x, y, epsilon): 42 | return epsilon * x + (1 - epsilon) * y 43 | 44 | def reduce_loss(self, loss, reduction='mean'): 45 | return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss 46 | 47 | def forward(self, preds, target): 48 | n = preds.size()[-1] 49 | log_preds = F.log_softmax(preds, dim=-1) 50 | loss = self.reduce_loss(-log_preds.sum(dim=-1), self.reduction) 51 | nll = F.nll_loss(log_preds, target, reduction=self.reduction) 52 | return self.linear_combination(loss / n, nll, self.epsilon) 53 | 54 | 55 | class Cross_Entropy(nn.Module): 56 | # sequence-level cross-entropy 57 | def __init__(self, label_smoothing=0.0): 58 | super(Cross_Entropy, self).__init__() 59 | self.label_smoothing = label_smoothing 60 | self.ce = nn.CrossEntropyLoss().to(device) 61 | self.ce_ls = LabelSmoothingCrossEntropy(epsilon=label_smoothing).to(device) 62 | 63 | def forward(self, logit, cap, cap_len): 64 | target = cap[:, 1:] 65 | cap_len = cap_len - 1 66 | 67 | target = pack_padded_sequence(target, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0] 68 | logit = pack_padded_sequence(logit, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0] 69 | 70 | # cross_entropy 71 | if self.label_smoothing > 0: 72 | loss_ce = self.ce_ls(logit, target) 73 | else: 74 | loss_ce = self.ce(logit, target) 75 | 76 | return loss_ce 77 | 78 | 79 | # cross-entropy computed only on the knowledge keywords, used to locate knowledge-related parameters 80 | class Cross_Entropy_Keyword(nn.Module): 81 | # sequence-level cross-entropy 82 | def __init__(self): 83 | super(Cross_Entropy_Keyword, self).__init__() 84 | self.ce = nn.CrossEntropyLoss().to(device) 85 | 86 | def forward(self, logit, cap, cap_len, if_keyword): 87 | target = cap[:, 1:] 88 | if_keyword = if_keyword[:, 1:] > 0 89 | logit = logit[:, :-1] 90 | cap_len = cap_len - 1 91 | 92 | target = target[if_keyword] 93 | logit = logit[if_keyword] 94 | 95 | # cross_entropy 96 | loss_ce = self.ce(logit, target) 97 | 98 | return loss_ce 99 | 100 |
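# A toy shape-level check of Loss_KD and Cross_Entropy above (illustrative tensor sizes only,
# kept inactive inside a string block; logits are (batch, seq_len, vocab), cap has BOS at position 0):
"""
logit = torch.randn(2, 5, 10).to(device)           # student logits for 5 decoding steps
logit_teacher = torch.randn(2, 5, 10).to(device)   # teacher logits with the same shape
cap = torch.randint(4, 10, (2, 6)).to(device)      # token ids, position 0 is the BOS token
cap_len = torch.tensor([6, 4])                     # caption lengths including BOS
print(Cross_Entropy(label_smoothing=0.1)(logit, cap, cap_len))   # scalar XE loss
print(Loss_KD(KD_T=8)(logit, logit_teacher, cap_len - 1))        # scalar distillation KL loss
"""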
101 | # the core loss of K-Replay: encourage the model to predict the knowledge keywords 102 | class Sent_Level_Concept_Coverage(nn.Module): 103 | def __init__(self): 104 | super(Sent_Level_Concept_Coverage, self).__init__() 105 | self.softmax = nn.Softmax(dim=2) 106 | self.sigmoid = nn.Sigmoid() 107 | 108 | def forward(self, logit_rwc, cap_rwc_label, cap_len_rwc, model_type): 109 | softmax_rwc = self.softmax(logit_rwc) 110 | loss_cov = torch.zeros(cap_len_rwc.shape[0]).to(device) 111 | loss_rep = torch.zeros(cap_len_rwc.shape[0]).to(device) 112 | for i in range(cap_len_rwc.shape[0]): 113 | softmax_sen = softmax_rwc[i][:cap_len_rwc[i].item()] 114 | softmax_agg = softmax_sen.sum(dim=0) 115 | sigmoid_agg = self.sigmoid(softmax_agg) 116 | if model_type == 'OFA': 117 | label = cap_rwc_label[i][cap_rwc_label[i]>2] 118 | elif model_type == 'BLIP': 119 | label = cap_rwc_label[i][(cap_rwc_label[i]!=0) & (cap_rwc_label[i]!=102) & (cap_rwc_label[i]!=30522) 120 | & (cap_rwc_label[i]!=1037) & (cap_rwc_label[i]!=3861) & (cap_rwc_label[i]!=1997)] 121 | elif model_type == 'GIT': 122 | label = cap_rwc_label[i][(cap_rwc_label[i]!=0) & (cap_rwc_label[i]!=101) & (cap_rwc_label[i]!=102)] 123 | prob = sigmoid_agg[label] 124 | log_prob = -torch.log(prob).mean() 125 | loss_cov[i] = log_prob 126 | prob_softmax = softmax_agg[label] 127 | prob_pow = torch.pow(1-prob_softmax, 2).mean() 128 | loss_rep[i] = prob_pow 129 | loss_cov = loss_cov.mean() 130 | loss_rep = loss_rep.mean() 131 | loss_rwc = loss_cov+loss_rep 132 | return loss_rwc 133 | 134 | 135 | class Loss_Params_Regular(nn.Module): 136 | def __init__(self, params_init, params_fisher): 137 | super(Loss_Params_Regular, self).__init__() 138 | self.params_init = params_init 139 | self.params_fisher = params_fisher 140 | self.gamma = 50000 141 | 142 | def forward(self, model): 143 | loss = 0 144 | for name, params in model.named_parameters(): 145 | if params.requires_grad == True: 146 | loss_p = 0.5 * self.gamma * self.params_fisher[name] * torch.pow(params-self.params_init[name], 2) 147 | loss += loss_p.sum() 148 | return loss 149 | 150 | 151 | class Loss_SCST(nn.Module): 152 | 153 | def __init__(self, config): 154 | super(Loss_SCST, self).__init__() 155 | self.config = config 156 | self.batch_size = config.batch_size 157 | self.beam_num = config.beam_num 158 | self.vocab = pickle.load(open(config.vocab, 'rb')) 159 | self.train = json.load(open(config.train, 'r')) 160 | self.cider_texts = {i: [' '.join(item['caption'])] for i, item in enumerate(self.train)} 161 | self.cider_train = Cider(self.cider_texts) 162 | 163 | def vanilla_scst(self, all_tokens, all_tokens_greedy, all_logprob, refs): 164 | # vanilla SCST: draw beam_num multinomial samples and use the greedy output as the baseline 165 | # first replicate the greedy outputs and the references beam_num times 166 | gen_num = len(all_tokens) 167 | all_tokens_greedy_beam = [] 168 | for item in all_tokens_greedy: 169 | all_tokens_greedy_beam.extend([item for i in range(self.beam_num)]) 170 | refs_beam = [] 171 | for item in refs: 172 | refs_beam.extend([item for i in range(self.beam_num)]) 173 | 174 | # collect the samples, greedy outputs and references, then compute the metric 175 | caps_gen = {i: [self.vocab.idList_to_sent(item)] for i, item in enumerate(all_tokens)} 176 | caps_gen_greedy = {i: [self.vocab.idList_to_sent(item)] for i, item in enumerate(all_tokens_greedy_beam)} 177 | caps_gt = {i: item for i, item in enumerate(refs_beam)} 178 | reward = self.cider_train.compute_score(caps_gt, caps_gen)[1].astype(np.float32) 179 | reward = 
torch.from_numpy(reward).to(device).view(gen_num) 180 | reward_baseline = self.cider_train.compute_score(caps_gt, caps_gen_greedy)[1].astype(np.float32) 181 | reward_baseline = torch.from_numpy(reward_baseline).to(device).view(gen_num) 182 | 183 | # 对采样结果的log_prob补齐 184 | all_logprob_pad = [] 185 | for logprob in all_logprob: 186 | logprob = torch.cat([logprob, logprob.new([0 for i in range(self.config.fixed_len - logprob.shape[0])])], dim=0) 187 | all_logprob_pad.append(logprob.unsqueeze(0)) 188 | all_logprob_pad = torch.cat(all_logprob_pad, dim=0) 189 | 190 | # 计算损失 191 | loss = -torch.mean(all_logprob_pad, -1) * (reward - reward_baseline) 192 | loss = loss.mean() 193 | 194 | # 计算训练reward 195 | reward_train = reward.mean() 196 | 197 | return loss, reward_train 198 | 199 | 200 | class Loss_SCST_OFA(nn.Module): 201 | 202 | def __init__(self, config): 203 | super(Loss_SCST_OFA, self).__init__() 204 | self.config = config 205 | self.batch_size = config.batch_size 206 | self.beam_num = config.beam_num 207 | self.tokenizer = OFATokenizer.from_pretrained(self.config.ofa_ckpts) 208 | self.train = json.load(open(config.train, 'r')) 209 | self.cider_texts = {i: [' '.join(item['caption'])] for i, item in enumerate(self.train)} 210 | self.cider_train = Cider(self.cider_texts) 211 | 212 | def vanilla_scst(self, all_tokens, all_tokens_greedy, all_logprob, refs): 213 | # vanilla scst: 多项式采样beam_num个,greedy作为baseline 214 | # 首先将greedy和ref复制beam_num倍 215 | gen_num = len(all_tokens) 216 | all_tokens_greedy_beam = [] 217 | for item in all_tokens_greedy: 218 | all_tokens_greedy_beam.extend([item for i in range(self.beam_num)]) 219 | refs_beam = [] 220 | for item in refs: 221 | refs_beam.extend([item for i in range(self.beam_num)]) 222 | 223 | # 整理采样、greedy和ref计算指标 224 | caps_gen = {i: [self.tokenizer.batch_decode(item.unsqueeze(0), skip_special_tokens=True)[0].strip()] for i, item in enumerate(all_tokens)} 225 | caps_gen_greedy = {i: [self.tokenizer.batch_decode(item.unsqueeze(0), skip_special_tokens=True)[0].strip()] for i, item in enumerate(all_tokens_greedy_beam)} 226 | caps_gt = {i: item for i, item in enumerate(refs_beam)} 227 | reward = self.cider_train.compute_score(caps_gt, caps_gen)[1].astype(np.float32) 228 | reward = torch.from_numpy(reward).to(device).view(gen_num) 229 | reward_baseline = self.cider_train.compute_score(caps_gt, caps_gen_greedy)[1].astype(np.float32) 230 | reward_baseline = torch.from_numpy(reward_baseline).to(device).view(gen_num) 231 | 232 | # 对采样结果的log_prob补齐 233 | all_logprob_pad = [] 234 | for logprob in all_logprob: 235 | logprob = torch.cat([logprob, logprob.new([0 for i in range(self.config.fixed_len - logprob.shape[0])])], dim=0) 236 | all_logprob_pad.append(logprob.unsqueeze(0)) 237 | all_logprob_pad = torch.cat(all_logprob_pad, dim=0) 238 | 239 | # 计算损失 240 | loss = -torch.mean(all_logprob_pad, -1) * (reward - reward_baseline) 241 | loss = loss.mean() 242 | 243 | # 计算训练reward 244 | reward_train = reward.mean() 245 | 246 | return loss, reward_train 247 | 248 | 249 | -------------------------------------------------------------------------------- /utils/prepro_data.py: -------------------------------------------------------------------------------- 1 | # construct the data in ./data 2 | # Steps 1-5 are used to construct the training, validation and test sets used for K-Replay, 3 | # and steps 6-7 are used to adjust the replay dataset 4 | 5 | import os 6 | import json 7 | import random 8 | from tqdm import tqdm 9 | import nltk 10 | 11 | """ 12 | # 1. 
Split COCO as train, val and test (follow KarpathSplit by dataset_coco.json) 13 | dataset_coco_karpath = json.load(open('../data/dataset_coco.json', 'r'))["images"] 14 | images_dir_train2014 = '/home/data_ti4_c/chengkz/data/coco_dataset/train2014' 15 | images_dir_val2014 = '/home/data_ti4_c/chengkz/data/coco_dataset/val2014' 16 | data_train = [] 17 | data_val = [] 18 | data_test = [] 19 | 20 | for item in tqdm(dataset_coco_karpath): 21 | if item['split'] == 'train' or item['split'] == 'restval': 22 | image_id = item['filename'][:-4] 23 | filename = os.path.join(images_dir_train2014 if item['filepath'] == 'train2014' else images_dir_val2014, item['filename']) 24 | refs = [] 25 | for sentence in item['sentences']: 26 | refs.append(' '.join(sentence["tokens"])) 27 | for sentence in item['sentences']: 28 | item_train = {'split': 'train', 'image_id': image_id, 'filename': filename, 'caption': sentence["tokens"], 'refs': refs} 29 | data_train.append(item_train) 30 | else: 31 | image_id = item['filename'][:-4] 32 | filename = os.path.join(images_dir_train2014 if item['filepath'] == 'train2014' else images_dir_val2014, item['filename']) 33 | captions = [sentence["tokens"] for sentence in item["sentences"]] 34 | item_eval = {'split': 'val', 'image_id': image_id, 'filename': filename, 'caption': captions} 35 | if item['split'] == 'val': 36 | data_val.append(item_eval) 37 | elif item['split'] == 'test': 38 | data_test.append(item_eval) 39 | 40 | random.shuffle(data_train) 41 | 42 | print("Num of train: " + str(len(data_train))) 43 | print("Num of val: " + str(len(data_val))) 44 | print("Num of test: " + str(len(data_test))) 45 | json.dump(data_train, open('../data/train.json', 'w'), ensure_ascii=False) 46 | json.dump(data_val, open('../data/val.json', 'w'), ensure_ascii=False) 47 | json.dump(data_test, open('../data/test.json', 'w'), ensure_ascii=False) 48 | """ 49 | 50 | """ 51 | # 2. 
Split KnowCap as 1000 test (all & Unseen) and 424 val 52 | knowcap_240 = json.load(open('../data/knowcap_240.json', 'r')) 53 | print("Num of KnowCap_240: "+str(len(knowcap_240))) 54 | knowcap_240_test = knowcap_240[:1000] 55 | knowcap_240_val = knowcap_240[1000:] 56 | print("Num of KnowCap_240 val: "+str(len(knowcap_240_val))) 57 | print("Num of KnowCap_240 test: "+str(len(knowcap_240_test))) 58 | # statistics the categories contained in val and test 59 | print("Categories of KnowCap_240 val: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_val])))) 60 | print("Categories of KnowCap_240 test: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_test])))) 61 | json.dump(knowcap_240_val, open('../data/knowcap_240_val.json', 'w')) 62 | json.dump(knowcap_240_test, open('../data/knowcap_240_test.json', 'w')) 63 | 64 | categories_replay = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple', 'iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines', 'sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck', 'batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh'] 65 | knowcap_240_test_unseen = [] 66 | for item in knowcap_240_test: 67 | keyword = item["image"].split('/')[0] 68 | if keyword not in categories_replay: 69 | knowcap_240_test_unseen.append(item) 70 | print("Num of KnowCap_240 test unseen: "+str(len(knowcap_240_test_unseen))) 71 | print("Categories of KnowCap_240 test unseen: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_test_unseen])))) 72 | json.dump(knowcap_240_test_unseen, open('../data/knowcap_240_test_unseen.json', 'w')) 73 | """ 74 | 75 | """ 76 | 3. 
Adjust to the format of calculating metrics with pycoco 77 | for split in ['val', 'test']: 78 | ref_pycoco_path = os.path.join('../data', split+'_pycoco.json') 79 | data = json.load(open(os.path.join('../data', split+'.json'), 'r')) 80 | 81 | ref_pycoco = {} 82 | for i, item in tqdm(enumerate(data)): 83 | refs = [] 84 | for j, sentence in enumerate(item['caption']): 85 | ref = {} 86 | ref['image_id'] = item['image_id'] 87 | ref['id'] = j 88 | ref['caption'] = ' '.join(sentence) 89 | refs.append(ref) 90 | ref_pycoco[i] = refs 91 | 92 | print("Num: "+str(len(ref_pycoco))) 93 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False) 94 | 95 | 96 | ref_pycoco_path = os.path.join('../data/knowcap_240_val_pycoco.json') 97 | data = json.load(open(os.path.join('../data/knowcap_240_val.json'), 'r')) 98 | 99 | ref_pycoco = {} 100 | for i, item in tqdm(enumerate(data)): 101 | refs = [] 102 | for j, sentence in enumerate(item['captions']): 103 | ref = {} 104 | ref['image_id'] = item['image'] 105 | ref['id'] = j 106 | ref['caption'] = sentence 107 | refs.append(ref) 108 | ref_pycoco[i] = refs 109 | 110 | print("Num: "+str(len(ref_pycoco))) 111 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False) 112 | """ 113 | 114 | """ 115 | # 4. Convert the splitting results in train.json back to full sentences for use with our own tokenizer 116 | coco_train_all = json.load(open('../data/train.json', 'r')) 117 | print(len(coco_train_all)) 118 | random.shuffle(coco_train_all) 119 | coco_train_used = coco_train_all[:] 120 | print("coco: "+str(len(coco_train_used))) 121 | data_mix = [] 122 | for item in coco_train_used: 123 | item_coco = {'filename': item['filename'], 'caption': ' '.join(item['caption']), 'data': 'coco'} 124 | data_mix.append(item_coco) 125 | json.dump(data_mix, open('../data/train_all.json', 'w'), ensure_ascii=False) 126 | print("Num of coco used: "+str(len(data_mix))) 127 | """ 128 | 129 | """ 130 | # 5. 
Mix coco data and replay data as the hybrid dataset used for K-Replay training 131 | # data_cc12m_SelectForReplay.json contain 20000+ replay exemplars that randomly selected from the cc12m dataset based 132 | # on keyword matching, it contains 122 keywords as record in replay_keywords 133 | replay_keywords = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple', 'iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines', 'sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck', 'batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh'] 134 | cc12m_select = json.load(open('../data/data_cc12m_SelectForReplay.json', 'r')) 135 | for item in cc12m_select: 136 | if item['keyword'] not in replay_keywords: 137 | print("replay item not in replay keywords!") 138 | train_all = json.load(open('../data/train_all.json', 'r')) 139 | random.shuffle(cc12m_select) 140 | cc12m_select = cc12m_select[:5000] 141 | random.shuffle(train_all) 142 | print(len(cc12m_select)) 143 | print(len(train_all)) 144 | data_mix = [] 145 | data_mix += train_all[:27000] # mix the coco and replay data 146 | ablation = False 147 | for item in cc12m_select[:]: 148 | item_cc12m = {'filename': item['filename'], 'caption': item['keyword'], 'data': 'coco'} 149 | if ablation: # for ablation study, we use the origin web-harvested text as reference 150 | item_cc12m = {'filename': item['filename'], 'caption': item['caption'], 'data': 'coco'} 151 | data_mix.append(item_cc12m) 152 | random.shuffle(data_mix) 153 | json.dump(data_mix, open('../data/train_mix_32000.json', 'w'), ensure_ascii=False) 154 | print("Num of data_mix: "+str(len(data_mix))) 155 | """ 156 | 157 | """ 158 | # 6. 
Adjust the number of replay exemplars in train_mix_32000.json 159 | ratio = 0.1 160 | data = json.load(open('../data/train_mix_32000.json', 'r')) 161 | data_cc12m = [item for item in data if item['data'] == 'cc12m'] 162 | data_coco = [item for item in data if item['data'] == 'coco'] 163 | random.shuffle(data_cc12m) 164 | random.shuffle(data_coco) 165 | data_ratio = data_coco[:int(len(data_coco)*ratio)]+data_cc12m[:int(len(data_cc12m)*ratio)] 166 | print(len(data_ratio)) 167 | random.shuffle(data_ratio) 168 | json.dump(data_ratio, open('../data/train_mix_32000_0.1.json', 'w'), ensure_ascii=False) 169 | 170 | # select only 120 exemplars in train_mix_32000.json 171 | data = json.load(open('../data/train_mix_32000.json', 'r')) 172 | data_cc12m = [item for item in data if item['data'] == 'cc12m'] 173 | data_coco = [item for item in data if item['data'] == 'coco'] 174 | random.shuffle(data_cc12m) 175 | random.shuffle(data_coco) 176 | print(len(data_cc12m)) 177 | print(len(data_coco)) 178 | data_120 = [] 179 | categories = [] 180 | for item in data_cc12m: 181 | if item['caption'] not in categories: 182 | categories.append(item['caption']) 183 | data_120.append(item) 184 | else: 185 | continue 186 | data_mix = [] 187 | data_mix += data_coco[:12960] 188 | for i in range(20): 189 | data_mix += data_120 190 | random.shuffle(data_mix) 191 | print(len(data_mix)) 192 | json.dump(data_mix, open('../data/train_mix_32000_120.json', 'w'), ensure_ascii=False) 193 | """ 194 | 195 | """ 196 | # 7. Adjust the categories of replay exemplars in train_mix_32000.json 197 | data = json.load(open('../data/train_mix_32000.json', 'r')) 198 | data_cc12m = [item for item in data if item['data'] == 'cc12m'] 199 | data_coco = [item for item in data if item['data'] == 'coco'] 200 | random.shuffle(data_cc12m) 201 | random.shuffle(data_coco) 202 | print(len(data_cc12m)) 203 | print(len(data_coco)) 204 | cc12m_select = json.load(open('../data/data_cc12m_select_122all.json', 'r')) 205 | random.shuffle(cc12m_select) 206 | categories = [] 207 | for item in data_cc12m: 208 | categories.append(item['caption']) 209 | categories = list(set(categories)) 210 | print(len(categories)) 211 | random.shuffle(categories) 212 | # categories_ratio = categories[:20] 213 | # select 10 replay categories 214 | categories_ratio = ['white house', 'grand canyon', 'statue of liberty', 'iphone', 'porsche', 'facebook', 'sushi', 'smoked salmon', 'batman', 'barbie'] 215 | 216 | print(len(categories_ratio)) 217 | data_cc12m_new = [item for item in data_cc12m if item['caption'] in categories_ratio] 218 | print(len(data_cc12m_new)) 219 | for item in cc12m_select: 220 | item_cc12m = {'filename': item['filename'], 'caption': item['keyword'], 'data': 'cc12m'} 221 | if item_cc12m['caption'] in categories_ratio: 222 | data_cc12m_new.append(item_cc12m) 223 | if len(data_cc12m_new) == 5000: 224 | break 225 | print(len(data_cc12m_new)) 226 | categories_new = [] 227 | for item in data_cc12m_new: 228 | categories_new.append(item['caption']) 229 | print(len(list(set(categories_new)))) 230 | data_mix = [] 231 | data_mix += data_coco 232 | data_mix += data_cc12m_new 233 | random.shuffle(data_mix) 234 | print(len(data_mix)) 235 | json.dump(data_mix, open('../data/train_mix_32000_10cate.json', 'w'), ensure_ascii=False) 236 | """ 237 | -------------------------------------------------------------------------------- /utils/prepro_ref_pycoco.py: -------------------------------------------------------------------------------- 1 | # 将val和test转化为可用pycoco直接计算指标的格式 2 | 3 | import 
os 4 | import json 5 | from tqdm import tqdm 6 | """ 7 | for split in ['val', 'test']: 8 | ref_pycoco_path = os.path.join('../data', split+'_pycoco.json') 9 | data = json.load(open(os.path.join('../data', split+'.json'), 'r')) 10 | 11 | ref_pycoco = {} 12 | for i, item in tqdm(enumerate(data)): 13 | refs = [] 14 | for j, sentence in enumerate(item['caption']): 15 | ref = {} 16 | ref['image_id'] = item['image_id'] 17 | ref['id'] = j 18 | ref['caption'] = ' '.join(sentence) 19 | refs.append(ref) 20 | ref_pycoco[i] = refs 21 | 22 | print("Num: "+str(len(ref_pycoco))) 23 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False) 24 | """ 25 | ref_pycoco_path = os.path.join('../data/knowcap_240_val_pycoco.json') 26 | data = json.load(open(os.path.join('../data/knowcap_240_val.json'), 'r')) 27 | 28 | ref_pycoco = {} 29 | for i, item in tqdm(enumerate(data)): 30 | refs = [] 31 | for j, sentence in enumerate(item['captions']): 32 | ref = {} 33 | ref['image_id'] = item['image'] 34 | ref['id'] = j 35 | ref['caption'] = sentence 36 | refs.append(ref) 37 | ref_pycoco[i] = refs 38 | 39 | print("Num: "+str(len(ref_pycoco))) 40 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False) -------------------------------------------------------------------------------- /utils/prepro_rwcap.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | 4 | annot_excel = '/Users/cckevin/Desktop/RW_Label_100.xlsx' 5 | dataset_dir = '/Users/cckevin/Desktop/ofa/data/rwcap_100_keywords.json' 6 | 7 | invalid_list = ['a', 'on', 'of', 'the', 'in', 'with', 'and', 'is', 'to', 'an', 'two', 'at', 'are', 'that', 'it', 'by'] 8 | 9 | df = pd.read_excel(annot_excel) 10 | annot_list = df.to_dict(orient='records') 11 | dataset_rwcap = [] 12 | for item in annot_list: 13 | """ 14 | image_filename = item['filename'] 15 | image_filename = image_filename.strip() 16 | 17 | data_rwcap_item = {} 18 | refs = [] 19 | annot_name = ['SWP', 'CKZ', 'YHT'] 20 | for name in annot_name: 21 | ref = item[name].lower().strip() 22 | if ref[-1] == '.': 23 | ref = ref[:-1] 24 | refs.append(ref) 25 | data_rwcap_item['image'] = image_filename 26 | data_rwcap_item['captions'] = refs 27 | 28 | labels_list = [] 29 | """ 30 | keywords = item['Keywords'].strip().lower() 31 | keywords = keywords.split('#') 32 | dataset_rwcap += keywords 33 | """ 34 | for keyword in keywords: 35 | words = keyword.split(' ') 36 | for word in words: 37 | if word not in invalid_list and word not in labels_list: 38 | labels_list.append(word) 39 | data_rwcap_item['labels'] = labels_list 40 | 41 | dataset_rwcap.append(data_rwcap_item) 42 | """ 43 | 44 | dataset_rwcap = list(set(dataset_rwcap)) 45 | print("Num of dataset: "+str(len(dataset_rwcap))) 46 | json.dump(dataset_rwcap, open(dataset_dir, 'w'))
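The 'Keywords' column of the annotation spreadsheet holds one or more '#'-separated entity names per image; the loop above lowercases and splits them, and the final set() call deduplicates the accumulated list before it is written to rwcap_100_keywords.json. A toy illustration with hypothetical cell values (not taken from the real annotation file):

row_keywords = []
for cell in ['Statue of Liberty#Brooklyn Bridge', 'brooklyn bridge']:
    row_keywords += cell.strip().lower().split('#')
print(sorted(set(row_keywords)))   # ['brooklyn bridge', 'statue of liberty']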
-------------------------------------------------------------------------------- /utils/vocab.py: -------------------------------------------------------------------------------- 1 | # build the vocabulary used to convert between tokens and ids 2 | # words appearing fewer than 5 times are replaced by the special <unk> token 3 | 4 | import numpy as np 5 | import json 6 | import pickle 7 | from tqdm import tqdm 8 | 9 | class Vocabulary(): 10 | """Vocabulary""" 11 | def __init__(self): 12 | self._word2id = {} 13 | self._id2word = {} 14 | self._idx = 0 15 | self._word = [] 16 | 17 | # special tokens 18 | self.pad = '<pad>' # token used to pad sentences to a fixed length 19 | self.bos = '<bos>' # start-of-sentence token 20 | self.eos = '<eos>' # end-of-sentence token 21 | self.unk = '<unk>' # unknown-word token 22 | self.add_spe_sign() 23 | 24 | def add_word(self, word): 25 | '''add a word to the vocabulary''' 26 | if word not in self._word: 27 | self._word2id.update({word: self._idx}) 28 | self._id2word.update({self._idx: word}) 29 | self._word.append(word) 30 | self._idx += 1 31 | 32 | def word_to_id(self, word): 33 | '''convert a word to its id''' 34 | if word in self._word: 35 | return self._word2id[word] 36 | else: 37 | return self._word2id['<unk>'] 38 | 39 | def id_to_word(self, id): 40 | '''convert an id back to its word''' 41 | assert id <= self._idx, "the given id is larger than the largest id in the vocabulary" 42 | return self._id2word[id] 43 | 44 | def tokenList_to_idList(self, tokenList, fixed_len): 45 | '''convert a token list to an id list, adding <bos> and <eos> 46 | :param tokenList: the tokenized sentence, e.g. ["室内", "三个", "衣着", "各异", "的", "人", "坐在", "桌子", "旁", "交谈"] 47 | :param fixed_len: the maximum sentence length (before adding <bos> and <eos>) 48 | :return: list 49 | ''' 50 | sent_len = len(tokenList) 51 | tok_id = [self.word_to_id(token) for token in tokenList] 52 | if sent_len < fixed_len: 53 | tok_id.insert(0, self._word2id[self.bos]) 54 | tok_id.append(self._word2id[self.eos]) 55 | pad_num = fixed_len - sent_len 56 | tok_id += [0] * pad_num 57 | else: 58 | tok_id = tok_id[:fixed_len] 59 | tok_id.insert(0, self._word2id[self.bos]) 60 | tok_id.append(self._word2id[self.eos]) 61 | sent_len = fixed_len 62 | sent_len += 2 # account for the <bos> and <eos> tokens 63 | return tok_id, sent_len 64 | 65 | def idList_to_sent(self, id_List): 66 | '''convert an id list back to a sentence string 67 | :param id_List: the id form of one sentence, e.g. [1, 4, 5, 343, 4, 123, 2389 ,213, 233 ,678 ,2343 ,2, 0, 0, 0, 0, 0, 0] 68 | supported formats: list, tensor, numpy.array 69 | :return: a sentence string, e.g. "室内三个衣着各异的人坐在桌子旁交谈" 70 | ''' 71 | id_List = np.array(list(map(int, id_List))) 72 | word_array = np.array(self._word) 73 | eos_id = self._word2id[self.eos] 74 | eos_pos = np.where(id_List == eos_id)[0] 75 | if len(eos_pos) > 0: 76 | sent = word_array[id_List[1:eos_pos[0]]] 77 | else: 78 | sent = word_array[id_List[1:]] 79 | return ' '.join(sent) 80 | 81 | def add_spe_sign(self): 82 | self.add_word(self.pad) 83 | self.add_word(self.bos) 84 | self.add_word(self.eos) 85 | self.add_word(self.unk) 86 | 87 | def get_size(self): 88 | return self._idx 89 | 90 | if __name__ == '__main__': 91 | vocab = Vocabulary() 92 | data_train = json.load(open('../data/train.json', 'r')) 93 | 94 | counter = {} 95 | for item in tqdm(data_train): 96 | sentence_token = item['caption'] 97 | for token in sentence_token: 98 | counter[token] = counter.get(token, 0) + 1 99 | # cand_word = [token for token, f in counter.items() if f >= 5] 100 | # print(counter['tesla'])  # debugging probe, disabled 101 | # input() 102 | cand_word = sorted(counter.items(), key=lambda x: x[1], reverse=True) 103 | print("candidate word num: "+str(len(cand_word))) 104 | 105 | for word, freq in cand_word: 106 | vocab.add_word(word) 107 | print("vocab size: "+str(vocab.get_size())) 108 | 109 | # pickle.dump(vocab, open('../data/vocab.pkl', 'wb')) --------------------------------------------------------------------------------
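A quick round-trip sketch of the Vocabulary class above (toy words only; in the real pipeline the vocabulary is built from train.json and pickled to data/vocab.pkl):

vocab = Vocabulary()
for w in ['a', 'man', 'riding', 'a', 'horse']:
    vocab.add_word(w)                    # duplicates are ignored by add_word
ids, length = vocab.tokenList_to_idList(['a', 'man', 'riding', 'a', 'horse'], fixed_len=16)
print(ids[:7], length)                   # [1, 4, 5, 6, 4, 7, 2] 7  -> <bos> a man riding a horse <eos>
print(vocab.idList_to_sent(ids))         # 'a man riding a horse'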