├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── ofa.iml
│   ├── sshConfigs.xml
│   ├── vcs.xml
│   └── webServers.xml
├── __pycache__
│   ├── config.cpython-37.pyc
│   └── data_load.cpython-37.pyc
├── config.py
├── data
│   └── .DS_Store
├── data_load.py
├── eval.py
├── evaluation
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-37.pyc
│   ├── bleu
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── bleu.cpython-37.pyc
│   │   │   └── bleu_scorer.cpython-37.pyc
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── cider.cpython-37.pyc
│   │   │   └── cider_scorer.cpython-37.pyc
│   │   ├── cider.py
│   │   └── cider_scorer.py
│   ├── meteor
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── meteor.cpython-37.pyc
│   │   └── meteor.py
│   ├── rouge
│   │   ├── __init__.py
│   │   └── rouge.py
│   ├── stanford-corenlp-3.4.1.jar
│   └── tokenizer.py
├── knowcap.png
├── models
│   ├── .DS_Store
│   ├── BLIP
│   │   ├── __init__.py
│   │   ├── blip.py
│   │   ├── blip_itm.py
│   │   ├── blip_nlvr.py
│   │   ├── blip_pretrain.py
│   │   ├── blip_retrieval.py
│   │   ├── blip_vqa.py
│   │   ├── caption_coco.yaml
│   │   ├── caption_coco_teacher.yaml
│   │   ├── med.py
│   │   ├── med_config.json
│   │   ├── nlvr_encoder.py
│   │   └── vit.py
│   ├── GIT
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── git_model.cpython-37.pyc
│   │   ├── git.py
│   │   └── git_model.py
│   ├── OFA
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ofa.cpython-37.pyc
│   │   │   └── ofa_model.cpython-37.pyc
│   │   ├── ofa.py
│   │   └── ofa_model.py
│   └── Transformer
│       ├── __init__.py
│       └── transformer.py
├── readme.md
├── requirements.txt
├── test.py
├── test_knowcap.py
├── train_multitask.py
└── utils
    ├── .DS_Store
    ├── __pycache__
    │   ├── beamsearch.cpython-37.pyc
    │   ├── eval.cpython-37.pyc
    │   ├── import_models.cpython-37.pyc
    │   └── vocab.cpython-37.pyc
    ├── beamsearch.py
    ├── cc12m.py
    ├── convert_ofa.py
    ├── eval.py
    ├── import_models.py
    ├── knowcap.py
    ├── log.py
    ├── loss.py
    ├── optimizer_tools.py
    ├── prepro_data.py
    ├── prepro_ref_pycoco.py
    ├── prepro_rwcap.py
    └── vocab.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/ofa.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/sshConfigs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
20 |
21 |
--------------------------------------------------------------------------------
/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/data_load.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/__pycache__/data_load.cpython-37.pyc
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 |
5 | parser.add_argument('--seed', type=int, default=222)
6 | parser.add_argument('--id', type=str, default='test')
7 | parser.add_argument('--mode', type=str, default='train')
8 | parser.add_argument('--model', type=str, default=None)
9 | parser.add_argument('--test', type=str, default='test')
10 | parser.add_argument('--epochs', type=int, default=100)
11 | parser.add_argument('--ft_epoch', type=int, default=3)
12 | parser.add_argument('--ckpts_id', type=str, default=None)
13 | parser.add_argument('--step', type=int, default=0)
14 |
15 | parser.add_argument('--local_rank', type=int, default=-1)
16 | parser.add_argument('--nproc_per_node', type=int, default=-1)
17 |
18 | parser.add_argument('--trained_ckpts', default='/home/chengkz/checkpoints/ofa/log/ofa_m1.0_t16_k1.0_222/model/model_300.pt')
19 | parser.add_argument('--ofa_ckpts', default='/home/data_ti4_c/chengkz/scripts/OFA-large')
20 | parser.add_argument('--ofa_ckpts_distill', default='/home/chengkz/checkpoints/ofa/OFA-large-caption-XEfinetuned')
21 | parser.add_argument('--git', default="microsoft/git-large")
22 | parser.add_argument('--git_distill', default="microsoft/git-large-coco")
23 | parser.add_argument('--config_blip', default='./models/BLIP/caption_coco.yaml')
24 | parser.add_argument('--config_blip_t', default='./models/BLIP/caption_coco_teacher.yaml')
25 | parser.add_argument('--data_dir', default='./data')
26 | parser.add_argument('--vocab', default='./data/vocab.pkl')
27 | parser.add_argument('--train', default='./data/train.json')
28 | parser.add_argument('--train_mix', default='./data/train_mix_cc12m_keyword_large.json')
29 | parser.add_argument('--knowcap240', default='/home/chengkz/checkpoints/KnowCap_240')
30 | parser.add_argument('--data_mode', default='mix')
31 | parser.add_argument('--samples_dir', default='./examples/example_images')
32 | parser.add_argument('--samples_out', default=None)
33 | parser.add_argument('--pretrain_model', default=None)
34 |
35 | parser.add_argument('--save_loss_freq', type=int, default=20)
36 | parser.add_argument('--save_model_freq', type=int, default=100)
37 | parser.add_argument('--log_dir', default='/home/chengkz/checkpoints/ofa/log/{}')
38 |
39 | parser.add_argument('--batch_size', type=int, default=60)
40 | parser.add_argument('--val_batch_size', type=int, default=25)
41 | parser.add_argument('--num_workers', type=int, default=1)
42 | parser.add_argument('--fixed_len', type=int, default=20)
43 | parser.add_argument('--lr_enc', type=float, default=2e-5)
44 | parser.add_argument('--learning_rate', type=float, default=7e-6)
45 | parser.add_argument('--grad_clip', type=float, default=0.1)
46 | parser.add_argument('--beam_num', type=int, default=5)
47 | parser.add_argument('--gen_num', type=int, default=5)
48 | parser.add_argument('--beam_alpha', type=float, default=1.0)
49 | parser.add_argument('--length_penalty', type=float, default=1.0)
50 | parser.add_argument('--multitask_weight', type=float, default=0.5)
51 | parser.add_argument('--knowdistill_weight', type=float, default=1.0)
52 | parser.add_argument('--data_ratio', type=float, default=1.0)
53 | parser.add_argument('--label_smoothing', type=float, default=0.0)
54 | parser.add_argument('--KD_temperature', type=float, default=8.0)
55 |
56 | parser.add_argument('--image_dim', type=int, default=2048)
57 | parser.add_argument('--embed_dim', type=int, default=512)
58 | parser.add_argument('--hidden_dim', type=int, default=512)
59 | parser.add_argument('--att_dim', type=int, default=1024)
60 |
61 | parser.add_argument('--method', type=str, default=None)
62 |
63 | config = parser.parse_args()
64 |
--------------------------------------------------------------------------------
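A note on usage: parse_args() runs at import time, so every script in the repo that does `from config import config` sees the same parsed flags. A minimal sketch (the values shown are the defaults above, not recommendations, and the example invocation uses placeholder flag values):

    # any entry script (e.g. train_multitask.py, test.py) reads the shared flags
    from config import config

    print(config.model)        # None by default; set with --model OFA / BLIP / GIT
    print(config.batch_size)   # 60 unless overridden with --batch_size
    print(config.fixed_len)    # 20; caption length used for padding/truncation

    # typical invocation (flag values are placeholders):
    #   python train_multitask.py --model OFA --id run1 --batch_size 32
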
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/data/.DS_Store
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
1 | # Dataloader used for training
2 | # Different models require different preprocessing
3 |
4 | import torch
5 | import numpy as np
6 | import json
7 | import os
8 | import pickle
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from PIL import Image
12 | from torchvision.transforms import InterpolationMode
13 | import torchvision.transforms as transforms
14 | from utils.vocab import Vocabulary
15 |
16 | from transformers.models.ofa.tokenization_ofa import OFATokenizer
17 | from transformers import AutoProcessor
18 | from transformers import BertTokenizer
19 |
20 | from models.BLIP.blip import init_tokenizer
21 |
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 |
25 | class IC_data(Dataset):
26 | """作为val和test时的dataset"""
27 | def __init__(self, config, dir, mode):
28 | super(IC_data, self).__init__()
29 | self.config = config
30 | self.data = json.load(open(dir, 'r'))
31 | self.model = config.model
32 | # Choose the transforms according to the model
33 | self.patch_resize_transform = self.get_transforms(self.model)
34 | if self.model == 'OFA':
35 | self.ofa_ckpt = config.ofa_ckpts
36 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
37 | elif self.model == 'BLIP':
38 | self.tokenizer = init_tokenizer()
39 | elif self.model == 'GIT':
40 | self.processor = AutoProcessor.from_pretrained(config.git_distill, local_files_only=True)
41 | self.tokenizer = self.processor.tokenizer
42 |
43 | self.mode = mode
44 |
45 | def get_transforms(self, model):
46 | if model == 'OFA':
47 | self.resolution = 480
48 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
49 | patch_resize_transform = transforms.Compose([
50 | lambda image: image.convert("RGB"),
51 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean=self.mean, std=self.std)])
54 | elif model == 'BLIP':
55 | self.resolution = 384
56 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
57 | patch_resize_transform = transforms.Compose([
58 | lambda image: image.convert("RGB"),
59 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
60 | transforms.ToTensor(),
61 | transforms.Normalize(mean=self.mean, std=self.std)])
62 | elif model == 'GIT':
63 | patch_resize_transform = lambda img: self.processor(images=img, return_tensors='pt').pixel_values[0]
64 | return patch_resize_transform
65 |
66 | def __getitem__(self, item):
67 | if self.mode == 'train':
68 | """"""
69 | else:
70 | image_path = self.data[item]['filename']
71 | img = Image.open(image_path)
72 | patch_img = self.patch_resize_transform(img)
73 | image_id = self.data[item]['image_id']
74 | return image_id, patch_img
75 |
76 | def collate_fn_train(self, batch_data):
77 | """"""
78 |
79 | def collate_fn_eval(self, batch_data):
80 | image_id, image = zip(*batch_data)
81 | image = torch.stack(image, dim=0)
82 | image_feature = {'patch_image': image}
83 | return image_id, image_feature
84 |
85 | def __len__(self):
86 | return len(self.data)
87 |
88 |
89 | class RWConcept_data(Dataset):
90 |
91 | def __init__(self, config, dir, mode):
92 | super(RWConcept_data, self).__init__()
93 | self.config = config
94 | self.data = json.load(open(dir, 'r'))
95 | self.model = config.model
96 | # Choose the transforms according to the model
97 | self.patch_resize_transform = self.get_transforms(config.model)
98 | if self.model == 'OFA':
99 | self.ofa_ckpt = config.ofa_ckpts
100 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
101 | elif self.model == 'BLIP':
102 | self.tokenizer = init_tokenizer()
103 | elif self.model == 'GIT':
104 | self.processor = AutoProcessor.from_pretrained(config.git_distill, local_files_only=True)
105 | self.tokenizer = self.processor.tokenizer
106 | self.mode = mode
107 |
108 | def get_transforms(self, model):
109 | if model == 'OFA':
110 | self.resolution = 480
111 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
112 | patch_resize_transform = transforms.Compose([
113 | lambda image: image.convert("RGB"),
114 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
115 | transforms.ToTensor(),
116 | transforms.Normalize(mean=self.mean, std=self.std)])
117 | elif model == 'BLIP':
118 | self.resolution = 384
119 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
120 | patch_resize_transform = transforms.Compose([
121 | lambda image: image.convert("RGB"),
122 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
123 | transforms.ToTensor(),
124 | transforms.Normalize(mean=self.mean, std=self.std)])
125 | elif model == 'GIT':
126 | patch_resize_transform = lambda img: self.processor(images=img, return_tensors='pt').pixel_values[0]
127 | return patch_resize_transform
128 |
129 | def __getitem__(self, item):
130 | if self.mode == 'train':
131 | caption = self.data[item]['caption']
132 | # Different models prepend different caption prefixes
133 | if self.model == 'OFA':
134 | caption = ' '+caption
135 | elif self.model == "BLIP":
136 | caption = ' a picture of ' + caption
137 | elif self.model == 'GIT':
138 | caption = ' '+caption
139 |
140 | # Different models tokenize the caption differently
141 | if self.model == 'OFA':
142 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
143 | elif self.model == 'BLIP':
144 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=20, return_tensors="pt")
145 | cap_id = text.input_ids[0]
146 | cap_id[0] = self.tokenizer.bos_token_id
147 | elif self.model == 'GIT':
148 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
149 |
150 | cap_len = cap_id.shape[0]
151 | if cap_len < self.config.fixed_len:
152 | if self.model == 'OFA':
153 | cap_id = torch.cat([cap_id, torch.ones([self.config.fixed_len-cap_len])], dim=0)
154 | elif self.model == 'BLIP':
155 | cap_id = torch.cat([cap_id, torch.zeros([self.config.fixed_len-cap_len])], dim=0)
156 | elif self.model == 'GIT':
157 | cap_id = torch.cat([cap_id, torch.zeros([self.config.fixed_len-cap_len])], dim=0)
158 | att_mask = torch.cat([torch.ones([cap_len]), torch.zeros([self.config.fixed_len-cap_len])], dim=0)
159 | else:
160 | cap_id = cap_id[:self.config.fixed_len]
161 | cap_len = self.config.fixed_len
162 | att_mask = torch.ones(cap_id.shape)
163 |
164 | image_path = self.data[item]['filename']
165 | img = Image.open(image_path)
166 | patch_img = self.patch_resize_transform(img)
167 | label = 0 if self.data[item]['data'] == 'coco' else 1
168 | return patch_img, cap_id, att_mask, cap_len, label, self.data[item]
169 | else:
170 | image_path = self.data[item]['filename']
171 | img = Image.open(image_path)
172 | patch_img = self.patch_resize_transform(img)
173 | image_id = self.data[item]['image_id']
174 | return image_id, patch_img
175 |
176 | def collate_fn_train(self, batch_data):
177 | image, cap_id, att_mask, cap_len, label, data_item = zip(*batch_data)
178 | image = torch.stack(image, dim=0)
179 | image_feature = {'patch_image': image}
180 | cap_id = torch.stack(cap_id, dim=0)
181 | att_mask = torch.stack(att_mask, dim=0)
182 | cap_len = torch.Tensor(cap_len).int()
183 | label = torch.Tensor(label).int()
184 | return image_feature, cap_id.long(), att_mask.long(), cap_len, label, list(data_item)
185 |
186 | def collate_fn_eval(self, batch_data):
187 | image_id, image = zip(*batch_data)
188 | image = torch.stack(image, dim=0)
189 | image_feature = {'patch_image': image}
190 | return image_id, image_feature
191 |
192 | def __len__(self):
193 | return len(self.data)
194 |
195 |
196 | class RWConcept_data_EWC(Dataset):
197 |
198 | def __init__(self, config, dir, mode='train'):
199 | super(RWConcept_data_EWC, self).__init__()
200 | self.config = config
201 | self.data = json.load(open(dir, 'r'))
202 | self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
203 | self.resolution = 480
204 | self.patch_resize_transform = transforms.Compose([
205 | lambda image: image.convert("RGB"),
206 | transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC),
207 | transforms.ToTensor(),
208 | transforms.Normalize(mean=self.mean, std=self.std)])
209 | self.ofa_ckpt = config.ofa_ckpts
210 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpt)
211 | self.mode = mode
212 |
213 | def __getitem__(self, item):
214 | if self.mode == 'train':
215 | caption = ' '+self.data[item]['caption']
216 | cap_id = self.tokenizer([caption], return_tensors="pt").input_ids[0]
217 | keyword = ' '+self.data[item]['keyword']
218 | keyword_id = self.tokenizer([keyword], return_tensors="pt").input_ids[0]
219 | keyword_id = keyword_id[keyword_id > 2]
220 | cap_len = cap_id.shape[0]
221 | if cap_len < self.config.fixed_len:
222 | cap_id = torch.cat([cap_id, torch.ones([self.config.fixed_len-cap_len])], dim=0)
223 | att_mask = torch.cat([torch.ones([cap_len]), torch.zeros([self.config.fixed_len-cap_len])], dim=0)
224 | else:
225 | cap_id = cap_id[:self.config.fixed_len]
226 | cap_len = self.config.fixed_len
227 | att_mask = torch.ones(cap_id.shape)
228 | if_keyword = torch.Tensor([True if (item in keyword_id) else False for item in cap_id])
229 | image_path = self.data[item]['filename']
230 | img = Image.open(image_path)
231 | patch_img = self.patch_resize_transform(img)
232 | label = 0 if self.data[item]['data'] == 'coco' else 1
233 | return patch_img, cap_id, att_mask, cap_len, label, if_keyword
234 |
235 | def collate_fn_train(self, batch_data):
236 | image, cap_id, att_mask, cap_len, label, if_keyword = zip(*batch_data)
237 | image = torch.stack(image, dim=0)
238 | image_feature = {'patch_image': image}
239 | cap_id = torch.stack(cap_id, dim=0)
240 | if_keyword = torch.stack(if_keyword, dim=0)
241 | att_mask = torch.stack(att_mask, dim=0)
242 | cap_len = torch.Tensor(cap_len).int()
243 | label = torch.Tensor(label).int()
244 | return image_feature, cap_id.long(), att_mask.long(), cap_len, label, if_keyword
245 |
246 | def __len__(self):
247 | return len(self.data)
248 |
249 |
250 | def data_load_rwc_EWC(config, dir, mode):
251 | dataset = RWConcept_data_EWC(config, dir, mode)
252 | data_loader = DataLoader(dataset=dataset,
253 | batch_size=config.batch_size,
254 | shuffle=True,
255 | collate_fn=dataset.collate_fn_train,
256 | num_workers=config.num_workers,
257 | pin_memory=True,
258 | )
259 | return data_loader
260 |
261 |
262 | def data_load(config, dir, mode):
263 | if mode == 'train':
264 | print("warning: the train loader does not exist")
265 | dataset = IC_data(config, dir, mode)
266 | data_loader = DataLoader(dataset=dataset,
267 | batch_size=config.batch_size if mode == 'train' else config.val_batch_size,
268 | shuffle=True if mode == 'train' else False,
269 | collate_fn=dataset.collate_fn_train if mode == 'train' else dataset.collate_fn_eval,
270 | num_workers=config.num_workers,
271 | pin_memory=True,
272 | )
273 | return data_loader
274 |
275 | def data_load_rwc(config, dir, mode):
276 | dataset = RWConcept_data(config, dir, mode)
277 | data_loader = DataLoader(dataset=dataset,
278 | batch_size=config.batch_size if mode == 'train' else config.val_batch_size,
279 | shuffle=False,
280 | collate_fn=dataset.collate_fn_train if mode == 'train' else dataset.collate_fn_eval,
281 | num_workers=config.num_workers,
282 | pin_memory=True,
283 | )
284 | return data_loader
285 |
286 |
287 |
--------------------------------------------------------------------------------
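For orientation, a minimal sketch of driving the eval-mode loader defined above. Assumptions not fixed by the source: the val json path is a placeholder whose entries carry 'filename' and 'image_id' fields (as __getitem__ expects), config.model is set to 'OFA', and --ofa_ckpts points at a locally available OFA-large checkpoint for the tokenizer.

    from config import config
    from data_load import data_load

    config.model = 'OFA'                       # assumed backbone for this sketch
    # mode != 'train' -> IC_data with collate_fn_eval, val_batch_size, no shuffling
    val_loader = data_load(config, './data/val.json', 'eval')   # placeholder path
    for image_ids, image_feature in val_loader:
        patch_images = image_feature['patch_image']   # (B, 3, 480, 480) for OFA
        print(len(image_ids), patch_images.shape)
        break
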
/eval.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | from .tokenizer.ptbtokenizer import PTBTokenizer
3 | from .bleu.bleu import Bleu
4 | from .meteor.meteor import Meteor
5 | from .rouge.rouge import Rouge
6 | from .cider.cider import Cider
7 | from .spice.spice import Spice
8 |
9 |
10 | class COCOEvalCap:
11 | def __init__(self, coco, cocoRes):
12 | self.evalImgs = []
13 | self.eval = {}
14 | self.imgToEval = {}
15 | self.coco = coco
16 | self.cocoRes = cocoRes
17 | # self.params = {'image_id': coco.getImgIds()}
18 |
19 | def evaluate(self):
20 | imgIds = self.params['image_id']
21 | # imgIds = self.coco.getImgIds()
22 | gts = {}
23 | res = {}
24 | for imgId in imgIds:
25 | gts[imgId] = self.coco.imgToAnns[imgId]
26 | res[imgId] = self.cocoRes.imgToAnns[imgId]
27 |
28 | # =================================================
29 | # Set up scorers
30 | # =================================================
31 | print('tokenization...')
32 | tokenizer = PTBTokenizer()
33 | gts = tokenizer.tokenize(gts)
34 | res = tokenizer.tokenize(res)
35 |
36 | # =================================================
37 | # Set up scorers
38 | # =================================================
39 | print('setting up scorers...')
40 | scorers = [
41 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
42 | (Meteor(),"METEOR"),
43 | (Rouge(), "ROUGE_L"),
44 | (Cider(), "CIDEr"),
45 | (Spice(), "SPICE")
46 | ]
47 |
48 | # =================================================
49 | # Compute scores
50 | # =================================================
51 | for scorer, method in scorers:
52 | print('computing %s score...'%(scorer.method()))
53 | score, scores = scorer.compute_score(gts, res)
54 | if type(method) == list:
55 | for sc, scs, m in zip(score, scores, method):
56 | self.setEval(sc, m)
57 | self.setImgToEvalImgs(scs, gts.keys(), m)
58 | print("%s: %0.3f"%(m, sc))
59 | else:
60 | self.setEval(score, method)
61 | self.setImgToEvalImgs(scores, gts.keys(), method)
62 | print("%s: %0.3f"%(method, score))
63 | self.setEvalImgs()
64 |
65 | def evaluate_diy(self, gts, res):
66 | """
67 | imgIds = self.params['image_id']
68 | # imgIds = self.coco.getImgIds()
69 | gts = {}
70 | res = {}
71 | for imgId in imgIds:
72 | gts[imgId] = self.coco.imgToAnns[imgId]
73 | res[imgId] = self.cocoRes.imgToAnns[imgId]
74 | """
75 | # =================================================
76 | # Set up scorers
77 | # =================================================
78 | print('tokenization...')
79 | tokenizer = PTBTokenizer()
80 | gts = tokenizer.tokenize(gts)
81 | res = tokenizer.tokenize(res)
82 |
83 | # =================================================
84 | # Set up scorers
85 | # =================================================
86 | print('setting up scorers...')
87 | my_score = {}
88 | scorers = [
89 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
90 | (Meteor(),"METEOR"),
91 | (Rouge(), "ROUGE_L"),
92 | (Cider(), "CIDEr"),
93 | # (Spice(), "SPICE")
94 | ]
95 |
96 | # =================================================
97 | # Compute scores
98 | # =================================================
99 | for scorer, method in scorers:
100 | print('computing %s score...'%(scorer.method()))
101 | score, scores = scorer.compute_score(gts, res)
102 | if type(method) == list:
103 | for sc, scs, m in zip(score, scores, method):
104 | self.setEval(sc, m)
105 | my_score[m] = sc
106 | self.setImgToEvalImgs(scs, gts.keys(), m)
107 | print("%s: %0.3f"%(m, sc))
108 | else:
109 | self.setEval(score, method)
110 | my_score[method] = score
111 | self.setImgToEvalImgs(scores, gts.keys(), method)
112 | print("%s: %0.3f"%(method, score))
113 | self.setEvalImgs()
114 | return my_score
115 |
116 | def evaluate_diy_every(self, gts, res):
117 | """
118 | imgIds = self.params['image_id']
119 | # imgIds = self.coco.getImgIds()
120 | gts = {}
121 | res = {}
122 | for imgId in imgIds:
123 | gts[imgId] = self.coco.imgToAnns[imgId]
124 | res[imgId] = self.cocoRes.imgToAnns[imgId]
125 | """
126 | # =================================================
127 | # Set up scorers
128 | # =================================================
129 | print('tokenization...')
130 | tokenizer = PTBTokenizer()
131 | gts = tokenizer.tokenize(gts)
132 | res = tokenizer.tokenize(res)
133 |
134 | # =================================================
135 | # Set up scorers
136 | # =================================================
137 | print('setting up scorers...')
138 | my_score = {}
139 | my_score_every = {}
140 | scorers = [
141 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
142 | (Meteor(),"METEOR"),
143 | (Rouge(), "ROUGE_L"),
144 | (Cider(), "CIDEr"),
145 | # (Spice(), "SPICE")
146 | ]
147 |
148 | # =================================================
149 | # Compute scores
150 | # =================================================
151 | for scorer, method in scorers:
152 | print('computing %s score...'%(scorer.method()))
153 | score, scores = scorer.compute_score(gts, res)
154 | if type(method) == list:
155 | for sc, scs, m in zip(score, scores, method):
156 | self.setEval(sc, m)
157 | my_score[m] = sc
158 | my_score_every[m] = scs
159 | self.setImgToEvalImgs(scs, gts.keys(), m)
160 | print("%s: %0.3f"%(m, sc))
161 | else:
162 | self.setEval(score, method)
163 | my_score[method] = score
164 | my_score_every[method] = scores
165 | self.setImgToEvalImgs(scores, gts.keys(), method)
166 | print("%s: %0.3f"%(method, score))
167 | self.setEvalImgs()
168 | return my_score, my_score_every
169 |
170 | def evaluate_diy_test(self, gts, res):
171 | """
172 | imgIds = self.params['image_id']
173 | # imgIds = self.coco.getImgIds()
174 | gts = {}
175 | res = {}
176 | for imgId in imgIds:
177 | gts[imgId] = self.coco.imgToAnns[imgId]
178 | res[imgId] = self.cocoRes.imgToAnns[imgId]
179 | """
180 | # =================================================
181 | # Set up scorers
182 | # =================================================
183 | print('tokenization...')
184 | tokenizer = PTBTokenizer()
185 | gts = tokenizer.tokenize(gts)
186 | res = tokenizer.tokenize(res)
187 |
188 | # =================================================
189 | # Set up scorers
190 | # =================================================
191 | print('setting up scorers...')
192 | my_score = {}
193 | scorers = [
194 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
195 | (Meteor(),"METEOR"),
196 | (Rouge(), "ROUGE_L"),
197 | (Cider(), "CIDEr"),
198 | # (Spice(), "SPICE")
199 | ]
200 |
201 | # =================================================
202 | # Compute scores
203 | # =================================================
204 | for scorer, method in scorers:
205 | if not method == "CIDEr":
206 | continue
207 | print('computing %s score...'%(scorer.method()))
208 | score, scores = scorer.compute_score(gts, res)
209 | if type(method) == list:
210 | for sc, scs, m in zip(score, scores, method):
211 | self.setEval(sc, m)
212 | my_score[m] = sc
213 | self.setImgToEvalImgs(scs, gts.keys(), m)
214 | print("%s: %0.3f"%(m, sc))
215 | else:
216 | self.setEval(score, method)
217 | my_score[method] = score
218 | self.setImgToEvalImgs(scores, gts.keys(), method)
219 | print("%s: %0.3f"%(method, score))
220 | self.setEvalImgs()
221 | return scores
222 |
223 | def setEval(self, score, method):
224 | self.eval[method] = score
225 |
226 | def setImgToEvalImgs(self, scores, imgIds, method):
227 | for imgId, score in zip(imgIds, scores):
228 | if not imgId in self.imgToEval:
229 | self.imgToEval[imgId] = {}
230 | self.imgToEval[imgId]["image_id"] = imgId
231 | self.imgToEval[imgId][method] = score
232 |
233 | def setEvalImgs(self):
234 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
--------------------------------------------------------------------------------
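All scorers plugged into COCOEvalCap above share one calling convention: compute_score(gts, res), where both arguments are dicts keyed by image id, gts values are lists of reference strings and res values are single-element lists holding the generated caption; the PTBTokenizer step normally produces exactly this shape. A minimal sketch with hand-tokenized toy captions, using the metrics from this repo's evaluation package (METEOR and SPICE are left out here because they need extra Java dependencies):

    from evaluation.bleu import Bleu
    from evaluation.rouge import Rouge
    from evaluation.cider import Cider

    gts = {'img1': ['a man is riding a horse', 'a person rides a brown horse'],
           'img2': ['a bowl of fruit on a table']}
    res = {'img1': ['a man rides a horse'],
           'img2': ['a bowl of apples on a table']}

    for scorer, name in [(Bleu(4), 'BLEU-1..4'), (Rouge(), 'ROUGE_L'), (Cider(), 'CIDEr')]:
        score, per_image = scorer.compute_score(gts, res)
        print(name, score)
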
/evaluation/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/.DS_Store
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | #from .bleu import Bleu
2 | #from .meteor import Meteor
3 | #from .rouge import Rouge
4 | from .cider import Cider
5 | #from .tokenizer import PTBTokenizer
6 | """
7 | def compute_scores(gts, gen):
8 | metrics = (Bleu(), Meteor(), Rouge(), Cider())
9 | all_score = {}
10 | all_scores = {}
11 | for metric in metrics:
12 | score, scores = metric.compute_score(gts, gen)
13 | all_score[str(metric)] = score
14 | all_scores[str(metric)] = scores
15 |
16 | return all_score, all_scores
17 | """
--------------------------------------------------------------------------------
/evaluation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
--------------------------------------------------------------------------------
/evaluation/bleu/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/bleu/__pycache__/bleu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/bleu.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from .bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # default: compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | # score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | return score, scores
44 |
45 | def __str__(self):
46 | return 'BLEU'
47 |
--------------------------------------------------------------------------------
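Return shapes, for reference: the first value is the corpus-level BLEU-1 through BLEU-4, the second holds per-image score lists in the same order. A tiny sketch with toy, pre-tokenized captions:

    from evaluation.bleu import Bleu

    gts = {0: ['a cat sits on a mat', 'there is a cat on the mat']}
    res = {0: ['a cat is on the mat']}

    score, scores = Bleu(4).compute_score(gts, res)
    print(dict(zip(['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4'], score)))
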
/evaluation/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | ''' Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 |
24 | def precook(s, n=4, out=False):
25 | """Takes a string as input and returns an object that can be given to
26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
27 | can take string arguments as well."""
28 | words = s.split()
29 | counts = defaultdict(int)
30 | for k in range(1, n + 1):
31 | for i in range(len(words) - k + 1):
32 | ngram = tuple(words[i:i + k])
33 | counts[ngram] += 1
34 | return (len(words), counts)
35 |
36 |
37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
38 | '''Takes a list of reference sentences for a single segment
39 | and returns an object that encapsulates everything that BLEU
40 | needs to know about them.'''
41 |
42 | reflen = []
43 | maxcounts = {}
44 | for ref in refs:
45 | rl, counts = precook(ref, n)
46 | reflen.append(rl)
47 | for (ngram, count) in counts.items():
48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
49 |
50 | # Calculate effective reference sentence length.
51 | if eff == "shortest":
52 | reflen = min(reflen)
53 | elif eff == "average":
54 | reflen = float(sum(reflen)) / len(reflen)
55 |
56 | ## lhuang: N.B.: leave reflen computation to the very end!!
57 |
58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
59 |
60 | return (reflen, maxcounts)
61 |
62 |
63 | def cook_test(test, ref_tuple, eff=None, n=4):
64 | '''Takes a test sentence and returns an object that
65 | encapsulates everything that BLEU needs to know about it.'''
66 |
67 | testlen, counts = precook(test, n, True)
68 | reflen, refmaxcounts = ref_tuple
69 |
70 | result = {}
71 |
72 | # Calculate effective reference sentence length.
73 |
74 | if eff == "closest":
75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
76 | else: ## i.e., "average" or "shortest" or None
77 | result["reflen"] = reflen
78 |
79 | result["testlen"] = testlen
80 |
81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
82 |
83 | result['correct'] = [0] * n
84 | for (ngram, count) in counts.items():
85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
86 |
87 | return result
88 |
89 |
90 | class BleuScorer(object):
91 | """Bleu scorer.
92 | """
93 |
94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
95 |
96 | # special_reflen is used in oracle (proportional effective ref len for a node).
97 |
98 | def copy(self):
99 | ''' copy the refs.'''
100 | new = BleuScorer(n=self.n)
101 | new.ctest = copy.copy(self.ctest)
102 | new.crefs = copy.copy(self.crefs)
103 | new._score = None
104 | return new
105 |
106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
107 | ''' singular instance '''
108 |
109 | self.n = n
110 | self.crefs = []
111 | self.ctest = []
112 | self.cook_append(test, refs)
113 | self.special_reflen = special_reflen
114 |
115 | def cook_append(self, test, refs):
116 | '''called by constructor and __iadd__ to avoid creating new instances.'''
117 |
118 | if refs is not None:
119 | self.crefs.append(cook_refs(refs))
120 | if test is not None:
121 | cooked_test = cook_test(test, self.crefs[-1])
122 | self.ctest.append(cooked_test) ## N.B.: -1
123 | else:
124 | self.ctest.append(None) # lens of crefs and ctest have to match
125 |
126 | self._score = None ## need to recompute
127 |
128 | def ratio(self, option=None):
129 | self.compute_score(option=option)
130 | return self._ratio
131 |
132 | def score_ratio(self, option=None):
133 | '''
134 | return (bleu, len_ratio) pair
135 | '''
136 |
137 | return self.fscore(option=option), self.ratio(option=option)
138 |
139 | def score_ratio_str(self, option=None):
140 | return "%.4f (%.2f)" % self.score_ratio(option)
141 |
142 | def reflen(self, option=None):
143 | self.compute_score(option=option)
144 | return self._reflen
145 |
146 | def testlen(self, option=None):
147 | self.compute_score(option=option)
148 | return self._testlen
149 |
150 | def retest(self, new_test):
151 | if type(new_test) is str:
152 | new_test = [new_test]
153 | assert len(new_test) == len(self.crefs), new_test
154 | self.ctest = []
155 | for t, rs in zip(new_test, self.crefs):
156 | self.ctest.append(cook_test(t, rs))
157 | self._score = None
158 |
159 | return self
160 |
161 | def rescore(self, new_test):
162 | ''' replace test(s) with new test(s), and returns the new score.'''
163 |
164 | return self.retest(new_test).compute_score()
165 |
166 | def size(self):
167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
168 | return len(self.crefs)
169 |
170 | def __iadd__(self, other):
171 | '''add an instance (e.g., from another sentence).'''
172 |
173 | if type(other) is tuple:
174 | ## avoid creating new BleuScorer instances
175 | self.cook_append(other[0], other[1])
176 | else:
177 | assert self.compatible(other), "incompatible BLEUs."
178 | self.ctest.extend(other.ctest)
179 | self.crefs.extend(other.crefs)
180 | self._score = None ## need to recompute
181 |
182 | return self
183 |
184 | def compatible(self, other):
185 | return isinstance(other, BleuScorer) and self.n == other.n
186 |
187 | def single_reflen(self, option="average"):
188 | return self._single_reflen(self.crefs[0][0], option)
189 |
190 | def _single_reflen(self, reflens, option=None, testlen=None):
191 |
192 | if option == "shortest":
193 | reflen = min(reflens)
194 | elif option == "average":
195 | reflen = float(sum(reflens)) / len(reflens)
196 | elif option == "closest":
197 | reflen = min((abs(l - testlen), l) for l in reflens)[1]
198 | else:
199 | assert False, "unsupported reflen option %s" % option
200 |
201 | return reflen
202 |
203 | def recompute_score(self, option=None, verbose=0):
204 | self._score = None
205 | return self.compute_score(option, verbose)
206 |
207 | def compute_score(self, option=None, verbose=0):
208 | n = self.n
209 | small = 1e-9
210 | tiny = 1e-15 ## so that if guess is 0 still return 0
211 | bleu_list = [[] for _ in range(n)]
212 |
213 | if self._score is not None:
214 | return self._score
215 |
216 | if option is None:
217 | option = "average" if len(self.crefs) == 1 else "closest"
218 |
219 | self._testlen = 0
220 | self._reflen = 0
221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
222 |
223 | # for each sentence
224 | for comps in self.ctest:
225 | testlen = comps['testlen']
226 | self._testlen += testlen
227 |
228 | if self.special_reflen is None: ## need computation
229 | reflen = self._single_reflen(comps['reflen'], option, testlen)
230 | else:
231 | reflen = self.special_reflen
232 |
233 | self._reflen += reflen
234 |
235 | for key in ['guess', 'correct']:
236 | for k in range(n):
237 | totalcomps[key][k] += comps[key][k]
238 |
239 | # append per image bleu score
240 | bleu = 1.
241 | for k in range(n):
242 | bleu *= (float(comps['correct'][k]) + tiny) \
243 | / (float(comps['guess'][k]) + small)
244 | bleu_list[k].append(bleu ** (1. / (k + 1)))
245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
246 | if ratio < 1:
247 | for k in range(n):
248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio)
249 |
250 | if verbose > 1:
251 | print(comps, reflen)
252 |
253 | totalcomps['reflen'] = self._reflen
254 | totalcomps['testlen'] = self._testlen
255 |
256 | bleus = []
257 | bleu = 1.
258 | for k in range(n):
259 | bleu *= float(totalcomps['correct'][k] + tiny) \
260 | / (totalcomps['guess'][k] + small)
261 | bleus.append(bleu ** (1. / (k + 1)))
262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
263 | if ratio < 1:
264 | for k in range(n):
265 | bleus[k] *= math.exp(1 - 1 / ratio)
266 |
267 | if verbose > 0:
268 | print(totalcomps)
269 | print("ratio:", ratio)
270 |
271 | self._score = bleus
272 | return self._score, bleu_list
273 |
--------------------------------------------------------------------------------
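The scorer can also be built up incrementally, which is how the Bleu wrapper uses it: each `+=` cooks one (hypothesis, references) pair, and compute_score aggregates at the end. A small sketch with toy captions:

    from evaluation.bleu.bleu_scorer import BleuScorer

    scorer = BleuScorer(n=4)
    scorer += ('a cat is on the mat', ['a cat sits on a mat', 'there is a cat on the mat'])
    scorer += ('a man rides a horse', ['a man is riding a horse'])

    corpus_bleu, per_image = scorer.compute_score(option='closest', verbose=0)
    print(corpus_bleu)   # BLEU-1..4 over both pairs
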
/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | from .cider import Cider
--------------------------------------------------------------------------------
/evaluation/cider/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/cider/__pycache__/cider.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/cider.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from .cider_scorer import CiderScorer
11 |
12 | class Cider:
13 | """
14 | Main Class to compute the CIDEr metric
15 |
16 | """
17 | def __init__(self, gts=None, n=4, sigma=6.0):
18 | # set cider to sum over 1 to 4-grams
19 | self._n = n
20 | # set the standard deviation parameter for gaussian penalty
21 | self._sigma = sigma
22 | self.doc_frequency = None
23 | self.ref_len = None
24 | if gts is not None:
25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma)
26 | self.doc_frequency = tmp_cider.doc_frequency
27 | self.ref_len = tmp_cider.ref_len
28 |
29 | def compute_score(self, gts, res):
30 | """
31 | Main function to compute CIDEr score
32 | :param gts (dict) : dictionary with key and value
33 | res (dict) : dictionary with key and value
34 | :return: cider (float) : computed CIDEr score for the corpus
35 | """
36 | assert(gts.keys() == res.keys())
37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency,
38 | ref_len=self.ref_len)
39 | return cider_scorer.compute_score()
40 |
41 | def __str__(self):
42 | return 'CIDEr'
43 |
--------------------------------------------------------------------------------
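Usage mirrors the other metrics; a minimal sketch with toy, pre-tokenized captions:

    from evaluation.cider import Cider

    gts = {'img1': ['a dog runs on the beach', 'a dog is running along the beach'],
           'img2': ['two people ride bikes down a street']}
    res = {'img1': ['a dog running on the beach'],
           'img2': ['two people riding bikes on a street']}

    mean_cider, per_image = Cider().compute_score(gts, res)
    print(mean_cider, per_image)   # corpus mean and one CIDEr score per image
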
/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import math
9 |
10 | def precook(s, n=4):
11 | """
12 | Takes a string as input and returns an object that can be given to
13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
14 | can take string arguments as well.
15 | :param s: string : sentence to be converted into ngrams
16 | :param n: int : number of ngrams for which representation is calculated
17 | :return: term frequency vector for occurring ngrams
18 | """
19 | words = s.split()
20 | counts = defaultdict(int)
21 | for k in range(1,n+1):
22 | for i in range(len(words)-k+1):
23 | ngram = tuple(words[i:i+k])
24 | counts[ngram] += 1
25 | return counts
26 |
27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
28 | '''Takes a list of reference sentences for a single segment
29 | and returns an object that encapsulates everything that BLEU
30 | needs to know about them.
31 | :param refs: list of string : reference sentences for some image
32 | :param n: int : number of ngrams for which (ngram) representation is calculated
33 | :return: result (list of dict)
34 | '''
35 | return [precook(ref, n) for ref in refs]
36 |
37 | def cook_test(test, n=4):
38 | '''Takes a test sentence and returns an object that
39 | encapsulates everything that BLEU needs to know about it.
40 | :param test: list of string : hypothesis sentence for some image
41 | :param n: int : number of ngrams for which (ngram) representation is calculated
42 | :return: result (dict)
43 | '''
44 | return precook(test, n)
45 |
46 | class CiderScorer(object):
47 | """CIDEr scorer.
48 | """
49 |
50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None):
51 | ''' singular instance '''
52 | self.n = n
53 | self.sigma = sigma
54 | self.crefs = []
55 | self.ctest = []
56 | self.doc_frequency = defaultdict(float)
57 | self.ref_len = None
58 |
59 | for k in refs.keys():
60 | self.crefs.append(cook_refs(refs[k]))
61 | if test is not None:
62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1
63 | else:
64 | self.ctest.append(None) # lens of crefs and ctest have to match
65 |
66 | if doc_frequency is None and ref_len is None:
67 | # compute idf
68 | self.compute_doc_freq()
69 | # compute log reference length
70 | self.ref_len = np.log(float(len(self.crefs)))
71 | else:
72 | self.doc_frequency = doc_frequency
73 | self.ref_len = ref_len
74 |
75 | def compute_doc_freq(self):
76 | '''
77 | Compute term frequency for reference data.
78 | This will be used to compute idf (inverse document frequency later)
79 | The term frequency is stored in the object
80 | :return: None
81 | '''
82 | for refs in self.crefs:
83 | # refs, k ref captions of one image
84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
85 | self.doc_frequency[ngram] += 1
86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
87 |
88 | def compute_cider(self):
89 | def counts2vec(cnts):
90 | """
91 | Function maps counts of ngram to vector of tfidf weights.
92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
93 | The n-th entry of array denotes length of n-grams.
94 | :param cnts:
95 | :return: vec (array of dict), norm (array of float), length (int)
96 | """
97 | vec = [defaultdict(float) for _ in range(self.n)]
98 | length = 0
99 | norm = [0.0 for _ in range(self.n)]
100 | for (ngram,term_freq) in cnts.items():
101 | # give word count 1 if it doesn't appear in reference corpus
102 | df = np.log(max(1.0, self.doc_frequency[ngram]))
103 | # ngram index
104 | n = len(ngram)-1
105 | # tf (term_freq) * idf (precomputed idf) for n-grams
106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
107 | # compute norm for the vector. the norm will be used for computing similarity
108 | norm[n] += pow(vec[n][ngram], 2)
109 |
110 | if n == 1:
111 | length += term_freq
112 | norm = [np.sqrt(n) for n in norm]
113 | return vec, norm, length
114 |
115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
116 | '''
117 | Compute the cosine similarity of two vectors.
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
119 | :param vec_ref: array of dictionary for vector corresponding to reference
120 | :param norm_hyp: array of float for vector corresponding to hypothesis
121 | :param norm_ref: array of float for vector corresponding to reference
122 | :param length_hyp: int containing length of hypothesis
123 | :param length_ref: int containing length of reference
124 | :return: array of score for each n-grams cosine similarity
125 | '''
126 | delta = float(length_hyp - length_ref)
127 | # measure cosine similarity
128 | val = np.array([0.0 for _ in range(self.n)])
129 | for n in range(self.n):
130 | # ngram
131 | for (ngram,count) in vec_hyp[n].items():
132 | # vrama91 : added clipping
133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
134 |
135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
136 | val[n] /= (norm_hyp[n]*norm_ref[n])
137 |
138 | assert(not math.isnan(val[n]))
139 | # vrama91: added a length based gaussian penalty
140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
141 | return val
142 |
143 | scores = []
144 | for test, refs in zip(self.ctest, self.crefs):
145 | # compute vector for test captions
146 | vec, norm, length = counts2vec(test)
147 | # compute vector for ref captions
148 | score = np.array([0.0 for _ in range(self.n)])
149 | for ref in refs:
150 | vec_ref, norm_ref, length_ref = counts2vec(ref)
151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
152 | # change by vrama91 - mean of ngram scores, instead of sum
153 | score_avg = np.mean(score)
154 | # divide by number of references
155 | score_avg /= len(refs)
156 | # multiply score by 10
157 | score_avg *= 10.0
158 | # append score of an image to the score list
159 | scores.append(score_avg)
160 | return scores
161 |
162 | def compute_score(self):
163 | # compute cider score
164 | score = self.compute_cider()
165 | # debug
166 | # print score
167 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
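Because the scores hinge on the idf statistics above, Cider can also be primed once on a larger reference corpus and reused, via the optional gts argument shown in cider.py. A hedged sketch (the reference dict is a small stand-in for a real corpus):

    from evaluation.cider import Cider

    # precompute doc_frequency / ref_len on a (placeholder) reference corpus once ...
    ref_corpus = {0: ['a man is riding a horse'],
                  1: ['a bowl of fruit on a table'],
                  2: ['a dog runs on the beach']}
    cider = Cider(gts=ref_corpus)

    # ... then score new gts/res pairs against those frozen statistics
    gts = {'x': ['a dog runs on the beach']}
    res = {'x': ['a dog running on the beach']}
    print(cider.compute_score(gts, res))
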
/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | from .meteor import Meteor
--------------------------------------------------------------------------------
/evaluation/meteor/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/meteor/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/meteor/__pycache__/meteor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/meteor/__pycache__/meteor.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | # Python wrapper for METEOR implementation, by Xinlei Chen
2 | # Acknowledge Michael Denkowski for the generous discussion and help
3 |
4 | import os
5 | import subprocess
6 | import threading
7 | import tarfile
8 | from utils import download_from_url
9 |
10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz'
11 | METEOR_JAR = 'meteor-1.5.jar'
12 |
13 | class Meteor:
14 | def __init__(self):
15 | base_path = os.path.dirname(os.path.abspath(__file__))
16 | jar_path = os.path.join(base_path, METEOR_JAR)
17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL))
18 | if not os.path.isfile(jar_path):
19 | if not os.path.isfile(gz_path):
20 | download_from_url(METEOR_GZ_URL, gz_path)
21 | tar = tarfile.open(gz_path, "r")
22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__)))
23 | tar.close()
24 | os.remove(gz_path)
25 |
26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
27 | '-', '-', '-stdio', '-l', 'en', '-norm']
28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
29 | cwd=os.path.dirname(os.path.abspath(__file__)), \
30 | stdin=subprocess.PIPE, \
31 | stdout=subprocess.PIPE, \
32 | stderr=subprocess.PIPE)
33 | # Used to guarantee thread safety
34 | self.lock = threading.Lock()
35 |
36 | def compute_score(self, gts, res):
37 | assert(gts.keys() == res.keys())
38 | imgIds = gts.keys()
39 | scores = []
40 |
41 | eval_line = 'EVAL'
42 | self.lock.acquire()
43 | for i in imgIds:
44 | assert(len(res[i]) == 1)
45 | stat = self._stat(res[i][0], gts[i])
46 | eval_line += ' ||| {}'.format(stat)
47 |
48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
49 | self.meteor_p.stdin.flush()
50 | for i in range(0,len(imgIds)):
51 | scores.append(float(self.meteor_p.stdout.readline().strip()))
52 | score = float(self.meteor_p.stdout.readline().strip())
53 | self.lock.release()
54 |
55 | return score, scores
56 |
57 | def _stat(self, hypothesis_str, reference_list):
58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
59 | hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
62 | self.meteor_p.stdin.flush()
63 | raw = self.meteor_p.stdout.readline().decode().strip()
64 | numbers = [str(int(float(n))) for n in raw.split()]
65 | return ' '.join(numbers)
66 |
67 | def __del__(self):
68 | self.lock.acquire()
69 | self.meteor_p.stdin.close()
70 | self.meteor_p.kill()
71 | self.meteor_p.wait()
72 | self.lock.release()
73 |
74 | def __str__(self):
75 | return 'METEOR'
76 |
--------------------------------------------------------------------------------
/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | from .rouge import Rouge
--------------------------------------------------------------------------------
/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 |
14 | def my_lcs(string, sub):
15 | """
16 | Calculates longest common subsequence for a pair of tokenized strings
17 | :param string : list of str : tokens from a string split using whitespace
18 | :param sub : list of str : shorter string, also split using whitespace
19 | :returns: length (list of int): length of the longest common subsequence between the two strings
20 |
21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
22 | """
23 | if (len(string) < len(sub)):
24 | sub, string = string, sub
25 |
26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
27 |
28 | for j in range(1, len(sub) + 1):
29 | for i in range(1, len(string) + 1):
30 | if (string[i - 1] == sub[j - 1]):
31 | lengths[i][j] = lengths[i - 1][j - 1] + 1
32 | else:
33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
34 |
35 | return lengths[len(string)][len(sub)]
36 |
37 |
38 | class Rouge():
39 | '''
40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
41 |
42 | '''
43 |
44 | def __init__(self):
45 | # vrama91: updated the value below based on discussion with Hovy
46 | self.beta = 1.2
47 |
48 | def calc_score(self, candidate, refs):
49 | """
50 | Compute ROUGE-L score given one candidate and references for an image
51 | :param candidate: str : candidate sentence to be evaluated
52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
53 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
54 | """
55 | assert (len(candidate) == 1)
56 | assert (len(refs) > 0)
57 | prec = []
58 | rec = []
59 |
60 | # split into tokens
61 | token_c = candidate[0].split(" ")
62 |
63 | for reference in refs:
64 | # split into tokens
65 | token_r = reference.split(" ")
66 | # compute the longest common subsequence
67 | lcs = my_lcs(token_r, token_c)
68 | prec.append(lcs / float(len(token_c)))
69 | rec.append(lcs / float(len(token_r)))
70 |
71 | prec_max = max(prec)
72 | rec_max = max(rec)
73 |
74 | if (prec_max != 0 and rec_max != 0):
75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
76 | else:
77 | score = 0.0
78 | return score
79 |
80 | def compute_score(self, gts, res):
81 | """
82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
83 | Invoked by evaluate_captions.py
84 |         :param gts: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
85 |         :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
87 | """
88 | assert (gts.keys() == res.keys())
89 | imgIds = gts.keys()
90 |
91 | score = []
92 | for id in imgIds:
93 | hypo = res[id]
94 | ref = gts[id]
95 |
96 | score.append(self.calc_score(hypo, ref))
97 |
98 | # Sanity check.
99 | assert (type(hypo) is list)
100 | assert (len(hypo) == 1)
101 | assert (type(ref) is list)
102 | assert (len(ref) > 0)
103 |
104 | average_score = np.mean(np.array(score))
105 | return average_score, np.array(score)
106 |
107 | def __str__(self):
108 | return 'ROUGE'
109 |
--------------------------------------------------------------------------------
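A minimal usage sketch for the Rouge scorer above (the image id and captions are made up; compute_score expects exactly one tokenized candidate per id and at least one reference):

    from evaluation.rouge import Rouge

    gts = {'img1': ['a man is riding a horse on the beach',
                    'a person rides a horse near the ocean']}
    res = {'img1': ['a man riding a horse on a beach']}

    average, per_image = Rouge().compute_score(gts, res)
    # per image: F-measure with beta = 1.2 over the best precision/recall across references,
    #   score = (1 + beta^2) * P_max * R_max / (R_max + beta^2 * P_max)
    print(average, per_image)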
/evaluation/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/evaluation/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/evaluation/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Performs PTB tokenization and removes punctuation.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import subprocess
13 | import tempfile
14 |
15 | class PTBTokenizer(object):
16 | """Python wrapper of Stanford PTBTokenizer"""
17 |
18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar'
19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
21 |
22 | @classmethod
23 | def tokenize(cls, corpus):
24 | cmd = ['java', '-cp', cls.corenlp_jar, \
25 | 'edu.stanford.nlp.process.PTBTokenizer', \
26 | '-preserveLines', '-lowerCase']
27 |
28 | if isinstance(corpus, list) or isinstance(corpus, tuple):
29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple):
30 | corpus = {i:c for i, c in enumerate(corpus)}
31 | else:
32 | corpus = {i: [c, ] for i, c in enumerate(corpus)}
33 |
34 | # prepare data for PTB Tokenizer
35 | tokenized_corpus = {}
36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v])
38 |
39 | # save sentences to temporary file
40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
42 | tmp_file.write(sentences.encode())
43 | tmp_file.close()
44 |
45 | # tokenize sentence
46 | cmd.append(os.path.basename(tmp_file.name))
47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w'))
49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
50 | token_lines = token_lines.decode()
51 | lines = token_lines.split('\n')
52 | # remove temp file
53 | os.remove(tmp_file.name)
54 |
55 | # create dictionary for tokenized captions
56 | for k, line in zip(image_id, lines):
57 | if not k in tokenized_corpus:
58 | tokenized_corpus[k] = []
59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
60 | if w not in cls.punctuations])
61 | tokenized_corpus[k].append(tokenized_caption)
62 |
63 | return tokenized_corpus
--------------------------------------------------------------------------------
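PTBTokenizer.tokenize accepts either a {image_id: [captions]} dict or a plain list of captions, shells out to the bundled stanford-corenlp jar, lowercases the text, and drops the punctuation tokens listed above. A minimal sketch, assuming Java is on PATH and the script is run from the repository root so the evaluation package is importable:

    from evaluation.tokenizer import PTBTokenizer

    captions = {'img1': ['A man, riding a horse!', 'Someone rides a horse.']}
    tokenized = PTBTokenizer.tokenize(captions)
    # -> {'img1': ['a man riding a horse', 'someone rides a horse']}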
/knowcap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/knowcap.png
--------------------------------------------------------------------------------
/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/.DS_Store
--------------------------------------------------------------------------------
/models/BLIP/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/BLIP/__init__.py
--------------------------------------------------------------------------------
/models/BLIP/blip_itm.py:
--------------------------------------------------------------------------------
1 | from models.med import BertConfig, BertModel
2 | from transformers import BertTokenizer
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from models.blip import create_vit, init_tokenizer, load_checkpoint
9 |
10 | class BLIP_ITM(nn.Module):
11 | def __init__(self,
12 | med_config = 'configs/med_config.json',
13 | image_size = 384,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | embed_dim = 256,
18 | ):
19 | """
20 | Args:
21 | med_config (str): path for the mixture of encoder-decoder model's configuration file
22 | image_size (int): input image size
23 | vit (str): model size of vision transformer
24 | """
25 | super().__init__()
26 |
27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
28 | self.tokenizer = init_tokenizer()
29 | med_config = BertConfig.from_json_file(med_config)
30 | med_config.encoder_width = vision_width
31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
32 |
33 | text_width = self.text_encoder.config.hidden_size
34 |
35 | self.vision_proj = nn.Linear(vision_width, embed_dim)
36 | self.text_proj = nn.Linear(text_width, embed_dim)
37 |
38 | self.itm_head = nn.Linear(text_width, 2)
39 |
40 |
41 | def forward(self, image, caption, match_head='itm'):
42 |
43 | image_embeds = self.visual_encoder(image)
44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
45 |
46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
47 | return_tensors="pt").to(image.device)
48 |
49 |
50 | if match_head=='itm':
51 | output = self.text_encoder(text.input_ids,
52 | attention_mask = text.attention_mask,
53 | encoder_hidden_states = image_embeds,
54 | encoder_attention_mask = image_atts,
55 | return_dict = True,
56 | )
57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:])
58 | return itm_output
59 |
60 | elif match_head=='itc':
61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
62 | return_dict = True, mode = 'text')
63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
65 |
66 | sim = image_feat @ text_feat.t()
67 | return sim
68 |
69 |
70 | def blip_itm(pretrained='',**kwargs):
71 | model = BLIP_ITM(**kwargs)
72 | if pretrained:
73 | model,msg = load_checkpoint(model,pretrained)
74 | assert(len(msg.missing_keys)==0)
75 | return model
76 |
--------------------------------------------------------------------------------
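A hypothetical sketch of the two heads exposed by BLIP_ITM above: match_head='itm' returns two-way (no-match, match) logits from the cross-attended [CLS] state, while match_head='itc' returns a cosine-similarity matrix between the projected unimodal features. The config path and the random image tensor below are placeholders, and the import assumes the file's own `models.med` / `models.blip` imports resolve as in the upstream BLIP layout:

    import torch
    import torch.nn.functional as F
    from blip_itm import blip_itm   # assumption: upstream BLIP import layout

    model = blip_itm(pretrained='', med_config='med_config.json', image_size=384, vit='base')
    model.eval()

    image = torch.randn(2, 3, 384, 384)          # stand-in for a preprocessed 384x384 batch
    caption = ['a dog running on the grass', 'a bowl of fruit on a table']

    with torch.no_grad():
        itm_logits = model(image, caption, match_head='itm')    # [2, 2] (no-match, match) logits
        itm_prob = F.softmax(itm_logits, dim=1)[:, 1]           # probability that image and text match
        itc_sim = model(image, caption, match_head='itc')       # [2, 2] image-to-text similarity matrix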
/models/BLIP/blip_nlvr.py:
--------------------------------------------------------------------------------
1 | from models.med import BertConfig
2 | from models.nlvr_encoder import BertModel
3 | from models.vit import interpolate_pos_embed
4 | from models.blip import create_vit, init_tokenizer, is_url
5 |
6 | from timm.models.hub import download_cached_file
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from transformers import BertTokenizer
12 | import numpy as np
13 | import os  # needed by load_checkpoint below (os.path.isfile)
14 | class BLIP_NLVR(nn.Module):
15 | def __init__(self,
16 | med_config = 'configs/med_config.json',
17 | image_size = 480,
18 | vit = 'base',
19 | vit_grad_ckpt = False,
20 | vit_ckpt_layer = 0,
21 | ):
22 | """
23 | Args:
24 | med_config (str): path for the mixture of encoder-decoder model's configuration file
25 | image_size (int): input image size
26 | vit (str): model size of vision transformer
27 | """
28 | super().__init__()
29 |
30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
31 | self.tokenizer = init_tokenizer()
32 | med_config = BertConfig.from_json_file(med_config)
33 | med_config.encoder_width = vision_width
34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35 |
36 | self.cls_head = nn.Sequential(
37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
38 | nn.ReLU(),
39 | nn.Linear(self.text_encoder.config.hidden_size, 2)
40 | )
41 |
42 | def forward(self, image, text, targets, train=True):
43 |
44 | image_embeds = self.visual_encoder(image)
45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
47 |
48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
49 | text.input_ids[:,0] = self.tokenizer.enc_token_id
50 |
51 | output = self.text_encoder(text.input_ids,
52 | attention_mask = text.attention_mask,
53 | encoder_hidden_states = [image0_embeds,image1_embeds],
54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
55 | image_atts[image0_embeds.size(0):]],
56 | return_dict = True,
57 | )
58 | hidden_state = output.last_hidden_state[:,0,:]
59 | prediction = self.cls_head(hidden_state)
60 |
61 | if train:
62 | loss = F.cross_entropy(prediction, targets)
63 | return loss
64 | else:
65 | return prediction
66 |
67 | def blip_nlvr(pretrained='',**kwargs):
68 | model = BLIP_NLVR(**kwargs)
69 | if pretrained:
70 | model,msg = load_checkpoint(model,pretrained)
71 | print("missing keys:")
72 | print(msg.missing_keys)
73 | return model
74 |
75 |
76 | def load_checkpoint(model,url_or_filename):
77 | if is_url(url_or_filename):
78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
79 | checkpoint = torch.load(cached_file, map_location='cpu')
80 | elif os.path.isfile(url_or_filename):
81 | checkpoint = torch.load(url_or_filename, map_location='cpu')
82 | else:
83 | raise RuntimeError('checkpoint url or path is invalid')
84 | state_dict = checkpoint['model']
85 |
86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
87 |
88 | for key in list(state_dict.keys()):
89 | if 'crossattention.self.' in key:
90 | new_key0 = key.replace('self','self0')
91 | new_key1 = key.replace('self','self1')
92 | state_dict[new_key0] = state_dict[key]
93 | state_dict[new_key1] = state_dict[key]
94 | elif 'crossattention.output.dense.' in key:
95 | new_key0 = key.replace('dense','dense0')
96 | new_key1 = key.replace('dense','dense1')
97 | state_dict[new_key0] = state_dict[key]
98 | state_dict[new_key1] = state_dict[key]
99 |
100 | msg = model.load_state_dict(state_dict,strict=False)
101 | print('load checkpoint from %s'%url_or_filename)
102 | return model,msg
103 |
--------------------------------------------------------------------------------
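Because the forward above splits image_embeds with torch.split(image_embeds, targets.size(0)), the expected calling convention is that the two NLVR2 images of each example are concatenated along the batch dimension: the first B rows are image0, the last B rows are image1, with one statement and one binary target per pair. A shape-only sketch of that convention (`model` stands for an already constructed BLIP_NLVR):

    import torch

    B = 4
    image0 = torch.randn(B, 3, 480, 480)
    image1 = torch.randn(B, 3, 480, 480)
    image = torch.cat([image0, image1], dim=0)     # [2*B, 3, 480, 480]
    text = ['the left image contains more dogs than the right image'] * B
    targets = torch.randint(0, 2, (B,))            # one label per image pair

    loss = model(image, text, targets, train=True)         # scalar cross-entropy loss
    logits = model(image, text, targets, train=False)      # [B, 2] class predictions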
/models/BLIP/blip_retrieval.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel
2 | from transformers import BertTokenizer
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from .blip import create_vit, init_tokenizer, load_checkpoint
9 |
10 | class BLIP_Retrieval(nn.Module):
11 | def __init__(self,
12 | med_config = 'configs/med_config.json',
13 | image_size = 384,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | embed_dim = 256,
18 | queue_size = 57600,
19 | momentum = 0.995,
20 | negative_all_rank = False,
21 | ):
22 | """
23 | Args:
24 | med_config (str): path for the mixture of encoder-decoder model's configuration file
25 | image_size (int): input image size
26 | vit (str): model size of vision transformer
27 | """
28 | super().__init__()
29 |
30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
31 | self.tokenizer = init_tokenizer()
32 | med_config = BertConfig.from_json_file(med_config)
33 | med_config.encoder_width = vision_width
34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35 |
36 | text_width = self.text_encoder.config.hidden_size
37 |
38 | self.vision_proj = nn.Linear(vision_width, embed_dim)
39 | self.text_proj = nn.Linear(text_width, embed_dim)
40 |
41 | self.itm_head = nn.Linear(text_width, 2)
42 |
43 | # create momentum encoders
44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size)
45 | self.vision_proj_m = nn.Linear(vision_width, embed_dim)
46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
47 | self.text_proj_m = nn.Linear(text_width, embed_dim)
48 |
49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
50 | [self.vision_proj,self.vision_proj_m],
51 | [self.text_encoder,self.text_encoder_m],
52 | [self.text_proj,self.text_proj_m],
53 | ]
54 | self.copy_params()
55 |
56 | # create the queue
57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
61 |
62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
64 |
65 | self.queue_size = queue_size
66 | self.momentum = momentum
67 | self.temp = nn.Parameter(0.07*torch.ones([]))
68 |
69 | self.negative_all_rank = negative_all_rank
70 |
71 |
72 | def forward(self, image, caption, alpha, idx):
73 | with torch.no_grad():
74 | self.temp.clamp_(0.001,0.5)
75 |
76 | image_embeds = self.visual_encoder(image)
77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
79 |
80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
81 | return_tensors="pt").to(image.device)
82 |
83 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
84 | return_dict = True, mode = 'text')
85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
86 |
87 | ###============== Image-text Contrastive Learning ===================###
88 | idx = idx.view(-1,1)
89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
90 | pos_idx = torch.eq(idx, idx_all).float()
91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
92 |
93 | # get momentum features
94 | with torch.no_grad():
95 | self._momentum_update()
96 | image_embeds_m = self.visual_encoder_m(image)
97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
99 |
100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
101 | return_dict = True, mode = 'text')
102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
104 |
105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
107 |
108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
110 |
111 | sim_i2t = image_feat @ text_feat_m_all / self.temp
112 | sim_t2i = text_feat @ image_feat_m_all / self.temp
113 |
114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
116 |
117 | loss_ita = (loss_i2t+loss_t2i)/2
118 |
119 | idxs = concat_all_gather(idx)
120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
121 |
122 | ###============== Image-text Matching ===================###
123 | encoder_input_ids = text.input_ids.clone()
124 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id
125 |
126 |         # forward the positive image-text pair
127 | bs = image.size(0)
128 | output_pos = self.text_encoder(encoder_input_ids,
129 | attention_mask = text.attention_mask,
130 | encoder_hidden_states = image_embeds,
131 | encoder_attention_mask = image_atts,
132 | return_dict = True,
133 | )
134 |
135 |
136 | if self.negative_all_rank:
137 | # compute sample similarity
138 | with torch.no_grad():
139 | mask = torch.eq(idx, idxs.t())
140 |
141 | image_feat_world = concat_all_gather(image_feat)
142 | text_feat_world = concat_all_gather(text_feat)
143 |
144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp
145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp
146 |
147 | weights_i2t = F.softmax(sim_i2t,dim=1)
148 | weights_i2t.masked_fill_(mask, 0)
149 |
150 | weights_t2i = F.softmax(sim_t2i,dim=1)
151 | weights_t2i.masked_fill_(mask, 0)
152 |
153 | image_embeds_world = all_gather_with_grad(image_embeds)
154 |
155 | # select a negative image (from all ranks) for each text
156 | image_embeds_neg = []
157 | for b in range(bs):
158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
159 | image_embeds_neg.append(image_embeds_world[neg_idx])
160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
161 |
162 | # select a negative text (from all ranks) for each image
163 | input_ids_world = concat_all_gather(encoder_input_ids)
164 | att_mask_world = concat_all_gather(text.attention_mask)
165 |
166 | text_ids_neg = []
167 | text_atts_neg = []
168 | for b in range(bs):
169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
170 | text_ids_neg.append(input_ids_world[neg_idx])
171 | text_atts_neg.append(att_mask_world[neg_idx])
172 |
173 | else:
174 | with torch.no_grad():
175 | mask = torch.eq(idx, idx.t())
176 |
177 | sim_i2t = image_feat @ text_feat.t() / self.temp
178 | sim_t2i = text_feat @ image_feat.t() / self.temp
179 |
180 | weights_i2t = F.softmax(sim_i2t,dim=1)
181 | weights_i2t.masked_fill_(mask, 0)
182 |
183 | weights_t2i = F.softmax(sim_t2i,dim=1)
184 | weights_t2i.masked_fill_(mask, 0)
185 |
186 | # select a negative image (from same rank) for each text
187 | image_embeds_neg = []
188 | for b in range(bs):
189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
190 | image_embeds_neg.append(image_embeds[neg_idx])
191 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
192 |
193 | # select a negative text (from same rank) for each image
194 | text_ids_neg = []
195 | text_atts_neg = []
196 | for b in range(bs):
197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
198 | text_ids_neg.append(encoder_input_ids[neg_idx])
199 | text_atts_neg.append(text.attention_mask[neg_idx])
200 |
201 | text_ids_neg = torch.stack(text_ids_neg,dim=0)
202 | text_atts_neg = torch.stack(text_atts_neg,dim=0)
203 |
204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
206 |
207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0)
209 |
210 | output_neg = self.text_encoder(text_ids_all,
211 | attention_mask = text_atts_all,
212 | encoder_hidden_states = image_embeds_all,
213 | encoder_attention_mask = image_atts_all,
214 | return_dict = True,
215 | )
216 |
217 |
218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
219 | vl_output = self.itm_head(vl_embeddings)
220 |
221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
222 | dim=0).to(image.device)
223 | loss_itm = F.cross_entropy(vl_output, itm_labels)
224 |
225 | return loss_ita, loss_itm
226 |
227 |
228 | @torch.no_grad()
229 | def copy_params(self):
230 | for model_pair in self.model_pairs:
231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
232 | param_m.data.copy_(param.data) # initialize
233 |                 param_m.requires_grad = False # not updated by gradient
234 |
235 |
236 | @torch.no_grad()
237 | def _momentum_update(self):
238 | for model_pair in self.model_pairs:
239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
240 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
241 |
242 |
243 | @torch.no_grad()
244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
245 | # gather keys before updating queue
246 | image_feats = concat_all_gather(image_feat)
247 | text_feats = concat_all_gather(text_feat)
248 |
249 |
250 | batch_size = image_feats.shape[0]
251 |
252 | ptr = int(self.ptr_queue)
253 | assert self.queue_size % batch_size == 0 # for simplicity
254 |
255 | # replace the keys at ptr (dequeue and enqueue)
256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
259 | ptr = (ptr + batch_size) % self.queue_size # move pointer
260 |
261 | self.ptr_queue[0] = ptr
262 |
263 |
264 | def blip_retrieval(pretrained='',**kwargs):
265 | model = BLIP_Retrieval(**kwargs)
266 | if pretrained:
267 | model,msg = load_checkpoint(model,pretrained)
268 | print("missing keys:")
269 | print(msg.missing_keys)
270 | return model
271 |
272 |
273 | @torch.no_grad()
274 | def concat_all_gather(tensor):
275 | """
276 | Performs all_gather operation on the provided tensors.
277 | *** Warning ***: torch.distributed.all_gather has no gradient.
278 | """
279 | tensors_gather = [torch.ones_like(tensor)
280 | for _ in range(torch.distributed.get_world_size())]
281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
282 |
283 | output = torch.cat(tensors_gather, dim=0)
284 | return output
285 |
286 |
287 | class GatherLayer(torch.autograd.Function):
288 | """
289 | Gather tensors from all workers with support for backward propagation:
290 | This implementation does not cut the gradients as torch.distributed.all_gather does.
291 | """
292 |
293 | @staticmethod
294 | def forward(ctx, x):
295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
296 | torch.distributed.all_gather(output, x)
297 | return tuple(output)
298 |
299 | @staticmethod
300 | def backward(ctx, *grads):
301 | all_gradients = torch.stack(grads)
302 | torch.distributed.all_reduce(all_gradients)
303 | return all_gradients[torch.distributed.get_rank()]
304 |
305 |
306 | def all_gather_with_grad(tensors):
307 | """
308 | Performs all_gather operation on the provided tensors.
309 | Graph remains connected for backward grad computation.
310 | """
311 | # Queue the gathered tensors
312 | world_size = torch.distributed.get_world_size()
313 | # There is no need for reduction in the single-proc case
314 | if world_size == 1:
315 | return tensors
316 |
317 | tensor_all = GatherLayer.apply(tensors)
318 |
319 | return torch.cat(tensor_all, dim=0)
320 |
--------------------------------------------------------------------------------
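Two pieces of machinery in BLIP_Retrieval are easy to check in isolation: the momentum encoders are an exponential moving average of the online parameters (param_m <- m * param_m + (1 - m) * param, with m = 0.995), and _dequeue_and_enqueue writes each gathered batch of features into a fixed-size ring buffer whose columns are feature vectors, advancing a wrap-around pointer. A small self-contained sketch that mirrors those updates rather than calling the class:

    import torch

    momentum, embed_dim, queue_size, batch_size = 0.995, 4, 8, 2

    # EMA update, as in _momentum_update()
    param = torch.randn(3)
    param_m = torch.zeros(3)
    param_m = param_m * momentum + param * (1.0 - momentum)

    # Ring-buffer update, as in _dequeue_and_enqueue()
    feature_queue = torch.randn(embed_dim, queue_size)
    ptr = 0
    new_feats = torch.randn(batch_size, embed_dim)           # gathered batch features
    assert queue_size % batch_size == 0                      # same simplifying assumption as the code
    feature_queue[:, ptr:ptr + batch_size] = new_feats.T     # overwrite the oldest columns
    ptr = (ptr + batch_size) % queue_size                    # advance and wrap the pointer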
/models/BLIP/blip_vqa.py:
--------------------------------------------------------------------------------
1 | from models.med import BertConfig, BertModel, BertLMHeadModel
2 | from models.blip import create_vit, init_tokenizer, load_checkpoint
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from transformers import BertTokenizer
8 | import numpy as np
9 |
10 | class BLIP_VQA(nn.Module):
11 | def __init__(self,
12 | med_config = 'configs/med_config.json',
13 | image_size = 480,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | ):
18 | """
19 | Args:
20 | med_config (str): path for the mixture of encoder-decoder model's configuration file
21 | image_size (int): input image size
22 | vit (str): model size of vision transformer
23 | """
24 | super().__init__()
25 |
26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
27 | self.tokenizer = init_tokenizer()
28 |
29 | encoder_config = BertConfig.from_json_file(med_config)
30 | encoder_config.encoder_width = vision_width
31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
32 |
33 | decoder_config = BertConfig.from_json_file(med_config)
34 | self.text_decoder = BertLMHeadModel(config=decoder_config)
35 |
36 |
37 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
38 |
39 | image_embeds = self.visual_encoder(image)
40 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
41 |
42 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
43 | return_tensors="pt").to(image.device)
44 | question.input_ids[:,0] = self.tokenizer.enc_token_id
45 |
46 | if train:
47 | '''
48 | n: number of answers for each question
49 | weights: weight for each answer
50 | '''
51 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
52 | answer.input_ids[:,0] = self.tokenizer.bos_token_id
53 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
54 |
55 | question_output = self.text_encoder(question.input_ids,
56 | attention_mask = question.attention_mask,
57 | encoder_hidden_states = image_embeds,
58 | encoder_attention_mask = image_atts,
59 | return_dict = True)
60 |
61 | question_states = []
62 | question_atts = []
63 | for b, n in enumerate(n):
64 | question_states += [question_output.last_hidden_state[b]]*n
65 | question_atts += [question.attention_mask[b]]*n
66 | question_states = torch.stack(question_states,0)
67 | question_atts = torch.stack(question_atts,0)
68 |
69 | answer_output = self.text_decoder(answer.input_ids,
70 | attention_mask = answer.attention_mask,
71 | encoder_hidden_states = question_states,
72 | encoder_attention_mask = question_atts,
73 | labels = answer_targets,
74 | return_dict = True,
75 | reduction = 'none',
76 | )
77 |
78 | loss = weights * answer_output.loss
79 | loss = loss.sum()/image.size(0)
80 |
81 | return loss
82 |
83 |
84 | else:
85 | question_output = self.text_encoder(question.input_ids,
86 | attention_mask = question.attention_mask,
87 | encoder_hidden_states = image_embeds,
88 | encoder_attention_mask = image_atts,
89 | return_dict = True)
90 |
91 | if inference=='generate':
92 | num_beams = 3
93 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
94 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
95 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
96 |
97 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
98 |
99 | outputs = self.text_decoder.generate(input_ids=bos_ids,
100 | max_length=10,
101 | min_length=1,
102 | num_beams=num_beams,
103 | eos_token_id=self.tokenizer.sep_token_id,
104 | pad_token_id=self.tokenizer.pad_token_id,
105 | **model_kwargs)
106 |
107 | answers = []
108 | for output in outputs:
109 | answer = self.tokenizer.decode(output, skip_special_tokens=True)
110 | answers.append(answer)
111 | return answers
112 |
113 | elif inference=='rank':
114 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
115 | answer.input_ids, answer.attention_mask, k_test)
116 | return max_ids
117 |
118 |
119 |
120 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
121 |
122 | num_ques = question_states.size(0)
123 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
124 |
125 | start_output = self.text_decoder(start_ids,
126 | encoder_hidden_states = question_states,
127 | encoder_attention_mask = question_atts,
128 | return_dict = True,
129 | reduction = 'none')
130 | logits = start_output.logits[:,0,:] # first token's logit
131 |
132 | # topk_probs: top-k probability
133 | # topk_ids: [num_question, k]
134 | answer_first_token = answer_ids[:,1]
135 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
136 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
137 |
138 | # answer input: [num_question*k, answer_len]
139 | input_ids = []
140 | input_atts = []
141 | for b, topk_id in enumerate(topk_ids):
142 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
143 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
144 | input_ids = torch.cat(input_ids,dim=0)
145 | input_atts = torch.cat(input_atts,dim=0)
146 |
147 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
148 |
149 | # repeat encoder's output for top-k answers
150 | question_states = tile(question_states, 0, k)
151 | question_atts = tile(question_atts, 0, k)
152 |
153 | output = self.text_decoder(input_ids,
154 | attention_mask = input_atts,
155 | encoder_hidden_states = question_states,
156 | encoder_attention_mask = question_atts,
157 | labels = targets_ids,
158 | return_dict = True,
159 | reduction = 'none')
160 |
161 | log_probs_sum = -output.loss
162 | log_probs_sum = log_probs_sum.view(num_ques,k)
163 |
164 | max_topk_ids = log_probs_sum.argmax(dim=1)
165 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
166 |
167 | return max_ids
168 |
169 |
170 | def blip_vqa(pretrained='',**kwargs):
171 | model = BLIP_VQA(**kwargs)
172 | if pretrained:
173 | model,msg = load_checkpoint(model,pretrained)
174 | # assert(len(msg.missing_keys)==0)
175 | return model
176 |
177 |
178 | def tile(x, dim, n_tile):
179 | init_dim = x.size(dim)
180 | repeat_idx = [1] * x.dim()
181 | repeat_idx[dim] = n_tile
182 | x = x.repeat(*(repeat_idx))
183 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
184 | return torch.index_select(x, dim, order_index.to(x.device))
185 |
186 |
--------------------------------------------------------------------------------
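The tile helper at the bottom of blip_vqa.py repeats each row of a tensor n_tile times consecutively, so the k answer candidates of question b line up with k copies of question b's encoder states; along dim 0 it is equivalent to torch.repeat_interleave. A standalone check (the helper is re-declared here only so the snippet runs on its own):

    import numpy as np
    import torch

    def tile(x, dim, n_tile):                    # same body as in blip_vqa.py
        init_dim = x.size(dim)
        repeat_idx = [1] * x.dim()
        repeat_idx[dim] = n_tile
        x = x.repeat(*repeat_idx)
        order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
        return torch.index_select(x, dim, order_index.to(x.device))

    x = torch.tensor([[1, 1], [2, 2]])
    assert torch.equal(tile(x, 0, 3), torch.repeat_interleave(x, 3, dim=0))
    # rows become [[1,1],[1,1],[1,1],[2,2],[2,2],[2,2]]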
/models/BLIP/caption_coco.yaml:
--------------------------------------------------------------------------------
1 | # image_root: './export/share/datasets/vision/coco/images/'
2 | image_root: './'
3 | ann_root: 'annotation'
4 | coco_gt_root: 'annotation/coco_gt'
5 |
6 | # set pretrained as a file path or an url
7 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth'
8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
9 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth'
10 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
11 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth'
12 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
13 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
14 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
15 |
16 | # size of vit model; base or large
17 | vit: 'large'
18 | vit_grad_ckpt: False
19 | vit_ckpt_layer: 0
20 | batch_size: 32
21 | init_lr: 1e-5
22 |
23 | # vit: 'large'
24 | # vit_grad_ckpt: True
25 | # vit_ckpt_layer: 5
26 | # batch_size: 16
27 | # init_lr: 2e-6
28 |
29 | image_size: 384
30 |
31 | # generation configs
32 | max_length: 20
33 | min_length: 5
34 | num_beams: 3
35 | prompt: 'a picture of '
36 |
37 | # optimizer
38 | weight_decay: 0.05
39 | min_lr: 0
40 | max_epoch: 5
41 |
42 |
43 |
--------------------------------------------------------------------------------
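caption_coco.yaml is a flat YAML config; the commented-out `pretrained:` entries are alternative BLIP checkpoints, and the active one points at model_large.pth. A minimal sketch of reading it with PyYAML (the repo's training scripts may load it differently):

    import yaml   # assumption: PyYAML is installed

    with open('models/BLIP/caption_coco.yaml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['pretrained'])                                   # checkpoint URL (model_large.pth)
    print(cfg['vit'], cfg['image_size'], cfg['num_beams'])     # 'large' 384 3
    print(repr(cfg['prompt']))                                 # 'a picture of '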
/models/BLIP/caption_coco_teacher.yaml:
--------------------------------------------------------------------------------
1 | # image_root: './export/share/datasets/vision/coco/images/'
2 | image_root: './'
3 | ann_root: 'annotation'
4 | coco_gt_root: 'annotation/coco_gt'
5 |
6 | # set pretrained as a file path or an url
7 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth'
8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
9 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth'
10 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
11 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth'
12 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
13 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
14 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
15 |
16 | # size of vit model; base or large
17 | vit: 'large'
18 | vit_grad_ckpt: False
19 | vit_ckpt_layer: 0
20 | batch_size: 32
21 | init_lr: 1e-5
22 |
23 | # vit: 'large'
24 | # vit_grad_ckpt: True
25 | # vit_ckpt_layer: 5
26 | # batch_size: 16
27 | # init_lr: 2e-6
28 |
29 | image_size: 384
30 |
31 | # generation configs
32 | max_length: 20
33 | min_length: 5
34 | num_beams: 3
35 | prompt: 'a picture of '
36 |
37 | # optimizer
38 | weight_decay: 0.05
39 | min_lr: 0
40 | max_epoch: 5
41 |
42 |
43 |
--------------------------------------------------------------------------------
/models/BLIP/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30524,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
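The BLIP modules above consume med_config.json through BertConfig.from_json_file and then overwrite `encoder_width` with the actual vision width before building the text encoder/decoder. A minimal sketch using the Hugging Face BertConfig directly (upstream BLIP's med module wraps the same config class; the 1024 below is just the ViT-large width used as an example):

    from transformers import BertConfig

    cfg = BertConfig.from_json_file('models/BLIP/med_config.json')
    print(cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size)   # 768 12 30524
    cfg.encoder_width = 1024    # e.g. replaced with the vision encoder's width, as the BLIP modules do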
/models/BLIP/vit.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | * Based on timm code base
8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from functools import partial
15 |
16 | from timm.models.vision_transformer import _cfg, PatchEmbed
17 | from timm.models.registry import register_model
18 | from timm.models.layers import trunc_normal_, DropPath
19 | from timm.models.helpers import named_apply, adapt_input_conv
20 |
21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22 |
23 | class Mlp(nn.Module):
24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25 | """
26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x):
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50 | self.scale = qk_scale or head_dim ** -0.5
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 | self.attn_gradients = None
56 | self.attention_map = None
57 |
58 | def save_attn_gradients(self, attn_gradients):
59 | self.attn_gradients = attn_gradients
60 |
61 | def get_attn_gradients(self):
62 | return self.attn_gradients
63 |
64 | def save_attention_map(self, attention_map):
65 | self.attention_map = attention_map
66 |
67 | def get_attention_map(self):
68 | return self.attention_map
69 |
70 | def forward(self, x, register_hook=False):
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74 |
75 | attn = (q @ k.transpose(-2, -1)) * self.scale
76 | attn = attn.softmax(dim=-1)
77 | attn = self.attn_drop(attn)
78 |
79 | if register_hook:
80 | self.save_attention_map(attn)
81 | attn.register_hook(self.save_attn_gradients)
82 |
83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84 | x = self.proj(x)
85 | x = self.proj_drop(x)
86 | return x
87 |
88 |
89 | class Block(nn.Module):
90 |
91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93 | super().__init__()
94 | self.norm1 = norm_layer(dim)
95 | self.attn = Attention(
96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99 | self.norm2 = norm_layer(dim)
100 | mlp_hidden_dim = int(dim * mlp_ratio)
101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102 |
103 | if use_grad_checkpointing:
104 | self.attn = checkpoint_wrapper(self.attn)
105 | self.mlp = checkpoint_wrapper(self.mlp)
106 |
107 | def forward(self, x, register_hook=False):
108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109 | x = x + self.drop_path(self.mlp(self.norm2(x)))
110 | return x
111 |
112 |
113 | class VisionTransformer(nn.Module):
114 | """ Vision Transformer
115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116 | https://arxiv.org/abs/2010.11929
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121 | use_grad_checkpointing=False, ckpt_layer=0):
122 | """
123 | Args:
124 | img_size (int, tuple): input image size
125 | patch_size (int, tuple): patch size
126 | in_chans (int): number of input channels
127 | num_classes (int): number of classes for classification head
128 | embed_dim (int): embedding dimension
129 | depth (int): depth of transformer
130 | num_heads (int): number of attention heads
131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132 | qkv_bias (bool): enable bias for qkv if True
133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135 | drop_rate (float): dropout rate
136 | attn_drop_rate (float): attention dropout rate
137 | drop_path_rate (float): stochastic depth rate
138 | norm_layer: (nn.Module): normalization layer
139 | """
140 | super().__init__()
141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143 |
144 | self.patch_embed = PatchEmbed(
145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146 |
147 | num_patches = self.patch_embed.num_patches
148 |
149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151 | self.pos_drop = nn.Dropout(p=drop_rate)
152 |
153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154 | self.blocks = nn.ModuleList([
155 | Block(
156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159 | )
160 | for i in range(depth)])
161 | self.norm = norm_layer(embed_dim)
162 |
163 | trunc_normal_(self.pos_embed, std=.02)
164 | trunc_normal_(self.cls_token, std=.02)
165 | self.apply(self._init_weights)
166 |
167 | def _init_weights(self, m):
168 | if isinstance(m, nn.Linear):
169 | trunc_normal_(m.weight, std=.02)
170 | if isinstance(m, nn.Linear) and m.bias is not None:
171 | nn.init.constant_(m.bias, 0)
172 | elif isinstance(m, nn.LayerNorm):
173 | nn.init.constant_(m.bias, 0)
174 | nn.init.constant_(m.weight, 1.0)
175 |
176 | @torch.jit.ignore
177 | def no_weight_decay(self):
178 | return {'pos_embed', 'cls_token'}
179 |
180 | def forward(self, x, register_blk=-1):
181 | B = x.shape[0]
182 | x = self.patch_embed(x)
183 |
184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185 | x = torch.cat((cls_tokens, x), dim=1)
186 |
187 | x = x + self.pos_embed[:,:x.size(1),:]
188 | x = self.pos_drop(x)
189 |
190 | for i,blk in enumerate(self.blocks):
191 | x = blk(x, register_blk==i)
192 | x = self.norm(x)
193 |
194 | return x
195 |
196 | @torch.jit.ignore()
197 | def load_pretrained(self, checkpoint_path, prefix=''):
198 | _load_weights(self, checkpoint_path, prefix)
199 |
200 |
201 | @torch.no_grad()
202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204 | """
205 | import numpy as np
206 |
207 | def _n2p(w, t=True):
208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209 | w = w.flatten()
210 | if t:
211 | if w.ndim == 4:
212 | w = w.transpose([3, 2, 0, 1])
213 | elif w.ndim == 3:
214 | w = w.transpose([2, 0, 1])
215 | elif w.ndim == 2:
216 | w = w.transpose([1, 0])
217 | return torch.from_numpy(w)
218 |
219 | w = np.load(checkpoint_path)
220 | if not prefix and 'opt/target/embedding/kernel' in w:
221 | prefix = 'opt/target/'
222 |
223 | if hasattr(model.patch_embed, 'backbone'):
224 | # hybrid
225 | backbone = model.patch_embed.backbone
226 | stem_only = not hasattr(backbone, 'stem')
227 | stem = backbone if stem_only else backbone.stem
228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231 | if not stem_only:
232 | for i, stage in enumerate(backbone.stages):
233 | for j, block in enumerate(stage.blocks):
234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235 | for r in range(3):
236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239 | if block.downsample is not None:
240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244 | else:
245 | embed_conv_w = adapt_input_conv(
246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247 | model.patch_embed.proj.weight.copy_(embed_conv_w)
248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251 | if pos_embed_w.shape != model.pos_embed.shape:
252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254 | model.pos_embed.copy_(pos_embed_w)
255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263 | for i, block in enumerate(model.blocks.children()):
264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268 | block.attn.qkv.weight.copy_(torch.cat([
269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270 | block.attn.qkv.bias.copy_(torch.cat([
271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274 | for r in range(2):
275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279 |
280 |
281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282 | # interpolate position embedding
283 | embedding_size = pos_embed_checkpoint.shape[-1]
284 | num_patches = visual_encoder.patch_embed.num_patches
285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286 | # height (== width) for the checkpoint position embedding
287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288 | # height (== width) for the new position embedding
289 | new_size = int(num_patches ** 0.5)
290 |
291 | if orig_size!=new_size:
292 | # class_token and dist_token are kept unchanged
293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294 | # only the position tokens are interpolated
295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297 | pos_tokens = torch.nn.functional.interpolate(
298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302 |
303 | return new_pos_embed
304 | else:
305 | return pos_embed_checkpoint
--------------------------------------------------------------------------------
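The VisionTransformer above has no classification head: it prepends a CLS token and returns the full normalized token sequence, so a 224x224 input with 16x16 patches yields 1 + 196 tokens. A small shape check (assumes timm and fairscale are installed and that models/BLIP is importable as a package):

    import torch
    from models.BLIP.vit import VisionTransformer

    vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    x = torch.randn(2, 3, 224, 224)
    feats = vit(x)
    print(feats.shape)    # torch.Size([2, 197, 768]); feats[:, 0] is the CLS token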
/models/GIT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__init__.py
--------------------------------------------------------------------------------
/models/GIT/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/GIT/__pycache__/git_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/GIT/__pycache__/git_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/GIT/git.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/data_ti4_c/chengkz/ofa/models/GIT')
3 | import torch
4 | import torch.nn as nn
5 | from git_model import GitForCausalLM
6 | from transformers import AutoProcessor
7 | import torch.nn.functional as F
8 | from utils.beamsearch import beam_search, beam_search_scst
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 |
12 | class GIT(nn.Module):
13 |
14 | def __init__(self, config, distill_model=False):
15 | super(GIT, self).__init__()
16 | self.config = config
17 | self.processor = AutoProcessor.from_pretrained("microsoft/git-large-coco", local_files_only=False)
18 | self.tokenizer = self.processor.tokenizer
19 | if distill_model:
20 | self.git_model = GitForCausalLM.from_pretrained(config.git_distill)
21 | else:
22 | self.git_model = GitForCausalLM.from_pretrained(config.git)
23 |
24 | def get_enc_output(self, patch_img):
25 |         # To support decode_step, the visual encoding is factored out so the visual encoder is not re-run at every decoding step
26 |         # GitModel's forward interface is also adapted to accept precomputed visual features and avoid redundant computation
27 | projected_visual_features = None
28 | if patch_img is not None:
29 | if patch_img.ndim == 4:
30 | # here we assume patch_img is of shape (batch_size, num_channels, height, width)
31 | visual_features = self.git_model.git.image_encoder(patch_img).last_hidden_state
32 | elif patch_img.ndim == 5:
33 | # here we assume patch_img is of shape (batch_size, num_frames, num_channels, height, width)
34 | visual_features = []
35 | for frame_idx in range(patch_img.shape[1]):
36 | visual_features_frame = self.git_model.git.image_encoder(patch_img[:, frame_idx, :, :]).last_hidden_state
37 | visual_features_frame += self.git_model.git.img_temperal_embedding[frame_idx]
38 | visual_features.append(visual_features_frame)
39 |
40 | # finally, concatenate all features along sequence dimension
41 | visual_features = torch.cat(visual_features, dim=1)
42 | else:
43 | raise ValueError("patch_img must be of rank 4 or 5")
44 | projected_visual_features = self.git_model.git.visual_projection(visual_features)
45 | return projected_visual_features
46 |
47 | def forward(self, patch_img, cap, att_mask, cap_len):
48 | batch_size = patch_img.shape[0]
49 | with torch.no_grad():
50 | visual_features = self.get_enc_output(patch_img)
51 | logits = self.git_model(input_ids=cap, attention_mask=att_mask, visual_features=visual_features, pixel_values=patch_img).logits
52 | logits = logits[:, -20:, :]
53 | return logits
54 |
55 | def decode_step(self, input_ids, context):
56 | visual_features = context[0]
57 | patch_img = context[1]
58 | att_mask = torch.ones(input_ids.shape).long().to(device)
59 | logits = self.git_model(input_ids=input_ids, attention_mask=att_mask, visual_features=visual_features, pixel_values=patch_img).logits
60 | return logits, None
61 |
62 | def greedy_search(self, patch_img, mode='max'):
63 | """
64 | patch_img: [batch_size, *img_patch_size]
65 | """
66 |         # Greedy search; the returned tokens should include the start and end tokens so they can serve as pseudo-captions
67 | fixed_len = self.config.fixed_len
68 | gen_num = self.config.beam_num if mode == 'prob' else 1
69 | batch_size = patch_img.shape[0]*gen_num
70 |         # The BOS token of the GIT model is BERT's [CLS] token, id 101
71 | sentences = torch.full((batch_size, 1), self.tokenizer.cls_token_id).long().to(device)
72 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device)
73 | cap_len = torch.LongTensor([fixed_len for i in range(batch_size)]).to(device)
74 |
75 | with torch.no_grad():
76 | visual_features = self.get_enc_output(patch_img)
77 |
78 | for i in range(fixed_len):
79 | attention_mask = torch.ones(sentences.shape).long().to(device)
80 | logits_all = self.git_model(input_ids=sentences, attention_mask=attention_mask, visual_features=visual_features,
81 | pixel_values=patch_img).logits
82 | logits = logits_all[:, -1, :]
83 | probs = F.softmax(logits, dim=-1)
84 | if mode == 'prob':
85 | token_id = torch.multinomial(probs, 1)[:, 0]
86 | else:
87 | score, token_id = torch.max(probs, dim=-1)
88 |                 for j in range(batch_size): # record the generated sentence lengths during generation
89 | if token_id[j].item() == self.tokenizer.sep_token_id and cap_len[j].item() == fixed_len:
90 | cap_len[j] = i + 1
91 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1)
92 | token_id = token_id.unsqueeze(1)
93 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1)
94 |
95 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)]
96 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)]
97 | return all_tokens, all_logprob
98 |
99 | def generate_caption_batchbs(self, patch_img):
100 | batch_size = patch_img.shape[0]
101 | with torch.no_grad():
102 | visual_features = self.get_enc_output(patch_img)
103 | visual_features = visual_features.repeat_interleave(self.config.beam_num, dim=0)
104 |
105 | vocab_size = 30522
106 | captions = beam_search('Transformer', [visual_features, patch_img], self, batch_size, self.config.fixed_len, self.config.beam_num,
107 | vocab_size, self.config.length_penalty, bos_token_id=self.tokenizer.cls_token_id, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.sep_token_id)
108 | return captions
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
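As the comments in greedy_search note, the returned token sequences keep the BOS ([CLS], id 101) and EOS ([SEP]) markers so they can serve directly as pseudo-captions; turning them into text is just a tokenizer decode. A hypothetical sketch (`git` is an already constructed GIT model and `patch_img` a preprocessed image batch):

    all_tokens, all_logprobs = git.greedy_search(patch_img, mode='max')
    captions = [git.tokenizer.decode(t.tolist(), skip_special_tokens=True) for t in all_tokens]
    for cap, logprob in zip(captions, all_logprobs):
        print(cap, float(logprob.sum()))   # caption text and its summed token log-probability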
/models/OFA/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/.DS_Store
--------------------------------------------------------------------------------
/models/OFA/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__init__.py
--------------------------------------------------------------------------------
/models/OFA/__pycache__/ofa.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__pycache__/ofa.cpython-37.pyc
--------------------------------------------------------------------------------
/models/OFA/__pycache__/ofa_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/OFA/__pycache__/ofa_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/OFA/ofa.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/data_ti4_c/chengkz/ofa/models/OFA')
3 | import torch
4 | import torch.nn as nn
5 | from ofa_model import OFAModel
6 | from transformers.models.ofa.tokenization_ofa import OFATokenizer
7 | import torch.nn.functional as F
8 | from utils.beamsearch import beam_search, beam_search_scst
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 |
12 | class OFA(nn.Module):
13 |
14 | def __init__(self, config, distill_model=False):
15 | super(OFA, self).__init__()
16 | self.config = config
17 | self.ofa_ckpts = config.ofa_ckpts
18 | self.tokenizer = OFATokenizer.from_pretrained(self.ofa_ckpts)
19 | if distill_model:
20 | self.ofa_model = OFAModel.from_pretrained(self.config.ofa_ckpts_distill, use_cache=False).to(device)
21 | else:
22 | self.ofa_model = OFAModel.from_pretrained(self.config.ofa_ckpts, use_cache=False).to(device)
23 | #self.ofa_encoder = self.ofa_model.encoder
24 | self.prompt = " what does the image describe?"
25 | self.prompt_input = self.tokenizer([self.prompt], return_tensors="pt").input_ids.to(device)
26 | self.frozen()
27 | # self.re_init()
28 |
29 | def frozen(self):
30 | for name, params in self.named_parameters():
31 | if 'encoder' in name:
32 | params.requires_grad = False
33 |
34 | def un_frozen(self):
35 | for name, params in self.named_parameters():
36 | if 'encoder' in name:
37 | params.requires_grad = True
38 |
39 | def re_init(self):
40 | print("reinit decoder")
41 | self.ofa_model.decoder.init_weights()
42 |
43 | def gen_enc_output(self, patch_img):
44 | """
45 | patch_img: [batch_size, *img_patch_size]
46 | return: [batch_size, 908, 1024]
47 | """
48 | batch_size = patch_img.shape[0]
49 | prompt_input = self.prompt_input.expand([batch_size, self.prompt_input.shape[1]])
50 | encoder_outputs = self.ofa_model.encoder(input_ids=prompt_input, patch_images=patch_img)
51 | return encoder_outputs
52 |
53 | def forward(self, patch_img, cap, att_mask, cap_len):
54 | batch_size = patch_img.shape[0]
55 | # with torch.no_grad():
56 | enc_output = self.gen_enc_output(patch_img)
57 | sentences = cap
58 | attention_mask = att_mask
59 | logits = self.ofa_model(decoder_input_ids=sentences, # [batch_size, cap_len, vocab_size]
60 | attention_mask=attention_mask, encoder_outputs=enc_output).logits
61 | return logits
62 |
63 | def decode_step(self, input_ids, context):
64 | enc_output = context[0]
65 | sentences = input_ids
66 | attention_mask = torch.ones(sentences.shape).long().to(device)
67 | logits = self.ofa_model(decoder_input_ids=sentences, # [batch_size, cap_len, vocab_size]
68 | attention_mask=attention_mask, encoder_outputs=enc_output).logits
69 | return logits, None
70 |
71 |
72 | def greedy_search(self, patch_img, mode='max'):
73 | """
74 | patch_img: [batch_size, *img_patch_size]
75 | """
76 | # Greedy search; the returned tokens should include the BOS and EOS tokens so they can be used as pseudo-captions
77 | fixed_len = self.config.fixed_len
78 | gen_num = self.config.beam_num if mode == 'prob' else 1
79 | batch_size = patch_img.shape[0]*gen_num
80 | # The BOS token id of the OFA model is 0
81 | sentences = torch.zeros([batch_size, 1]).long().to(device)
82 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device)
83 | cap_len = torch.LongTensor([fixed_len for i in range(batch_size)]).to(device)
84 |
85 | with torch.no_grad():
86 | enc_output = self.gen_enc_output(patch_img) # [batch_size, 908, 1024]
87 | if mode == 'prob':
88 | enc_output.last_hidden_state = enc_output.last_hidden_state.repeat(1, gen_num, 1). \
89 | view(enc_output.last_hidden_state.shape[0] * gen_num,
90 | enc_output.last_hidden_state.shape[1], enc_output.last_hidden_state.shape[2])
91 | enc_output.position_embedding = enc_output.position_embedding.repeat(1, gen_num, 1). \
92 | view(enc_output.position_embedding.shape[0] * gen_num,
93 | enc_output.position_embedding.shape[1], enc_output.position_embedding.shape[2])
94 | enc_output.padding_mask = enc_output.padding_mask.repeat(1, gen_num). \
95 | view(enc_output.padding_mask.shape[0] * gen_num, enc_output.padding_mask.shape[1])
96 |
97 | for i in range(fixed_len):
98 | attention_mask = torch.ones(sentences.shape).long().to(device)
99 | logits_all = self.ofa_model(decoder_input_ids=sentences, # [batch_size, 1, vocab_size]
100 | attention_mask=attention_mask, encoder_outputs=enc_output).logits
101 | logits = logits_all[:, -1, :]
102 | probs = F.softmax(logits, dim=-1)
103 | if mode == 'prob':
104 | token_id = torch.multinomial(probs, 1)[:, 0]
105 | else:
106 | score, token_id = torch.max(probs, dim=-1)
107 | for j in range(batch_size): # record the generated sentence lengths during decoding
108 | if token_id[j].item() == 2 and cap_len[j].item() == fixed_len:
109 | cap_len[j] = i + 1
110 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1)
111 | token_id = token_id.unsqueeze(1)
112 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1)
113 |
114 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)]
115 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)]
116 | return all_tokens, all_logprob
117 |
118 | def generate_caption_batchbs(self, patch_img):
119 | batch_size = patch_img.shape[0]
120 | with torch.no_grad():
121 | enc_output = self.gen_enc_output(patch_img)
122 | enc_output.last_hidden_state = enc_output.last_hidden_state.repeat(1, self.config.beam_num, 1).\
123 | view(enc_output.last_hidden_state.shape[0]*self.config.beam_num, enc_output.last_hidden_state.shape[1], enc_output.last_hidden_state.shape[2])
124 | enc_output.position_embedding = enc_output.position_embedding.repeat(1, self.config.beam_num, 1).\
125 | view(enc_output.position_embedding.shape[0]*self.config.beam_num, enc_output.position_embedding.shape[1], enc_output.position_embedding.shape[2])
126 | enc_output.padding_mask = enc_output.padding_mask.repeat(1, self.config.beam_num).\
127 | view(enc_output.padding_mask.shape[0]*self.config.beam_num, enc_output.padding_mask.shape[1])
128 | vocab_size = 59457
129 | captions = beam_search('Transformer', [enc_output], self, batch_size, self.config.fixed_len, self.config.beam_num,
130 | vocab_size, self.config.length_penalty, bos_token_id=0, pad_token_id=1, eos_token_id=2)
131 | return captions
--------------------------------------------------------------------------------
/models/Transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/models/Transformer/__init__.py
--------------------------------------------------------------------------------
/models/Transformer/transformer.py:
--------------------------------------------------------------------------------
1 | # Transformer-based image captioning model
2 | # Supports both faster-rcnn features and CNN features as input
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.utils.weight_norm import weight_norm
7 | import torch.nn.functional as F
8 | import pickle
9 | import math
10 | from utils.beamsearch import beam_search, beam_search_scst
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 |
13 |
14 | class PositionalEncoding(nn.Module):
15 |
16 | def __init__(self, d_model, dropout=0.1, max_len=30):
17 | super(PositionalEncoding, self).__init__()
18 | self.dropout = nn.Dropout(p=dropout)
19 |
20 | pe = torch.zeros(max_len, d_model)
21 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
22 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
23 | pe[:, 0::2] = torch.sin(position * div_term)
24 | pe[:, 1::2] = torch.cos(position * div_term)
25 | pe = pe.unsqueeze(0).transpose(0, 1)
26 | self.register_buffer('pe', pe)
27 |
28 | def forward(self, x):
29 | x = x + self.pe[:x.size(0), :]
30 | return self.dropout(x)
31 |
32 |
33 | class Transformer_Encoder(nn.Module):
34 |
35 | def __init__(self, config):
36 | super(Transformer_Encoder, self).__init__()
37 | self.config = config
38 | self.image_dim = config.image_dim
39 | self.embed_dim = config.embed_dim
40 | self.fea2embed = nn.Linear(self.image_dim, self.embed_dim)
41 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8)
42 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
43 |
44 | def forward(self, fea_maps):
45 | fea_maps = self.fea2embed(fea_maps)
46 | fea_maps_seq = fea_maps.permute(1, 0, 2)
47 | memory = self.transformer_encoder(src=fea_maps_seq)
48 | return memory
49 |
50 |
51 | class Transformer_Decoder(nn.Module):
52 |
53 | def __init__(self, config):
54 | super(Transformer_Decoder, self).__init__()
55 | self.config = config
56 | self.vocab = pickle.load(open(self.config.vocab, 'rb'))
57 | self.vocab_size = self.vocab.get_size()
58 | self.embed_dim = config.embed_dim
59 |
60 | self.embed = nn.Embedding(self.vocab_size, self.embed_dim)
61 |
62 | self.pos_encoder = PositionalEncoding(self.embed_dim)
63 | decoder_layer = nn.TransformerDecoderLayer(d_model=self.embed_dim, nhead=8)
64 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
65 |
66 | self.fc = weight_norm(nn.Linear(self.embed_dim, self.vocab_size))
67 | self.dropout = nn.Dropout(0.5)
68 |
69 | def gen_tgt_mask(self, length):
70 | mask = torch.triu(torch.ones(length, length)).permute(1, 0).to(device)
71 | mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask==1, float(0.0))
72 | return mask
73 |
74 | def forward(self, memory, cap, cap_len):
75 | cap = cap.permute(1, 0)
76 | tgt_pos_embedding = self.pos_encoder(self.embed(cap)*math.sqrt(self.embed_dim))
77 | tgt_mask = self.gen_tgt_mask(tgt_pos_embedding.shape[0])
78 | out = self.transformer_decoder(tgt=tgt_pos_embedding, memory=memory, tgt_mask=tgt_mask)
79 |
80 | pred = self.fc(self.dropout(out))
81 | pred = pred.permute(1, 0, 2)
82 |
83 | return pred
84 |
85 | def decode_step(self, input_ids, context):
86 | memory = context[0]
87 | cap = input_ids.permute(1, 0)
88 | tgt_pos_embedding = self.pos_encoder(self.embed(cap) * math.sqrt(self.embed_dim))
89 | tgt_mask = self.gen_tgt_mask(tgt_pos_embedding.shape[0])
90 | out = self.transformer_decoder(tgt=tgt_pos_embedding, memory=memory, tgt_mask=tgt_mask)
91 |
92 | pred = self.fc(self.dropout(out))
93 | pred = pred.permute(1, 0, 2)
94 | return pred, None
95 |
96 |
97 | class Transformer_Cap(nn.Module):
98 |
99 | def __init__(self, config):
100 | super(Transformer_Cap, self).__init__()
101 | self.config = config
102 | self.transformer_encoder = Transformer_Encoder(self.config)
103 | self.transformer_decoder = Transformer_Decoder(self.config)
104 |
105 | def forward(self, image_feature, cap, cap_len, mode='xe'):
106 | if mode == 'xe':
107 | fea_maps = image_feature['feature_map']
108 | memory = self.transformer_encoder(fea_maps)
109 | logit = self.transformer_decoder(memory, cap, cap_len)
110 | return logit
111 | elif mode == 'vanilla_scst':
112 | return self.greedy_search(image_feature, 'prob')
113 |
114 | def beam_search(self, image_feature):
115 | fea_maps = image_feature['feature_map']
116 | batch_size = fea_maps.shape[0]
117 | memory = self.transformer_encoder(fea_maps)
118 | memory = memory.repeat(1, 1, self.config.beam_num).view(memory.shape[0], memory.shape[1]*self.config.beam_num, memory.shape[2])
119 | captions, all_tokens, all_logprob = beam_search_scst('Transformer', [memory], self.transformer_decoder, batch_size, self.config.fixed_len, self.config.beam_num,
120 | self.transformer_decoder.vocab_size, self.config.length_penalty)
121 | return captions, all_tokens, all_logprob
122 |
123 | def greedy_search(self, image_feature, mode='max'):
124 | # greedy search or multinomial-sampling search
125 | fea_maps = image_feature['feature_map']
126 | # sample beam_num results per sample
127 | gen_num = self.config.beam_num if mode == 'prob' else 1
128 | fea_maps = fea_maps.unsqueeze(dim=1)
129 | fea_maps = fea_maps.expand([fea_maps.shape[0], gen_num, fea_maps.shape[2], fea_maps.shape[3]])
130 | fea_maps = fea_maps.reshape(fea_maps.shape[0] * fea_maps.shape[1], fea_maps.shape[2], fea_maps.shape[3])
131 | batch_size = fea_maps.shape[0]
132 |
133 | sentences = torch.ones([batch_size, 1]).to(device).long()
134 | log_probs_sen = torch.full((batch_size, 0), 0.0).to(device)
135 | cap_len = torch.LongTensor([20 for i in range(batch_size)]).to(device)
136 |
137 | memory = self.transformer_encoder(fea_maps)
138 | context = [memory]
139 | for i in range(self.config.fixed_len):
140 | outputs, _ = self.transformer_decoder.decode_step(sentences, context)
141 | logits = outputs[:, -1, :]
142 | probs = F.softmax(logits, dim=-1)
143 | if mode == 'prob':
144 | token_id = torch.multinomial(probs, 1)[:, 0]
145 | else:
146 | score, token_id = torch.max(probs, dim=-1)
147 | for j in range(batch_size): # record the generated sentence lengths during decoding
148 | if token_id[j].item() == 2 and cap_len[j].item() == 20:
149 | cap_len[j] = i + 1
150 | sentences = torch.cat([sentences, token_id.unsqueeze(1)], dim=1)
151 | token_id = token_id.unsqueeze(1)
152 | log_probs_sen = torch.cat([log_probs_sen, torch.log(torch.gather(probs, 1, token_id))], dim=-1)
153 |
154 | # mask using the generated sentence lengths
155 | all_tokens = [sentences[i][:(cap_len[i] + 1)] for i in range(batch_size)]
156 | all_logprob = [log_probs_sen[i][:cap_len[i]] for i in range(batch_size)]
157 |
158 | return all_tokens, all_logprob
159 |
160 | def generate_caption_batchbs(self, image_feature):
161 | fea_maps = image_feature['feature_map']
162 | batch_size = fea_maps.shape[0]
163 | memory = self.transformer_encoder(fea_maps)
164 | memory = memory.repeat(1, 1, self.config.beam_num).view(memory.shape[0], memory.shape[1]*self.config.beam_num, memory.shape[2])
165 | caption = beam_search('Transformer', [memory], self.transformer_decoder, batch_size, self.config.fixed_len, self.config.beam_num,
166 | self.transformer_decoder.vocab_size, self.config.length_penalty)
167 | return caption
168 |
169 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model
2 | This repo provides the source code & data of our paper: [Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model. (ACMMM 23)](https://arxiv.org/abs/2308.01126)
3 | ```
4 | @misc{cheng2023generic,
5 | title={Beyond Generic: Enhancing Image Captioning with Real-World Knowledge using Vision-Language Pre-Training Model},
6 | author={Kanzhi Cheng and Wenpo Song and Zheng Ma and Wenhao Zhu and Zixuan Zhu and Jianbing Zhang},
7 | year={2023},
8 | eprint={2308.01126},
9 | archivePrefix={arXiv},
10 | primaryClass={cs.CV}
11 | }
12 | ```
13 | ### Code Structure
14 | ***
15 | ````
16 | ├── config.py # config
17 | ├── data # coco data & knowcap data
18 | │ ├── data_cc12m_SelectForRreplay.json
19 | │ ├── dataset_coco.json
20 | │ ├── test.json
21 | │ ├── train.json
22 | │ ├── val.json
23 | │ ├── knowcap_240.json
24 | │ ├── knowcap_240_test.json
25 | │ ├── knowcap_240_test_unseen.json
26 | │ ├── knowcap_240_val.json
27 | │ ├── train_mix_32000.json
28 | │ └── ...
29 | ├── data_load.py # dataloader
30 | ├── test.py # evaluation on coco
31 | ├── test_knowcap.py # evaluation on knowcap
32 | ├── models # models (OFA,BLIP,GIT)
33 | │ ├── OFA
34 | │ ├── BLIP
35 | │ └── GIT
36 | ├── train_multitask.py # K-Replay training
37 | └── utils # support codes & tools
38 | ├── beamsearch.py # beamsearch
39 | ├── cc12m.py # filter relay data from cc12m
40 | ├── convert_ofa.py # ckpts convert
41 | ├── eval.py # generate captions & calculate metrics
42 | ├── import_models.py
43 | ├── log.py
44 | ├── loss.py # loss function of K-Replay
45 | ├── optimizer_tools.py
46 | └── prepro_data.py # construct the data in ./data
47 | ````
48 | ### KnowCap Dataset
49 | ***
50 | KnowCap is a new dataset for the evaluation of knowledge-enhanced image captioning, containing 1424 images and 4156 reference descriptions
51 | carefully written by human annotators.
52 |
53 | 
54 |
55 | Download the images and annotations of [KnowCap](https://drive.google.com/file/d/1DOk5WZZgHyO6tKT8A135hMgePid-akFq/view?usp=drive_link).
56 | ### Preparing Data&Model
57 | ***
58 | #### Step1:
59 | Download the images of:
60 | * [COCO2014](https://github.com/ruotianluo/ImageCaptioning.pytorch/blob/master/data/README.md)
61 | * [KnowCap](https://drive.google.com/file/d/1DOk5WZZgHyO6tKT8A135hMgePid-akFq/view?usp=drive_link)
62 | * [Replay images selected from cc12m](https://drive.google.com/file/d/1tdVZ1rUpr5va-NwInMwBglRpSGOzUoMu/view?usp=drive_link)
63 | #### Step2:
64 | Run `prepro_data.py` to collate and split the coco and knowcap datasets into ./data.
65 |
66 | Alternatively, we provide the processed [data](https://drive.google.com/file/d/1DBdnqcH_lOm--t5pZOlac1j1my4kVgrP/view?usp=drive_link) that can be placed in the ./data directory. Note that the file_path in each dataset needs to be modified to point to the images downloaded in step1. Similarly, some of the parameters in config need to be adjusted for your own environment.
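For example, a minimal sketch of re-pointing the image paths in a downloaded split (the `filename` key, the list layout, and the local directory below are assumptions; check the actual structure of your json):

```python
# Sketch: rewrite the image paths of a downloaded split to a local image directory.
# The "filename" key, the list layout, and local_img_dir are assumptions, not guaranteed by the repo.
import json
import os

split_path = "./data/train.json"
local_img_dir = "/path/to/coco/train2014"  # hypothetical local image directory

data = json.load(open(split_path, "r"))
for item in data:
    item["filename"] = os.path.join(local_img_dir, os.path.basename(item["filename"]))
json.dump(data, open(split_path, "w"))
```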
67 |
68 | #### Step3:
69 | Prepare the ckpts of VLP models (take OFA as an example) for training and testing.
70 | 1. Download the transformers version ckpts of [OFA](https://huggingface.co/OFA-Sys/ofa-large)
71 | 2. However, since there are some [problems](https://github.com/OFA-Sys/OFA/issues/296) with the official transformers ckpts, we manually replace their parameters with those from the official fairseq ckpts using `convert_ofa.py`
72 |
73 | Alternatively, we provide the converted [ckpts](https://drive.google.com/file/d/1QQZ9eyO63JBBtyK5YIKA4CJ3jjAPuhQM/view?usp=drive_link).
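As a quick sanity check, the converted directory can be loaded the same way `models/OFA/ofa.py` does; a minimal sketch (the checkpoint path below is a placeholder):

```python
# Sketch: load the converted OFA checkpoint with the customized transformers fork,
# mirroring models/OFA/ofa.py. The checkpoint directory below is a placeholder.
import sys
sys.path.append("models/OFA")  # so that ofa_model is importable, as in ofa.py

from transformers.models.ofa.tokenization_ofa import OFATokenizer
from ofa_model import OFAModel

ckpt_dir = "/path/to/OFA-large-caption-trainedenc"
tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
```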
74 | ### Reproduce the main results
75 | ***
76 | The baseline result of *OFA* on knowcap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --ofa_ckpts xxx --length_penalty 1.0`, where `ofa_ckpts` is the checkpoint obtained in step3.
77 | 
78 | The *OFA+K-Replay* result on knowcap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --trained_ckpts xxx --length_penalty 1.0`, where `trained_ckpts` can be downloaded from [here](https://drive.google.com/file/d/1z2InwjGOcmTOFGr25nIFI_tGCPNBnc1H/view?usp=drive_link).
79 |
80 | To evaluate on coco, use `test.py` instead of `test_knowcap.py`.
81 |
82 | > #### Tips:
83 | To eliminate the need for `coco_id` in the evaluation, we customized the COCOEval function in `eval.py`.
84 | Therefore, `xxx/site-packages/pycocoevalcap/eval.py` needs to be replaced with (or modified according to) our `eval.py` before using the current evaluation code.
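To find the installed file that needs to be replaced, something like the following should work (assuming `pycocoevalcap` was installed as a regular package):

```python
# Sketch: print the path of the installed pycocoevalcap eval.py that our eval.py should replace.
import os
import pycocoevalcap

print(os.path.join(os.path.dirname(pycocoevalcap.__file__), "eval.py"))
```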
85 | ### Training with K-Replay
86 | ***
87 | #### Step4:
88 | Start Training with K-Replay:
89 | `CUDA_VISIBLE_DEVICES=0 python train_multitask.py --mode train --model OFA --id ofa_kreplay --batch_size 60 --learning_rate 7e-6 --label_smoothing 0.1 --multitask_weight 1.0 --KD_temperature 16.0 --knowdistill_weight 1.0 --save_model_freq 100 --ofa_ckpts /home/chengkz/checkpoints/ofa/OFA-large-caption-trainedenc --ofa_ckpts_distill /home/chengkz/checkpoints/ofa/OFA-large-caption-XEfinetuned --train_mix ./data/train_mix_32000.json --method XEdistill`.
90 |
91 | The `ofa_ckpts` and `ofa_ckpts_distill` are obtained in step3, and `train_mix_32000.json` is obtained in step2.
92 | #### Step5:
93 | Evaluation on COCO:
94 | `CUDA_VISIBLE_DEVICES=0 python test.py --model OFA --id ofa_kreplay --step 300 --length_penalty 1.0`.
95 |
96 | Evaluation on KnowCap: `CUDA_VISIBLE_DEVICES=0 python test_knowcap.py --model OFA --id ofa_kreplay --step 300 --length_penalty 1.0`.
97 | > #### Tips:
98 | OFA uses `resnet` as the backbone of its visual encoder. In our experiments, we found that the `batchnorm` layers in the resnet backbone do not give good estimates of the `mean` and `std` due to the small batch size we used, which degrades model performance. Therefore, we fixed the `mean` and `std` of these layers during training by setting `momentum=0.0` in `./transformers/models/ofa/resnet.py`.
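If you prefer not to edit the transformers source, a roughly equivalent workaround is to zero the momentum of the backbone's batchnorm layers after the model is built; a sketch, assuming the backbone exposes standard `nn.BatchNorm2d` modules:

```python
# Sketch: freeze the running mean/std of all BatchNorm2d layers in the OFA visual backbone.
# Assumes the backbone uses standard torch.nn.BatchNorm2d modules; `model` stands for the OFA
# wrapper from models/OFA/ofa.py (a hypothetical variable name here).
import torch.nn as nn

def freeze_batchnorm_stats(module):
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = 0.0  # running statistics are no longer updated during training

freeze_batchnorm_stats(model.ofa_model.encoder)
```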
99 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | aiohttp
3 | astunparse==1.6.3
4 | async-timeout==3.0.1
5 | attrs
6 | blinker==1.4
7 | brotlipy==0.7.0
8 | cached-property==1.5.2
9 | cachetools
10 | certifi==2021.5.30
11 | cffi==1.14.0
12 | chardet
13 | click
14 | clip
15 | coverage
16 | cryptography
17 | cycler==0.10.0
18 | Cython==0.29.24
19 | decorator==4.4.2
20 | dnspython==2.2.1
21 | echo1-coco-split==0.1.5
22 | et-xmlfile==1.1.0
23 | eventlet==0.33.2
24 | fairscale==0.4.6
25 | filelock==3.3.2
26 | flatbuffers==23.5.26
27 | ftfy==6.1.1
28 | funcy==1.17
29 | future==0.18.2
30 | gast==0.4.0
31 | google-auth
32 | google-auth-oauthlib==0.4.1
33 | google-pasta==0.2.0
34 | greenlet==2.0.1
35 | grpcio
36 | h5py==3.3.0
37 | huggingface-hub==0.13.3
38 | idna
39 | imageio==2.9.0
40 | importlib-metadata==5.1.0
41 | jieba==0.42.1
42 | joblib==1.1.0
43 | jsonlines==3.0.0
44 | keras==2.11.0
45 | kiwisolver
46 | libclang==16.0.6
47 | llvmlite==0.39.1
48 | Markdown
49 | matplotlib==3.4.2
50 | mkl-service==2.3.0
51 | multidict
52 | networkx==2.5.1
53 | nltk==3.4.5
54 | numba==0.56.4
55 | numpy==1.21.0
56 | oauthlib==3.1.0
57 | olefile==0.46
58 | opencv-python==4.5.5.64
59 | openpyxl==3.0.10
60 | opt-einsum==3.3.0
61 | packaging==21.2
62 | pandas==1.3.5
63 | Pillow==8.3.0
64 | protobuf==3.14.0
65 | pyasn1==0.4.8
66 | pyasn1-modules==0.2.8
67 | pycocoevalcap==1.2
68 | pycocotools==2.0.2
69 | pycparser
70 | PyJWT==1.7.1
71 | pynndescent==0.5.8
72 | pyOpenSSL
73 | pyparsing
74 | PySocks
75 | python-dateutil
76 | pytz
77 | PyWavelets==1.1.1
78 | PyYAML==6.0
79 | regex==2021.11.1
80 | requests
81 | requests-oauthlib==1.3.0
82 | rouge==1.0.1
83 | rsa
84 | sacremoses==0.0.46
85 | scikit-image==0.18.2
86 | scikit-learn==1.0.2
87 | scipy==1.7.0
88 | seaborn==0.12.1
89 | sentencepiece==0.1.99
90 | six
91 | sklearn==0.0
92 | tensorboard==2.11.2
93 | tensorboard-data-server==0.6.1
94 | tensorboard-plugin-wit==1.6.0
95 | tensorflow==2.11.0
96 | tensorflow-estimator==2.11.0
97 | tensorflow-io-gcs-filesystem==0.33.0
98 | termcolor==2.3.0
99 | threadpoolctl==3.1.0
100 | tifffile==2021.7.2
101 | timeout-decorator==0.5.0
102 | timm==0.6.13
103 | tokenizers==0.12.1
104 | torch==1.8.1+cu101
105 | torchaudio==0.8.1
106 | torchvision==0.9.1+cu101
107 | tornado
108 | tqdm==4.65.0
109 | transformers==4.28.0.dev0
110 | typing-extensions
111 | umap==0.1.1
112 | umap-learn==0.5.3
113 | urllib3==1.25.8
114 | wcwidth==0.2.5
115 | Werkzeug
116 | wget==3.2
117 | wrapt==1.15.0
118 | XlsxWriter==3.0.3
119 | yarl
120 | zipp
121 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # Testing; the test argument specifies whether to evaluate on testa or testb
2 |
3 | import torch
4 | import random
5 | import numpy as np
6 | import json
7 |
8 | from config import config
9 |
10 | from utils.import_models import construct_model
11 | from utils.eval import generate_captions, eval_pycoco
12 | from utils.vocab import Vocabulary
13 |
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 | # random seed
18 | seed = config.seed
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 |
25 | # model
26 | model = construct_model(config).to(device)
27 | log_path = config.log_dir.format(config.id)
28 | trained_model_path = log_path + '/model/model_' + str(config.step) + '.pt'
29 | model.load_state_dict(torch.load(trained_model_path))
30 | model.eval()
31 | with torch.no_grad():
32 | gen_pycoco_path = generate_captions(config, model, config.step, config.test, final_test=True)
33 |
34 | if False: # some official ckpts produce extra characters that need post-processing
35 | pycoco = json.load(open(gen_pycoco_path, 'r'))
36 | for k, v in pycoco.items():
37 | caption_origin = v[0]["caption"]
38 | caption_new = caption_origin.replace(')', '').replace('\\', '').replace('}', '').replace(']', '').strip()
39 | v[0]["caption"] = caption_new
40 | json.dump(pycoco, open(gen_pycoco_path, 'w'))
41 |
42 | pycoco_results = eval_pycoco(config, gen_pycoco_path, config.test)
43 | print(pycoco_results)
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/train_multitask.py:
--------------------------------------------------------------------------------
1 | # training with K-Replay
2 | import yaml
3 | import json
4 | import torch
5 | import random
6 | import numpy as np
7 |
8 | import time
9 | import torch.nn as nn
10 |
11 | from config import config
12 | from data_load import data_load_rwc
13 |
14 | from utils.import_models import construct_model
15 | from utils.loss import Cross_Entropy, Loss_SCST_OFA, Sent_Level_Concept_Coverage, Loss_Params_Regular, Loss_KD
16 | from utils.log import Log_Writer, train_print
17 | from utils.eval import generate_captions, eval_pycoco
18 | from utils.optimizer_tools import adjust_weight, adjust_lr, cal_fisher_coco, cal_fisher_downtask_mask, adjust_mask, model_grad_mask, RecAdam, cal_fisher_downtask, ratio_dataset
19 | from utils.vocab import Vocabulary
20 | from test_knowcap import cal_knowcap
21 | from models.OFA.ofa import OFA
22 | from models.BLIP.blip import blip_decoder
23 | from models.GIT.git import GIT
24 |
25 |
26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27 |
28 | # random seed
29 | seed = config.seed
30 | torch.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed)
32 | np.random.seed(seed)
33 | random.seed(seed)
34 | torch.backends.cudnn.deterministic = True
35 |
36 | # log
37 | writer = Log_Writer(config)
38 | global_step = 0
39 | loss_avg = 0
40 | loss_ce_avg = 0
41 | loss_rwc_avg = 0
42 | mt_weight = config.multitask_weight
43 | kd_weight = config.knowdistill_weight
44 |
45 | train_mix = config.train_mix # dataset used for training; either the mixed data or coco alone
46 | if config.data_ratio != 1.0: # the ratio of coco to the other data can be adjusted
47 | train_mix_data_new = ratio_dataset(train_mix, config.data_ratio)
48 | train_mix = './data/train_mix_cc12m_keyword_'+str(config.data_ratio)+'.json'
49 | json.dump(train_mix_data_new, open(train_mix, 'w'))
50 |
51 | data_mode = config.data_mode # used together with train_mix to determine the training data and mode, mix|single
52 | method = config.method # the various methods being compared
53 | model_type = config.model
54 | # data_loader
55 | train_loader = data_load_rwc(config, train_mix, 'train')
56 |
57 | # model
58 | model = construct_model(config).to(device)
59 | if method == 'XEdistill':
60 | if model_type == 'OFA':
61 | model_t = OFA(config, distill_model=True)
62 | elif model_type == 'BLIP':
63 | argst = yaml.load(open(config.config_blip_t, 'r'), Loader=yaml.Loader)
64 | model_t = blip_decoder(pretrained=argst['pretrained'], config=config, image_size=argst['image_size'],
65 | vit=argst['vit'],
66 | vit_grad_ckpt=argst['vit_grad_ckpt'], vit_ckpt_layer=argst['vit_ckpt_layer'],
67 | prompt=argst['prompt'])
68 | elif model_type == 'GIT':
69 | model_t = GIT(config, distill_model=True)
70 | model_t = model_t.to(device)
71 | loss_distill = Loss_KD(config.KD_temperature)
72 | if method == 'Adapter': # Adapter lets only a small fraction of the model parameters be trained
73 | for name, p in model.named_parameters():
74 | if p.requires_grad == True:
75 | if 'adapter_ln1' in name or 'adapter_ln2' in name:
76 | p.requires_grad = True
77 | print(name)
78 | else:
79 | p.requires_grad = False
80 |
81 | # optimizer
82 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
83 | if method == 'RecAdam': # Recall and Learn uses the optimizer for regularization
84 | pretrain_params = []
85 | for name, p in model.named_parameters():
86 | pretrain_params.append(p)
87 | optimizer = RecAdam(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.98), eps=1e-9, anneal_k=1.0, anneal_t0=100, pretrain_params=pretrain_params)
88 |
89 | # loss
90 | loss_cov = Sent_Level_Concept_Coverage()
91 | loss_fn = Cross_Entropy(label_smoothing=config.label_smoothing)
92 |
93 | if method == 'child-tuning': # Child-tuning and EWC depend on the parameters' gradients on the corresponding task, so compute those gradients first
94 | grads_mask_coco = cal_fisher_downtask_mask(config, model)
95 | # grads_mask_coco = adjust_mask(grads_mask_coco)
96 | elif method == 'k-tuning':
97 | grads_mask_knowledge = cal_fisher_downtask_mask(config, model) # find the parameters related to knowledge
98 | elif method == 'EWC':
99 | params_fisher = cal_fisher_downtask(config, model)
100 | params_init = dict()
101 | for name, params in model.named_parameters():
102 | if params.requires_grad == True:
103 | params_init[name] = params
104 | loss_params_regular = Loss_Params_Regular(params_init, params_fisher)
105 |
106 | if config.step != 0:
107 | log_path = config.log_dir.format(config.ckpts_id)
108 | trained_model_path = log_path + '/model/model_' + str(config.step) + '.pt'
109 | model.load_state_dict(torch.load(trained_model_path))
110 | global_step = config.step
111 |
112 | for epoch in range(config.epochs):
113 | if global_step == 800:
114 | break
115 | model.train()
116 | total_step = len(train_loader)
117 | epoch_time = time.time()
118 | step_time = time.time()
119 |
120 | optimizer = adjust_lr(optimizer, epoch)
121 | for step, (image_feature, cap, att_mask, cap_len, labels, data_item) in enumerate(train_loader):
122 |
123 | data_mode = config.data_mode
124 | global_step += 1
125 | optimizer.zero_grad()
126 |
127 | patch_image = image_feature['patch_image']
128 | patch_image = patch_image.to(device)
129 | cap = cap.to(device)
130 | cap_len = cap_len.to(device)
131 | labels = labels.to(device)
132 | att_mask = att_mask.to(device)
133 |
134 | if labels.sum().item() == 0:
135 | data_mode = 'single'
136 |
137 | if data_mode == 'mix': # find the rwconcept samples and build pseudo pairs for training
138 | index_rwc = torch.nonzero(labels==1).squeeze().long()
139 | if index_rwc.shape == torch.Size([]):
140 | index_rwc = index_rwc.unsqueeze(0)
141 | index_coco = torch.nonzero(labels==0).squeeze(dim=1).long()
142 | # keep the original caption as the label
143 | cap_rwc_label = cap[index_rwc]
144 | # generate pseudo-captions for these samples with the current model
145 | if index_rwc.shape != torch.Size([0]):
146 | with torch.no_grad():
147 | patch_image_rwc = patch_image[index_rwc]
148 | all_tokens, all_logprob = model.greedy_search(patch_image_rwc, 'max')
149 | cap_new = []
150 | att_mask_new = []
151 | cap_len_new = []
152 | for cap_id in all_tokens:
153 | cap_len_g = cap_id.shape[0]
154 | if cap_len_g < config.fixed_len:
155 | if model_type == 'OFA':
156 | cap_id = torch.cat([cap_id, torch.ones([config.fixed_len - cap_len_g]).to(device)], dim=0)
157 | elif model_type == 'BLIP':
158 | cap_id = torch.cat([cap_id, torch.zeros([config.fixed_len - cap_len_g]).to(device)], dim=0)
159 | elif model_type == 'GIT':
160 | cap_id = torch.cat([cap_id, torch.zeros([config.fixed_len - cap_len_g]).to(device)], dim=0)
161 | att_mask_g = torch.cat([torch.ones([cap_len_g]).to(device), torch.zeros([config.fixed_len - cap_len_g]).to(device)], dim=0)
162 | else:
163 | cap_id = cap_id[:config.fixed_len]
164 | cap_len_g = config.fixed_len
165 | att_mask_g = torch.ones(cap_id.shape).to(device)
166 | cap_new.append(cap_id)
167 | att_mask_new.append(att_mask_g)
168 | cap_len_new.append(cap_len_g)
169 | cap_new = torch.stack(cap_new, dim=0).long()
170 | att_mask_new = torch.stack(att_mask_new, dim=0).long()
171 | cap_len_new = torch.Tensor(cap_len_new).int()
172 | # put the pseudo-captions back into the batch so they are forwarded together
173 | cap[index_rwc] = cap_new.to(device)
174 | att_mask[index_rwc] = att_mask_new.to(device)
175 | cap_len[index_rwc] = cap_len_new.to(device)
176 | # knowledge distillation: one forward pass with the teacher to obtain its logits
177 | if method == 'XEdistill':
178 | with torch.no_grad():
179 | logit_t = model_t(patch_image[index_rwc], cap[index_rwc], att_mask[index_rwc], cap_len[index_rwc])
180 |
181 | logit = model(patch_image, cap, att_mask, cap_len)
182 | if data_mode == 'single':
183 | loss = loss_fn(logit, cap, cap_len)
184 | loss_avg += loss.item()
185 | elif data_mode == 'mix':
186 | loss_ce = loss_fn(logit[index_coco], cap[index_coco], cap_len[index_coco])
187 | loss_rwc = loss_cov(logit[index_rwc], cap_rwc_label, cap_len[index_rwc], model_type)
188 | loss = loss_ce + mt_weight * loss_rwc
189 | if method == 'XEdistill':
190 | loss_kd = loss_distill(logit[index_rwc], logit_t, cap_len[index_rwc])
191 | loss += kd_weight*loss_kd
192 | loss_ce_avg += loss_ce.item()
193 | loss_rwc_avg += loss_rwc.item()
194 | loss_avg += loss.item()
195 |
196 | if method == 'EWC':
197 | loss = loss + loss_params_regular(model)
198 |
199 | loss.backward()
200 | nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
201 | if method == 'child-tuning':
202 | model_grad_mask(model, grads_mask_coco)
203 | optimizer.step()
204 |
205 | if global_step % config.save_loss_freq == 0:
206 | writer.write_tensorboard('loss', loss_avg/config.save_loss_freq, global_step)
207 | loss_avg = 0
208 | if data_mode == 'mix':
209 | writer.write_tensorboard('loss_ce', loss_ce_avg/config.save_loss_freq, global_step)
210 | writer.write_tensorboard('loss_rwc', loss_rwc_avg/config.save_loss_freq, global_step)
211 | loss_ce_avg = 0
212 | loss_rwc_avg = 0
213 |
214 | train_print(loss.item(), step, total_step, epoch, time.time() - step_time, time.time() - epoch_time)
215 | step_time = time.time()
216 |
217 | if global_step % config.save_model_freq == 0:
218 | print("Evaluating...")
219 |
220 | # save the model
221 | if global_step % 100 == 0:
222 | writer.save_model(model, global_step)
223 |
224 | # validation
225 | model.eval()
226 | with torch.no_grad():
227 | gen_pycoco_path = generate_captions(config, model, global_step, 'val')
228 | pycoco_results = eval_pycoco(config, gen_pycoco_path, 'val')
229 | pycoco_results_knowcap, acc = cal_knowcap(model, global_step)
230 | writer.write_metrics(pycoco_results, global_step)
231 | writer.write_metrics(pycoco_results_knowcap, global_step)
232 | writer.write_metrics(acc, global_step)
233 |
234 | model.train()
235 |
236 | if global_step == 800:
237 | break
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__pycache__/beamsearch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/beamsearch.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/eval.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/import_models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/import_models.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/vocab.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njucckevin/KnowCap/85965a45bd5cbfbcd580050b8de3c842dbf27542/utils/__pycache__/vocab.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/beamsearch.py:
--------------------------------------------------------------------------------
1 | # batch beamsearch
2 | # Modeled on the huggingface implementation; see https://zhuanlan.zhihu.com/p/167072494 and http://www.wuyuanhao.com/2020/03/20/解读beam-search-1-2/
3 | # Besides supporting beam search for a whole batch of samples at once, the biggest difference from conventional beam search is:
4 | # when a sequence in the beam produces the end token, the beam width is not reduced; instead, the finished sequence is stored in BeamHypotheses and a new unfinished sequence is added to the beam,
5 | # and the search continues with the full beam width, while BeamHypotheses is continually updated with newly finished sequences
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 |
13 | class BeamHypotheses(object):
14 | # One instance per sample; maintains the num_beams best sequences so far; new sequences can be added and the lowest-scoring one is automatically removed
15 | def __init__(self, num_beams, max_length, length_penalty):
16 | # initialization
17 | self.max_length = max_length - 1
18 | self.num_beams = num_beams
19 | self.length_penalty = length_penalty
20 | self.beams = []
21 | self.worst_score = 1e9
22 |
23 | def __len__(self):
24 | return len(self.beams)
25 |
26 | def add(self, hyp, sum_logprobs):
27 | # length penalty; can be customized
28 | score = sum_logprobs / len(hyp) ** self.length_penalty
29 | # score = sum_logprobs / (pow((5+len(hyp)+1), self.length_penalty)/pow(5+1, self.length_penalty))
30 | if len(self) < self.num_beams or score > self.worst_score:
31 | # can be added
32 | self.beams.append((score, hyp))
33 | if len(self) > self.num_beams:
34 | # need to remove one
35 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
36 | del self.beams[sorted_scores[0][1]]
37 | self.worst_score = sorted_scores[1][0]
38 | else:
39 | self.worst_score = min(score, self.worst_score)
40 |
41 | def add_scst(self, hyp, logprob, sum_logprobs):
42 | # length penalty; can be customized
43 | score = sum_logprobs / len(hyp) ** self.length_penalty
44 | # score = sum_logprobs / (pow((5+len(hyp)+1), self.length_penalty)/pow(5+1, self.length_penalty))
45 | if len(self) < self.num_beams or score > self.worst_score:
46 | # can be added
47 | self.beams.append((score, hyp, logprob))
48 | if len(self) > self.num_beams:
49 | # need to remove one
50 | sorted_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
51 | del self.beams[sorted_scores[0][1]]
52 | self.worst_score = sorted_scores[1][0]
53 | else:
54 | self.worst_score = min(score, self.worst_score)
55 |
56 | def is_done(self, best_sum_logprobs, cur_len=None):
57 | # whether this sample has finished generating; the key is not whether num_beams finished sequences exist, but whether the best of the new step's beam-width candidates is worse than the lowest score already stored
58 | # best_sum_logprobs is the highest score among the new candidate sequences
59 | if len(self) < self.num_beams:
60 | return False
61 | else:
62 | if cur_len is None:
63 | cur_len = self.max_length
64 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty
65 | # cur_score = best_sum_logprobs / (pow((5+cur_len+1), self.length_penalty)/pow(5+1, self.length_penalty))
66 | # if the best new score is worse than the stored worst score, generation is done
67 | ret = self.worst_score >= cur_score
68 | return ret
69 |
70 |
71 | def beam_search(mode, context, model, batch_size, max_length, num_beams, vocab_size, length_penalty,
72 | bos_token_id=1, pad_token_id=0, eos_token_id=2, prompt=None):
73 | # batch beamsearch
74 | # track each sample's generated sequences, their scores, and whether generation has finished
75 | generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty) for _ in range(batch_size)]
76 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float).to(device)
77 | beam_scores[:, 1:] = -1e9 # otherwise at t=1 the num_beams best candidates would all be the same token, making all num_beams results identical afterwards
78 | beam_scores = beam_scores.view(-1)
79 | done = [False for _ in range(batch_size)]
80 |
81 | # initial input and current length
82 | if prompt is None:
83 | input_ids = torch.full((batch_size*num_beams, 1), bos_token_id, dtype=torch.long).to(device)
84 | else:
85 | input_ids = prompt.repeat([num_beams, 1])
86 | cur_len = 1
87 |
88 | # initial state hidden: (batch_size*num_beams, *)
89 | # for LSTM-based models, hidden is the decoder's hidden state and must be updated at every step; for Transformer-based models, hidden is the encoder output and stays the same throughout decoding
90 | # hidden = context
91 |
92 | while cur_len < max_length:
93 | # the model must implement an interface that, given the hidden state and the tokens generated so far, returns the next-step vocabulary distribution (and, for LSTM-based models, the updated hidden state)
94 | outputs, hidden = model.decode_step(input_ids, context)
95 | next_token_logits = outputs[:, -1, :]
96 |
97 | scores = F.log_softmax(next_token_logits, dim=-1)
98 | next_scores = scores + beam_scores[:, None].expand_as(scores)
99 | next_scores = next_scores.view(batch_size, num_beams*vocab_size) # so topk can select the best candidates for every sample in the batch
100 |
101 | # next_scores/next_tokens: (batch_size, num_beams)
102 | # key point: keep 2*num_beams candidates so that even if some beams produce eos, there are still num_beams options that can continue
103 | next_scores, next_tokens = torch.topk(next_scores, 2*num_beams, dim=1, largest=True, sorted=True)
104 |
105 | next_batch_beam = [] # prepare for the next step (score, token_id, beam_id)
106 | for batch_idx in range(batch_size):
107 | if done[batch_idx]: # if this sample is already finished, just pad
108 | next_batch_beam.extend([(0, pad_token_id, 0)]*num_beams)
109 | continue
110 | next_sent_beam = [] # the beam_num best (and unfinished) candidates for this sample
111 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
112 | zip(next_tokens[batch_idx], next_scores[batch_idx])
113 | ):
114 | beam_id = beam_token_id // vocab_size # beam_id: which beam of the current sample this token comes from
115 | token_id = beam_token_id % vocab_size
116 | effective_beam_id = batch_idx * num_beams + beam_id # position in the original (batch_size*num_beams, *) layout
117 | if token_id.item() == eos_token_id:
118 | # eos was generated: store this beam's sentence
119 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
120 | if is_beam_token_worse_than_top_num_beams:
121 | continue
122 | # the stored sequence does not include eos
123 | generated_hyps[batch_idx].add(input_ids[effective_beam_id].clone(), beam_token_score.item())
124 | else:
125 | # keep the extended candidate
126 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
127 |
128 | if len(next_sent_beam) == num_beams: # no matter how many eos tokens were produced, this sample still keeps num_beams expandable sequences
129 | break
130 |
131 | # when is a sample finished? num_beams complete sentences have been generated, and nothing from the current step (complete or not) is better than them
132 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx]\
133 | .is_done(next_scores[batch_idx].max().item(), cur_len)
134 |
135 | next_batch_beam.extend(next_sent_beam)
136 |
137 | if all(done):
138 | break
139 |
140 | # prepare for the next step
141 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
142 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
143 | beam_idx = input_ids.new([x[2] for x in next_batch_beam])
144 |
145 | input_ids = input_ids[beam_idx, :]
146 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
147 |
148 | if mode == 'LSTM': # LSTM needs its hidden state updated
149 | hidden = [item[beam_idx, :] for item in hidden]
150 | #h, c = hidden
151 | #h = h[beam_idx, :]
152 | #c = c[beam_idx, :]
153 | #hidden = (h, c)
154 | context[-1] = hidden
155 |
156 | cur_len += 1
157 |
158 | # manually finish samples that never produced eos
159 | for batch_idx in range(batch_size):
160 | if done[batch_idx]:
161 | continue
162 | for beam_id in range(num_beams):
163 | # for samples that must be finished manually, try to add all of their beams
164 | effective_beam_id = batch_idx*num_beams+beam_id
165 | final_score = beam_scores[effective_beam_id].item()
166 | final_tokens = input_ids[effective_beam_id]
167 | generated_hyps[batch_idx].add(final_tokens, final_score)
168 |
169 | # at this point, generated_hyps holds the num_beams best sequences for each sample
170 | best = []
171 | for i, hypotheses in enumerate(generated_hyps):
172 | sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
173 | best_hyp = sorted_hyps.pop()[1]
174 | best.append(best_hyp)
175 |
176 | return best
177 |
178 |
179 | def beam_search_scst(mode, context, model, batch_size, max_length, num_beams, vocab_size, length_penalty,
180 | bos_token_id=1, pad_token_id=0, eos_token_id=2):
181 | # batch beamsearch
182 | # track each sample's generated sequences, their scores, and whether generation has finished
183 | # at every step of the beam search, also keep the per-step logprobs of the current best beams from the start up to the current step
184 | generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty) for _ in range(batch_size)]
185 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float).to(device)
186 | beam_scores[:, 1:] = -1e9 # otherwise at t=1 the num_beams best candidates would all be the same token, making all num_beams results identical afterwards
187 | beam_scores = beam_scores.view(-1)
188 | done = [False for _ in range(batch_size)]
189 |
190 | # initial input and current length
191 | input_ids = torch.full((batch_size*num_beams, 1), bos_token_id, dtype=torch.long).to(device)
192 | ids_logprob = torch.full((batch_size*num_beams, 0), 0.0).to(device)
193 | cur_len = 1
194 |
195 | # initial state hidden: (batch_size*num_beams, *)
196 | # for LSTM-based models, hidden is the decoder's hidden state and must be updated at every step; for Transformer-based models, hidden is the encoder output and stays the same throughout decoding
197 | # hidden = context
198 |
199 | while cur_len < max_length:
200 | # the model must implement an interface that, given the hidden state and the tokens generated so far, returns the next-step vocabulary distribution (and, for LSTM-based models, the updated hidden state)
201 | outputs, hidden = model.decode_step(input_ids, context)
202 | next_token_logits = outputs[:, -1, :]
203 | scores = F.log_softmax(next_token_logits, dim=-1)
204 | next_scores = scores + beam_scores[:, None].expand_as(scores)
205 | next_scores = next_scores.view(batch_size, num_beams*vocab_size) # so topk can select the best candidates for every sample in the batch
206 | scores = scores.view(batch_size, num_beams*vocab_size) # so the corresponding probabilities can be looked up by the topk ids
207 |
208 | # next_scores/next_tokens: (batch_size, num_beams)
209 | # key point: keep 2*num_beams candidates so that even if some beams produce eos, there are still num_beams options that can continue
210 | next_scores, next_tokens = torch.topk(next_scores, 2*num_beams, dim=1, largest=True, sorted=True)
211 |
212 | next_batch_beam = [] # prepare for the next step (score, token_id, beam_id, logprob)
213 | for batch_idx in range(batch_size):
214 | if done[batch_idx]: # if this sample is already finished, just pad
215 | next_batch_beam.extend([(0, pad_token_id, 0, 0)]*num_beams)
216 | continue
217 | next_sent_beam = [] # the beam_num best (and unfinished) candidates for this sample
218 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
219 | zip(next_tokens[batch_idx], next_scores[batch_idx])
220 | ):
221 | beam_id = beam_token_id // vocab_size # beam_id: which beam of the current sample this token comes from
222 | token_id = beam_token_id % vocab_size
223 | logprob = scores[batch_idx][beam_token_id]
224 | effective_beam_id = batch_idx * num_beams + beam_id # position in the original (batch_size*num_beams, *) layout
225 | if token_id.item() == eos_token_id:
226 | # eos was generated: store this beam's sentence
227 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
228 | if is_beam_token_worse_than_top_num_beams:
229 | continue
230 | # the stored sequence does not include eos
231 | logprob_add = torch.cat([ids_logprob[effective_beam_id].clone(), logprob.unsqueeze(0)], dim=0)
232 | generated_hyps[batch_idx].add_scst(input_ids[effective_beam_id].clone(), logprob_add, beam_token_score.item())
233 |
234 | else:
235 | # keep the extended candidate
236 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id, logprob))
237 |
238 | if len(next_sent_beam) == num_beams: # no matter how many eos tokens were produced, this sample still keeps num_beams expandable sequences
239 | break
240 |
241 | # when is a sample finished? num_beams complete sentences have been generated, and nothing from the current step (complete or not) is better than them
242 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx]\
243 | .is_done(next_scores[batch_idx].max().item(), cur_len)
244 |
245 | next_batch_beam.extend(next_sent_beam)
246 |
247 | if all(done):
248 | break
249 |
250 | # prepare for the next step
251 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
252 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
253 | beam_idx = input_ids.new([x[2] for x in next_batch_beam])
254 | beam_logprob = ids_logprob.new([x[3] for x in next_batch_beam])
255 |
256 | input_ids = input_ids[beam_idx, :]
257 | ids_logprob = ids_logprob[beam_idx, :]
258 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
259 | ids_logprob = torch.cat([ids_logprob, beam_logprob.unsqueeze(1)], dim=-1)
260 |
261 | if mode == 'LSTM': # LSTM needs its hidden state updated
262 | hidden = [item[beam_idx, :] for item in hidden]
263 | #h, c = hidden
264 | #h = h[beam_idx, :]
265 | #c = c[beam_idx, :]
266 | #hidden = (h, c)
267 | context[-1] = hidden
268 |
269 | cur_len += 1
270 |
271 | # manually finish samples that never produced eos
272 | for batch_idx in range(batch_size):
273 | if done[batch_idx]:
274 | continue
275 | for beam_id in range(num_beams):
276 | # for samples that must be finished manually, try to add all of their beams
277 | effective_beam_id = batch_idx*num_beams+beam_id
278 | final_score = beam_scores[effective_beam_id].item()
279 | final_tokens = input_ids[effective_beam_id]
280 | final_logprob = ids_logprob[effective_beam_id]
281 | generated_hyps[batch_idx].add_scst(final_tokens, final_logprob, final_score)
282 |
283 | # at this point, generated_hyps holds the num_beams best sequences for each sample
284 | best = []
285 | all_tokens = []
286 | all_logprob = []
287 | for i, hypotheses in enumerate(generated_hyps):
288 | sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
289 | best_hyp = sorted_hyps[-1][1]
290 | best.append(best_hyp)
291 | all_tokens.extend([item[1] for item in sorted_hyps])
292 | all_logprob.extend([item[2] for item in sorted_hyps])
293 |
294 | return best, all_tokens, all_logprob
295 |
296 |
--------------------------------------------------------------------------------
/utils/cc12m.py:
--------------------------------------------------------------------------------
1 | # Automatically filter some data by keywords from cc12m
2 |
3 | import csv
4 | from tqdm import tqdm
5 | import json
6 | import os
7 | import xlsxwriter
8 | from PIL import Image
9 | import random
10 | import re
11 | from torchvision import transforms
12 | import requests
13 |
14 | landmarks_replay = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple']
15 | brands_replay = ['iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines']
16 | foods_replay = ['sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck']
17 | charas_replay = ['batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh']
18 | keywords = landmarks_replay+brands_replay+foods_replay+charas_replay
19 | print(keywords)
20 | print(len(keywords))
21 | input()
22 |
23 | """
24 | # cc12m
25 | cc12m_data = []
26 | cc12m_path = '/Users/cckevin/Downloads/cc12m.tsv'
27 | with open(cc12m_path, 'r') as f:
28 | text = f.read()
29 | lines = text.split('\n')
30 | for line in lines:
31 | cc12m_data.append(line.split('\t'))
32 | print("Num: "+str(len(cc12m_data)))
33 |
34 | # random.shuffle(cc12m_data)
35 | # cc12m_data_tiny = cc12m_data[:50000]
36 | """
37 |
38 | """
39 | # filter in cc12m
40 | keywords = [item.lower() for item in keywords]
41 | keywords_num = {keyword: 0 for keyword in keywords}
42 |
43 | cc12m_select = []
44 | for item in tqdm(cc12m_data):
45 | try:
46 | img_dir = item[0]
47 | caption = item[1]
48 | caption = caption.lower()
49 | for keyword in keywords:
50 | if re.search(keyword, caption) != None:
51 | if keywords_num[keyword] < 1000:
52 | keywords_num[keyword] += 1
53 | cc12m_select.append([img_dir, caption, keyword])
54 | break
55 | except:
56 | continue
57 |
58 | print("Num of select: "+str(len(cc12m_select)))
59 | print(keywords_num)
60 | cc12m_data_path = '/Users/cckevin/Downloads/cc12m_select.json'
61 | with open(cc12m_data_path, 'w') as f:
62 | json.dump(cc12m_select, f)
63 | """
64 |
65 |
66 | # download images
67 | cc12m_select = json.load(open('/home/data_ti4_c/chengkz/scripts/cc12m_select.json', 'r'))
68 | print(len(cc12m_select))
69 | download_img_dir = '/home/chengkz/checkpoints/ofa/cc12m_select'
70 | cc12m_select = cc12m_select[:]
71 |
72 | for i, item in tqdm(enumerate(cc12m_select)):
73 | url = item[0]
74 | filename = str(i)+'.jpg'
75 | download_img_path = os.path.join(download_img_dir, filename)
76 | if os.path.exists(download_img_path) == False:
77 | try:
78 | download_file = requests.get(url, timeout=5)
79 | open(download_img_path, 'wb').write(download_file.content)
80 | except:
81 | continue
82 |
83 |
84 |
85 | # Filter out the images that can be used as replay data
86 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
87 | resolution = 480
88 | patch_resize_transform = transforms.Compose([
89 | lambda image: image.convert("RGB"),
90 | transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
91 | transforms.ToTensor(),
92 | transforms.Normalize(mean=mean, std=std)
93 | ])
94 |
95 | data_cc12m = []
96 | rwconcept_num = {keyword.lower(): 0 for keyword in keywords}
97 | for i, item in tqdm(enumerate(cc12m_select)):
98 | filename = str(i)+'.jpg'
99 | img_path = os.path.join(download_img_dir, filename)
100 | if os.path.exists(img_path) == False:
101 | continue
102 | try:
103 | img = Image.open(img_path)
104 | patch_img = patch_resize_transform(img)
105 | except:
106 | continue
107 | else:
108 | caption = item[1]
109 | keyword = item[2]
110 | rwconcept_num[keyword] += 1
111 | caption = caption.lower()
112 | data_cc12m.append({"filename": img_path, "caption": caption, "keyword": keyword, 'data': 'cc12m'})
113 |
114 | print(rwconcept_num)
115 | print("Num of select success: "+str(len(data_cc12m)))
116 | json.dump(data_cc12m, open('/home/chengkz/checkpoints/ofa/data_cc12m_SelectForReplay.json', 'w'), ensure_ascii=False)
117 |
--------------------------------------------------------------------------------
/utils/convert_ofa.py:
--------------------------------------------------------------------------------
1 | # convert the official fairseq version ckpts to the transformers version ckpts
2 | # notice that our K-Replay train the OFA begin with a ckpts with fine-tuned encoder+pre-trained decoder
3 | # eg:
4 | # 1.download official transformers version ckpts in https://huggingface.co/OFA-Sys/ofa-large
5 | # 2.download official fairseq version ckpts in https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt
6 | # 3.using the following code to obtain the correct transformers version ckpts
7 | import torch
8 | """
9 | model_t = torch.load('/home/chengkz/checkpoints/ofa/OFA-large/pytorch_model.bin')
10 | model_f = torch.load('/home/chengkz/checkpoints/ofa/OFA-large-fairseq/ofa_large.pt')['model']
11 |
12 | key_t = set([k for k in model_t.keys()])
13 | key_f = set([k for k in model_f.keys()])
14 | print(len(key_t), len(key_f))
15 | common_key = key_t.intersection(key_f)
16 | print(len(common_key))
17 |
18 | for k in model_t.keys():
19 | # if 'encoder' in k:
20 | if k in common_key:
21 | model_t[k] = model_f[k]
22 | del model_f[k]
23 | key_t.remove(k)
24 | key_f.remove(k)
25 | print(len(key_t), len(key_f))
26 |
27 | for k in model_f.keys():
28 | #if 'encoder' in k:
29 | k_pred = k.replace('ffn_layernorm', 'ffn_layer_norm')
30 | k_pred = k_pred.replace('self_attn_ln', 'self_attn_mid_layer_norm')
31 | k_pred = k_pred.replace('cross_attn_ln', 'cross_attn_mid_layer_norm')
32 | k_pred = k_pred.replace('encoder_attn', 'cross_attn')
33 | k_pred = k_pred.replace('attn_ln', 'self_attn_mid_layer_norm')
34 | if k_pred in key_t:
35 | model_t[k_pred] = model_f[k]
36 | key_t.remove(k_pred)
37 | key_f.remove(k)
38 | print(len(key_t), len(key_f))
39 | print(key_f)
40 |
41 | torch.save(model_t, '/home/chengkz/checkpoints/ofa/OFA-large-caption-trainedenc/pytorch_model.bin')
42 | """
43 |
44 | """
45 | code for BLIP
46 | model_pretrain = torch.load('/home/chengkz/.cache/torch/hub/checkpoints/model_large.pth')
47 | model_ft = torch.load('/home/chengkz/.cache/torch/hub/checkpoints/model_large_caption.pth')['model']
48 | key_ft = set([k for k in model_ft.keys()])
49 | key_ft_vision = {item for item in key_ft if 'visual_encoder' in item}
50 | for k in key_ft_vision:
51 | model_pretrain['model'][k] = model_ft[k]
52 |
53 | torch.save(model_pretrain, '/home/chengkz/.cache/torch/hub/checkpoints/model_large_trainedenc.pth')
54 | """
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | # Evaluate the model
2 | # Generate captions for the validation/test sets and save them in a format that pycoco can use directly to compute metrics
3 | # Compute the metrics from the saved captions
4 |
5 | import os
6 | import torch
7 | import pickle
8 | import json
9 | import numpy as np
10 |
11 | from data_load import data_load
12 | from tqdm import tqdm
13 | from pycocoevalcap.eval import COCOEvalCap
14 | from evaluation import Cider
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |
18 |
19 | def generate_captions(config, model, step, mode, final_test=False):
20 | print("Generating captions...")
21 |
22 | log_path = config.log_dir.format(config.id)
23 | result_dir = os.path.join(log_path, 'generated')
24 | if not os.path.exists(result_dir):
25 | os.makedirs(result_dir)
26 | gen_pycoco_path = os.path.join(result_dir, mode+'_'+str(step)+'.json')
27 |
28 | data_dir = os.path.join(config.data_dir, mode+'.json')
29 |
30 | eval_loader = data_load(config, data_dir, mode)
31 | model.eval()
32 | gen_pycoco = {}
33 |
34 | for i, (image_id, image_feature) in tqdm(enumerate(eval_loader)):
35 | patch_image = image_feature['patch_image']
36 | patch_image = patch_image.to(device)
37 | batch_size = len(image_id)
38 | if not final_test:
39 | captions, _ = model.greedy_search(patch_image)
40 | else:
41 | captions = model.generate_caption_batchbs(patch_image)
42 | for j, cap_id in enumerate(captions):
43 | if config.model == 'OFA':
44 | gen = cap_id.unsqueeze(0)
45 | caption = model.tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip()
46 | elif config.model == 'BLIP':
47 | caption = model.tokenizer.decode(cap_id, skip_special_tokens=True)
48 | caption = caption[len(model.prompt):]
49 | elif config.model == 'GIT':
50 | gen = cap_id.unsqueeze(0)
51 | caption = model.tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip()
52 | refs = []
53 | ref = {'image_id': image_id[j], 'id': i * batch_size + j, 'caption': caption}
54 | refs.append(ref)
55 | gen_pycoco[i * batch_size + j] = refs
56 | if not final_test:
57 | if len(gen_pycoco) >= 200:
58 | break
59 |
60 | json.dump(gen_pycoco, open(gen_pycoco_path, 'w'), ensure_ascii=False)
61 |
62 | return gen_pycoco_path
63 |
64 |
65 | def eval_pycoco(config, gen_pycoco_path, mode):
66 | print("Calculating pycoco...")
67 | ref_pycoco_path = os.path.join(config.data_dir, mode+'_pycoco.json')
68 | ref_pycoco = json.load(open(ref_pycoco_path, 'r'))
69 | gen_pycoco = json.load(open(gen_pycoco_path, 'r'))
70 | num = len(gen_pycoco)
71 | ref_pycoco = {int(k): v for k, v in ref_pycoco.items() if int(k) < num}  # JSON loads keys as str, which causes problems when computing SPICE
72 | gen_pycoco = {int(k): v for k, v in gen_pycoco.items() if int(k) < num}
73 | """
74 | ref_cider = {int(k): [item["caption"] for item in v] for k, v in ref_pycoco.items()}
75 | gen_cider = {int(k): [v[0]["caption"]] for k, v in gen_pycoco.items()}
76 | reward = cider_train.compute_score(ref_cider, gen_cider)[1].astype(np.float32)
77 | reward = torch.from_numpy(reward).to(device).view(-1)
78 | print("CIDEr: "+str(reward.mean()))
79 | """
80 | cocoEval = COCOEvalCap('diy', 'diy')
81 | pycoco_results = cocoEval.evaluate_diy(ref_pycoco, gen_pycoco)
82 |
83 | return pycoco_results
84 |
85 |
--------------------------------------------------------------------------------
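A minimal usage sketch of the two helpers above, assuming the repo root is on PYTHONPATH and that `config` and `model` are built as elsewhere in this repo; the `validate` wrapper itself is illustrative, not part of the codebase.

```python
# Sketch: generate captions on the validation split and score them with pycoco.
from utils.eval import generate_captions, eval_pycoco

def validate(config, model, step):
    gen_path = generate_captions(config, model, step, mode='val', final_test=False)
    results = eval_pycoco(config, gen_path, mode='val')  # e.g. {'Bleu_4': ..., 'CIDEr': ...}
    print(results)
    return results
```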
/utils/import_models.py:
--------------------------------------------------------------------------------
1 | # Build the model according to the command-line arguments
2 |
3 | import os
4 | import yaml
5 | from pathlib import Path
6 | from models.Transformer.transformer import Transformer_Cap
7 | from models.OFA.ofa import OFA
8 | from models.BLIP.blip import blip_decoder
9 | from models.GIT.git import GIT
10 |
11 |
12 | def construct_model(config):
13 | if config.model == 'Transformer':
14 | model = Transformer_Cap(config)
15 | elif config.model == 'OFA':
16 | model = OFA(config)
17 | elif config.model == 'BLIP':
18 | args = yaml.load(open(config.config_blip, 'r'), Loader=yaml.Loader)
19 | model = blip_decoder(pretrained='/home/chengkz/.cache/torch/hub/checkpoints/model_large_trainedenc.pth', config=config, image_size=args['image_size'],
20 | vit=args['vit'],
21 | vit_grad_ckpt=args['vit_grad_ckpt'], vit_ckpt_layer=args['vit_ckpt_layer'],
22 | prompt=args['prompt'])
23 | elif config.model == 'GIT':
24 | model = GIT(config)
25 | else:
26 | print("model "+str(config.model)+" not found")
27 | return None
28 | return model
--------------------------------------------------------------------------------
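A minimal sketch of driving `construct_model` from a config namespace. Only `model` (and `config_blip` for BLIP) is read by the code above; any other fields, such as `ofa_ckpts` below, are whatever the chosen model class expects and are shown here only as placeholders.

```python
# Sketch: build a model from a minimal config namespace (placeholder paths).
import torch
from argparse import Namespace
from utils.import_models import construct_model

config = Namespace(model='OFA', ofa_ckpts='/path/to/OFA-large-caption-trainedenc')
model = construct_model(config)
if model is not None:
    model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
```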
/utils/log.py:
--------------------------------------------------------------------------------
1 | # Training log
2 | # Writes tensorboard summaries
3 | # Saves model checkpoints
4 |
5 | import time
6 | import os
7 | import json
8 | import sys
9 | import shutil
10 | import torch
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 |
14 | def train_print(loss, step, total_step, epoch, step_time, epoch_time):
15 | epoch_time = time.localtime(epoch_time)
16 | mins = epoch_time.tm_min  # avoid shadowing the built-in min()
17 | secs = epoch_time.tm_sec
18 | print(f"\rloss:{format(loss, '.8f')} |"
19 | f"step: {step}/{total_step} |"
20 | f"epoch: {epoch} |"
21 | f"step time:{format(step_time, '.2f')}secs |",
22 | f"epoch time: {mins}min {secs}sec", end='')
23 |
24 |
25 | class Log_Writer():
26 |
27 | def __init__(self, config):
28 | super(Log_Writer, self).__init__()
29 |
30 | print("Creating Log dir...")
31 | self.log_path = config.log_dir.format(config.id)
32 | if not os.path.exists(self.log_path):  # create the log directory
33 | os.makedirs(self.log_path)
34 |
35 | para_path = os.path.join(self.log_path, 'para.json')  # save the command-line arguments
36 | with open(para_path, 'w') as f:
37 | json.dump(sys.argv, f)
38 | shutil.copy('./config.py', self.log_path)  # save the config parameters
39 |
40 | self.writer = SummaryWriter(self.log_path) # tensorboard writer
41 |
42 | def write_tensorboard(self, scalar_name, scalar, step):
43 | self.writer.add_scalar(scalar_name, scalar, step)
44 |
45 | def write_metrics(self, pycoco_results, step):
46 | # metrics_list = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", "METEOR", "ROUGE_L", "CIDEr"]
47 | for metric in pycoco_results:
48 | self.write_tensorboard(metric, pycoco_results[metric], step)
49 |
50 | def save_model(self, model, global_step):
51 | model_path = os.path.join(self.log_path, 'model')
52 | if not os.path.exists(model_path):
53 | os.makedirs(model_path)
54 | save_path = os.path.join(model_path, f'model_{global_step}.pt')
55 | torch.save(model.state_dict(), save_path)
56 |
57 |
--------------------------------------------------------------------------------
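A short sketch of the intended `Log_Writer` workflow during training; `log_dir` and `id` follow the attribute names used above, the scalar values are dummies, and the script is assumed to run from the repo root (so that `./config.py` exists).

```python
# Sketch: typical Log_Writer usage inside a training loop (dummy values).
from argparse import Namespace
from utils.log import Log_Writer

config = Namespace(log_dir='./log/{}', id='demo_run')
writer = Log_Writer(config)  # creates ./log/demo_run and copies config.py there

writer.write_tensorboard('train/loss', 1.234, step=100)
writer.write_metrics({'Bleu_4': 0.38, 'CIDEr': 1.25}, step=100)
# writer.save_model(model, global_step=100)  # model omitted in this sketch
```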
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence
4 | from evaluation import Cider
5 | import numpy as np
6 | import pickle
7 | import json
8 | from transformers.models.ofa.tokenization_ofa import OFATokenizer
9 | import torch.nn.functional as F
10 |
11 | from config import config
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 | # Used for XE distillation
16 | class Loss_KD(nn.Module):
17 |
18 | def __init__(self, KD_T=8):
19 | super(Loss_KD, self).__init__()
20 | self.softmax = nn.Softmax(dim=-1)
21 | self.temperature = KD_T
22 |
23 | def forward(self, logit, logit_teacher, cap_len):
24 | prob = self.softmax(logit / self.temperature)
25 | prob_teacher = self.softmax(logit_teacher / self.temperature)
26 |
27 | pred = pack_padded_sequence(prob, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0]
28 | target = pack_padded_sequence(prob_teacher, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0]
29 |
30 | loss_kl = F.kl_div(pred.log(), target, reduction='sum') / logit.shape[0]
31 | return loss_kl
32 |
33 |
34 | # Label Smoothing
35 | class LabelSmoothingCrossEntropy(nn.Module):
36 | def __init__(self, epsilon: float = 0.1, reduction='mean'):
37 | super().__init__()
38 | self.epsilon = epsilon
39 | self.reduction = reduction
40 |
41 | def linear_combination(self, x, y, epsilon):
42 | return epsilon * x + (1 - epsilon) * y
43 |
44 | def reduce_loss(self, loss, reduction='mean'):
45 | return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss
46 |
47 | def forward(self, preds, target):
48 | n = preds.size()[-1]
49 | log_preds = F.log_softmax(preds, dim=-1)
50 | loss = self.reduce_loss(-log_preds.sum(dim=-1), self.reduction)
51 | nll = F.nll_loss(log_preds, target, reduction=self.reduction)
52 | return self.linear_combination(loss / n, nll, self.epsilon)
53 |
54 |
55 | class Cross_Entropy(nn.Module):
56 | # Sequence-form cross entropy
57 | def __init__(self, label_smoothing=0.0):
58 | super(Cross_Entropy, self).__init__()
59 | self.label_smoothing = label_smoothing
60 | self.ce = nn.CrossEntropyLoss().to(device)
61 | self.ce_ls = LabelSmoothingCrossEntropy(epsilon=label_smoothing).to(device)
62 |
63 | def forward(self, logit, cap, cap_len):
64 | target = cap[:, 1:]
65 | cap_len = cap_len - 1
66 |
67 | target = pack_padded_sequence(target, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0]
68 | logit = pack_padded_sequence(logit, cap_len.cpu(), batch_first=True, enforce_sorted=False)[0]
69 |
70 | # cross_entropy
71 | if self.label_smoothing > 0:
72 | loss_ce = self.ce_ls(logit, target)
73 | else:
74 | loss_ce = self.ce(logit, target)
75 |
76 | return loss_ce
77 |
78 |
79 | # Cross entropy computed only on knowledge keywords, used to locate knowledge-related parameters
80 | class Cross_Entropy_Keyword(nn.Module):
81 | # Sequence-form cross entropy
82 | def __init__(self):
83 | super(Cross_Entropy_Keyword, self).__init__()
84 | self.ce = nn.CrossEntropyLoss().to(device)
85 |
86 | def forward(self, logit, cap, cap_len, if_keyword):
87 | target = cap[:, 1:]
88 | if_keyword = if_keyword[:, 1:] > 0
89 | logit = logit[:, :-1]
90 | cap_len = cap_len - 1
91 |
92 | target = target[if_keyword]
93 | logit = logit[if_keyword]
94 |
95 | # cross_entropy
96 | loss_ce = self.ce(logit, target)
97 |
98 | return loss_ce
99 |
100 |
101 | # Core loss of K-Replay: predict the knowledge keywords
102 | class Sent_Level_Concept_Coverage(nn.Module):
103 | def __init__(self):
104 | super(Sent_Level_Concept_Coverage, self).__init__()
105 | self.softmax = nn.Softmax(dim=2)
106 | self.sigmoid = nn.Sigmoid()
107 |
108 | def forward(self, logit_rwc, cap_rwc_label, cap_len_rwc, model_type):
109 | softmax_rwc = self.softmax(logit_rwc)
110 | loss_cov = torch.zeros(cap_len_rwc.shape[0]).to(device)
111 | loss_rep = torch.zeros(cap_len_rwc.shape[0]).to(device)
112 | for i in range(cap_len_rwc.shape[0]):
113 | softmax_sen = softmax_rwc[i][:cap_len_rwc[i].item()]
114 | softmax_agg = softmax_sen.sum(dim=0)
115 | sigmoid_agg = self.sigmoid(softmax_agg)
116 | if model_type == 'OFA':
117 | label = cap_rwc_label[i][cap_rwc_label[i]>2]
118 | elif model_type == 'BLIP':
119 | label = cap_rwc_label[i][(cap_rwc_label[i]!=0) & (cap_rwc_label[i]!=102) & (cap_rwc_label[i]!=30522)
120 | & (cap_rwc_label[i]!=1037) & (cap_rwc_label[i]!=3861) & (cap_rwc_label[i]!=1997)]
121 | elif model_type == 'GIT':
122 | label = cap_rwc_label[i][(cap_rwc_label[i]!=0) & (cap_rwc_label[i]!=101) & (cap_rwc_label[i]!=102)]
123 | prob = sigmoid_agg[label]
124 | log_prob = -torch.log(prob).mean()
125 | loss_cov[i] = log_prob
126 | prob_softmax = softmax_agg[label]
127 | prob_pow = torch.pow(1-prob_softmax, 2).mean()
128 | loss_rep[i] = prob_pow
129 | loss_cov = loss_cov.mean()
130 | loss_rep = loss_rep.mean()
131 | loss_rwc = loss_cov+loss_rep
132 | return loss_rwc
133 |
134 |
135 | class Loss_Params_Regular(nn.Module):
136 | def __init__(self, params_init, params_fisher):
137 | super(Loss_Params_Regular, self).__init__()
138 | self.params_init = params_init
139 | self.params_fisher = params_fisher
140 | self.gamma = 50000
141 |
142 | def forward(self, model):
143 | loss = 0
144 | for name, params in model.named_parameters():
145 | if params.requires_grad:
146 | loss_p = 0.5 * self.gamma * self.params_fisher[name] * torch.pow(params-self.params_init[name], 2)
147 | loss += loss_p.sum()
148 | return loss
149 |
150 |
151 | class Loss_SCST(nn.Module):
152 |
153 | def __init__(self, config):
154 | super(Loss_SCST, self).__init__()
155 | self.config = config
156 | self.batch_size = config.batch_size
157 | self.beam_num = config.beam_num
158 | self.vocab = pickle.load(open(config.vocab, 'rb'))
159 | self.train = json.load(open(config.train, 'r'))
160 | self.cider_texts = {i: [' '.join(item['caption'])] for i, item in enumerate(self.train)}
161 | self.cider_train = Cider(self.cider_texts)
162 |
163 | def vanilla_scst(self, all_tokens, all_tokens_greedy, all_logprob, refs):
164 | # vanilla SCST: draw beam_num multinomial samples, with greedy decoding as the baseline
165 | # first replicate the greedy outputs and the references beam_num times
166 | gen_num = len(all_tokens)
167 | all_tokens_greedy_beam = []
168 | for item in all_tokens_greedy:
169 | all_tokens_greedy_beam.extend([item for i in range(self.beam_num)])
170 | refs_beam = []
171 | for item in refs:
172 | refs_beam.extend([item for i in range(self.beam_num)])
173 |
174 | # gather the samples, greedy outputs and references to compute the metric
175 | caps_gen = {i: [self.vocab.idList_to_sent(item)] for i, item in enumerate(all_tokens)}
176 | caps_gen_greedy = {i: [self.vocab.idList_to_sent(item)] for i, item in enumerate(all_tokens_greedy_beam)}
177 | caps_gt = {i: item for i, item in enumerate(refs_beam)}
178 | reward = self.cider_train.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
179 | reward = torch.from_numpy(reward).to(device).view(gen_num)
180 | reward_baseline = self.cider_train.compute_score(caps_gt, caps_gen_greedy)[1].astype(np.float32)
181 | reward_baseline = torch.from_numpy(reward_baseline).to(device).view(gen_num)
182 |
183 | # pad the log-probs of the sampled sequences
184 | all_logprob_pad = []
185 | for logprob in all_logprob:
186 | logprob = torch.cat([logprob, logprob.new([0 for i in range(self.config.fixed_len - logprob.shape[0])])], dim=0)
187 | all_logprob_pad.append(logprob.unsqueeze(0))
188 | all_logprob_pad = torch.cat(all_logprob_pad, dim=0)
189 |
190 | # compute the loss
191 | loss = -torch.mean(all_logprob_pad, -1) * (reward - reward_baseline)
192 | loss = loss.mean()
193 |
194 | # compute the training reward
195 | reward_train = reward.mean()
196 |
197 | return loss, reward_train
198 |
199 |
200 | class Loss_SCST_OFA(nn.Module):
201 |
202 | def __init__(self, config):
203 | super(Loss_SCST_OFA, self).__init__()
204 | self.config = config
205 | self.batch_size = config.batch_size
206 | self.beam_num = config.beam_num
207 | self.tokenizer = OFATokenizer.from_pretrained(self.config.ofa_ckpts)
208 | self.train = json.load(open(config.train, 'r'))
209 | self.cider_texts = {i: [' '.join(item['caption'])] for i, item in enumerate(self.train)}
210 | self.cider_train = Cider(self.cider_texts)
211 |
212 | def vanilla_scst(self, all_tokens, all_tokens_greedy, all_logprob, refs):
213 | # vanilla SCST: draw beam_num multinomial samples, with greedy decoding as the baseline
214 | # first replicate the greedy outputs and the references beam_num times
215 | gen_num = len(all_tokens)
216 | all_tokens_greedy_beam = []
217 | for item in all_tokens_greedy:
218 | all_tokens_greedy_beam.extend([item for i in range(self.beam_num)])
219 | refs_beam = []
220 | for item in refs:
221 | refs_beam.extend([item for i in range(self.beam_num)])
222 |
223 | # gather the samples, greedy outputs and references to compute the metric
224 | caps_gen = {i: [self.tokenizer.batch_decode(item.unsqueeze(0), skip_special_tokens=True)[0].strip()] for i, item in enumerate(all_tokens)}
225 | caps_gen_greedy = {i: [self.tokenizer.batch_decode(item.unsqueeze(0), skip_special_tokens=True)[0].strip()] for i, item in enumerate(all_tokens_greedy_beam)}
226 | caps_gt = {i: item for i, item in enumerate(refs_beam)}
227 | reward = self.cider_train.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
228 | reward = torch.from_numpy(reward).to(device).view(gen_num)
229 | reward_baseline = self.cider_train.compute_score(caps_gt, caps_gen_greedy)[1].astype(np.float32)
230 | reward_baseline = torch.from_numpy(reward_baseline).to(device).view(gen_num)
231 |
232 | # pad the log-probs of the sampled sequences
233 | all_logprob_pad = []
234 | for logprob in all_logprob:
235 | logprob = torch.cat([logprob, logprob.new([0 for i in range(self.config.fixed_len - logprob.shape[0])])], dim=0)
236 | all_logprob_pad.append(logprob.unsqueeze(0))
237 | all_logprob_pad = torch.cat(all_logprob_pad, dim=0)
238 |
239 | # compute the loss
240 | loss = -torch.mean(all_logprob_pad, -1) * (reward - reward_baseline)
241 | loss = loss.mean()
242 |
243 | # compute the training reward
244 | reward_train = reward.mean()
245 |
246 | return loss, reward_train
247 |
248 |
249 |
--------------------------------------------------------------------------------
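To make the behaviour of the K-Replay coverage loss (`Sent_Level_Concept_Coverage`) above concrete, here is a tiny toy example with random logits and OFA-style keyword labels; the vocabulary size, lengths and token ids are made up purely for illustration, and the import assumes the repo environment (loss.py pulls in config.py and the custom OFA tokenizer).

```python
# Toy example: run Sent_Level_Concept_Coverage on random logits.
# Shapes follow forward(): logits are (batch, seq_len, vocab);
# for model_type='OFA', label ids > 2 are treated as keyword tokens.
import torch
from utils.loss import Sent_Level_Concept_Coverage

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

batch, seq_len, vocab = 2, 6, 50
logit_rwc = torch.randn(batch, seq_len, vocab).to(device)
cap_rwc_label = torch.tensor([[0, 17, 23, 2, 0, 0],              # keyword ids 17, 23
                              [0, 41, 2, 0, 0, 0]]).to(device)   # keyword id 41
cap_len_rwc = torch.tensor([4, 3])

criterion = Sent_Level_Concept_Coverage()
loss = criterion(logit_rwc, cap_rwc_label, cap_len_rwc, model_type='OFA')
print(loss.item())  # scalar: coverage term + repetition penalty
```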
/utils/prepro_data.py:
--------------------------------------------------------------------------------
1 | # construct the data in ./data
2 | # Steps 1-5 are used to construct the training, validation and test sets used for K-Replay,
3 | # and steps 6-7 are used to adjust the replay dataset
4 |
5 | import os
6 | import json
7 | import random
8 | from tqdm import tqdm
9 | import nltk
10 |
11 | """
12 | # 1. Split COCO as train, val and test (follow KarpathSplit by dataset_coco.json)
13 | dataset_coco_karpath = json.load(open('../data/dataset_coco.json', 'r'))["images"]
14 | images_dir_train2014 = '/home/data_ti4_c/chengkz/data/coco_dataset/train2014'
15 | images_dir_val2014 = '/home/data_ti4_c/chengkz/data/coco_dataset/val2014'
16 | data_train = []
17 | data_val = []
18 | data_test = []
19 |
20 | for item in tqdm(dataset_coco_karpath):
21 | if item['split'] == 'train' or item['split'] == 'restval':
22 | image_id = item['filename'][:-4]
23 | filename = os.path.join(images_dir_train2014 if item['filepath'] == 'train2014' else images_dir_val2014, item['filename'])
24 | refs = []
25 | for sentence in item['sentences']:
26 | refs.append(' '.join(sentence["tokens"]))
27 | for sentence in item['sentences']:
28 | item_train = {'split': 'train', 'image_id': image_id, 'filename': filename, 'caption': sentence["tokens"], 'refs': refs}
29 | data_train.append(item_train)
30 | else:
31 | image_id = item['filename'][:-4]
32 | filename = os.path.join(images_dir_train2014 if item['filepath'] == 'train2014' else images_dir_val2014, item['filename'])
33 | captions = [sentence["tokens"] for sentence in item["sentences"]]
34 | item_eval = {'split': 'val', 'image_id': image_id, 'filename': filename, 'caption': captions}
35 | if item['split'] == 'val':
36 | data_val.append(item_eval)
37 | elif item['split'] == 'test':
38 | data_test.append(item_eval)
39 |
40 | random.shuffle(data_train)
41 |
42 | print("Num of train: " + str(len(data_train)))
43 | print("Num of val: " + str(len(data_val)))
44 | print("Num of test: " + str(len(data_test)))
45 | json.dump(data_train, open('../data/train.json', 'w'), ensure_ascii=False)
46 | json.dump(data_val, open('../data/val.json', 'w'), ensure_ascii=False)
47 | json.dump(data_test, open('../data/test.json', 'w'), ensure_ascii=False)
48 | """
49 |
50 | """
51 | # 2. Split KnowCap as 1000 test (all & Unseen) and 424 val
52 | knowcap_240 = json.load(open('../data/knowcap_240.json', 'r'))
53 | print("Num of KnowCap_240: "+str(len(knowcap_240)))
54 | knowcap_240_test = knowcap_240[:1000]
55 | knowcap_240_val = knowcap_240[1000:]
56 | print("Num of KnowCap_240 val: "+str(len(knowcap_240_val)))
57 | print("Num of KnowCap_240 test: "+str(len(knowcap_240_test)))
58 | # count the categories contained in val and test
59 | print("Categories of KnowCap_240 val: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_val]))))
60 | print("Categories of KnowCap_240 test: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_test]))))
61 | json.dump(knowcap_240_val, open('../data/knowcap_240_val.json', 'w'))
62 | json.dump(knowcap_240_test, open('../data/knowcap_240_test.json', 'w'))
63 |
64 | categories_replay = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple', 'iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines', 'sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck', 'batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh']
65 | knowcap_240_test_unseen = []
66 | for item in knowcap_240_test:
67 | keyword = item["image"].split('/')[0]
68 | if keyword not in categories_replay:
69 | knowcap_240_test_unseen.append(item)
70 | print("Num of KnowCap_240 test unseen: "+str(len(knowcap_240_test_unseen)))
71 | print("Categories of KnowCap_240 test unseen: "+str(len(set([item["image"].split('/')[0] for item in knowcap_240_test_unseen]))))
72 | json.dump(knowcap_240_test_unseen, open('../data/knowcap_240_test_unseen.json', 'w'))
73 | """
74 |
75 | """
76 | # 3. Convert to the format used for computing metrics with pycoco
77 | for split in ['val', 'test']:
78 | ref_pycoco_path = os.path.join('../data', split+'_pycoco.json')
79 | data = json.load(open(os.path.join('../data', split+'.json'), 'r'))
80 |
81 | ref_pycoco = {}
82 | for i, item in tqdm(enumerate(data)):
83 | refs = []
84 | for j, sentence in enumerate(item['caption']):
85 | ref = {}
86 | ref['image_id'] = item['image_id']
87 | ref['id'] = j
88 | ref['caption'] = ' '.join(sentence)
89 | refs.append(ref)
90 | ref_pycoco[i] = refs
91 |
92 | print("Num: "+str(len(ref_pycoco)))
93 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False)
94 |
95 |
96 | ref_pycoco_path = os.path.join('../data/knowcap_240_val_pycoco.json')
97 | data = json.load(open(os.path.join('../data/knowcap_240_val.json'), 'r'))
98 |
99 | ref_pycoco = {}
100 | for i, item in tqdm(enumerate(data)):
101 | refs = []
102 | for j, sentence in enumerate(item['captions']):
103 | ref = {}
104 | ref['image_id'] = item['image']
105 | ref['id'] = j
106 | ref['caption'] = sentence
107 | refs.append(ref)
108 | ref_pycoco[i] = refs
109 |
110 | print("Num: "+str(len(ref_pycoco)))
111 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False)
112 | """
113 |
114 | """
115 | # 4. Convert the splitting results in train.json back to full sentences for use with our own tokenizer
116 | coco_train_all = json.load(open('../data/train.json', 'r'))
117 | print(len(coco_train_all))
118 | random.shuffle(coco_train_all)
119 | coco_train_used = coco_train_all[:]
120 | print("coco: "+str(len(coco_train_used)))
121 | data_mix = []
122 | for item in coco_train_used:
123 | item_coco = {'filename': item['filename'], 'caption': ' '.join(item['caption']), 'data': 'coco'}
124 | data_mix.append(item_coco)
125 | json.dump(data_mix, open('../data/train_all.json', 'w'), ensure_ascii=False)
126 | print("Num of coco used: "+str(len(data_mix)))
127 | """
128 |
129 | """
130 | # 5. Mix coco data and replay data as the hybrid dataset used for K-Replay training
131 | # data_cc12m_SelectForReplay.json contains 20000+ replay exemplars randomly selected from the cc12m dataset based
132 | # on keyword matching; it covers the 122 keywords recorded in replay_keywords
133 | replay_keywords = ['white house', 'grand canyon', 'statue of liberty', 'buckingham palace', 'forbidden city', 'colosseum', 'kremlin', 'alhambra', 'brooklyn bridge', 'red square', 'london eye', 'burj khalifa', 'parthenon', 'great wall of china', 'windsor castle', 'machu picchu', 'mount everest', 'westminster abbey', 'mount fuji', 'cn tower', 'sydney harbour bridge', 'stonehenge', 'palace of versailles', 'trevi fountain', 'pyramids of giza', 'edinburgh castle', 'palace of westminster', 'uluru', 'neuschwanstein castle', 'brandenburg gate', 'berlin wall', 'chichen itza', 'wailing wall', 'hoover dam', 'tokyo tower', 'vatican museums', 'mount kilimanjaro', 'mount rushmore', 'acropolis of athens', 'meiji shrine', 'mont saint michel', 'willis tower', 'captiol hill', 'victoria harbour', 'sensoji temple', 'iphone', 'apple', 'shell', 'nike', 'samsung', 'chevrolet', 'porsche', 'dodge', 'chanel', 'facebook', 'microsoft', 'mercedes-benz', 'disneyland', 'burberry', 'cadillac', 'rolex', 'yamaha', 'fifa world cup', 'louis vuitton', 'coca cola', 'huawei', 'nokia', 'kawasaki', 'dell', 'rolls-royce', 'burger king', 'intel', 'philips', 'logitech', 'kfc', 'panasonic', 'bose', 'american express', "domino's", 'oppo', 'china southern airlines', 'sushi', 'ramen', 'white wine', 'pho', 'kebab', 'kimchi', 'smoked salmon', 'pad thai', 'fish and chips', 'croissants', 'tempura', 'hot pot', 'tiramisu', 'fajitas', 'churros', 'escargot', 'kung pao chicken', 'peking duck', 'batman', 'barbie', 'santa claus', 'iron man', 'cinderella', 'super mario', 'mickey mouse', 'the grinch', 'charlie brown', 'woody', 'rapunzel', 'the tramp', 'shrek', 'olaf', 'monkey king', 'mulan', 'merida', 'minnie mouse', 'bugs bunny', 'gandalf', 'big bird', 'buzz lightyear', 'winnie-the-pooh']
134 | cc12m_select = json.load(open('../data/data_cc12m_SelectForReplay.json', 'r'))
135 | for item in cc12m_select:
136 | if item['keyword'] not in replay_keywords:
137 | print("replay item not in replay keywords!")
138 | train_all = json.load(open('../data/train_all.json', 'r'))
139 | random.shuffle(cc12m_select)
140 | cc12m_select = cc12m_select[:5000]
141 | random.shuffle(train_all)
142 | print(len(cc12m_select))
143 | print(len(train_all))
144 | data_mix = []
145 | data_mix += train_all[:27000] # mix the coco and replay data
146 | ablation = False
147 | for item in cc12m_select[:]:
148 | item_cc12m = {'filename': item['filename'], 'caption': item['keyword'], 'data': 'cc12m'}  # tag replay items as 'cc12m' so steps 6-7 can filter them
149 | if ablation:  # for the ablation study, use the original web-harvested text as the reference
150 | item_cc12m = {'filename': item['filename'], 'caption': item['caption'], 'data': 'cc12m'}
151 | data_mix.append(item_cc12m)
152 | random.shuffle(data_mix)
153 | json.dump(data_mix, open('../data/train_mix_32000.json', 'w'), ensure_ascii=False)
154 | print("Num of data_mix: "+str(len(data_mix)))
155 | """
156 |
157 | """
158 | # 6. Adjust the number of replay exemplars in train_mix_32000.json
159 | ratio = 0.1
160 | data = json.load(open('../data/train_mix_32000.json', 'r'))
161 | data_cc12m = [item for item in data if item['data'] == 'cc12m']
162 | data_coco = [item for item in data if item['data'] == 'coco']
163 | random.shuffle(data_cc12m)
164 | random.shuffle(data_coco)
165 | data_ratio = data_coco[:int(len(data_coco)*ratio)]+data_cc12m[:int(len(data_cc12m)*ratio)]
166 | print(len(data_ratio))
167 | random.shuffle(data_ratio)
168 | json.dump(data_ratio, open('../data/train_mix_32000_0.1.json', 'w'), ensure_ascii=False)
169 |
170 | # select only 120 exemplars in train_mix_32000.json
171 | data = json.load(open('../data/train_mix_32000.json', 'r'))
172 | data_cc12m = [item for item in data if item['data'] == 'cc12m']
173 | data_coco = [item for item in data if item['data'] == 'coco']
174 | random.shuffle(data_cc12m)
175 | random.shuffle(data_coco)
176 | print(len(data_cc12m))
177 | print(len(data_coco))
178 | data_120 = []
179 | categories = []
180 | for item in data_cc12m:
181 | if item['caption'] not in categories:
182 | categories.append(item['caption'])
183 | data_120.append(item)
184 | else:
185 | continue
186 | data_mix = []
187 | data_mix += data_coco[:12960]
188 | for i in range(20):
189 | data_mix += data_120
190 | random.shuffle(data_mix)
191 | print(len(data_mix))
192 | json.dump(data_mix, open('../data/train_mix_32000_120.json', 'w'), ensure_ascii=False)
193 | """
194 |
195 | """
196 | # 7. Adjust the categories of replay exemplars in train_mix_32000.json
197 | data = json.load(open('../data/train_mix_32000.json', 'r'))
198 | data_cc12m = [item for item in data if item['data'] == 'cc12m']
199 | data_coco = [item for item in data if item['data'] == 'coco']
200 | random.shuffle(data_cc12m)
201 | random.shuffle(data_coco)
202 | print(len(data_cc12m))
203 | print(len(data_coco))
204 | cc12m_select = json.load(open('../data/data_cc12m_select_122all.json', 'r'))
205 | random.shuffle(cc12m_select)
206 | categories = []
207 | for item in data_cc12m:
208 | categories.append(item['caption'])
209 | categories = list(set(categories))
210 | print(len(categories))
211 | random.shuffle(categories)
212 | # categories_ratio = categories[:20]
213 | # select 10 replay categories
214 | categories_ratio = ['white house', 'grand canyon', 'statue of liberty', 'iphone', 'porsche', 'facebook', 'sushi', 'smoked salmon', 'batman', 'barbie']
215 |
216 | print(len(categories_ratio))
217 | data_cc12m_new = [item for item in data_cc12m if item['caption'] in categories_ratio]
218 | print(len(data_cc12m_new))
219 | for item in cc12m_select:
220 | item_cc12m = {'filename': item['filename'], 'caption': item['keyword'], 'data': 'cc12m'}
221 | if item_cc12m['caption'] in categories_ratio:
222 | data_cc12m_new.append(item_cc12m)
223 | if len(data_cc12m_new) == 5000:
224 | break
225 | print(len(data_cc12m_new))
226 | categories_new = []
227 | for item in data_cc12m_new:
228 | categories_new.append(item['caption'])
229 | print(len(list(set(categories_new))))
230 | data_mix = []
231 | data_mix += data_coco
232 | data_mix += data_cc12m_new
233 | random.shuffle(data_mix)
234 | print(len(data_mix))
235 | json.dump(data_mix, open('../data/train_mix_32000_10cate.json', 'w'), ensure_ascii=False)
236 | """
237 |
--------------------------------------------------------------------------------
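For reference, a sketch of the per-item formats that the steps above produce; the paths and captions are made up, and the real files are written by the commented-out steps.

```python
# Illustrative item formats (values are placeholders, not real data).

# a COCO item in train_all.json (step 4)
coco_item = {
    'filename': '/path/to/coco_dataset/train2014/COCO_train2014_000000123456.jpg',
    'caption': 'a man riding a horse on a beach',
    'data': 'coco',
}

# a replay exemplar mixed in by step 5: the matched keyword is used as the
# (pseudo-)caption so the K-Replay loss can supervise keyword coverage
replay_item = {
    'filename': '/path/to/cc12m/12345.jpg',
    'caption': 'statue of liberty',
    'data': 'cc12m',
}
```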
/utils/prepro_ref_pycoco.py:
--------------------------------------------------------------------------------
1 | # Convert val and test into the format that pycoco can score directly
2 |
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | """
7 | for split in ['val', 'test']:
8 | ref_pycoco_path = os.path.join('../data', split+'_pycoco.json')
9 | data = json.load(open(os.path.join('../data', split+'.json'), 'r'))
10 |
11 | ref_pycoco = {}
12 | for i, item in tqdm(enumerate(data)):
13 | refs = []
14 | for j, sentence in enumerate(item['caption']):
15 | ref = {}
16 | ref['image_id'] = item['image_id']
17 | ref['id'] = j
18 | ref['caption'] = ' '.join(sentence)
19 | refs.append(ref)
20 | ref_pycoco[i] = refs
21 |
22 | print("Num: "+str(len(ref_pycoco)))
23 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False)
24 | """
25 | ref_pycoco_path = os.path.join('../data/knowcap_240_val_pycoco.json')
26 | data = json.load(open(os.path.join('../data/knowcap_240_val.json'), 'r'))
27 |
28 | ref_pycoco = {}
29 | for i, item in tqdm(enumerate(data)):
30 | refs = []
31 | for j, sentence in enumerate(item['captions']):
32 | ref = {}
33 | ref['image_id'] = item['image']
34 | ref['id'] = j
35 | ref['caption'] = sentence
36 | refs.append(ref)
37 | ref_pycoco[i] = refs
38 |
39 | print("Num: "+str(len(ref_pycoco)))
40 | json.dump(ref_pycoco, open(ref_pycoco_path, 'w'), ensure_ascii=False)
--------------------------------------------------------------------------------
/utils/prepro_rwcap.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 |
4 | annot_excel = '/Users/cckevin/Desktop/RW_Label_100.xlsx'
5 | dataset_dir = '/Users/cckevin/Desktop/ofa/data/rwcap_100_keywords.json'
6 |
7 | invalid_list = ['a', 'on', 'of', 'the', 'in', 'with', 'and', 'is', 'to', 'an', 'two', 'at', 'are', 'that', 'it', 'by']
8 |
9 | df = pd.read_excel(annot_excel)
10 | annot_list = df.to_dict(orient='records')
11 | dataset_rwcap = []
12 | for item in annot_list:
13 | """
14 | image_filename = item['filename']
15 | image_filename = image_filename.strip()
16 |
17 | data_rwcap_item = {}
18 | refs = []
19 | annot_name = ['SWP', 'CKZ', 'YHT']
20 | for name in annot_name:
21 | ref = item[name].lower().strip()
22 | if ref[-1] == '.':
23 | ref = ref[:-1]
24 | refs.append(ref)
25 | data_rwcap_item['image'] = image_filename
26 | data_rwcap_item['captions'] = refs
27 |
28 | labels_list = []
29 | """
30 | keywords = item['Keywords'].strip().lower()
31 | keywords = keywords.split('#')
32 | dataset_rwcap += keywords
33 | """
34 | for keyword in keywords:
35 | words = keyword.split(' ')
36 | for word in words:
37 | if word not in invalid_list and word not in labels_list:
38 | labels_list.append(word)
39 | data_rwcap_item['labels'] = labels_list
40 |
41 | dataset_rwcap.append(data_rwcap_item)
42 | """
43 |
44 | dataset_rwcap = list(set(dataset_rwcap))
45 | print("Num of dataset: "+str(len(dataset_rwcap)))
46 | json.dump(dataset_rwcap, open(dataset_dir, 'w'))
--------------------------------------------------------------------------------
/utils/vocab.py:
--------------------------------------------------------------------------------
1 | # Build the vocabulary, used to convert between tokens and ids
2 | # Words appearing fewer than 5 times are replaced with a special symbol
3 |
4 | import numpy as np
5 | import json
6 | import pickle
7 | from tqdm import tqdm
8 |
9 | class Vocabulary():
10 | """单词表"""
11 | def __init__(self):
12 | self._word2id = {}
13 | self._id2word = {}
14 | self._idx = 0
15 | self._word = []
16 |
17 | # special symbols
18 | self.pad = '<pad>'  # token used to pad sequences to a fixed length
19 | self.bos = '<bos>'  # begin-of-sentence token
20 | self.eos = '<eos>'  # end-of-sentence token
21 | self.unk = '<unk>'  # unknown-word token
22 | self.add_spe_sign()
23 |
24 | def add_word(self, word):
25 | '''Add a word to the vocabulary'''
26 | if word not in self._word:
27 | self._word2id.update({word: self._idx})
28 | self._id2word.update({self._idx: word})
29 | self._word.append(word)
30 | self._idx += 1
31 |
32 | def word_to_id(self, word):
33 | '''Convert a word to its id'''
34 | if word in self._word2id:
35 | return self._word2id[word]
36 | else:
37 | return self._word2id[self.unk]
38 |
39 | def id_to_word(self, id):
40 | '''Convert an id back to its word'''
41 | assert id < self._idx, "input id exceeds the maximum id"
42 | return self._id2word[id]
43 |
44 | def tokenList_to_idList(self, tokenList, fixed_len):
45 | '''Convert a tokenList into ids, adding <bos> and <eos>
46 | :param tokenList: the tokens of one sentence, e.g. ["a", "man", "riding", "a", "horse"]
47 | :param fixed_len: maximum number of caption tokens (<bos> and <eos> are added on top of this)
48 | :return: list
49 | '''
50 | sent_len = len(tokenList)
51 | tok_id = [self.word_to_id(token) for token in tokenList]
52 | if sent_len < fixed_len:
53 | tok_id.insert(0, self._word2id[self.bos])
54 | tok_id.append(self._word2id[self.eos])
55 | pad_num = fixed_len - sent_len
56 | tok_id += [0] * pad_num
57 | else:
58 | tok_id = tok_id[:fixed_len]
59 | tok_id.insert(0, self._word2id[self.bos])
60 | tok_id.append(self._word2id[self.eos])
61 | sent_len = fixed_len
62 | sent_len += 2  # account for <bos> and <eos>
63 | return tok_id, sent_len
64 |
65 | def idList_to_sent(self, id_List):
66 | '''Convert an idList back into a sentence
67 | :param id_List: the ids of one sentence, e.g.: [1, 4, 5, 343, 4, 123, 2389, 213, 233, 678, 2343, 2, 0, 0, 0, 0, 0, 0]
68 | supported input types: list, tensor, numpy.array
69 | :return: a str sentence, e.g.: "a man riding a horse"
70 | '''
71 | id_List = np.array(list(map(int, id_List)))
72 | word_array = np.array(self._word)
73 | eos_id = self._word2id[self.eos]
74 | eos_pos = np.where(id_List == eos_id)[0]
75 | if len(eos_pos) > 0:
76 | sent = word_array[id_List[1:eos_pos[0]]]
77 | else:
78 | sent = word_array[id_List[1:]]
79 | return ' '.join(sent)
80 |
81 | def add_spe_sign(self):
82 | self.add_word(self.pad)
83 | self.add_word(self.bos)
84 | self.add_word(self.eos)
85 | self.add_word(self.unk)
86 |
87 | def get_size(self):
88 | return self._idx
89 |
90 | if __name__ == '__main__':
91 | vocab = Vocabulary()
92 | data_train = json.load(open('../data/train.json', 'r'))
93 |
94 | counter = {}
95 | for item in tqdm(data_train):
96 | sentence_token = item['caption']
97 | for token in sentence_token:
98 | counter[token] = counter.get(token, 0) + 1
99 | # cand_word = [token for token, f in counter.items() if f >= 5]
100 | # print(counter['tesla'])  # leftover debug check of a specific word's frequency
101 | # input()
102 | cand_word = sorted(counter.items(), key=lambda x: x[1], reverse=True)
103 | print("word (f>=5) num: "+str(len(cand_word)))
104 |
105 | for word, freq in cand_word:  # counter.items() yields (word, frequency) pairs
106 | vocab.add_word(word)
107 | print("vocab size: "+str(vocab.get_size()))
108 |
109 | # pickle.dump(vocab, open('../data/vocab.pkl', 'wb'))
--------------------------------------------------------------------------------
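A small round-trip sketch of the `Vocabulary` class above; the tokens are arbitrary examples and the import assumes the repo root is on PYTHONPATH.

```python
# Sketch: build a tiny vocabulary and round-trip a caption through it.
from utils.vocab import Vocabulary

vocab = Vocabulary()
for tok in ['a', 'man', 'riding', 'a', 'horse']:
    vocab.add_word(tok)

ids, length = vocab.tokenList_to_idList(['a', 'man', 'riding', 'a', 'horse'], fixed_len=8)
print(ids, length)                # id list with <bos>/<eos> and padding, plus the true length
print(vocab.idList_to_sent(ids))  # -> "a man riding a horse"
```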