├── Inference_example.ipynb ├── LICENSE ├── README.md ├── assets ├── COCO_val2014_000000462565.jpeg └── teaser.png ├── captioning ├── __init__.py ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── pth_loader.py │ └── pth_loader_FineCapEval.py ├── models │ ├── AoAModel.py │ ├── AttEnsemble.py │ ├── AttModel.py │ ├── BertCapModel.py │ ├── CaptionModel.py │ ├── FCModel.py │ ├── M2Transformer.py │ ├── ShowTellModel.py │ ├── TransformerModel.py │ ├── __init__.py │ ├── cachedTransformer.py │ └── utils.py ├── modules │ ├── loss_wrapper.py │ └── losses.py └── utils │ ├── __init__.py │ ├── clipscore.py │ ├── config.py │ ├── dist_utils.py │ ├── div_utils.py │ ├── eval_multi.py │ ├── eval_utils.py │ ├── misc.py │ ├── opts.py │ ├── resnet.py │ ├── resnet_utils.py │ ├── rewards.py │ └── utils.py ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── cog.yaml ├── configs ├── phase1 │ ├── FineCapEval_clipRN50_mle.yml │ ├── clipRN50_mle.yml │ └── transformer.yml └── phase2 │ ├── FineCapEval_clipRN50_cider.yml │ ├── FineCapEval_clipRN50_cider_clips.yml │ ├── FineCapEval_clipRN50_clips.yml │ ├── FineCapEval_clipRN50_clips_grammar.yml │ ├── clipRN50_cider.yml │ ├── clipRN50_cider_clips.yml │ ├── clipRN50_clips.yml │ ├── clipRN50_clips_grammar.yml │ └── transformer.yml ├── data └── README.md ├── predict.py ├── requirements.txt ├── retrieval ├── README.md ├── caption_data.py ├── clip_model.py ├── configs │ └── clip_negative_text.yaml ├── param.py ├── pth_loader.py ├── text_utils.py └── train_pl.py ├── save └── README.md ├── scripts ├── build_bpe_subword_nmt.py ├── clip_prepro_feats.py ├── clipscore_prepro_feats.py ├── copy_model.sh ├── dump_to_h5df.py ├── dump_to_lmdb.py ├── make_bu_data.py ├── prepro_feats.py ├── prepro_labels.py ├── prepro_ngrams.py └── prepro_reference_json.py ├── scripts_FineCapEval ├── clip_prepro_feats.py ├── clipscore_prepro_feats.py └── prepro_labels.py ├── setup.py └── tools ├── eval.py ├── eval_clip_retrieval.py ├── eval_finecapeval.py ├── finecapeval_inference.py └── train_pl.py /assets/COCO_val2014_000000462565.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/assets/COCO_val2014_000000462565.jpeg -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/assets/teaser.png -------------------------------------------------------------------------------- /captioning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/captioning/__init__.py -------------------------------------------------------------------------------- /captioning/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/captioning/data/__init__.py -------------------------------------------------------------------------------- /captioning/models/AttEnsemble.py: -------------------------------------------------------------------------------- 1 | # This file is the implementation for ensemble evaluation. 
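# Descriptive note (added for clarity; not part of the original file): the
# ensemble wraps several trained attention captioning models and, at each
# decoding step, averages their per-word probability distributions with the
# supplied weights before taking the log (see get_logprobs_state below):
#
#     p(w) = sum_i weights[i] * softmax(logits_i)[w] / sum(weights)
#     logprobs = log(p)
#
# A minimal usage sketch, assuming `model_a` and `model_b` are compatible,
# already-trained AttModel instances (names are illustrative only):
#
#     ensemble = AttEnsemble([model_a, model_b], weights=[0.5, 0.5]).eval()
#     seq, seq_logprobs = ensemble._old_sample_beam(fc_feats, att_feats, att_masks)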
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import * 12 | 13 | from .CaptionModel import CaptionModel 14 | from .AttModel import pack_wrapper, AttModel 15 | 16 | class AttEnsemble(AttModel): 17 | def __init__(self, models, weights=None): 18 | CaptionModel.__init__(self) 19 | # super(AttEnsemble, self).__init__() 20 | 21 | self.models = nn.ModuleList(models) 22 | self.vocab_size = models[0].vocab_size 23 | self.seq_length = models[0].seq_length 24 | self.bad_endings_ix = models[0].bad_endings_ix 25 | self.ss_prob = 0 26 | weights = weights or [1.0] * len(self.models) 27 | self.register_buffer('weights', torch.tensor(weights)) 28 | 29 | def init_hidden(self, batch_size): 30 | state = [m.init_hidden(batch_size) for m in self.models] 31 | return self.pack_state(state) 32 | 33 | def pack_state(self, state): 34 | self.state_lengths = [len(_) for _ in state] 35 | return sum([list(_) for _ in state], []) 36 | 37 | def unpack_state(self, state): 38 | out = [] 39 | for l in self.state_lengths: 40 | out.append(state[:l]) 41 | state = state[l:] 42 | return out 43 | 44 | def embed(self, it): 45 | return [m.embed(it) for m in self.models] 46 | 47 | def core(self, *args): 48 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) 49 | 50 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1): 51 | # 'it' contains a word index 52 | xt = self.embed(it) 53 | 54 | state = self.unpack_state(state) 55 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) 56 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log() 57 | 58 | return logprobs, self.pack_state(state) 59 | 60 | def _prepare_feature(self, *args): 61 | return tuple(zip(*[m._prepare_feature(*args) for m in self.models])) 62 | 63 | def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 64 | beam_size = opt.get('beam_size', 10) 65 | batch_size = fc_feats.size(0) 66 | 67 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 68 | 69 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 70 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 71 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) 72 | # lets process every image independently for now, for simplicity 73 | 74 | self.done_beams = [[] for _ in range(batch_size)] 75 | for k in range(batch_size): 76 | state = self.init_hidden(beam_size) 77 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] 78 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 79 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 80 | tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)] 81 | 82 | it = fc_feats[0].data.new(beam_size).long().zero_() 83 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) 84 | 85 | self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) 86 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 87 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 88 | # return the samples and their log likelihoods 89 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 90 | # return the samples and their log likelihoods 91 | -------------------------------------------------------------------------------- /captioning/models/BertCapModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | BertCapModel is using huggingface transformer bert model as seq2seq model. 3 | 4 | The result is not as goog as original transformer. 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import copy 16 | import math 17 | import numpy as np 18 | 19 | from .CaptionModel import CaptionModel 20 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 21 | try: 22 | from transformers import BertModel, BertConfig 23 | except: 24 | print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers') 25 | from .TransformerModel import subsequent_mask, TransformerModel, Generator 26 | 27 | class EncoderDecoder(nn.Module): 28 | """ 29 | A standard Encoder-Decoder architecture. Base for this and many 30 | other models. 31 | """ 32 | def __init__(self, encoder, decoder, generator): 33 | super(EncoderDecoder, self).__init__() 34 | self.encoder = encoder 35 | self.decoder = decoder 36 | self.generator = generator 37 | 38 | def forward(self, src, tgt, src_mask, tgt_mask): 39 | "Take in and process masked src and target sequences." 
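# Descriptive comment (added for clarity; not in the original file): `src`
# carries visual feature embeddings that go straight into the BERT encoder
# through `inputs_embeds` (its token-embedding layer is bypassed in
# make_model below), while `tgt` carries caption token ids for the BERT
# decoder, which attends to the encoder output returned by encode().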
40 | return self.decode(self.encode(src, src_mask), src_mask, 41 | tgt, tgt_mask) 42 | 43 | def encode(self, src, src_mask): 44 | return self.encoder(inputs_embeds=src, 45 | attention_mask=src_mask)[0] 46 | 47 | def decode(self, memory, src_mask, tgt, tgt_mask): 48 | return self.decoder(input_ids=tgt, 49 | attention_mask=tgt_mask, 50 | encoder_hidden_states=memory, 51 | encoder_attention_mask=src_mask)[0] 52 | 53 | 54 | class BertCapModel(TransformerModel): 55 | 56 | def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, 57 | d_model=512, d_ff=2048, h=8, dropout=0.1): 58 | "Helper: Construct a model from hyperparameters." 59 | enc_config = BertConfig(vocab_size=1, 60 | hidden_size=d_model, 61 | num_hidden_layers=N_enc, 62 | num_attention_heads=h, 63 | intermediate_size=d_ff, 64 | hidden_dropout_prob=dropout, 65 | attention_probs_dropout_prob=dropout, 66 | max_position_embeddings=1, 67 | type_vocab_size=1) 68 | dec_config = BertConfig(vocab_size=tgt_vocab, 69 | hidden_size=d_model, 70 | num_hidden_layers=N_dec, 71 | num_attention_heads=h, 72 | intermediate_size=d_ff, 73 | hidden_dropout_prob=dropout, 74 | attention_probs_dropout_prob=dropout, 75 | max_position_embeddings=17, 76 | type_vocab_size=1, 77 | is_decoder=True) 78 | encoder = BertModel(enc_config) 79 | def return_embeds(*args, **kwargs): 80 | return kwargs['inputs_embeds'] 81 | del encoder.embeddings; encoder.embeddings = return_embeds 82 | decoder = BertModel(dec_config) 83 | model = EncoderDecoder( 84 | encoder, 85 | decoder, 86 | Generator(d_model, tgt_vocab)) 87 | return model 88 | 89 | def __init__(self, opt): 90 | super(BertCapModel, self).__init__(opt) 91 | 92 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 93 | """ 94 | state = [ys.unsqueeze(0)] 95 | """ 96 | if len(state) == 0: 97 | ys = it.unsqueeze(1) 98 | else: 99 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 100 | out = self.model.decode(memory, mask, 101 | ys, 102 | subsequent_mask(ys.size(1)) 103 | .to(memory.device)) 104 | return out[:, -1], [ys.unsqueeze(0)] 105 | -------------------------------------------------------------------------------- /captioning/models/M2Transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226) 3 | 4 | pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git 5 | 6 | Note: 7 | Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating. 
8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import copy 19 | import math 20 | import numpy as np 21 | 22 | from .CaptionModel import CaptionModel 23 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 24 | 25 | try: 26 | from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory 27 | except: 28 | print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`') 29 | from .TransformerModel import subsequent_mask, TransformerModel 30 | 31 | 32 | class M2TransformerModel(TransformerModel): 33 | 34 | def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, 35 | d_model=512, d_ff=2048, h=8, dropout=0.1): 36 | "Helper: Construct a model from hyperparameters." 37 | encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory, 38 | attention_module_kwargs={'m': 40}) 39 | # Another implementation is to use MultiLevelEncoder + att_embed 40 | decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding; 41 | model = Transformer(0, encoder, decoder) # 0 is bos 42 | return model 43 | 44 | def __init__(self, opt): 45 | super(M2TransformerModel, self).__init__(opt) 46 | delattr(self, 'att_embed') 47 | self.att_embed = lambda x: x # The visual embed is in the MAEncoder 48 | # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5? 49 | # Also the attention mask seems wrong in MAEncoder too...intersting 50 | 51 | def logit(self, x): # unsafe way 52 | return x # M2transformer always output logsoftmax 53 | 54 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 55 | 56 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 57 | memory, att_masks = self.model.encoder(att_feats) 58 | 59 | return fc_feats[...,:0], att_feats[...,:0], memory, att_masks 60 | 61 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 62 | if seq.ndim == 3: # B * seq_per_img * seq_len 63 | seq = seq.reshape(-1, seq.shape[2]) 64 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 65 | 66 | seq = seq.clone() 67 | seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding) 68 | outputs = self.model(att_feats, seq) 69 | 70 | return outputs 71 | 72 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 73 | """ 74 | state = [ys.unsqueeze(0)] 75 | """ 76 | if len(state) == 0: 77 | ys = it.unsqueeze(1) 78 | else: 79 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 80 | out = self.model.decoder(ys, memory, mask) 81 | return out[:, -1], [ys.unsqueeze(0)] 82 | 83 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 84 | beam_size = opt.get('beam_size', 10) 85 | group_size = opt.get('group_size', 1) 86 | sample_n = opt.get('sample_n', 10) 87 | assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' 88 | 89 | att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks) 90 | seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0, 91 | beam_size, return_probs=True, out_size=beam_size) 92 | seq = seq.reshape(-1, *seq.shape[2:]) 93 | seqLogprobs = 
seqLogprobs.reshape(-1, *seqLogprobs.shape[2:]) 94 | 95 | # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all(): 96 | # import pudb;pu.db 97 | # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1]) 98 | return seq, seqLogprobs -------------------------------------------------------------------------------- /captioning/models/ShowTellModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | from . import utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class ShowTellModel(CaptionModel): 14 | def __init__(self, opt): 15 | super(ShowTellModel, self).__init__() 16 | self.vocab_size = opt.vocab_size 17 | self.input_encoding_size = opt.input_encoding_size 18 | self.rnn_type = opt.rnn_type 19 | self.rnn_size = opt.rnn_size 20 | self.num_layers = opt.num_layers 21 | self.drop_prob_lm = opt.drop_prob_lm 22 | self.seq_length = opt.seq_length 23 | self.fc_feat_size = opt.fc_feat_size 24 | 25 | self.ss_prob = 0.0 # Schedule sampling probability 26 | 27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 31 | self.dropout = nn.Dropout(self.drop_prob_lm) 32 | 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | initrange = 0.1 37 | self.embed.weight.data.uniform_(-initrange, initrange) 38 | self.logit.bias.data.fill_(0) 39 | self.logit.weight.data.uniform_(-initrange, initrange) 40 | 41 | def init_hidden(self, bsz): 42 | weight = self.logit.weight 43 | if self.rnn_type == 'lstm': 44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 46 | else: 47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 48 | 49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 50 | batch_size = fc_feats.size(0) 51 | seq_per_img = seq.shape[0] // batch_size 52 | state = self.init_hidden(batch_size*seq_per_img) 53 | outputs = [] 54 | 55 | if seq_per_img > 1: 56 | fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) 57 | 58 | for i in range(seq.size(1) + 1): 59 | if i == 0: 60 | xt = self.img_embed(fc_feats) 61 | else: 62 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 63 | sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) 64 | sample_mask = sample_prob < self.ss_prob 65 | if sample_mask.sum() == 0: 66 | it = seq[:, i-1].clone() 67 | else: 68 | sample_ind = sample_mask.nonzero().view(-1) 69 | it = seq[:, i-1].data.clone() 70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 74 | else: 75 | it = seq[:, i-1].clone() 76 | # break if all 
the sequences end 77 | if i >= 2 and seq[:, i-1].data.sum() == 0: 78 | break 79 | xt = self.embed(it) 80 | 81 | output, state = self.core(xt.unsqueeze(0), state) 82 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 83 | outputs.append(output) 84 | 85 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 86 | 87 | def get_logprobs_state(self, it, state): 88 | # 'it' contains a word index 89 | xt = self.embed(it) 90 | 91 | output, state = self.core(xt.unsqueeze(0), state) 92 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 93 | 94 | return logprobs, state 95 | 96 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 97 | beam_size = opt.get('beam_size', 10) 98 | batch_size = fc_feats.size(0) 99 | 100 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 101 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 102 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 103 | # lets process every image independently for now, for simplicity 104 | 105 | self.done_beams = [[] for _ in range(batch_size)] 106 | for k in range(batch_size): 107 | state = self.init_hidden(beam_size) 108 | for t in range(2): 109 | if t == 0: 110 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 111 | elif t == 1: # input 112 | it = fc_feats.data.new(beam_size).long().zero_() 113 | xt = self.embed(it) 114 | 115 | output, state = self.core(xt.unsqueeze(0), state) 116 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 117 | 118 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 119 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 120 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 121 | # return the samples and their log likelihoods 122 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 123 | 124 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 125 | sample_method = opt.get('sample_method', 'greedy') 126 | beam_size = opt.get('beam_size', 1) 127 | temperature = opt.get('temperature', 1.0) 128 | if beam_size > 1 and sample_method in ['greedy', 'beam_search']: 129 | return self.sample_beam(fc_feats, att_feats, opt) 130 | 131 | batch_size = fc_feats.size(0) 132 | state = self.init_hidden(batch_size) 133 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 134 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 135 | for t in range(self.seq_length + 2): 136 | if t == 0: 137 | xt = self.img_embed(fc_feats) 138 | else: 139 | if t == 1: # input 140 | it = fc_feats.data.new(batch_size).long().zero_() 141 | xt = self.embed(it) 142 | 143 | output, state = self.core(xt.unsqueeze(0), state) 144 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 145 | 146 | # sample the next word 147 | if t == self.seq_length + 1: # skip if we achieve maximum length 148 | break 149 | if sample_method == 'greedy': 150 | sampleLogprobs, it = torch.max(logprobs.data, 1) 151 | it = it.view(-1).long() 152 | else: 153 | if temperature == 1.0: 154 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 155 | else: 156 | # scale logprobs by temperature 157 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 158 | it = torch.multinomial(prob_prev, 1).to(logprobs.device) 159 | 
sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 160 | it = it.view(-1).long() # and flatten indices for downstream processing 161 | 162 | if t >= 1: 163 | # stop when all finished 164 | if t == 1: 165 | unfinished = it > 0 166 | else: 167 | unfinished = unfinished & (it > 0) 168 | it = it * unfinished.type_as(it) 169 | seq[:,t-1] = it #seq[t] the input of t+2 time step 170 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 171 | if unfinished.sum() == 0: 172 | break 173 | 174 | return seq, seqLogprobs -------------------------------------------------------------------------------- /captioning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .ShowTellModel import ShowTellModel 12 | from .FCModel import FCModel 13 | from .AttModel import * 14 | from .TransformerModel import TransformerModel 15 | from .cachedTransformer import TransformerModel as cachedTransformer 16 | from .BertCapModel import BertCapModel 17 | from .M2Transformer import M2TransformerModel 18 | from .AoAModel import AoAModel 19 | 20 | def setup(opt): 21 | if opt.caption_model in ['fc', 'show_tell']: 22 | print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model) 23 | if opt.caption_model == 'fc': 24 | print('Use newfc instead of fc') 25 | if opt.caption_model == 'fc': 26 | model = FCModel(opt) 27 | elif opt.caption_model == 'language_model': 28 | model = LMModel(opt) 29 | elif opt.caption_model == 'newfc': 30 | model = NewFCModel(opt) 31 | elif opt.caption_model == 'show_tell': 32 | model = ShowTellModel(opt) 33 | # Att2in model in self-critical 34 | elif opt.caption_model == 'att2in': 35 | model = Att2inModel(opt) 36 | # Att2in model with two-layer MLP img embedding and word embedding 37 | elif opt.caption_model == 'att2in2': 38 | model = Att2in2Model(opt) 39 | elif opt.caption_model == 'att2all2': 40 | print('Warning: this is not a correct implementation of the att2all model in the original paper.') 41 | model = Att2all2Model(opt) 42 | # Adaptive Attention model from Knowing when to look 43 | elif opt.caption_model == 'adaatt': 44 | model = AdaAttModel(opt) 45 | # Adaptive Attention with maxout lstm 46 | elif opt.caption_model == 'adaattmo': 47 | model = AdaAttMOModel(opt) 48 | # Top-down attention model 49 | elif opt.caption_model in ['topdown', 'updown']: 50 | model = UpDownModel(opt) 51 | # StackAtt 52 | elif opt.caption_model == 'stackatt': 53 | model = StackAttModel(opt) 54 | # DenseAtt 55 | elif opt.caption_model == 'denseatt': 56 | model = DenseAttModel(opt) 57 | # Transformer 58 | elif opt.caption_model == 'transformer': 59 | if getattr(opt, 'cached_transformer', False): 60 | model = cachedTransformer(opt) 61 | else: 62 | model = TransformerModel(opt) 63 | # AoANet 64 | elif opt.caption_model == 'aoa': 65 | model = AoAModel(opt) 66 | elif opt.caption_model == 'bert': 67 | model = BertCapModel(opt) 68 | elif opt.caption_model == 'm2transformer': 69 | model = M2TransformerModel(opt) 70 | else: 71 | raise Exception("Caption model not supported: {}".format(opt.caption_model)) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /captioning/models/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def repeat_tensors(n, x): 4 | """ 5 | For a tensor of size Bx..., we repeat it n times, and make it Bnx... 6 | For collections, do nested repeat 7 | """ 8 | if torch.is_tensor(x): 9 | x = x.unsqueeze(1) # Bx1x... 10 | x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... 11 | x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... 12 | elif type(x) is list or type(x) is tuple: 13 | x = [repeat_tensors(n, _) for _ in x] 14 | return x 15 | 16 | 17 | def split_tensors(n, x): 18 | if torch.is_tensor(x): 19 | assert x.shape[0] % n == 0 20 | x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) 21 | elif type(x) is list or type(x) is tuple: 22 | x = [split_tensors(n, _) for _ in x] 23 | elif x is None: 24 | x = [None] * n 25 | return x -------------------------------------------------------------------------------- /captioning/modules/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import losses 3 | from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward 4 | from ..utils.clipscore import CLIPScore 5 | import numpy as np 6 | 7 | class LossWrapper(torch.nn.Module): 8 | def __init__(self, model, opt): 9 | super(LossWrapper, self).__init__() 10 | self.opt = opt 11 | self.model = model 12 | if opt.label_smoothing > 0: 13 | self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing) 14 | else: 15 | self.crit = losses.LanguageModelCriterion() 16 | self.rl_crit = losses.RewardCriterion() 17 | self.struc_crit = losses.StructureLosses(opt) 18 | 19 | self.clipscore_model = None 20 | if self.opt.use_clipscore: 21 | use_grammar = getattr(self.opt, 'use_grammar', False) 22 | joint_out = getattr(self.opt, 'joint_out', False) 23 | self.clipscore_model = CLIPScore( 24 | mode=opt.clipscore_mode, 25 | use_grammar=use_grammar, 26 | joint_out=joint_out, 27 | ) 28 | for p in self.clipscore_model.parameters(): 29 | p.requires_grad = False 30 | 31 | if use_grammar: 32 | state_dict = torch.load(self.opt.clip_load_path, map_location='cpu') 33 | self.clipscore_model.load_state_dict(state_dict['state_dict']) 34 | 35 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, 36 | sc_flag, struc_flag, clip_vis_feats=None): 37 | opt = self.opt 38 | 39 | out = {} 40 | if struc_flag: 41 | if opt.structure_loss_weight < 1: 42 | lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) 43 | else: 44 | lm_loss = torch.tensor(0).type_as(fc_feats) 45 | if opt.structure_loss_weight > 0: 46 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, 47 | opt={'sample_method':opt.train_sample_method, 48 | 'beam_size':opt.train_beam_size, 49 | 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\ 50 | or not 'margin' in opt.structure_loss_type, 51 | 'sample_n': opt.train_sample_n}, 52 | mode='sample') 53 | gts = [gts[_] for _ in gt_indices.tolist()] 54 | struc_loss = self.struc_crit(sample_logprobs, gen_result, gts) 55 | else: 56 | struc_loss = {'loss': torch.tensor(0).type_as(fc_feats), 57 | 'reward': torch.tensor(0).type_as(fc_feats)} 58 | loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss'] 59 | out['lm_loss'] = lm_loss 60 | out['struc_loss'] = struc_loss['loss'] 61 | out['reward'] = struc_loss['reward'] 62 | elif not 
sc_flag: 63 | loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) 64 | else: 65 | self.model.eval() 66 | with torch.no_grad(): 67 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, 68 | mode='sample', 69 | opt={'sample_method': opt.sc_sample_method, 70 | 'beam_size': opt.sc_beam_size}) 71 | self.model.train() 72 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, 73 | opt={'sample_method':opt.train_sample_method, 74 | 'beam_size':opt.train_beam_size, 75 | 'sample_n': opt.train_sample_n}, 76 | mode='sample') 77 | gts = [gts[_] for _ in gt_indices.tolist()] 78 | 79 | if getattr(self.opt, 'use_multi_rewards', False): 80 | assert self.opt.use_clipscore 81 | clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward( 82 | greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) 83 | 84 | if self.opt.clipscore_mode == 'clip_s': 85 | out['CLIP-S'] = clipscore_unnormalized_mean 86 | elif self.opt.clipscore_mode == 'refclip_s': 87 | out['RefCLIP-S'] = clipscore_unnormalized_mean 88 | 89 | if getattr(self.opt, 'use_grammar', False): 90 | out['grammar_reward'] = grammar_rewards.mean() 91 | 92 | reward = clipscore_reward_normalized + grammar_rewards 93 | 94 | 95 | else: 96 | assert grammar_rewards is None 97 | 98 | cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( 99 | greedy_res, gts, gen_result, self.opt) 100 | out['CIDEr'] = cider_unnormalized_mean 101 | if isinstance(cider_reward_normalized, np.ndarray): 102 | cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device) 103 | 104 | reward = clipscore_reward_normalized + cider_reward_normalized 105 | else: 106 | if self.opt.use_clipscore: 107 | clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward( 108 | greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) 109 | if self.opt.clipscore_mode == 'clip_s': 110 | out['CLIP-S'] = clipscore_unnormalized_mean 111 | elif self.opt.clipscore_mode == 'refclip_s': 112 | out['RefCLIP-S'] = clipscore_unnormalized_mean 113 | reward = clipscore_reward_normalized 114 | else: 115 | cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( 116 | greedy_res, gts, gen_result, self.opt) 117 | out['CIDEr'] = cider_unnormalized_mean 118 | reward = cider_reward_normalized 119 | 120 | if isinstance(reward, np.ndarray): 121 | reward = torch.from_numpy(reward) 122 | reward = reward.to(sample_logprobs) 123 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward) 124 | out['reward'] = reward[:,0].mean() 125 | out['loss'] = loss 126 | return out 127 | 128 | -------------------------------------------------------------------------------- /captioning/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..utils.rewards import get_scores, get_self_cider_scores 4 | 5 | class RewardCriterion(nn.Module): 6 | def __init__(self): 7 | super(RewardCriterion, self).__init__() 8 | 9 | def forward(self, input, seq, reward): 10 | input = input.gather(2, seq.unsqueeze(2)).squeeze(2) 11 | 12 | input = input.reshape(-1) 13 | reward = reward.reshape(-1) 14 | mask = (seq>0).to(input) 15 | mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1) 16 | output = - input 
* reward * mask 17 | output = torch.sum(output) / torch.sum(mask) 18 | 19 | return output 20 | 21 | class StructureLosses(nn.Module): 22 | """ 23 | This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018). 24 | """ 25 | def __init__(self, opt): 26 | super(StructureLosses, self).__init__() 27 | self.opt = opt 28 | self.loss_type = opt.structure_loss_type 29 | 30 | def forward(self, input, seq, data_gts): 31 | """ 32 | Input is either logits or log softmax 33 | """ 34 | out = {} 35 | 36 | batch_size = input.size(0)# batch_size = sample_size * seq_per_img 37 | seq_per_img = batch_size // len(data_gts) 38 | 39 | assert seq_per_img == self.opt.train_sample_n, seq_per_img 40 | 41 | mask = (seq>0).to(input) 42 | mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1) 43 | 44 | scores = get_scores(data_gts, seq, self.opt) 45 | scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img) 46 | out['reward'] = scores #.mean() 47 | if self.opt.entropy_reward_weight > 0: 48 | entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data 49 | entropy = (entropy * mask).sum(1) / mask.sum(1) 50 | print('entropy', entropy.mean().item()) 51 | scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img) 52 | # rescale cost to [0,1] 53 | costs = - scores 54 | if self.loss_type == 'risk' or self.loss_type == 'softmax_margin': 55 | costs = costs - costs.min(1, keepdim=True)[0] 56 | costs = costs / costs.max(1, keepdim=True)[0] 57 | # in principle 58 | # Only risk need such rescale 59 | # margin should be alright; Let's try. 60 | 61 | # Gather input: BxTxD -> BxT 62 | input = input.gather(2, seq.unsqueeze(2)).squeeze(2) 63 | 64 | if self.loss_type == 'seqnll': 65 | # input is logsoftmax 66 | input = input * mask 67 | input = input.sum(1) / mask.sum(1) 68 | input = input.view(-1, seq_per_img) 69 | 70 | target = costs.min(1)[1] 71 | output = F.cross_entropy(input, target) 72 | elif self.loss_type == 'risk': 73 | # input is logsoftmax 74 | input = input * mask 75 | input = input.sum(1) 76 | input = input.view(-1, seq_per_img) 77 | 78 | output = (F.softmax(input.exp()) * costs).sum(1).mean() 79 | 80 | # test 81 | # avg_scores = input 82 | # probs = F.softmax(avg_scores.exp_()) 83 | # loss = (probs * costs.type_as(probs)).sum() / input.size(0) 84 | # print(output.item(), loss.item()) 85 | 86 | elif self.loss_type == 'max_margin': 87 | # input is logits 88 | input = input * mask 89 | input = input.sum(1) / mask.sum(1) 90 | input = input.view(-1, seq_per_img) 91 | _, __ = costs.min(1, keepdim=True) 92 | costs_star = _ 93 | input_star = input.gather(1, __) 94 | output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2 95 | output = output.mean() 96 | 97 | # sanity test 98 | # avg_scores = input + costs 99 | # scores_with_high_target = avg_scores.clone() 100 | # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10) 101 | 102 | # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2] 103 | # avg_scores = avg_scores.gather(1, target_and_offender_index) 104 | # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long) 105 | # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0) 106 | # print(loss.item() * 2, output.item()) 107 | 108 | elif self.loss_type == 'multi_margin': 109 | # input is logits 110 | input = input * mask 111 | input = input.sum(1) / mask.sum(1) 112 | input = input.view(-1, 
seq_per_img) 113 | _, __ = costs.min(1, keepdim=True) 114 | costs_star = _ 115 | input_star = input.gather(1, __) 116 | output = F.relu(costs - costs_star - input_star + input) 117 | output = output.mean() 118 | 119 | # sanity test 120 | # avg_scores = input + costs 121 | # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0) 122 | # print(output, loss) 123 | 124 | elif self.loss_type == 'softmax_margin': 125 | # input is logsoftmax 126 | input = input * mask 127 | input = input.sum(1) / mask.sum(1) 128 | input = input.view(-1, seq_per_img) 129 | 130 | input = input + costs 131 | target = costs.min(1)[1] 132 | output = F.cross_entropy(input, target) 133 | 134 | elif self.loss_type == 'real_softmax_margin': 135 | # input is logits 136 | # This is what originally defined in Kevin's paper 137 | # The result should be equivalent to softmax_margin 138 | input = input * mask 139 | input = input.sum(1) / mask.sum(1) 140 | input = input.view(-1, seq_per_img) 141 | 142 | input = input + costs 143 | target = costs.min(1)[1] 144 | output = F.cross_entropy(input, target) 145 | 146 | elif self.loss_type == 'new_self_critical': 147 | """ 148 | A different self critical 149 | Self critical uses greedy decoding score as baseline; 150 | This setting uses the average score of the rest samples as baseline 151 | (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) ) 152 | """ 153 | baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1) 154 | scores = scores - baseline 155 | # self cider used as reward to promote diversity (not working that much in this way) 156 | if getattr(self.opt, 'self_cider_reward_weight', 0) > 0: 157 | _scores = get_self_cider_scores(data_gts, seq, self.opt) 158 | _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1) 159 | _scores = _scores.expand_as(scores - 1) 160 | scores += self.opt.self_cider_reward_weight * _scores 161 | output = - input * mask * scores.view(-1, 1) 162 | output = torch.sum(output) / torch.sum(mask) 163 | 164 | out['loss'] = output 165 | return out 166 | 167 | class LanguageModelCriterion(nn.Module): 168 | def __init__(self): 169 | super(LanguageModelCriterion, self).__init__() 170 | 171 | def forward(self, input, target, mask): 172 | if target.ndim == 3: 173 | target = target.reshape(-1, target.shape[2]) 174 | mask = mask.reshape(-1, mask.shape[2]) 175 | # truncate to the same size 176 | target = target[:, :input.size(1)] 177 | mask = mask[:, :input.size(1)].to(input) 178 | 179 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask 180 | # Average over each token 181 | output = torch.sum(output) / torch.sum(mask) 182 | 183 | return output 184 | 185 | class LabelSmoothing(nn.Module): 186 | "Implement label smoothing." 
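# Descriptive comment (added for clarity; not part of the original file):
# forward() below builds a smoothed target distribution in which the gold
# token gets probability (1 - smoothing) and every other vocabulary entry
# gets smoothing / (V - 1), where V = input.size(1); the loss is the
# token-masked KL divergence between that distribution and the model's
# log-probabilities. For example, with smoothing=0.1 and V=11 the gold
# token gets 0.9 and each of the remaining 10 entries gets 0.01.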
187 | def __init__(self, size=0, padding_idx=0, smoothing=0.0): 188 | super(LabelSmoothing, self).__init__() 189 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False) 190 | # self.padding_idx = padding_idx 191 | self.confidence = 1.0 - smoothing 192 | self.smoothing = smoothing 193 | # self.size = size 194 | self.true_dist = None 195 | 196 | def forward(self, input, target, mask): 197 | if target.ndim == 3: 198 | target = target.reshape(-1, target.shape[2]) 199 | mask = mask.reshape(-1, mask.shape[2]) 200 | # truncate to the same size 201 | target = target[:, :input.size(1)] 202 | mask = mask[:, :input.size(1)] 203 | 204 | input = input.reshape(-1, input.size(-1)) 205 | target = target.reshape(-1) 206 | mask = mask.reshape(-1).to(input) 207 | 208 | # assert x.size(1) == self.size 209 | self.size = input.size(1) 210 | # true_dist = x.data.clone() 211 | true_dist = input.data.clone() 212 | # true_dist.fill_(self.smoothing / (self.size - 2)) 213 | true_dist.fill_(self.smoothing / (self.size - 1)) 214 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 215 | # true_dist[:, self.padding_idx] = 0 216 | # mask = torch.nonzero(target.data == self.padding_idx) 217 | # self.true_dist = true_dist 218 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() -------------------------------------------------------------------------------- /captioning/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/captioning/utils/__init__.py -------------------------------------------------------------------------------- /captioning/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Copy from fvcore 3 | 4 | import logging 5 | import os 6 | from typing import Any 7 | import yaml 8 | from yacs.config import CfgNode as _CfgNode 9 | 10 | import io as PathManager 11 | 12 | BASE_KEY = "_BASE_" 13 | 14 | 15 | class CfgNode(_CfgNode): 16 | """ 17 | Our own extended version of :class:`yacs.config.CfgNode`. 18 | It contains the following extra features: 19 | 20 | 1. The :meth:`merge_from_file` method supports the "_BASE_" key, 21 | which allows the new CfgNode to inherit all the attributes from the 22 | base configuration file. 23 | 2. Keys that start with "COMPUTED_" are treated as insertion-only 24 | "computed" attributes. They can be inserted regardless of whether 25 | the CfgNode is frozen or not. 26 | 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate 27 | expressions in config. See examples in 28 | https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types 29 | Note that this may lead to arbitrary code execution: you must not 30 | load a config file from untrusted sources before manually inspecting 31 | the content of the file. 32 | """ 33 | 34 | @staticmethod 35 | def load_yaml_with_base(filename, allow_unsafe = False): 36 | """ 37 | Just like `yaml.load(open(filename))`, but inherit attributes from its 38 | `_BASE_`. 39 | 40 | Args: 41 | filename (str): the file name of the current config. Will be used to 42 | find the base config file. 43 | allow_unsafe (bool): whether to allow loading the config file with 44 | `yaml.unsafe_load`. 
45 | 46 | Returns: 47 | (dict): the loaded yaml 48 | """ 49 | with PathManager.open(filename, "r") as f: 50 | try: 51 | cfg = yaml.safe_load(f) 52 | except yaml.constructor.ConstructorError: 53 | if not allow_unsafe: 54 | raise 55 | logger = logging.getLogger(__name__) 56 | logger.warning( 57 | "Loading config {} with yaml.unsafe_load. Your machine may " 58 | "be at risk if the file contains malicious content.".format( 59 | filename 60 | ) 61 | ) 62 | f.close() 63 | with open(filename, "r") as f: 64 | cfg = yaml.unsafe_load(f) 65 | 66 | def merge_a_into_b(a, b): 67 | # merge dict a into dict b. values in a will overwrite b. 68 | for k, v in a.items(): 69 | if isinstance(v, dict) and k in b: 70 | assert isinstance( 71 | b[k], dict 72 | ), "Cannot inherit key '{}' from base!".format(k) 73 | merge_a_into_b(v, b[k]) 74 | else: 75 | b[k] = v 76 | 77 | if BASE_KEY in cfg: 78 | base_cfg_file = cfg[BASE_KEY] 79 | if base_cfg_file.startswith("~"): 80 | base_cfg_file = os.path.expanduser(base_cfg_file) 81 | if not any( 82 | map(base_cfg_file.startswith, ["/", "https://", "http://"]) 83 | ): 84 | # the path to base cfg is relative to the config file itself. 85 | base_cfg_file = os.path.join( 86 | os.path.dirname(filename), base_cfg_file 87 | ) 88 | base_cfg = CfgNode.load_yaml_with_base( 89 | base_cfg_file, allow_unsafe=allow_unsafe 90 | ) 91 | del cfg[BASE_KEY] 92 | 93 | merge_a_into_b(cfg, base_cfg) 94 | return base_cfg 95 | return cfg 96 | 97 | def merge_from_file(self, cfg_filename, allow_unsafe = False): 98 | """ 99 | Merge configs from a given yaml file. 100 | 101 | Args: 102 | cfg_filename: the file name of the yaml config. 103 | allow_unsafe: whether to allow loading the config file with 104 | `yaml.unsafe_load`. 105 | """ 106 | loaded_cfg = CfgNode.load_yaml_with_base( 107 | cfg_filename, allow_unsafe=allow_unsafe 108 | ) 109 | loaded_cfg = type(self)(loaded_cfg) 110 | self.merge_from_other_cfg(loaded_cfg) 111 | 112 | # Forward the following calls to base, but with a check on the BASE_KEY. 113 | def merge_from_other_cfg(self, cfg_other): 114 | """ 115 | Args: 116 | cfg_other (CfgNode): configs to merge from. 117 | """ 118 | assert ( 119 | BASE_KEY not in cfg_other 120 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 121 | return super().merge_from_other_cfg(cfg_other) 122 | 123 | def merge_from_list(self, cfg_list): 124 | """ 125 | Args: 126 | cfg_list (list): list of configs to merge from. 127 | """ 128 | keys = set(cfg_list[0::2]) 129 | assert ( 130 | BASE_KEY not in keys 131 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 132 | return super().merge_from_list(cfg_list) 133 | 134 | def __setattr__(self, name, val): 135 | if name.startswith("COMPUTED_"): 136 | if name in self: 137 | old_val = self[name] 138 | if old_val == val: 139 | return 140 | raise KeyError( 141 | "Computed attributed '{}' already exists " 142 | "with a different value! 
old={}, new={}.".format( 143 | name, old_val, val 144 | ) 145 | ) 146 | self[name] = val 147 | else: 148 | super().__setattr__(name, val) 149 | 150 | 151 | if __name__ == '__main__': 152 | cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') 153 | print(cfg) -------------------------------------------------------------------------------- /captioning/utils/div_utils.py: -------------------------------------------------------------------------------- 1 | from random import uniform 2 | import numpy as np 3 | from collections import OrderedDict, defaultdict 4 | from itertools import tee 5 | import time 6 | 7 | # ----------------------------------------------- 8 | def find_ngrams(input_list, n): 9 | return zip(*[input_list[i:] for i in range(n)]) 10 | 11 | def compute_div_n(caps,n=1): 12 | aggr_div = [] 13 | for k in caps: 14 | all_ngrams = set() 15 | lenT = 0. 16 | for c in caps[k]: 17 | tkns = c.split() 18 | lenT += len(tkns) 19 | ng = find_ngrams(tkns, n) 20 | all_ngrams.update(ng) 21 | aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) 22 | return np.array(aggr_div).mean(), np.array(aggr_div) 23 | 24 | def compute_global_div_n(caps,n=1): 25 | aggr_div = [] 26 | all_ngrams = set() 27 | lenT = 0. 28 | for k in caps: 29 | for c in caps[k]: 30 | tkns = c.split() 31 | lenT += len(tkns) 32 | ng = find_ngrams(tkns, n) 33 | all_ngrams.update(ng) 34 | if n == 1: 35 | aggr_div.append(float(len(all_ngrams))) 36 | else: 37 | aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) 38 | return aggr_div[0], np.repeat(np.array(aggr_div),len(caps)) -------------------------------------------------------------------------------- /captioning/utils/eval_multi.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import numpy as np 9 | import json 10 | from json import encoder 11 | import random 12 | import string 13 | import time 14 | import os 15 | import sys 16 | from . import misc as utils 17 | from eval_utils import getCOCO 18 | 19 | from .div_utils import compute_div_n, compute_global_div_n 20 | 21 | import sys 22 | try: 23 | sys.path.append("coco-caption") 24 | annFile = 'coco-caption/annotations/captions_val2014.json' 25 | from pycocotools.coco import COCO 26 | from pycocoevalcap.eval import COCOEvalCap 27 | from pycocoevalcap.eval_spice import COCOEvalCapSpice 28 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 29 | from pycocoevalcap.bleu.bleu import Bleu 30 | sys.path.append("cider") 31 | from pyciderevalcap.cider.cider import Cider 32 | except: 33 | print('Warning: requirements for eval_multi not satisfied') 34 | 35 | 36 | def eval_allspice(dataset, preds_n, model_id, split): 37 | coco = getCOCO(dataset) 38 | valids = coco.getImgIds() 39 | 40 | capsById = {} 41 | for d in preds_n: 42 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 43 | 44 | # filter results to only those in MSCOCO validation set (will be about a third) 45 | preds_filt_n = [p for p in preds_n if p['image_id'] in valids] 46 | print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n))) 47 | cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 48 | json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API... 
49 | 50 | # Eval AllSPICE 51 | cocoRes_n = coco.loadRes(cache_path_n) 52 | cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n) 53 | cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds() 54 | cocoEvalAllSPICE.evaluate() 55 | 56 | out = {} 57 | for metric, score in cocoEvalAllSPICE.eval.items(): 58 | out['All'+metric] = score 59 | 60 | imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval 61 | # collect SPICE_sub_score 62 | for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys(): 63 | if k != 'All': 64 | out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()]) 65 | out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean() 66 | for p in preds_filt_n: 67 | image_id, caption = p['image_id'], p['caption'] 68 | imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id] 69 | return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE} 70 | 71 | def eval_oracle(dataset, preds_n, model_id, split): 72 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 73 | 74 | coco = getCOCO(dataset) 75 | valids = coco.getImgIds() 76 | 77 | capsById = {} 78 | for d in preds_n: 79 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 80 | 81 | sample_n = capsById[list(capsById.keys())[0]] 82 | for i in range(len(capsById[list(capsById.keys())[0]])): 83 | preds = [_[i] for _ in capsById.values()] 84 | 85 | json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 86 | 87 | cocoRes = coco.loadRes(cache_path) 88 | cocoEval = COCOEvalCap(coco, cocoRes) 89 | cocoEval.params['image_id'] = cocoRes.getImgIds() 90 | cocoEval.evaluate() 91 | 92 | imgToEval = cocoEval.imgToEval 93 | for img_id in capsById.keys(): 94 | tmp = imgToEval[img_id] 95 | for k in tmp['SPICE'].keys(): 96 | if k != 'All': 97 | tmp['SPICE_'+k] = tmp['SPICE'][k]['f'] 98 | if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan 99 | tmp['SPICE_'+k] = -100 100 | tmp['SPICE'] = tmp['SPICE']['All']['f'] 101 | if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100 102 | capsById[img_id][i]['scores'] = imgToEval[img_id] 103 | 104 | out = {'overall': {}, 'ImgToEval': {}} 105 | for img_id in capsById.keys(): 106 | out['ImgToEval'][img_id] = {} 107 | for metric in capsById[img_id][0]['scores'].keys(): 108 | if metric == 'image_id': continue 109 | out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]]) 110 | out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id]) 111 | out['ImgToEval'][img_id]['captions'] = capsById[img_id] 112 | for metric in list(out['ImgToEval'].values())[0].keys(): 113 | if metric == 'captions': 114 | continue 115 | tmp = np.array([_[metric] for _ in out['ImgToEval'].values()]) 116 | tmp = tmp[tmp!=-100] 117 | out['overall'][metric] = tmp.mean() 118 | 119 | return out 120 | 121 | def eval_div_stats(dataset, preds_n, model_id, split): 122 | tokenizer = PTBTokenizer() 123 | 124 | capsById = {} 125 | for i, d in enumerate(preds_n): 126 | d['id'] = i 127 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 128 | 129 | n_caps_perimg = len(capsById[list(capsById.keys())[0]]) 130 | print(n_caps_perimg) 131 | _capsById = capsById # save the untokenized version 132 | capsById = tokenizer.tokenize(capsById) 133 | 134 | div_1, adiv_1 = compute_div_n(capsById,1) 135 | div_2, adiv_2 = compute_div_n(capsById,2) 136 | 137 | globdiv_1, _= compute_global_div_n(capsById,1) 138 | 139 | print('Diversity Statistics 
are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1)) 140 | 141 | # compute mbleu 142 | scorer = Bleu(4) 143 | all_scrs = [] 144 | scrperimg = np.zeros((n_caps_perimg, len(capsById))) 145 | 146 | for i in range(n_caps_perimg): 147 | tempRefsById = {} 148 | candsById = {} 149 | for k in capsById: 150 | tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:] 151 | candsById[k] = [capsById[k][i]] 152 | 153 | score, scores = scorer.compute_score(tempRefsById, candsById) 154 | all_scrs.append(score) 155 | scrperimg[i,:] = scores[1] 156 | 157 | all_scrs = np.array(all_scrs) 158 | 159 | out = {} 160 | out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1} 161 | for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()): 162 | out['overall'].update({'mBLeu_%d'%(k+1): score}) 163 | imgToEval = {} 164 | for i,imgid in enumerate(capsById.keys()): 165 | imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()} 166 | imgToEval[imgid]['individuals'] = [] 167 | for j, d in enumerate(_capsById[imgid]): 168 | imgToEval[imgid]['individuals'].append(preds_n[d['id']]) 169 | imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i] 170 | out['ImgToEval'] = imgToEval 171 | 172 | print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4') 173 | print(all_scrs.mean(axis=0)) 174 | 175 | return out 176 | 177 | def eval_self_cider(dataset, preds_n, model_id, split): 178 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 179 | 180 | coco = getCOCO(dataset) 181 | valids = coco.getImgIds() 182 | 183 | # Get Cider_scorer 184 | Cider_scorer = Cider(df='corpus') 185 | 186 | tokenizer = PTBTokenizer() 187 | gts = {} 188 | for imgId in valids: 189 | gts[imgId] = coco.imgToAnns[imgId] 190 | gts = tokenizer.tokenize(gts) 191 | 192 | for imgId in valids: 193 | Cider_scorer.cider_scorer += (None, gts[imgId]) 194 | Cider_scorer.cider_scorer.compute_doc_freq() 195 | Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs))) 196 | 197 | # Prepare captions 198 | capsById = {} 199 | for d in preds_n: 200 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 201 | 202 | capsById = tokenizer.tokenize(capsById) 203 | imgIds = list(capsById.keys()) 204 | scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds]) 205 | 206 | def get_div(eigvals): 207 | eigvals = np.clip(eigvals, 0, None) 208 | return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals)) 209 | sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores] 210 | score = np.mean(np.array(sc_scores)) 211 | 212 | imgToEval = {} 213 | for i, image_id in enumerate(imgIds): 214 | imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()} 215 | return {'overall': {'self_cider': score}, 'imgToEval': imgToEval} 216 | 217 | 218 | return score 219 | -------------------------------------------------------------------------------- /captioning/utils/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.resnet 4 | from torchvision.models.resnet import BasicBlock, Bottleneck 5 | 6 | class ResNet(torchvision.models.resnet.ResNet): 7 | def __init__(self, block, layers, num_classes=1000): 8 | super(ResNet, self).__init__(block, layers, num_classes) 9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 10 | for i in range(2, 5): 11 | 
getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 13 | 14 | def resnet18(pretrained=False): 15 | """Constructs a ResNet-18 model. 16 | 17 | Args: 18 | pretrained (bool): If True, returns a model pre-trained on ImageNet 19 | """ 20 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 21 | if pretrained: 22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 23 | return model 24 | 25 | 26 | def resnet34(pretrained=False): 27 | """Constructs a ResNet-34 model. 28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 33 | if pretrained: 34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 35 | return model 36 | 37 | 38 | def resnet50(pretrained=False): 39 | """Constructs a ResNet-50 model. 40 | 41 | Args: 42 | pretrained (bool): If True, returns a model pre-trained on ImageNet 43 | """ 44 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 45 | if pretrained: 46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 47 | return model 48 | 49 | 50 | def resnet101(pretrained=False): 51 | """Constructs a ResNet-101 model. 52 | 53 | Args: 54 | pretrained (bool): If True, returns a model pre-trained on ImageNet 55 | """ 56 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 57 | if pretrained: 58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 59 | return model 60 | 61 | 62 | def resnet152(pretrained=False): 63 | """Constructs a ResNet-152 model. 64 | 65 | Args: 66 | pretrained (bool): If True, returns a model pre-trained on ImageNet 67 | """ 68 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 69 | if pretrained: 70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 71 | return model -------------------------------------------------------------------------------- /captioning/utils/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class myResnet(nn.Module): 6 | def __init__(self, resnet): 7 | super(myResnet, self).__init__() 8 | self.resnet = resnet 9 | 10 | def forward(self, img, att_size=14): 11 | x = img.unsqueeze(0) 12 | 13 | x = self.resnet.conv1(x) 14 | x = self.resnet.bn1(x) 15 | x = self.resnet.relu(x) 16 | x = self.resnet.maxpool(x) 17 | 18 | x = self.resnet.layer1(x) 19 | x = self.resnet.layer2(x) 20 | x = self.resnet.layer3(x) 21 | x = self.resnet.layer4(x) 22 | 23 | fc = x.mean(3).mean(2).squeeze() 24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 25 | 26 | return fc, att 27 | 28 | -------------------------------------------------------------------------------- /captioning/utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | import collections 6 | import logging 7 | 8 | def get_area(pos): 9 | """ 10 | Args 11 | pos: [B, N, 4] 12 | (x1, x2, y1, y2) 13 | 14 | Return 15 | area : [B, N] 16 | """ 17 | # [B, N] 18 | height = pos[:, :, 3] - pos[:, :, 2] 19 | width = pos[:, :, 1] - pos[:, :, 0] 20 | area = height * width 21 | return area 22 | 23 | def get_relative_distance(pos): 24 | """ 25 | Args 26 | pos: [B, N, 4] 27 | (x1, x2, y1, y2) 28 | 29 | Return 30 | out : [B, N, N, 4] 31 | """ 32 | # B, N = pos.size()[:-1] 33 | 34 | # [B, N, N, 4] 35 | relative_distance = pos.unsqueeze(1) - 
pos.unsqueeze(2) 36 | 37 | return relative_distance 38 | 39 | 40 | class LossMeter(object): 41 | def __init__(self, maxlen=100): 42 | """Computes and stores the running average""" 43 | self.vals = collections.deque([], maxlen=maxlen) 44 | 45 | def __len__(self): 46 | return len(self.vals) 47 | 48 | def update(self, new_val): 49 | self.vals.append(new_val) 50 | 51 | @property 52 | def val(self): 53 | return sum(self.vals) / len(self.vals) 54 | 55 | def __repr__(self): 56 | return str(self.val) 57 | 58 | 59 | def count_parameters(model): 60 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | 62 | 63 | def load_state_dict(state_dict_path, loc='cpu'): 64 | state_dict = torch.load(state_dict_path, map_location=loc) 65 | # Change Multi GPU to single GPU 66 | original_keys = list(state_dict.keys()) 67 | for key in original_keys: 68 | if key.startswith("module."): 69 | new_key = key[len("module."):] 70 | state_dict[new_key] = state_dict.pop(key) 71 | return state_dict 72 | 73 | 74 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 75 | """ 76 | Override logging levels of different modules based on their name as a prefix. 77 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized. 78 | 79 | Args: 80 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR 81 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. 82 | Default is `[""]` to match all active loggers. 83 | The match is a case-sensitive `module_name.startswith(prefix)` 84 | """ 85 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') 86 | for name in logging.root.manager.loggerDict: 87 | if re.match(prefix_re, name): 88 | logging.getLogger(name).setLevel(level) 89 | 90 | 91 | def get_iou(anchors, gt_boxes): 92 | """ 93 | anchors: (N, 4) torch floattensor 94 | gt_boxes: (K, 4) torch floattensor 95 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 96 | """ 97 | N = anchors.size(0) 98 | 99 | if gt_boxes.size() == (4,): 100 | gt_boxes = gt_boxes.view(1, 4) 101 | K = gt_boxes.size(0) 102 | 103 | gt_boxes_area = ( 104 | (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * 105 | (gt_boxes[:, 3] - gt_boxes[:, 1] + 1) 106 | ).view(1, K) 107 | 108 | anchors_area = ( 109 | (anchors[:, 2] - anchors[:, 0] + 1) * 110 | (anchors[:, 3] - anchors[:, 1] + 1) 111 | ).view(N, 1) 112 | 113 | boxes = anchors.view(N, 1, 4).expand(N, K, 4) 114 | query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4) 115 | 116 | iw = ( 117 | torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) 118 | - torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) 119 | + 1 120 | ) 121 | iw[iw < 0] = 0 122 | 123 | ih = ( 124 | torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) 125 | - torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) 126 | + 1 127 | ) 128 | ih[ih < 0] = 0 129 | 130 | ua = anchors_area + gt_boxes_area - (iw * ih) 131 | overlaps = iw * ih / ua 132 | 133 | return overlaps 134 | 135 | 136 | def xywh_to_xyxy(boxes): 137 | """Convert [x y w h] box format to [x1 y1 x2 y2] format.""" 138 | return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1)) 139 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-min/CLIP-Caption-Reward/ca5fe53b848d7e7b1fb9808984f7a3187b2a32d6/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 23 | } 24 | 25 | 26 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 27 | os.makedirs(root, exist_ok=True) 28 | filename = os.path.basename(url) 29 | 30 | expected_sha256 = url.split("/")[-2] 31 | download_target = os.path.join(root, filename) 32 | 33 | if os.path.exists(download_target) and not os.path.isfile(download_target): 34 | raise RuntimeError(f"{download_target} exists and is not a regular file") 35 | 36 | if os.path.isfile(download_target): 37 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 38 | return download_target 39 | else: 40 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 41 | 42 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 43 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 44 | while True: 45 | buffer = source.read(8192) 46 | if not buffer: 47 | break 48 | 49 | output.write(buffer) 50 | loop.update(len(buffer)) 51 | 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 53 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 54 | 55 | return download_target 56 | 57 | 58 | def _transform(n_px): 59 | return Compose([ 60 | Resize(n_px, interpolation=Image.BICUBIC), 61 | CenterCrop(n_px), 62 | lambda image: image.convert("RGB"), 63 | ToTensor(), 64 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 65 | ]) 66 | 67 | 68 | def available_models() -> List[str]: 69 | """Returns the names of available CLIP models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 74 | """Load a CLIP model 75 | 76 | Parameters 77 | ---------- 78 | name : str 79 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 80 | 81 | device : Union[str, torch.device] 82 | The device to put the 
loaded model 83 | 84 | jit : bool 85 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 86 | 87 | Returns 88 | ------- 89 | model : torch.nn.Module 90 | The CLIP model 91 | 92 | preprocess : Callable[[PIL.Image], torch.Tensor] 93 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 94 | """ 95 | if name in _MODELS: 96 | model_path = _download(_MODELS[name]) 97 | elif os.path.isfile(name): 98 | model_path = name 99 | else: 100 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 101 | 102 | try: 103 | # loading JIT archive 104 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 105 | state_dict = None 106 | except RuntimeError: 107 | # loading saved state dict 108 | if jit: 109 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 110 | jit = False 111 | state_dict = torch.load(model_path, map_location="cpu") 112 | 113 | if not jit: 114 | model = build_model(state_dict or model.state_dict()).to(device) 115 | if str(device) == "cpu": 116 | model.float() 117 | return model, _transform(model.visual.input_resolution) 118 | 119 | # patch the device names 120 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 121 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 122 | 123 | def patch_device(module): 124 | graphs = [module.graph] if hasattr(module, "graph") else [] 125 | if hasattr(module, "forward1"): 126 | graphs.append(module.forward1.graph) 127 | 128 | for graph in graphs: 129 | for node in graph.findAllNodes("prim::Constant"): 130 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 131 | node.copyAttributes(device_node) 132 | 133 | model.apply(patch_device) 134 | patch_device(model.encode_image) 135 | patch_device(model.encode_text) 136 | 137 | # patch dtype to float32 on CPU 138 | if str(device) == "cpu": 139 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 140 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 141 | float_node = float_input.node() 142 | 143 | def patch_float(module): 144 | graphs = [module.graph] if hasattr(module, "graph") else [] 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("aten::to"): 150 | inputs = list(node.inputs()) 151 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 152 | if inputs[i].node()["value"] == 5: 153 | inputs[i].node().copyAttributes(float_node) 154 | 155 | model.apply(patch_float) 156 | patch_float(model.encode_image) 157 | patch_float(model.encode_text) 158 | 159 | model.float() 160 | 161 | return model, _transform(model.input_resolution.item()) 162 | 163 | 164 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 165 | """ 166 | Returns the tokenized representation of given input string(s) 167 | 168 | Parameters 169 | ---------- 170 | texts : Union[str, List[str]] 171 | An input string or a list of input strings to tokenize 172 | 173 | context_length : int 174 | The context length to use; all CLIP models use 77 as the context length 175 | 176 | Returns 177 | ------- 178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 179 | """ 180 | if 
isinstance(texts, str): 181 | texts = [texts] 182 | 183 | sot_token = _tokenizer.encoder["<|startoftext|>"] 184 | eot_token = _tokenizer.encoder["<|endoftext|>"] 185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 187 | 188 | for i, tokens in enumerate(all_tokens): 189 | if len(tokens) > context_length: 190 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 191 | result[i, :len(tokens)] = torch.tensor(tokens) 192 | 193 | return result 194 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
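    For example, for the word tuple ('h', 'e', 'l', 'l', 'o') the returned set is
    {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}; bpe() below repeatedly merges
    the pair with the best (lowest) rank in bpe_ranks until no remaining pair has a
    rank.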
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "11.0" 3 | gpu: true 4 | python_version: "3.7" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "ipython==7.21.0" 10 | - "transformers==4.19.2" 11 | - "h5py==3.7.0" 12 | - "numpy==1.20.3" 13 | - "pandas==1.3.3" 14 | - "scikit-image==0.18.3" 15 | - "ipywidgets==7.7.0" 16 | - "wandb==0.12.17" 17 | - "bert-score==0.3.11" 
18 | - "ftfy==6.1.1" 19 | - "timm==0.5.4" 20 | - "lmdbdict==0.2.2" 21 | - "yacs==0.1.8" 22 | - "pyemd==0.5.1" 23 | - "gensim==4.2.0" 24 | - "pytorch-lightning==1.6.3" 25 | 26 | predict: "predict.py:Predictor" 27 | -------------------------------------------------------------------------------- /configs/phase1/FineCapEval_clipRN50_mle.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/FineCapEval.json 6 | input_label_h5: none 7 | input_fc_dir: data/FineCapEval_clip_RN50_fc 8 | input_att_dir: data/FineCapEval_clip_RN50_att 9 | input_clipscore_vis_dir: data/FineCapEval_clipscore_vis 10 | 11 | seq_per_img: 5 12 | batch_size: 200 13 | learning_rate: 0.0005 14 | 15 | checkpoint_path: ./save/clipRN50_mle/clipRN50_mle 16 | 17 | # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 18 | 19 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 20 | # N=num_layers 21 | # d_model=input_encoding_size 22 | # d_ff=rnn_size 23 | 24 | # will be ignored 25 | num_layers: 6 26 | input_encoding_size: 512 27 | rnn_size: 2048 28 | 29 | # Transformer config 30 | N_enc: 6 31 | N_dec: 6 32 | d_model: 512 33 | d_ff: 2048 34 | num_att_heads: 8 35 | dropout: 0.1 36 | 37 | 38 | learning_rate_decay_start: 0 39 | scheduled_sampling_start: -1 40 | save_checkpoint_every: 3000 41 | language_eval: 1 42 | val_images_use: 5000 43 | max_epochs: 15 44 | train_sample_n: 5 45 | 46 | REFORWARD: false 47 | 48 | # _BASE_: transformer.yml 49 | reduce_on_plateau: false 50 | noamopt: false 51 | learning_rate: 0.000005 52 | learning_rate_decay_start: -1 53 | 54 | self_critical_after: 15 55 | max_epochs: 50 56 | 57 | verbose: false 58 | precision: 32 59 | 60 | use_clipscore: false -------------------------------------------------------------------------------- /configs/phase1/clipRN50_mle.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | # noamopt: false 4 | noamopt_warmup: 20000 5 | label_smoothing: 0.0 6 | input_json: data/cocotalk.json 7 | input_label_h5: data/cocotalk_label.h5 8 | input_fc_dir: data/cocotalk_clip_RN50_fc 9 | input_att_dir: data/cocotalk_clip_RN50_att 10 | input_clipscore_vis_dir: data/cocotalk_clipscore_vis 11 | seq_per_img: 5 12 | # batch_size: 600 13 | batch_size: 200 14 | 15 | learning_rate: 0.0005 16 | 17 | # checkpoint_path: ./save/trans_clip_rn50_sc_pl 18 | checkpoint_path: save/clipRN50_mle/clipRN50_mle 19 | 20 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 21 | # N=num_layers 22 | # d_model=input_encoding_size 23 | # d_ff=rnn_size 24 | 25 | # will be ignored 26 | num_layers: 6 27 | input_encoding_size: 512 28 | rnn_size: 2048 29 | 30 | # Transformer config 31 | N_enc: 6 32 | N_dec: 6 33 | d_model: 512 34 | d_ff: 2048 35 | num_att_heads: 8 36 | dropout: 0.1 37 | 38 | 39 | learning_rate_decay_start: 0 40 | scheduled_sampling_start: -1 41 | save_checkpoint_every: 3000 42 | language_eval: 1 43 | val_images_use: 5000 44 | # max_epochs: 15 45 | max_epochs: 25 46 | train_sample_n: 5 47 | 48 | REFORWARD: false 49 | 50 | 51 | verbose: false 52 | precision: 16 -------------------------------------------------------------------------------- /configs/phase1/transformer.yml: 
-------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_att_dir: data/cocotalk_att 8 | seq_per_img: 5 9 | batch_size: 10 10 | learning_rate: 0.0005 11 | 12 | checkpoint_path: ./save/trans_rn50_sc 13 | 14 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 15 | # N=num_layers 16 | # d_model=input_encoding_size 17 | # d_ff=rnn_size 18 | 19 | # will be ignored 20 | num_layers: 6 21 | input_encoding_size: 512 22 | rnn_size: 2048 23 | 24 | # Transformer config 25 | N_enc: 6 26 | N_dec: 6 27 | d_model: 512 28 | d_ff: 2048 29 | num_att_heads: 8 30 | dropout: 0.1 31 | 32 | 33 | learning_rate_decay_start: 0 34 | scheduled_sampling_start: -1 35 | save_checkpoint_every: 3000 36 | language_eval: 1 37 | val_images_use: 5000 38 | max_epochs: 15 39 | train_sample_n: 5 40 | 41 | REFORWARD: false -------------------------------------------------------------------------------- /configs/phase2/FineCapEval_clipRN50_cider.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/FineCapEval.json 6 | input_label_h5: none 7 | input_fc_dir: data/FineCapEval_clip_RN50_fc 8 | input_att_dir: data/FineCapEval_clip_RN50_att 9 | input_clipscore_vis_dir: data/FineCapEval_clipscore_vis 10 | 11 | seq_per_img: 5 12 | batch_size: 200 13 | learning_rate: 0.0005 14 | 15 | checkpoint_path: ./save/clipRN50_cider/clipRN50_cider 16 | 17 | # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 18 | 19 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 20 | # N=num_layers 21 | # d_model=input_encoding_size 22 | # d_ff=rnn_size 23 | 24 | # will be ignored 25 | num_layers: 6 26 | input_encoding_size: 512 27 | rnn_size: 2048 28 | 29 | # Transformer config 30 | N_enc: 6 31 | N_dec: 6 32 | d_model: 512 33 | d_ff: 2048 34 | num_att_heads: 8 35 | dropout: 0.1 36 | 37 | 38 | learning_rate_decay_start: 0 39 | scheduled_sampling_start: -1 40 | save_checkpoint_every: 3000 41 | language_eval: 1 42 | val_images_use: 5000 43 | max_epochs: 15 44 | train_sample_n: 5 45 | 46 | REFORWARD: false 47 | 48 | # _BASE_: transformer.yml 49 | reduce_on_plateau: false 50 | noamopt: false 51 | learning_rate: 0.000005 52 | learning_rate_decay_start: -1 53 | 54 | self_critical_after: 15 55 | max_epochs: 50 56 | 57 | verbose: false 58 | precision: 32 59 | 60 | # use_clipscore: true 61 | use_clipscore: false -------------------------------------------------------------------------------- /configs/phase2/FineCapEval_clipRN50_cider_clips.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/FineCapEval.json 6 | input_label_h5: none 7 | input_fc_dir: data/FineCapEval_clip_RN50_fc 8 | input_att_dir: data/FineCapEval_clip_RN50_att 9 | input_clipscore_vis_dir: data/FineCapEval_clipscore_vis 10 | 11 | seq_per_img: 5 12 | batch_size: 200 13 | learning_rate: 0.0005 14 | 15 | checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips 16 | 17 | # clip_load_path: 
'/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 18 | 19 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 20 | # N=num_layers 21 | # d_model=input_encoding_size 22 | # d_ff=rnn_size 23 | 24 | # will be ignored 25 | num_layers: 6 26 | input_encoding_size: 512 27 | rnn_size: 2048 28 | 29 | # Transformer config 30 | N_enc: 6 31 | N_dec: 6 32 | d_model: 512 33 | d_ff: 2048 34 | num_att_heads: 8 35 | dropout: 0.1 36 | 37 | 38 | learning_rate_decay_start: 0 39 | scheduled_sampling_start: -1 40 | save_checkpoint_every: 3000 41 | language_eval: 1 42 | val_images_use: 5000 43 | max_epochs: 15 44 | train_sample_n: 5 45 | 46 | REFORWARD: false 47 | 48 | # _BASE_: transformer.yml 49 | reduce_on_plateau: false 50 | noamopt: false 51 | learning_rate: 0.000005 52 | learning_rate_decay_start: -1 53 | 54 | self_critical_after: 15 55 | max_epochs: 50 56 | 57 | verbose: false 58 | precision: 32 59 | 60 | # use_clipscore: true 61 | use_clipscore: false 62 | clipscore_reward_weight: 2.0 63 | clipscore_mode: clip_s 64 | 65 | use_multi_rewards: true -------------------------------------------------------------------------------- /configs/phase2/FineCapEval_clipRN50_clips.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/FineCapEval.json 6 | input_label_h5: none 7 | input_fc_dir: data/FineCapEval_clip_RN50_fc 8 | input_att_dir: data/FineCapEval_clip_RN50_att 9 | input_clipscore_vis_dir: data/FineCapEval_clipscore_vis 10 | seq_per_img: 5 11 | batch_size: 160 12 | learning_rate: 0.0005 13 | 14 | checkpoint_path: ./save/clipRN50_clips/clipRN50_clips 15 | 16 | use_multi_rewards: false 17 | use_grammar: false 18 | use_grammar_baseline: false 19 | # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 20 | 21 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 22 | # N=num_layers 23 | # d_model=input_encoding_size 24 | # d_ff=rnn_size 25 | 26 | # will be ignored 27 | num_layers: 6 28 | input_encoding_size: 512 29 | rnn_size: 2048 30 | 31 | # Transformer config 32 | N_enc: 6 33 | N_dec: 6 34 | d_model: 512 35 | d_ff: 2048 36 | num_att_heads: 8 37 | dropout: 0.1 38 | 39 | 40 | learning_rate_decay_start: 0 41 | scheduled_sampling_start: -1 42 | save_checkpoint_every: 3000 43 | language_eval: 0 44 | val_images_use: 5000 45 | max_epochs: 15 46 | train_sample_n: 5 47 | 48 | REFORWARD: false 49 | 50 | # _BASE_: transformer.yml 51 | reduce_on_plateau: false 52 | noamopt: false 53 | learning_rate: 0.000005 54 | learning_rate_decay_start: -1 55 | 56 | self_critical_after: 15 57 | max_epochs: 50 58 | 59 | verbose: false 60 | precision: 32 61 | 62 | # use_clipscore: true 63 | use_clipscore: false 64 | clipscore_reward_weight: 2.0 -------------------------------------------------------------------------------- /configs/phase2/FineCapEval_clipRN50_clips_grammar.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/FineCapEval.json 6 | input_label_h5: none 7 | input_fc_dir: data/FineCapEval_clip_RN50_fc 8 | input_att_dir: data/FineCapEval_clip_RN50_att 9 | input_clipscore_vis_dir: data/FineCapEval_clipscore_vis 10 | seq_per_img: 5 
11 | batch_size: 160 12 | learning_rate: 0.0005 13 | 14 | checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar 15 | 16 | use_multi_rewards: true 17 | use_grammar: true 18 | use_grammar_baseline: true 19 | # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 20 | 21 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 22 | # N=num_layers 23 | # d_model=input_encoding_size 24 | # d_ff=rnn_size 25 | 26 | # will be ignored 27 | num_layers: 6 28 | input_encoding_size: 512 29 | rnn_size: 2048 30 | 31 | # Transformer config 32 | N_enc: 6 33 | N_dec: 6 34 | d_model: 512 35 | d_ff: 2048 36 | num_att_heads: 8 37 | dropout: 0.1 38 | 39 | 40 | learning_rate_decay_start: 0 41 | scheduled_sampling_start: -1 42 | save_checkpoint_every: 3000 43 | language_eval: 0 44 | val_images_use: 5000 45 | max_epochs: 15 46 | train_sample_n: 5 47 | 48 | REFORWARD: false 49 | 50 | # _BASE_: transformer.yml 51 | reduce_on_plateau: false 52 | noamopt: false 53 | learning_rate: 0.000005 54 | learning_rate_decay_start: -1 55 | 56 | self_critical_after: 15 57 | max_epochs: 50 58 | 59 | verbose: false 60 | precision: 32 61 | 62 | # use_clipscore: true 63 | use_clipscore: false 64 | clipscore_reward_weight: 2.0 -------------------------------------------------------------------------------- /configs/phase2/clipRN50_cider.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_fc_dir: data/cocotalk_clip_RN50_fc 8 | input_att_dir: data/cocotalk_clip_RN50_att 9 | # used only for evaluation 10 | input_clipscore_vis_dir: data/cocotalk_clipscore_vis 11 | 12 | seq_per_img: 5 13 | batch_size: 200 14 | learning_rate: 0.0005 15 | 16 | # checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider 17 | checkpoint_path: save/clipRN50_cider/clipRN50_cider 18 | 19 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 20 | # N=num_layers 21 | # d_model=input_encoding_size 22 | # d_ff=rnn_size 23 | 24 | # will be ignored 25 | num_layers: 6 26 | input_encoding_size: 512 27 | rnn_size: 2048 28 | 29 | # Transformer config 30 | N_enc: 6 31 | N_dec: 6 32 | d_model: 512 33 | d_ff: 2048 34 | num_att_heads: 8 35 | dropout: 0.1 36 | 37 | 38 | learning_rate_decay_start: 0 39 | scheduled_sampling_start: -1 40 | save_checkpoint_every: 3000 41 | language_eval: 1 42 | val_images_use: 5000 43 | max_epochs: 15 44 | train_sample_n: 5 45 | 46 | REFORWARD: false 47 | 48 | # _BASE_: transformer.yml 49 | reduce_on_plateau: false 50 | noamopt: false 51 | learning_rate: 0.000005 52 | learning_rate_decay_start: -1 53 | 54 | self_critical_after: 15 55 | max_epochs: 40 56 | 57 | verbose: false 58 | precision: 32 -------------------------------------------------------------------------------- /configs/phase2/clipRN50_cider_clips.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_fc_dir: data/cocotalk_clip_RN50_fc 8 | input_att_dir: data/cocotalk_clip_RN50_att 9 | input_clipscore_vis_dir: data/cocotalk_clipscore_vis 10 | seq_per_img: 5 11 | batch_size: 160 12 | 
learning_rate: 0.0005 13 | 14 | checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips 15 | 16 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 17 | # N=num_layers 18 | # d_model=input_encoding_size 19 | # d_ff=rnn_size 20 | 21 | # will be ignored 22 | num_layers: 6 23 | input_encoding_size: 512 24 | rnn_size: 2048 25 | 26 | # Transformer config 27 | N_enc: 6 28 | N_dec: 6 29 | d_model: 512 30 | d_ff: 2048 31 | num_att_heads: 8 32 | dropout: 0.1 33 | 34 | 35 | learning_rate_decay_start: 0 36 | scheduled_sampling_start: -1 37 | save_checkpoint_every: 3000 38 | language_eval: 1 39 | val_images_use: 5000 40 | max_epochs: 15 41 | train_sample_n: 5 42 | 43 | REFORWARD: false 44 | 45 | # _BASE_: transformer.yml 46 | reduce_on_plateau: false 47 | noamopt: false 48 | learning_rate: 0.000005 49 | learning_rate_decay_start: -1 50 | 51 | self_critical_after: 15 52 | max_epochs: 40 53 | 54 | verbose: false 55 | precision: 32 56 | 57 | use_clipscore: true 58 | clipscore_reward_weight: 2.0 59 | clipscore_mode: clip_s 60 | 61 | use_multi_rewards: true -------------------------------------------------------------------------------- /configs/phase2/clipRN50_clips.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_fc_dir: data/cocotalk_clip_RN50_fc 8 | input_att_dir: data/cocotalk_clip_RN50_att 9 | input_clipscore_vis_dir: data/cocotalk_clipscore_vis 10 | seq_per_img: 5 11 | batch_size: 160 12 | learning_rate: 0.0005 13 | 14 | checkpoint_path: save/clipRN50_clips/clipRN50_clips 15 | 16 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 17 | # N=num_layers 18 | # d_model=input_encoding_size 19 | # d_ff=rnn_size 20 | 21 | # will be ignored 22 | num_layers: 6 23 | input_encoding_size: 512 24 | rnn_size: 2048 25 | 26 | # Transformer config 27 | N_enc: 6 28 | N_dec: 6 29 | d_model: 512 30 | d_ff: 2048 31 | num_att_heads: 8 32 | dropout: 0.1 33 | 34 | 35 | learning_rate_decay_start: 0 36 | scheduled_sampling_start: -1 37 | save_checkpoint_every: 3000 38 | language_eval: 1 39 | val_images_use: 5000 40 | max_epochs: 15 41 | train_sample_n: 5 42 | 43 | REFORWARD: false 44 | 45 | # _BASE_: transformer.yml 46 | reduce_on_plateau: false 47 | noamopt: false 48 | learning_rate: 0.000005 49 | learning_rate_decay_start: -1 50 | 51 | self_critical_after: 15 52 | max_epochs: 40 53 | 54 | verbose: false 55 | precision: 32 56 | 57 | use_clipscore: true 58 | clipscore_reward_weight: 2.0 -------------------------------------------------------------------------------- /configs/phase2/clipRN50_clips_grammar.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_fc_dir: data/cocotalk_clip_RN50_fc 8 | input_att_dir: data/cocotalk_clip_RN50_att 9 | input_clipscore_vis_dir: data/cocotalk_clipscore_vis 10 | seq_per_img: 5 11 | batch_size: 160 12 | learning_rate: 0.0005 13 | 14 | checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar 15 | 16 | use_multi_rewards: true 17 | use_grammar: true 18 | use_grammar_baseline: true 19 | # clip_load_path: 
'/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' 20 | clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt' 21 | 22 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 23 | # N=num_layers 24 | # d_model=input_encoding_size 25 | # d_ff=rnn_size 26 | 27 | # will be ignored 28 | num_layers: 6 29 | input_encoding_size: 512 30 | rnn_size: 2048 31 | 32 | # Transformer config 33 | N_enc: 6 34 | N_dec: 6 35 | d_model: 512 36 | d_ff: 2048 37 | num_att_heads: 8 38 | dropout: 0.1 39 | 40 | 41 | learning_rate_decay_start: 0 42 | scheduled_sampling_start: -1 43 | save_checkpoint_every: 3000 44 | language_eval: 1 45 | val_images_use: 5000 46 | max_epochs: 15 47 | train_sample_n: 5 48 | 49 | REFORWARD: false 50 | 51 | # _BASE_: transformer.yml 52 | reduce_on_plateau: false 53 | noamopt: false 54 | learning_rate: 0.000005 55 | learning_rate_decay_start: -1 56 | 57 | self_critical_after: 15 58 | max_epochs: 40 59 | 60 | verbose: false 61 | precision: 32 62 | 63 | use_clipscore: true 64 | clipscore_reward_weight: 2.0 -------------------------------------------------------------------------------- /configs/phase2/transformer.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_att_dir: data/cocotalk_att 8 | seq_per_img: 5 9 | batch_size: 10 10 | learning_rate: 0.0005 11 | 12 | checkpoint_path: ./save/trans_rn50_sc 13 | 14 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 15 | # N=num_layers 16 | # d_model=input_encoding_size 17 | # d_ff=rnn_size 18 | 19 | # will be ignored 20 | num_layers: 6 21 | input_encoding_size: 512 22 | rnn_size: 2048 23 | 24 | # Transformer config 25 | N_enc: 6 26 | N_dec: 6 27 | d_model: 512 28 | d_ff: 2048 29 | num_att_heads: 8 30 | dropout: 0.1 31 | 32 | 33 | learning_rate_decay_start: 0 34 | scheduled_sampling_start: -1 35 | save_checkpoint_every: 3000 36 | language_eval: 1 37 | val_images_use: 5000 38 | max_epochs: 15 39 | train_sample_n: 5 40 | 41 | REFORWARD: false -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | directory to store preprocessed files -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import torch 5 | import torch.nn as nn 6 | import clip 7 | import pytorch_lightning as pl 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from PIL import Image 10 | from timm.models.vision_transformer import resize_pos_embed 11 | from cog import BasePredictor, Path, Input 12 | 13 | import captioning.utils.opts as opts 14 | import captioning.models as models 15 | import captioning.utils.misc as utils 16 | 17 | 18 | class Predictor(BasePredictor): 19 | def setup(self): 20 | import __main__ 21 | __main__.ModelCheckpoint = pl.callbacks.ModelCheckpoint 22 | 23 | self.device = torch.device("cuda:0") 24 | self.dict_json = json.load(open("./data/cocotalk.json")) 25 | self.ix_to_word = self.dict_json["ix_to_word"] 26 | self.vocab_size = 
len(self.ix_to_word) 27 | self.clip_model, self.clip_transform = clip.load( 28 | "RN50", jit=False, device=self.device 29 | ) 30 | 31 | self.preprocess = Compose( 32 | [ 33 | Resize((448, 448), interpolation=Image.BICUBIC), 34 | CenterCrop((448, 448)), 35 | ToTensor(), 36 | ] 37 | ) 38 | 39 | def predict( 40 | self, 41 | image: Path = Input( 42 | description="Input image.", 43 | ), 44 | reward: str = Input( 45 | choices=["mle", "cider", "clips", "cider_clips", "clips_grammar"], 46 | default="clips_grammar", 47 | description="Choose a reward criterion.", 48 | ), 49 | ) -> str: 50 | 51 | self.device = torch.device("cuda:0") 52 | self.dict_json = json.load(open("./data/cocotalk.json")) 53 | self.ix_to_word = self.dict_json["ix_to_word"] 54 | self.vocab_size = len(self.ix_to_word) 55 | self.clip_model, self.clip_transform = clip.load( 56 | "RN50", jit=False, device=self.device 57 | ) 58 | 59 | self.preprocess = Compose( 60 | [ 61 | Resize((448, 448), interpolation=Image.BICUBIC), 62 | CenterCrop((448, 448)), 63 | ToTensor(), 64 | ] 65 | ) 66 | 67 | cfg = ( 68 | f"configs/phase1/clipRN50_{reward}.yml" 69 | if reward == "mle" 70 | else f"configs/phase2/clipRN50_{reward}.yml" 71 | ) 72 | print("Loading cfg from", cfg) 73 | 74 | opt = opts.parse_opt(parse=False, cfg=cfg) 75 | print("vocab size:", self.vocab_size) 76 | 77 | seq_length = 1 78 | opt.vocab_size = self.vocab_size 79 | opt.seq_length = seq_length 80 | 81 | opt.batch_size = 1 82 | opt.vocab = self.ix_to_word 83 | print(opt.caption_model) 84 | 85 | model = models.setup(opt) 86 | del opt.vocab 87 | 88 | ckpt_path = opt.checkpoint_path + "-last.ckpt" 89 | print("Loading checkpoint from", ckpt_path) 90 | raw_state_dict = torch.load(ckpt_path, map_location=self.device) 91 | 92 | strict = True 93 | state_dict = raw_state_dict["state_dict"] 94 | 95 | if "_vocab" in state_dict: 96 | model.vocab = utils.deserialize(state_dict["_vocab"]) 97 | del state_dict["_vocab"] 98 | elif strict: 99 | raise KeyError 100 | if "_opt" in state_dict: 101 | saved_model_opt = utils.deserialize(state_dict["_opt"]) 102 | del state_dict["_opt"] 103 | # Make sure the saved opt is compatible with the curren topt 104 | need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] 105 | for checkme in need_be_same: 106 | if ( 107 | getattr(saved_model_opt, checkme) 108 | in [ 109 | "updown", 110 | "topdown", 111 | ] 112 | and getattr(opt, checkme) in ["updown", "topdown"] 113 | ): 114 | continue 115 | assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), ( 116 | "Command line argument and saved model disagree on '%s' " % checkme 117 | ) 118 | elif strict: 119 | raise KeyError 120 | res = model.load_state_dict(state_dict, strict) 121 | print(res) 122 | 123 | model = model.to(self.device) 124 | model.eval() 125 | 126 | image_mean = ( 127 | torch.Tensor([0.48145466, 0.4578275, 0.40821073]) 128 | .to(self.device) 129 | .reshape(3, 1, 1) 130 | ) 131 | image_std = ( 132 | torch.Tensor([0.26862954, 0.26130258, 0.27577711]) 133 | .to(self.device) 134 | .reshape(3, 1, 1) 135 | ) 136 | 137 | num_patches = 196 # 600 * 1000 // 32 // 32 138 | pos_embed = nn.Parameter( 139 | torch.zeros( 140 | 1, 141 | num_patches + 1, 142 | self.clip_model.visual.attnpool.positional_embedding.shape[-1], 143 | device=self.device, 144 | ), 145 | ) 146 | pos_embed.weight = resize_pos_embed( 147 | self.clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed 148 | ) 149 | self.clip_model.visual.attnpool.positional_embedding = pos_embed 150 | 151 | with torch.no_grad(): 
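            # The steps that follow mirror the training-time feature pipeline:
            # resize/center-crop to 448x448, normalize with the CLIP RGB mean/std
            # tensors built above, then run the CLIP-RN50 visual encoder. With the
            # attention-pool positional embedding resized for a 14x14 (= 196) patch
            # grid, the spatial output is flattened into 196 attention features of
            # width 2048, the att_feats shape the caption transformer expects.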
152 | image = self.preprocess(Image.open(str(image)).convert("RGB")) 153 | image = torch.tensor(np.stack([image])).to(self.device) 154 | image -= image_mean 155 | image /= image_std 156 | 157 | tmp_att, tmp_fc = self.clip_model.encode_image(image) 158 | tmp_att = tmp_att[0].permute(1, 2, 0) 159 | 160 | att_feat = tmp_att 161 | 162 | # Inference configurations 163 | eval_kwargs = {} 164 | eval_kwargs.update(vars(opt)) 165 | 166 | with torch.no_grad(): 167 | fc_feats = torch.zeros((1, 0)).to(self.device) 168 | att_feats = att_feat.view(1, 196, 2048).float().to(self.device) 169 | att_masks = None 170 | 171 | # forward the model to also get generated samples for each image 172 | # Only leave one feature for each image, in case duplicate sample 173 | tmp_eval_kwargs = eval_kwargs.copy() 174 | tmp_eval_kwargs.update({"sample_n": 1}) 175 | seq, seq_logprobs = model( 176 | fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode="sample" 177 | ) 178 | seq = seq.data 179 | 180 | sents = utils.decode_sequence(model.vocab, seq) 181 | 182 | return sents[0] 183 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | h5py 3 | pandas 4 | tqdm 5 | scikit-image 6 | ipywidgets 7 | wandb 8 | bert-score 9 | ftfy 10 | timm 11 | lmdbdict 12 | yacs 13 | pyemd 14 | gensim 15 | pytorch-lightning==1.0.0 16 | -------------------------------------------------------------------------------- /retrieval/README.md: -------------------------------------------------------------------------------- 1 | # Finetuning CLIP reward model 2 | 3 | ```bash 4 | python train_pl.py --cfg clip_negative_text --id clip_negative_text 5 | ``` -------------------------------------------------------------------------------- /retrieval/configs/clip_negative_text.yaml: -------------------------------------------------------------------------------- 1 | checkpoint_dir: ./save/clip_negative_text/ 2 | 3 | losses_log_every: 25 4 | precision: 32 5 | load_feat: true 6 | data_in_memory: false 7 | 8 | batch_size: 1600 9 | valid_batch_size: 200 10 | clip_grad_norm: 0 11 | 12 | epochs: 30 13 | use_grammar: true 14 | joint_out: false -------------------------------------------------------------------------------- /retrieval/param.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import pprint 8 | import yaml 9 | 10 | 11 | def str2bool(v): 12 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 13 | return True 14 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 15 | return False 16 | else: 17 | raise argparse.ArgumentTypeError('Boolean value expected.') 18 | 19 | 20 | def is_interactive(): 21 | import __main__ as main 22 | return not hasattr(main, '__file__') 23 | 24 | 25 | def get_optimizer(optim, verbose=False): 26 | # Bind the optimizer 27 | if optim == 'rms': 28 | if verbose: 29 | print("Optimizer: Using RMSProp") 30 | optimizer = torch.optim.RMSprop 31 | elif optim == 'adam': 32 | if verbose: 33 | print("Optimizer: Using Adam") 34 | optimizer = torch.optim.Adam 35 | elif optim == 'adamw': 36 | if verbose: 37 | print("Optimizer: Using AdamW") 38 | # optimizer = torch.optim.AdamW 39 | optimizer = 'adamw' 40 | elif optim == 'adamax': 41 | if verbose: 42 | print("Optimizer: Using Adamax") 43 | optimizer = torch.optim.Adamax 44 | elif optim == 'sgd': 45 | if verbose: 46 | 
print("Optimizer: SGD") 47 | optimizer = torch.optim.SGD 48 | else: 49 | assert False, "Please add your optimizer %s in the list." % optim 50 | 51 | return optimizer 52 | 53 | 54 | def parse_args(parse=True, **optional_kwargs): 55 | parser = argparse.ArgumentParser() 56 | 57 | parser.add_argument('--seed', type=int, default=9595, help='random seed') 58 | 59 | # Data Splits 60 | parser.add_argument("--train", default='karpathy_train') 61 | parser.add_argument("--valid", default='karpathy_val') 62 | parser.add_argument("--test", default='karpathy_test') 63 | # parser.add_argument('--test_only', action='store_true') 64 | 65 | # Quick experiments 66 | parser.add_argument('--train_topk', type=int, default=-1) 67 | parser.add_argument('--valid_topk', type=int, default=-1) 68 | 69 | # Checkpoint 70 | parser.add_argument('--output', type=str, default='snap/test') 71 | parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).') 72 | parser.add_argument('--from_scratch', action='store_true') 73 | 74 | # CPU/GPU 75 | parser.add_argument("--multiGPU", action='store_const', default=False, const=True) 76 | parser.add_argument('--fp16', action='store_true') 77 | parser.add_argument("--distributed", action='store_true') 78 | parser.add_argument("--num_workers", default=0, type=int) 79 | parser.add_argument('--local_rank', type=int, default=-1) 80 | # parser.add_argument('--rank', type=int, default=-1) 81 | 82 | # Model Config 83 | # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32') 84 | # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased') 85 | parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32') 86 | 87 | # parser.add_argument('--position_embedding_type', type=str, default='absolute') 88 | 89 | # parser.add_argument('--encoder_transform', action='store_true') 90 | 91 | parser.add_argument('--max_text_length', type=int, default=40) 92 | 93 | # parser.add_argument('--image_size', type=int, default=224) 94 | # parser.add_argument('--patch_size', type=int, default=32) 95 | 96 | # parser.add_argument('--decoder_num_layers', type=int, default=12) 97 | 98 | # Training 99 | parser.add_argument('--batch_size', type=int, default=256) 100 | parser.add_argument('--valid_batch_size', type=int, default=None) 101 | 102 | parser.add_argument('--optim', default='adamw') 103 | 104 | parser.add_argument('--warmup_ratio', type=float, default=0.05) 105 | parser.add_argument('--weight_decay', type=float, default=0.01) 106 | parser.add_argument('--clip_grad_norm', type=float, default=-1.0) 107 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 108 | parser.add_argument('--lr', type=float, default=1e-4) 109 | parser.add_argument('--adam_eps', type=float, default=1e-6) 110 | parser.add_argument('--adam_beta1', type=float, default=0.9) 111 | parser.add_argument('--adam_beta2', type=float, default=0.999) 112 | 113 | parser.add_argument('--epochs', type=int, default=20) 114 | # parser.add_argument('--dropout', type=float, default=0.1) 115 | 116 | 117 | # Inference 118 | # parser.add_argument('--num_beams', type=int, default=1) 119 | # parser.add_argument('--gen_max_length', type=int, default=20) 120 | 121 | parser.add_argument('--start_from', type=str, default=None) 122 | 123 | # Data 124 | # parser.add_argument('--do_lower_case', type=str2bool, default=None) 125 | 126 | # parser.add_argument('--prefix', type=str, default=None) 127 | 128 | 129 | # COCO 
Caption 130 | # parser.add_argument('--no_prefix', action='store_true') 131 | 132 | parser.add_argument('--no_cls', action='store_true') 133 | 134 | parser.add_argument('--cfg', type=str, default=None) 135 | parser.add_argument('--id', type=str, default=None) 136 | 137 | # Etc. 138 | parser.add_argument('--comment', type=str, default='') 139 | parser.add_argument("--dry", action='store_true') 140 | 141 | # Parse the arguments. 142 | if parse: 143 | args = parser.parse_args() 144 | # For interative engironmnet (ex. jupyter) 145 | else: 146 | args = parser.parse_known_args()[0] 147 | 148 | loaded_kwargs = {} 149 | if args.cfg is not None: 150 | cfg_path = f'configs/{args.cfg}.yaml' 151 | with open(cfg_path, 'r') as f: 152 | loaded_kwargs = yaml.safe_load(f) 153 | 154 | # Namespace => Dictionary 155 | parsed_kwargs = vars(args) 156 | parsed_kwargs.update(optional_kwargs) 157 | 158 | kwargs = {} 159 | kwargs.update(parsed_kwargs) 160 | kwargs.update(loaded_kwargs) 161 | 162 | args = Config(**kwargs) 163 | 164 | # Bind optimizer class. 165 | verbose = False 166 | args.optimizer = get_optimizer(args.optim, verbose=verbose) 167 | 168 | # Set seeds 169 | torch.manual_seed(args.seed) 170 | random.seed(args.seed) 171 | np.random.seed(args.seed) 172 | 173 | return args 174 | 175 | 176 | class Config(object): 177 | def __init__(self, **kwargs): 178 | """Configuration Class: set kwargs as class attributes with setattr""" 179 | for k, v in kwargs.items(): 180 | setattr(self, k, v) 181 | 182 | @property 183 | def config_str(self): 184 | return pprint.pformat(self.__dict__) 185 | 186 | def __repr__(self): 187 | """Pretty-print configurations in alphabetical order""" 188 | config_str = 'Configurations\n' 189 | config_str += self.config_str 190 | return config_str 191 | 192 | # def update(self, **kwargs): 193 | # for k, v in kwargs.items(): 194 | # setattr(self, k, v) 195 | 196 | # def save(self, path): 197 | # with open(path, 'w') as f: 198 | # yaml.dump(self.__dict__, f, default_flow_style=False) 199 | 200 | # @classmethod 201 | # def load(cls, path): 202 | # with open(path, 'r') as f: 203 | # kwargs = yaml.load(f) 204 | 205 | # return Config(**kwargs) 206 | 207 | 208 | if __name__ == '__main__': 209 | args = parse_args(True) 210 | -------------------------------------------------------------------------------- /retrieval/text_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | def repeat(text, n_max_gram=3, n_max_repeat=3): 4 | """repeat n-grams""" 5 | tokens = text.split() 6 | 7 | n_gram = random.randint(1, n_max_gram) 8 | 9 | repeat_token_idx = random.randint(0, len(tokens) - n_gram) 10 | 11 | repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram] 12 | 13 | n_repeat = random.randint(1, n_max_repeat) 14 | for _ in range(n_repeat): 15 | insert_idx = random.randint(0, len(tokens)) 16 | tokens = tokens[:insert_idx] + \ 17 | repeated_tokens + tokens[insert_idx:] 18 | 19 | new_text = " ".join(tokens) 20 | return new_text 21 | 22 | def remove(text, n_max_gram=3): 23 | """remove n-grams""" 24 | tokens = text.split() 25 | 26 | n_gram = random.randint(1, n_max_gram) 27 | 28 | remove_token_idx = random.randint(0, len(tokens) - n_gram) 29 | 30 | tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:] 31 | 32 | new_text = " ".join(tokens) 33 | return new_text 34 | 35 | def insert(text, vocab, n_max_tokens=3): 36 | """Insert tokens""" 37 | tokens = text.split() 38 | 39 | n_insert_token = random.randint(1, n_max_tokens) 40 | 
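    # repeat/remove/insert/swap/shuffle are lightweight caption corruptions; the
    # retrieval trainer (retrieval/train_pl.py) presumably uses them to synthesize
    # grammatically broken "negative" captions that are contrasted with the original
    # captions when fine-tuning the CLIP text encoder for the grammar reward (see
    # use_grammar in retrieval/configs/clip_negative_text.yaml).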
41 | for _ in range(n_insert_token): 42 | insert_token_idx = random.randint(0, len(tokens) - 1) 43 | insert_token = random.choice(vocab) 44 | tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:] 45 | 46 | new_text = " ".join(tokens) 47 | return new_text 48 | 49 | def swap(text, vocab, n_max_tokens=3): 50 | """Swap tokens""" 51 | tokens = text.split() 52 | 53 | n_swap_tokens = random.randint(1, n_max_tokens) 54 | 55 | for _ in range(n_swap_tokens): 56 | swap_token_idx = random.randint(0, len(tokens) - 1) 57 | 58 | swap_token = random.choice(vocab) 59 | while swap_token == tokens[swap_token_idx]: 60 | swap_token = random.choice(vocab) 61 | 62 | tokens[swap_token_idx] = swap_token 63 | 64 | new_text = " ".join(tokens) 65 | return new_text 66 | 67 | def shuffle(text): 68 | """shuffle tokens""" 69 | tokens = text.split() 70 | 71 | random.shuffle(tokens) 72 | 73 | new_text = " ".join(tokens) 74 | return new_text 75 | -------------------------------------------------------------------------------- /save/README.md: -------------------------------------------------------------------------------- 1 | Directory for checkpoints -------------------------------------------------------------------------------- /scripts/build_bpe_subword_nmt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 15 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 16 | first and last indices (in range 1..M) of labels for each image 17 | /label_length stores the length of the sequence for each of the M sequences 18 | 19 | The json file has a dict that contains: 20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 21 | - an 'images' field that is a list holding auxiliary information for each image, 22 | such as in particular the 'split' it was assigned to. 
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import json 31 | import argparse 32 | from random import shuffle, seed 33 | import string 34 | # non-standard dependencies: 35 | import h5py 36 | import numpy as np 37 | import torch 38 | import torchvision.models as models 39 | import skimage.io 40 | from PIL import Image 41 | 42 | import codecs 43 | import tempfile 44 | from subword_nmt import learn_bpe, apply_bpe 45 | 46 | # python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe 47 | 48 | def build_vocab(imgs, params): 49 | # count up the number of words 50 | captions = [] 51 | for img in imgs: 52 | for sent in img['sentences']: 53 | captions.append(' '.join(sent['tokens'])) 54 | captions='\n'.join(captions) 55 | all_captions = tempfile.NamedTemporaryFile(delete=False) 56 | all_captions.close() 57 | with open(all_captions.name, 'w') as txt_file: 58 | txt_file.write(captions) 59 | 60 | # 61 | codecs_output = tempfile.NamedTemporaryFile(delete=False) 62 | codecs_output.close() 63 | with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output: 64 | learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count']) 65 | 66 | with codecs.open(codecs_output.name, encoding='UTF-8') as codes: 67 | bpe = apply_bpe.BPE(codes) 68 | 69 | tmp = tempfile.NamedTemporaryFile(delete=False) 70 | tmp.close() 71 | 72 | tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') 73 | 74 | for _, img in enumerate(imgs): 75 | img['final_captions'] = [] 76 | for sent in img['sentences']: 77 | txt = ' '.join(sent['tokens']) 78 | txt = bpe.segment(txt).strip() 79 | img['final_captions'].append(txt.split(' ')) 80 | tmpout.write(txt) 81 | tmpout.write('\n') 82 | if _ < 20: 83 | print(txt) 84 | 85 | tmpout.close() 86 | tmpin = codecs.open(tmp.name, encoding='UTF-8') 87 | 88 | vocab = learn_bpe.get_vocabulary(tmpin) 89 | vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True) 90 | 91 | # Always insert UNK 92 | print('inserting the special UNK token') 93 | vocab.append('UNK') 94 | 95 | print('Vocab size:', len(vocab)) 96 | 97 | os.remove(all_captions.name) 98 | with open(codecs_output.name, 'r') as codes: 99 | bpe = codes.read() 100 | os.remove(codecs_output.name) 101 | os.remove(tmp.name) 102 | 103 | return vocab, bpe 104 | 105 | def encode_captions(imgs, params, wtoi): 106 | """ 107 | encode all captions into one large array, which will be 1-indexed. 108 | also produces label_start_ix and label_end_ix which store 1-indexed 109 | and inclusive (Lua-style) pointers to the first and last caption for 110 | each image in the dataset. 
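    Illustrative example: if the first image has 5 captions and the second has 3,
    label_start_ix = [1, 6], label_end_ix = [5, 8], and the returned array L has
    8 rows, each a zero-padded sequence of at most max_length word indices.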
111 | """ 112 | 113 | max_length = params['max_length'] 114 | N = len(imgs) 115 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 116 | 117 | label_arrays = [] 118 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 119 | label_end_ix = np.zeros(N, dtype='uint32') 120 | label_length = np.zeros(M, dtype='uint32') 121 | caption_counter = 0 122 | counter = 1 123 | for i,img in enumerate(imgs): 124 | n = len(img['final_captions']) 125 | assert n > 0, 'error: some image has no captions' 126 | 127 | Li = np.zeros((n, max_length), dtype='uint32') 128 | for j,s in enumerate(img['final_captions']): 129 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 130 | caption_counter += 1 131 | for k,w in enumerate(s): 132 | if k < max_length: 133 | Li[j,k] = wtoi[w] 134 | 135 | # note: word indices are 1-indexed, and captions are padded with zeros 136 | label_arrays.append(Li) 137 | label_start_ix[i] = counter 138 | label_end_ix[i] = counter + n - 1 139 | 140 | counter += n 141 | 142 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 143 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 144 | assert np.all(label_length > 0), 'error: some caption had no words?' 145 | 146 | print('encoded captions to array of size ', L.shape) 147 | return L, label_start_ix, label_end_ix, label_length 148 | 149 | def main(params): 150 | 151 | imgs = json.load(open(params['input_json'], 'r')) 152 | imgs = imgs['images'] 153 | 154 | seed(123) # make reproducible 155 | 156 | # create the vocab 157 | vocab, bpe = build_vocab(imgs, params) 158 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 159 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 160 | 161 | # encode captions in large arrays, ready to ship to hdf5 file 162 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 163 | 164 | # create output h5 file 165 | N = len(imgs) 166 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 167 | f_lb.create_dataset("labels", dtype='uint32', data=L) 168 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 169 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 170 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 171 | f_lb.close() 172 | 173 | # create output json file 174 | out = {} 175 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 176 | out['images'] = [] 177 | out['bpe'] = bpe 178 | for i,img in enumerate(imgs): 179 | 180 | jimg = {} 181 | jimg['split'] = img['split'] 182 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need 183 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 184 | 185 | if params['images_root'] != '': 186 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 187 | jimg['width'], jimg['height'] = _img.size 188 | 189 | out['images'].append(jimg) 190 | 191 | json.dump(out, open(params['output_json'], 'w')) 192 | print('wrote ', params['output_json']) 193 | 194 | if __name__ == "__main__": 195 | 196 | parser = argparse.ArgumentParser() 197 | 198 | # input json 199 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 200 | parser.add_argument('--output_json', default='data.json', help='output json file') 201 | parser.add_argument('--output_h5', default='data', help='output h5 file') 202 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 203 | 204 | # options 205 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 206 | parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab') 207 | 208 | args = parser.parse_args() 209 | params = vars(args) # convert to ordinary dict 210 | print('parsed input parameters:') 211 | print(json.dumps(params, indent = 2)) 212 | main(params) 213 | 214 | 215 | -------------------------------------------------------------------------------- /scripts/clip_prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 33 | from PIL import Image 34 | from torch import nn 35 | 36 | preprocess = Compose([ 37 | Resize((448, 448), interpolation=Image.BICUBIC), 38 | CenterCrop((448, 448)), 39 | ToTensor() 40 | ]) 41 | 42 | 43 | from clip.clip import load 44 | from timm.models.vision_transformer import resize_pos_embed 45 | import timm 46 | 47 | from captioning.utils.resnet_utils import myResnet 48 | import captioning.utils.resnet as resnet 49 | 50 | from tqdm import tqdm 51 | 52 | 53 | def main(params): 54 | if params["model_type"] != 'vit_base_patch32_224_in21k': 55 | model, transform = load(params["model_type"], jit=False) 56 | else: 57 | model = timm.create_model(params["model_type"], pretrained=True) 58 | model = model.cuda() 59 | 60 | if params["model_type"] != 'vit_base_patch32_224_in21k': 61 | save_model_type = params["model_type"].split("-")[0] 62 | mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) 63 | std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) 64 | 65 | if "RN" in params["model_type"]: 66 | num_patches = 196 #600 * 1000 // 32 // 32 67 | pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) 68 | pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) 69 | model.visual.attnpool.positional_embedding = pos_embed 70 | 71 | else: 72 | save_model_type = 'vit_base' 73 | mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) 74 | std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) 75 | 76 | num_patches = 196 #600 * 1000 // 32 // 32 77 | pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) 78 | pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) 79 | model.pos_embed = pos_embed 80 | 81 | if params["model_type"] == "ViT-B/32": 82 | num_patches = 196 #600 * 1000 // 32 // 32 83 | pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) 84 | pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) 85 | model.visual.positional_embedding = pos_embed 86 | imgs = json.load(open(params['input_json'], 'r')) 87 | 88 | imgs = imgs['images'] 89 | 90 | if args.n_jobs > 1: 91 | print('Total imgs:', len(imgs)) 92 | print('Using {} jobs'.format(args.n_jobs)) 93 | print('job id:', args.job_id) 94 | imgs = imgs[args.job_id::args.n_jobs] 95 | 96 | N = len(imgs) 97 | 98 | seed(123) # make reproducible 99 | 100 | dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' 101 | dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' 102 | if not os.path.isdir(dir_fc): 103 | os.mkdir(dir_fc) 104 | if not os.path.isdir(dir_att): 105 | os.mkdir(dir_att) 106 | 107 | for i,img in enumerate(tqdm(imgs)): 108 | # load the image 109 | with torch.no_grad(): 110 
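            # The steps below produce two arrays per image:
            #   tmp_fc  -- a single global feature vector (saved to the *_fc folder)
            #   tmp_att -- a 14x14 grid of patch features (saved compressed to the *_att folder)
            # Three model branches are handled:
            #   * CLIP ResNets ("RN" in model_type): encode_image() yields a spatial grid
            #     plus a global vector; the grid is permuted to H x W x C.
            #   * timm ViT (vit_base_patch32_224_in21k): token 0 is taken as the global
            #     feature and the remaining 196 tokens are reshaped to 14 x 14 x 768.
            #   * CLIP ViT-B/32: the visual transformer is run block by block
            #     (conv1 -> ln_pre -> resblocks) so that per-patch tokens stay accessible.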
| 111 | image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) 112 | image = torch.tensor(np.stack([image])).cuda() 113 | image -= mean 114 | image /= std 115 | if "RN" in params["model_type"]: 116 | tmp_att, tmp_fc = model.encode_image(image) 117 | tmp_att = tmp_att[0].permute(1, 2, 0) 118 | tmp_fc = tmp_fc[0] 119 | elif params["model_type"] == 'vit_base_patch32_224_in21k': 120 | x = model(image) 121 | tmp_fc = x[0, 0, :] 122 | tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 123 | else: 124 | x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] 125 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 126 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 127 | x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 128 | x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] 129 | x = model.visual.ln_pre(x) 130 | 131 | x = x.permute(1, 0, 2) # NLD -> LND 132 | 133 | for layer_idx, layer in enumerate(model.visual.transformer.resblocks): 134 | x = layer(x) 135 | 136 | x = x.permute(1, 0, 2) 137 | tmp_fc = x[0, 0, :] 138 | tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 139 | 140 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 141 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 142 | 143 | 144 | # if i % 1000 == 0: 145 | # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 146 | print('wrote ', dir_fc, dir_att) 147 | 148 | if __name__ == "__main__": 149 | 150 | parser = argparse.ArgumentParser() 151 | 152 | # input json 153 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 154 | parser.add_argument('--output_dir', default='data', help='output h5 file') 155 | 156 | # options 157 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 158 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 159 | parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') 160 | 161 | parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') 162 | parser.add_argument('--job_id', default=0, type=int, help='job id') 163 | parser.add_argument('--batch_size', default=1, type=int, help='batch size') 164 | 165 | 166 | args = parser.parse_args() 167 | params = vars(args) # convert to ordinary dict 168 | print('parsed input parameters:') 169 | print(json.dumps(params, indent = 2)) 170 | main(params) 171 | -------------------------------------------------------------------------------- /scripts/clipscore_prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. 
', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 33 | from PIL import Image 34 | from torch import nn 35 | 36 | # preprocess = Compose([ 37 | # Resize((448, 448), interpolation=Image.BICUBIC), 38 | # CenterCrop((448, 448)), 39 | # ToTensor() 40 | # ]) 41 | 42 | 43 | # from clip.clip import load 44 | # from timm.models.vision_transformer import resize_pos_embed 45 | # import timm 46 | 47 | # from captioning.utils.resnet_utils import myResnet 48 | # import captioning.utils.resnet as resnet 49 | 50 | from captioning.utils.clipscore import CLIPScore 51 | 52 | from tqdm import tqdm 53 | 54 | 55 | 56 | def main(params): 57 | 58 | clipscore_model = CLIPScore() 59 | clipscore_model.to('cuda') 60 | 61 | imgs = json.load(open(params['input_json'], 'r')) 62 | imgs = imgs['images'] 63 | 64 | if args.n_jobs > 1: 65 | print('Total imgs:', len(imgs)) 66 | print('Using {} jobs'.format(args.n_jobs)) 67 | print('job id:', args.job_id) 68 | imgs = imgs[args.job_id::args.n_jobs] 69 | 70 | N = len(imgs) 71 | 72 | seed(123) # make reproducible 73 | 74 | # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' 75 | # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' 76 | 77 | vis_dir_fc = params['output_dir']+'_clipscore_vis' 78 | if not os.path.isdir(vis_dir_fc): 79 | os.mkdir(vis_dir_fc) 80 | 81 | # text_dir_fc = params['output_dir']+'_clipscore_text' 82 | # if not os.path.isdir(text_dir_fc): 83 | # os.mkdir(text_dir_fc) 84 | 85 | # if not os.path.isdir(dir_att): 86 | # os.mkdir(dir_att) 87 | 88 | for i, img in enumerate(tqdm(imgs)): 89 | # load the image 90 | 91 | img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) 92 | img_feat = clipscore_model.image_extract(img_path) 93 | img_feat = img_feat.view(512) 94 | 95 | # for d in img['sentences']: 96 | # text = d['raw'].strip() 97 | # text_feat = clipscore_model.text_extract(text) 98 | 99 | 100 | # with torch.no_grad(): 101 | 102 | # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) 103 | # image = torch.tensor(np.stack([image])).cuda() 104 | # image -= mean 105 | # image /= std 106 | # if "RN" in params["model_type"]: 107 | # tmp_att, tmp_fc = model.encode_image(image) 108 | # tmp_att = tmp_att[0].permute(1, 2, 0) 109 | # tmp_fc = tmp_fc[0] 110 | # elif params["model_type"] == 'vit_base_patch32_224_in21k': 111 | # x = model(image) 112 | # tmp_fc = x[0, 
0, :] 113 | # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 114 | # else: 115 | # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] 116 | # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 117 | # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 118 | # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 119 | # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] 120 | # x = model.visual.ln_pre(x) 121 | 122 | # x = x.permute(1, 0, 2) # NLD -> LND 123 | 124 | # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): 125 | # x = layer(x) 126 | 127 | # x = x.permute(1, 0, 2) 128 | # tmp_fc = x[0, 0, :] 129 | # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 130 | 131 | np.save(os.path.join(vis_dir_fc, str(img['cocoid'])), img_feat.data.cpu().float().numpy()) 132 | # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 133 | 134 | 135 | # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 136 | 137 | if i % 1000 == 0: 138 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 139 | print('wrote ', vis_dir_fc) 140 | 141 | if __name__ == "__main__": 142 | 143 | parser = argparse.ArgumentParser() 144 | 145 | # input json 146 | # dataset_coco.json 147 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 148 | parser.add_argument('--output_dir', default='data', help='output h5 file') 149 | 150 | # options 151 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 152 | # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 153 | # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') 154 | 155 | parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') 156 | parser.add_argument('--job_id', default=0, type=int, help='job id') 157 | 158 | args = parser.parse_args() 159 | params = vars(args) # convert to ordinary dict 160 | print('parsed input parameters:') 161 | print(json.dumps(params, indent = 2)) 162 | main(params) 163 | -------------------------------------------------------------------------------- /scripts/copy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ ! 
-d log_$2 ]; then 4 | cp -r log_$1 log_$2 5 | cd log_$2 6 | mv infos_$1-best.pkl infos_$2-best.pkl 7 | mv infos_$1.pkl infos_$2.pkl 8 | cd ../ 9 | fi 10 | -------------------------------------------------------------------------------- /scripts/dump_to_h5df.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import os 4 | import numpy as np 5 | import json 6 | from tqdm import tqdm 7 | 8 | 9 | def main(params): 10 | 11 | imgs = json.load(open(params['input_json'], 'r')) 12 | imgs = imgs['images'] 13 | N = len(imgs) 14 | 15 | if params['fc_input_dir'] is not None: 16 | print('processing fc') 17 | with h5py.File(params['fc_output']) as file_fc: 18 | for i, img in enumerate(tqdm(imgs)): 19 | npy_fc_path = os.path.join( 20 | params['fc_input_dir'], 21 | str(img['cocoid']) + '.npy') 22 | 23 | d_set_fc = file_fc.create_dataset( 24 | str(img['cocoid']), data=np.load(npy_fc_path)) 25 | file_fc.close() 26 | 27 | if params['att_input_dir'] is not None: 28 | print('processing att') 29 | with h5py.File(params['att_output']) as file_att: 30 | for i, img in enumerate(tqdm(imgs)): 31 | npy_att_path = os.path.join( 32 | params['att_input_dir'], 33 | str(img['cocoid']) + '.npz') 34 | 35 | d_set_att = file_att.create_dataset( 36 | str(img['cocoid']), 37 | data=np.load(npy_att_path)['feat']) 38 | file_att.close() 39 | 40 | 41 | if __name__ == "__main__": 42 | 43 | parser = argparse.ArgumentParser() 44 | 45 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 46 | parser.add_argument('--fc_output', default='data', help='output h5 filename for fc') 47 | parser.add_argument('--att_output', default='data', help='output h5 file for att') 48 | parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files') 49 | parser.add_argument('--att_input_dir', default=None, help='input directory for numpy att files') 50 | 51 | args = parser.parse_args() 52 | params = vars(args) # convert to ordinary dict 53 | print('parsed input parameters:') 54 | print(json.dumps(params, indent=2)) 55 | 56 | main(params) -------------------------------------------------------------------------------- /scripts/dump_to_lmdb.py: -------------------------------------------------------------------------------- 1 | # copy from https://github.com/Lyken17/Efficient-PyTorch/tools 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import os.path as osp 9 | import os, sys 10 | import os.path as osp 11 | from PIL import Image 12 | import six 13 | import string 14 | 15 | from lmdbdict import lmdbdict 16 | from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC 17 | import pickle 18 | import tqdm 19 | import numpy as np 20 | import argparse 21 | import json 22 | 23 | import torch 24 | import torch.utils.data as data 25 | from torch.utils.data import DataLoader 26 | 27 | import csv 28 | csv.field_size_limit(sys.maxsize) 29 | FIELDNAMES = ['image_id', 'status'] 30 | 31 | class FolderLMDB(data.Dataset): 32 | def __init__(self, db_path, fn_list=None): 33 | self.db_path = db_path 34 | self.lmdb = lmdbdict(db_path, unsafe=True) 35 | self.lmdb._key_dumps = DUMPS_FUNC['ascii'] 36 | self.lmdb._value_loads = LOADS_FUNC['identity'] 37 | if fn_list is not None: 38 | self.length = len(fn_list) 39 | self.keys = fn_list 40 | else: 41 | raise Error 42 | 43 | def __getitem__(self, index): 44 | byteflow = self.lmdb[self.keys[index]] 45 
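        # Each LMDB value holds the raw bytes of the original .npz/.npy file
        # (stored verbatim by folder2lmdb below); it is wrapped in a BytesIO
        # buffer so that np.load can decode it back into an array.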
| 46 | # load image 47 | imgbuf = byteflow 48 | buf = six.BytesIO() 49 | buf.write(imgbuf) 50 | buf.seek(0) 51 | try: 52 | if args.extension == '.npz': 53 | feat = np.load(buf)['feat'] 54 | else: 55 | feat = np.load(buf) 56 | except Exception as e: 57 | print(self.keys[index], e) 58 | return None 59 | 60 | return feat 61 | 62 | def __len__(self): 63 | return self.length 64 | 65 | def __repr__(self): 66 | return self.__class__.__name__ + ' (' + self.db_path + ')' 67 | 68 | 69 | def make_dataset(dir, extension): 70 | images = [] 71 | dir = os.path.expanduser(dir) 72 | for root, _, fnames in sorted(os.walk(dir)): 73 | for fname in sorted(fnames): 74 | if has_file_allowed_extension(fname, [extension]): 75 | path = os.path.join(root, fname) 76 | images.append(path) 77 | 78 | return images 79 | 80 | 81 | def raw_reader(path): 82 | with open(path, 'rb') as f: 83 | bin_data = f.read() 84 | return bin_data 85 | 86 | 87 | def raw_npz_reader(path): 88 | with open(path, 'rb') as f: 89 | bin_data = f.read() 90 | try: 91 | npz_data = np.load(six.BytesIO(bin_data))['feat'] 92 | except Exception as e: 93 | print(path) 94 | npz_data = None 95 | print(e) 96 | return bin_data, npz_data 97 | 98 | 99 | def raw_npy_reader(path): 100 | with open(path, 'rb') as f: 101 | bin_data = f.read() 102 | try: 103 | npy_data = np.load(six.BytesIO(bin_data)) 104 | except Exception as e: 105 | print(path) 106 | npy_data = None 107 | print(e) 108 | return bin_data, npy_data 109 | 110 | 111 | class Folder(data.Dataset): 112 | 113 | def __init__(self, root, loader, extension, fn_list=None): 114 | super(Folder, self).__init__() 115 | self.root = root 116 | if fn_list: 117 | samples = [os.path.join(root, str(_)+extension) for _ in fn_list] 118 | else: 119 | samples = make_dataset(self.root, extension) 120 | 121 | self.loader = loader 122 | self.extension = extension 123 | self.samples = samples 124 | 125 | def __getitem__(self, index): 126 | """ 127 | Args: 128 | index (int): Index 129 | Returns: 130 | tuple: (sample, target) where target is class_index of the target class. 
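        Note: in this variant the loader returns (raw bytes, decoded array or None),
        so the tuple actually returned below is (file id, raw bytes, decoded array),
        not the usual (sample, class_index) pair.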
131 | """ 132 | path = self.samples[index] 133 | sample = self.loader(path) 134 | 135 | return (path.split('/')[-1].split('.')[0],) + sample 136 | 137 | def __len__(self): 138 | return len(self.samples) 139 | 140 | 141 | def folder2lmdb(dpath, fn_list, write_frequency=5000): 142 | directory = osp.expanduser(osp.join(dpath)) 143 | print("Loading dataset from %s" % directory) 144 | if args.extension == '.npz': 145 | dataset = Folder(directory, loader=raw_npz_reader, extension='.npz', 146 | fn_list=fn_list) 147 | else: 148 | dataset = Folder(directory, loader=raw_npy_reader, extension='.npy', 149 | fn_list=fn_list) 150 | data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) 151 | 152 | # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1])) 153 | lmdb_path = osp.join("%s.lmdb" % (directory)) 154 | isdir = os.path.isdir(lmdb_path) 155 | 156 | print("Generate LMDB to %s" % lmdb_path) 157 | db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity') 158 | 159 | tsvfile = open(args.output_file, 'a') 160 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 161 | names = [] 162 | all_keys = [] 163 | for idx, data in enumerate(tqdm.tqdm(data_loader)): 164 | # print(type(data), data) 165 | name, byte, npz = data[0] 166 | if npz is not None: 167 | db[name] = byte 168 | all_keys.append(name) 169 | names.append({'image_id': name, 'status': str(npz is not None)}) 170 | if idx % write_frequency == 0: 171 | print("[%d/%d]" % (idx, len(data_loader))) 172 | print('writing') 173 | db.flush() 174 | # write in tsv 175 | for name in names: 176 | writer.writerow(name) 177 | names = [] 178 | tsvfile.flush() 179 | print('writing finished') 180 | # write all keys 181 | # txn.put("keys".encode(), pickle.dumps(all_keys)) 182 | # # finish iterating through dataset 183 | # txn.commit() 184 | for name in names: 185 | writer.writerow(name) 186 | tsvfile.flush() 187 | tsvfile.close() 188 | 189 | print("Flushing database ...") 190 | db.flush() 191 | del db 192 | 193 | def parse_args(): 194 | """ 195 | Parse input arguments 196 | """ 197 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 198 | # parser.add_argument('--json) 199 | parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str) 200 | parser.add_argument('--output_file', default='.dump_cache.tsv', type=str) 201 | parser.add_argument('--folder', default='./data/cocobu_att', type=str) 202 | parser.add_argument('--extension', default='.npz', type=str) 203 | 204 | args = parser.parse_args() 205 | return args 206 | 207 | if __name__ == "__main__": 208 | global args 209 | args = parse_args() 210 | 211 | args.output_file += args.folder.split('/')[-1] 212 | if args.folder.find('/') > 0: 213 | args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file 214 | print(args.output_file) 215 | 216 | img_list = json.load(open(args.input_json, 'r'))['images'] 217 | fn_list = [str(_['cocoid']) for _ in img_list] 218 | found_ids = set() 219 | try: 220 | with open(args.output_file, 'r') as tsvfile: 221 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 222 | for item in reader: 223 | if item['status'] == 'True': 224 | found_ids.add(item['image_id']) 225 | except: 226 | pass 227 | fn_list = [_ for _ in fn_list if _ not in found_ids] 228 | folder2lmdb(args.folder, fn_list) 229 | 230 | # Test existing. 
231 | found_ids = set() 232 | with open(args.output_file, 'r') as tsvfile: 233 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 234 | for item in reader: 235 | if item['status'] == 'True': 236 | found_ids.add(item['image_id']) 237 | 238 | folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids)) 239 | data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x) 240 | for data in tqdm.tqdm(data_loader): 241 | assert data[0] is not None -------------------------------------------------------------------------------- /scripts/make_bu_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import base64 7 | import numpy as np 8 | import csv 9 | import sys 10 | import zlib 11 | import time 12 | import mmap 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # output_dir 18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') 19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') 20 | 21 | args = parser.parse_args() 22 | 23 | csv.field_size_limit(sys.maxsize) 24 | 25 | 26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', 28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ 29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ 30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] 31 | 32 | os.makedirs(args.output_dir+'_att') 33 | os.makedirs(args.output_dir+'_fc') 34 | os.makedirs(args.output_dir+'_box') 35 | 36 | for infile in infiles: 37 | print('Reading ' + infile) 38 | with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file: 39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 40 | for item in reader: 41 | item['image_id'] = int(item['image_id']) 42 | item['num_boxes'] = int(item['num_boxes']) 43 | for field in ['boxes', 'features']: 44 | item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')), 45 | dtype=np.float32).reshape((item['num_boxes'],-1)) 46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) 47 | np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0)) 48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /scripts/prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision import transforms as trn 33 | preprocess = trn.Compose([ 34 | #trn.ToTensor(), 35 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 36 | ]) 37 | 38 | from captioning.utils.resnet_utils import myResnet 39 | import captioning.utils.resnet as resnet 40 | 41 | 42 | def main(params): 43 | net = getattr(resnet, params['model'])() 44 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 45 | my_resnet = myResnet(net) 46 | my_resnet.cuda() 47 | my_resnet.eval() 48 | 49 | imgs = json.load(open(params['input_json'], 'r')) 50 | imgs = imgs['images'] 51 | N = len(imgs) 52 | 53 | seed(123) # make reproducible 54 | 55 | dir_fc = params['output_dir']+'_fc' 56 | dir_att = params['output_dir']+'_att' 57 | if not os.path.isdir(dir_fc): 58 | os.mkdir(dir_fc) 59 | if not os.path.isdir(dir_att): 60 | os.mkdir(dir_att) 61 | 62 | for i,img in enumerate(imgs): 63 | # load the image 64 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 65 | # handle grayscale input images 66 | if len(I.shape) == 2: 67 | I = I[:,:,np.newaxis] 68 | I = np.concatenate((I,I,I), axis=2) 69 | 70 | I = I.astype('float32')/255.0 71 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 72 | I = preprocess(I) 73 | with torch.no_grad(): 74 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 75 | # write to pkl 76 | # print(dir_fc, str(img['cocoid']), tmp_fc.shape, tmp_att.shape, dir_att) 77 | # exit() 78 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 79 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 80 | 81 | if i % 1000 == 0: 82 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 83 | print('wrote ', params['output_dir']) 84 | 85 | if __name__ == "__main__": 86 | 87 | parser = argparse.ArgumentParser() 88 | 89 | # input json 90 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 91 | parser.add_argument('--output_dir', default='data', help='output h5 file') 92 | 93 | # options 94 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 95 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 96 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 97 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 98 | 99 | args = parser.parse_args() 100 | params = vars(args) # convert to ordinary dict 101 | print('parsed input parameters:') 102 | print(json.dumps(params, indent = 2)) 103 | 
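    # Example invocation (illustrative paths). The script expects ImageNet-pretrained
    # ResNet weights at <model_root>/<model>.pth, e.g. data/imagenet_weights/resnet101.pth:
    #   python scripts/prepro_feats.py --input_json data/dataset_coco.json \
    #       --output_dir data/cocotalk --images_root /path/to/coco/images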
main(params) 104 | -------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 15 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 16 | first and last indices (in range 1..M) of labels for each image 17 | /label_length stores the length of the sequence for each of the M sequences 18 | 19 | The json file has a dict that contains: 20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 21 | - an 'images' field that is a list holding auxiliary information for each image, 22 | such as in particular the 'split' it was assigned to. 
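Example invocation (illustrative paths):

    python scripts/prepro_labels.py --input_json data/dataset_coco.json \
        --output_json data/cocotalk.json --output_h5 data/cocotalk

This writes data/cocotalk.json (vocabulary plus per-image metadata) and
data/cocotalk_label.h5 (the encoded caption arrays described above).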
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import json 31 | import argparse 32 | from random import shuffle, seed 33 | import string 34 | # non-standard dependencies: 35 | import h5py 36 | import numpy as np 37 | import torch 38 | import torchvision.models as models 39 | import skimage.io 40 | from PIL import Image 41 | 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if bad_count > 0: 82 | # additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | 96 | def encode_captions(imgs, params, wtoi): 97 | """ 98 | encode all captions into one large array, which will be 1-indexed. 99 | also produces label_start_ix and label_end_ix which store 1-indexed 100 | and inclusive (Lua-style) pointers to the first and last caption for 101 | each image in the dataset. 
102 | """ 103 | 104 | max_length = params['max_length'] 105 | N = len(imgs) 106 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 107 | 108 | label_arrays = [] 109 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 110 | label_end_ix = np.zeros(N, dtype='uint32') 111 | label_length = np.zeros(M, dtype='uint32') 112 | caption_counter = 0 113 | counter = 1 114 | for i,img in enumerate(imgs): 115 | n = len(img['final_captions']) 116 | assert n > 0, 'error: some image has no captions' 117 | 118 | Li = np.zeros((n, max_length), dtype='uint32') 119 | for j,s in enumerate(img['final_captions']): 120 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 121 | caption_counter += 1 122 | for k,w in enumerate(s): 123 | if k < max_length: 124 | Li[j,k] = wtoi[w] 125 | 126 | # note: word indices are 1-indexed, and captions are padded with zeros 127 | label_arrays.append(Li) 128 | label_start_ix[i] = counter 129 | label_end_ix[i] = counter + n - 1 130 | 131 | counter += n 132 | 133 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 134 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 135 | assert np.all(label_length > 0), 'error: some caption had no words?' 136 | 137 | print('encoded captions to array of size ', L.shape) 138 | return L, label_start_ix, label_end_ix, label_length 139 | 140 | 141 | def main(params): 142 | 143 | imgs = json.load(open(params['input_json'], 'r')) 144 | imgs = imgs['images'] 145 | 146 | seed(123) # make reproducible 147 | 148 | # create the vocab 149 | vocab = build_vocab(imgs, params) 150 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 151 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 152 | 153 | # encode captions in large arrays, ready to ship to hdf5 file 154 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 155 | 156 | # create output h5 file 157 | N = len(imgs) 158 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 159 | f_lb.create_dataset("labels", dtype='uint32', data=L) 160 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 161 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 162 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 163 | f_lb.close() 164 | 165 | # create output json file 166 | out = {} 167 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 168 | out['images'] = [] 169 | for i,img in enumerate(imgs): 170 | 171 | jimg = {} 172 | jimg['split'] = img['split'] 173 | if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need 174 | if 'cocoid' in img: 175 | jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 176 | elif 'imgid' in img: 177 | jimg['id'] = img['imgid'] 178 | 179 | if params['images_root'] != '': 180 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 181 | jimg['width'], jimg['height'] = _img.size 182 | 183 | out['images'].append(jimg) 184 | 185 | json.dump(out, open(params['output_json'], 'w')) 186 | print('wrote ', params['output_json']) 187 | 188 | if __name__ == "__main__": 189 | 190 | parser = argparse.ArgumentParser() 191 | 192 | # input json 193 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 194 | parser.add_argument('--output_json', default='data.json', help='output json file') 195 | parser.add_argument('--output_h5', default='data', help='output h5 file') 196 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 197 | 198 | # options 199 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 200 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 201 | 202 | args = parser.parse_args() 203 | params = vars(args) # convert to ordinary dict 204 | print('parsed input parameters:') 205 | print(json.dumps(params, indent = 2)) 206 | main(params) 207 | -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Precompute ngram counts of captions, to accelerate cider computation during training time. 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | from six.moves import cPickle 9 | import captioning.utils.misc as utils 10 | from collections import defaultdict 11 | 12 | import sys 13 | sys.path.append("cider") 14 | from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer 15 | 16 | 17 | def get_doc_freq(refs, params): 18 | tmp = CiderScorer(df_mode="corpus") 19 | for ref in refs: 20 | tmp.cook_append(None, ref) 21 | tmp.compute_doc_freq() 22 | return tmp.document_frequency, len(tmp.crefs) 23 | 24 | 25 | def build_dict(imgs, wtoi, params): 26 | wtoi[''] = 0 27 | 28 | count_imgs = 0 29 | 30 | refs_words = [] 31 | refs_idxs = [] 32 | for img in imgs: 33 | if (params['split'] == img['split']) or \ 34 | (params['split'] == 'train' and img['split'] == 'restval') or \ 35 | (params['split'] == 'all'): 36 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 37 | ref_words = [] 38 | ref_idxs = [] 39 | for sent in img['sentences']: 40 | if hasattr(params, 'bpe'): 41 | sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') 42 | tmp_tokens = sent['tokens'] + [''] 43 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 44 | ref_words.append(' '.join(tmp_tokens)) 45 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 46 | refs_words.append(ref_words) 47 | refs_idxs.append(ref_idxs) 48 | count_imgs += 1 49 | print('total imgs:', count_imgs) 50 | 51 | ngram_words, count_refs = get_doc_freq(refs_words, params) 52 | ngram_idxs, count_refs = get_doc_freq(refs_idxs, params) 53 | print('count_refs:', count_refs) 54 | return ngram_words, ngram_idxs, count_refs 55 | 56 | def main(params): 57 | 58 | imgs = json.load(open(params['input_json'], 'r')) 59 | dict_json = 
json.load(open(params['dict_json'], 'r')) 60 | itow = dict_json['ix_to_word'] 61 | wtoi = {w:i for i,w in itow.items()} 62 | 63 | # Load bpe 64 | if 'bpe' in dict_json: 65 | import tempfile 66 | import codecs 67 | codes_f = tempfile.NamedTemporaryFile(delete=False) 68 | codes_f.close() 69 | with open(codes_f.name, 'w') as f: 70 | f.write(dict_json['bpe']) 71 | with codecs.open(codes_f.name, encoding='UTF-8') as codes: 72 | bpe = apply_bpe.BPE(codes) 73 | params.bpe = bpe 74 | 75 | imgs = imgs['images'] 76 | 77 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 78 | 79 | utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb')) 80 | utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb')) 81 | 82 | if __name__ == "__main__": 83 | 84 | parser = argparse.ArgumentParser() 85 | 86 | # input json 87 | parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') 88 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') 89 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') 90 | parser.add_argument('--split', default='all', help='test, val, train, all') 91 | args = parser.parse_args() 92 | params = vars(args) # convert to ordinary dict 93 | 94 | main(params) 95 | -------------------------------------------------------------------------------- /scripts/prepro_reference_json.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Create a reference json file used for evaluation with `coco-caption` repo. 4 | Used when reference json is not provided, (e.g., flickr30k, or you have your own split of train/val/test) 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import json 13 | import argparse 14 | import sys 15 | import hashlib 16 | from random import shuffle, seed 17 | 18 | 19 | def main(params): 20 | 21 | imgs = json.load(open(params['input_json'][0], 'r'))['images'] 22 | # tmp = [] 23 | # for k in imgs.keys(): 24 | # for img in imgs[k]: 25 | # img['filename'] = img['image_id'] # k+'/'+img['image_id'] 26 | # img['image_id'] = int( 27 | # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint) 28 | # tmp.append(img) 29 | # imgs = tmp 30 | 31 | # create output json file 32 | out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known 
copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'} 33 | out.update({'images': [], 'annotations': []}) 34 | 35 | cnt = 0 36 | empty_cnt = 0 37 | for i, img in enumerate(imgs): 38 | if img['split'] == 'train': 39 | continue 40 | out['images'].append( 41 | {'id': img.get('cocoid', img['imgid'])}) 42 | for j, s in enumerate(img['sentences']): 43 | if len(s) == 0: 44 | continue 45 | s = ' '.join(s['tokens']) 46 | out['annotations'].append( 47 | {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt}) 48 | cnt += 1 49 | 50 | json.dump(out, open(params['output_json'], 'w')) 51 | print('wrote ', params['output_json']) 52 | 53 | 54 | if __name__ == "__main__": 55 | 56 | parser = argparse.ArgumentParser() 57 | 58 | # input json 59 | parser.add_argument('--input_json', nargs='+', required=True, 60 | help='input json file to process into hdf5') 61 | parser.add_argument('--output_json', default='data.json', 62 | help='output json file') 63 | 64 | args = parser.parse_args() 65 | params = vars(args) # convert to ordinary dict 66 | print('parsed input parameters:') 67 | print(json.dumps(params, indent=2)) 68 | main(params) 69 | 70 | -------------------------------------------------------------------------------- /scripts_FineCapEval/clip_prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 33 | from PIL import Image 34 | from torch import nn 35 | 36 | preprocess = Compose([ 37 | Resize((448, 448), interpolation=Image.BICUBIC), 38 | CenterCrop((448, 448)), 39 | ToTensor() 40 | ]) 41 | 42 | 43 | from clip.clip import load 44 | from timm.models.vision_transformer import resize_pos_embed 45 | import timm 46 | 47 | from captioning.utils.resnet_utils import myResnet 48 | import captioning.utils.resnet as resnet 49 | 50 | from tqdm import tqdm 51 | 52 | 53 | def main(params): 54 | if params["model_type"] != 'vit_base_patch32_224_in21k': 55 | model, transform = load(params["model_type"], jit=False) 56 | else: 57 | model = timm.create_model(params["model_type"], pretrained=True) 58 | model = model.cuda() 59 | 60 | if params["model_type"] != 'vit_base_patch32_224_in21k': 61 | save_model_type = params["model_type"].split("-")[0] 62 | mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) 63 | std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) 64 | 65 | if "RN" in params["model_type"]: 66 | num_patches = 196 #600 * 1000 // 32 // 32 67 | pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) 68 | pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) 69 | model.visual.attnpool.positional_embedding = pos_embed 70 | 71 | else: 72 | save_model_type = 'vit_base' 73 | mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) 74 | std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) 75 | 76 | num_patches = 196 #600 * 1000 // 32 // 32 77 | pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) 78 | pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) 79 | model.pos_embed = pos_embed 80 | 81 | if params["model_type"] == "ViT-B/32": 82 | num_patches = 196 #600 * 1000 // 32 // 32 83 | pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) 84 | pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) 85 | model.visual.positional_embedding = pos_embed 86 | imgs = json.load(open(params['input_json'], 'r')) 87 | imgs = imgs['images'] 88 | N = len(imgs) 89 | 90 | seed(123) # make reproducible 91 | 92 | dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' 93 | dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' 94 | if not os.path.isdir(dir_fc): 95 | os.mkdir(dir_fc) 96 | if not os.path.isdir(dir_att): 97 | os.mkdir(dir_att) 98 | 99 | for i, img in enumerate(tqdm(imgs)): 100 | with torch.no_grad(): 101 | 102 | # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) 103 | # img_path = os.path.join(params['images_root'], img['file_name']) 104 | 105 | img_path = os.path.join(params['images_root'], 
img['file_path']) 106 | 107 | image = preprocess(Image.open( img_path ).convert("RGB")) 108 | image = torch.tensor(np.stack([image])).cuda() 109 | image -= mean 110 | image /= std 111 | if "RN" in params["model_type"]: 112 | tmp_att, tmp_fc = model.encode_image(image) 113 | tmp_att = tmp_att[0].permute(1, 2, 0) 114 | tmp_fc = tmp_fc[0] 115 | elif params["model_type"] == 'vit_base_patch32_224_in21k': 116 | x = model(image) 117 | tmp_fc = x[0, 0, :] 118 | tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 119 | else: 120 | x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] 121 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 122 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 123 | x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 124 | x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] 125 | x = model.visual.ln_pre(x) 126 | 127 | x = x.permute(1, 0, 2) # NLD -> LND 128 | 129 | for layer_idx, layer in enumerate(model.visual.transformer.resblocks): 130 | x = layer(x) 131 | 132 | x = x.permute(1, 0, 2) 133 | tmp_fc = x[0, 0, :] 134 | tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 135 | 136 | # np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 137 | # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 138 | np.save(os.path.join(dir_fc, str(img['id'])), tmp_fc.data.cpu().float().numpy()) 139 | np.savez_compressed(os.path.join(dir_att, str(img['id'])), feat=tmp_att.data.cpu().float().numpy()) 140 | 141 | 142 | # if i % 1000 == 0: 143 | # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 144 | print('wrote ', dir_fc, dir_att) 145 | 146 | if __name__ == "__main__": 147 | 148 | parser = argparse.ArgumentParser() 149 | 150 | # input json 151 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 152 | parser.add_argument('--output_dir', default='data', help='output h5 file') 153 | 154 | # options 155 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 156 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 157 | parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') 158 | 159 | args = parser.parse_args() 160 | params = vars(args) # convert to ordinary dict 161 | print('parsed input parameters:') 162 | print(json.dumps(params, indent = 2)) 163 | main(params) 164 | -------------------------------------------------------------------------------- /scripts_FineCapEval/clipscore_prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 33 | from PIL import Image 34 | from torch import nn 35 | 36 | # preprocess = Compose([ 37 | # Resize((448, 448), interpolation=Image.BICUBIC), 38 | # CenterCrop((448, 448)), 39 | # ToTensor() 40 | # ]) 41 | 42 | 43 | # from clip.clip import load 44 | # from timm.models.vision_transformer import resize_pos_embed 45 | # import timm 46 | 47 | # from captioning.utils.resnet_utils import myResnet 48 | # import captioning.utils.resnet as resnet 49 | 50 | from captioning.utils.clipscore import CLIPScore 51 | 52 | from tqdm import tqdm 53 | 54 | 55 | def main(params): 56 | 57 | clipscore_model = CLIPScore() 58 | clipscore_model.to('cuda') 59 | 60 | imgs = json.load(open(params['input_json'], 'r')) 61 | imgs = imgs['images'] 62 | N = len(imgs) 63 | 64 | seed(123) # make reproducible 65 | 66 | # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' 67 | # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' 68 | 69 | vis_dir_fc = params['output_dir']+'_clipscore_vis' 70 | if not os.path.isdir(vis_dir_fc): 71 | os.mkdir(vis_dir_fc) 72 | 73 | # text_dir_fc = params['output_dir']+'_clipscore_text' 74 | # if not os.path.isdir(text_dir_fc): 75 | # os.mkdir(text_dir_fc) 76 | 77 | # if not os.path.isdir(dir_att): 78 | # os.mkdir(dir_att) 79 | 80 | for i,img in enumerate(tqdm(imgs)): 81 | # load the image 82 | 83 | # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) 84 | # img_path = os.path.join(params['images_root'], img['file_name']) 85 | img_path = os.path.join(params['images_root'], img['file_path']) 86 | 87 | img_feat = clipscore_model.image_extract(img_path) 88 | img_feat = img_feat.view(512) 89 | 90 | # for d in img['sentences']: 91 | # text = d['raw'].strip() 92 | # text_feat = clipscore_model.text_extract(text) 93 | 94 | 95 | # with torch.no_grad(): 96 | 97 | # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) 98 | # image = torch.tensor(np.stack([image])).cuda() 99 | # image -= mean 100 | # image /= std 101 | # if "RN" in params["model_type"]: 102 | # tmp_att, tmp_fc = model.encode_image(image) 103 | # tmp_att = tmp_att[0].permute(1, 2, 0) 104 | # tmp_fc = tmp_fc[0] 105 | # elif params["model_type"] == 'vit_base_patch32_224_in21k': 106 | # x = model(image) 107 | # tmp_fc = x[0, 0, :] 108 | # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 109 | # else: 110 | # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] 111 | # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 112 | # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 113 | # x = 
torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 114 | # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] 115 | # x = model.visual.ln_pre(x) 116 | 117 | # x = x.permute(1, 0, 2) # NLD -> LND 118 | 119 | # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): 120 | # x = layer(x) 121 | 122 | # x = x.permute(1, 0, 2) 123 | # tmp_fc = x[0, 0, :] 124 | # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) 125 | 126 | np.save(os.path.join(vis_dir_fc, str(img['id'])), img_feat.data.cpu().float().numpy()) 127 | # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 128 | 129 | 130 | # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 131 | 132 | # if i % 1000 == 0: 133 | # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 134 | print('wrote ', vis_dir_fc) 135 | 136 | if __name__ == "__main__": 137 | 138 | parser = argparse.ArgumentParser() 139 | 140 | # input json 141 | # dataset_coco.json 142 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 143 | parser.add_argument('--output_dir', default='data', help='output h5 file') 144 | 145 | # options 146 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 147 | # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 148 | # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') 149 | 150 | args = parser.parse_args() 151 | params = vars(args) # convert to ordinary dict 152 | print('parsed input parameters:') 153 | print(json.dumps(params, indent = 2)) 154 | main(params) 155 | -------------------------------------------------------------------------------- /scripts_FineCapEval/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 15 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 16 | first and last indices (in range 1..M) of labels for each image 17 | /label_length stores the length of the sequence for each of the M sequences 18 | 19 | The json file has a dict that contains: 20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 21 | - an 'images' field that is a list holding auxiliary information for each image, 22 | such as in particular the 'split' it was assigned to. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import json 31 | import argparse 32 | from random import shuffle, seed 33 | import string 34 | # non-standard dependencies: 35 | import h5py 36 | import numpy as np 37 | import torch 38 | import torchvision.models as models 39 | import skimage.io 40 | from PIL import Image 41 | 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if bad_count > 0: 82 | # additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | 96 | def encode_captions(imgs, params, wtoi): 97 | """ 98 | encode all captions into one large array, which will be 1-indexed. 
99 | also produces label_start_ix and label_end_ix which store 1-indexed 100 | and inclusive (Lua-style) pointers to the first and last caption for 101 | each image in the dataset. 102 | """ 103 | 104 | max_length = params['max_length'] 105 | N = len(imgs) 106 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 107 | 108 | label_arrays = [] 109 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 110 | label_end_ix = np.zeros(N, dtype='uint32') 111 | label_length = np.zeros(M, dtype='uint32') 112 | caption_counter = 0 113 | counter = 1 114 | for i,img in enumerate(imgs): 115 | n = len(img['final_captions']) 116 | assert n > 0, 'error: some image has no captions' 117 | 118 | Li = np.zeros((n, max_length), dtype='uint32') 119 | for j,s in enumerate(img['final_captions']): 120 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 121 | caption_counter += 1 122 | for k,w in enumerate(s): 123 | if k < max_length: 124 | Li[j,k] = wtoi[w] 125 | 126 | # note: word indices are 1-indexed, and captions are padded with zeros 127 | label_arrays.append(Li) 128 | label_start_ix[i] = counter 129 | label_end_ix[i] = counter + n - 1 130 | 131 | counter += n 132 | 133 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 134 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 135 | assert np.all(label_length > 0), 'error: some caption had no words?' 136 | 137 | print('encoded captions to array of size ', L.shape) 138 | return L, label_start_ix, label_end_ix, label_length 139 | 140 | 141 | def main(params): 142 | 143 | imgs = json.load(open(params['input_json'], 'r')) 144 | imgs = imgs['images'] 145 | 146 | seed(123) # make reproducible 147 | 148 | # # create the vocab 149 | # vocab = build_vocab(imgs, params) 150 | # itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 151 | # wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 152 | 153 | itow = imgs['ix_to_word'] 154 | wtoi = {w:i for i, w in itow.items()} 155 | 156 | # encode captions in large arrays, ready to ship to hdf5 file 157 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 158 | 159 | # create output h5 file 160 | N = len(imgs) 161 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 162 | f_lb.create_dataset("labels", dtype='uint32', data=L) 163 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 164 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 165 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 166 | f_lb.close() 167 | 168 | # create output json file 169 | out = {} 170 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 171 | out['images'] = [] 172 | for i,img in enumerate(imgs): 173 | 174 | jimg = {} 175 | jimg['split'] = img['split'] 176 | if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need 177 | if 'cocoid' in img: 178 | jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 179 | elif 'imgid' in img: 180 | jimg['id'] = img['imgid'] 181 | 182 | if params['images_root'] != '': 183 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 184 | jimg['width'], jimg['height'] = _img.size 185 | 186 | out['images'].append(jimg) 187 | 188 | json.dump(out, open(params['output_json'], 'w')) 189 | print('wrote ', params['output_json']) 190 | 191 | if __name__ == "__main__": 192 | 193 | parser = argparse.ArgumentParser() 194 | 195 | # input json 196 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 197 | parser.add_argument('--output_json', default='data.json', help='output json file') 198 | parser.add_argument('--output_h5', default='data', help='output h5 file') 199 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 200 | 201 | # options 202 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 203 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 204 | 205 | args = parser.parse_args() 206 | params = vars(args) # convert to ordinary dict 207 | print('parsed input parameters:') 208 | print(json.dumps(params, indent = 2)) 209 | main(params) 210 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="captioning", 5 | version="0.0.1", 6 | packages=setuptools.find_packages(), 7 | ) -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import captioning.utils.opts as opts 13 | import captioning.models as models 14 | from captioning.data.dataloader import * 15 | # from captioning.data.dataloaderraw import * 16 | import captioning.utils.eval_utils as eval_utils 17 | import argparse 18 | import captioning.utils.misc as utils 19 | import captioning.modules.losses as losses 20 | import torch 21 | 22 | # Input arguments and options 23 | parser = argparse.ArgumentParser() 24 | # Input paths 25 | parser.add_argument('--model', type=str, default='', 26 | help='path to model to evaluate') 27 | parser.add_argument('--cnn_model', type=str, default='resnet101', 28 | help='resnet101, resnet152') 29 | parser.add_argument('--infos_path', type=str, default='', 30 | help='path to infos to evaluate') 31 | parser.add_argument('--only_lang_eval', type=int, default=0, 32 | help='lang eval on saved results') 33 | parser.add_argument('--force', type=int, default=0, 34 | help='force to evaluate no matter if there are results available') 35 | parser.add_argument('--device', type=str, default='cuda', 36 | help='cpu or cuda') 37 | opts.add_eval_options(parser) 38 | opts.add_diversity_opts(parser) 39 | opt = parser.parse_args() 40 | 41 | # Load infos 42 | with open(opt.infos_path, 'rb') as f: 43 | infos = utils.pickle_load(f) 44 | 45 | # override and collect parameters 46 | 
replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] 47 | ignore = ['start_from'] 48 | 49 | for k in vars(infos['opt']).keys(): 50 | if k in replace: 51 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) 52 | elif k not in ignore: 53 | if not k in vars(opt): 54 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 55 | 56 | vocab = infos['vocab'] # ix -> word mapping 57 | 58 | pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') 59 | result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') 60 | 61 | if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): 62 | # if results existed, then skip, unless force is on 63 | if not opt.force: 64 | try: 65 | if os.path.isfile(result_fn): 66 | print(result_fn) 67 | json.load(open(result_fn, 'r')) 68 | print('already evaluated') 69 | os._exit(0) 70 | except: 71 | pass 72 | 73 | predictions, n_predictions = torch.load(pred_fn) 74 | lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) 75 | print(lang_stats) 76 | os._exit(0) 77 | 78 | # At this point only_lang_eval if 0 79 | if not opt.force: 80 | # Check out if 81 | try: 82 | # if no pred exists, then continue 83 | tmp = torch.load(pred_fn) 84 | # if language_eval == 1, and no pred exists, then continue 85 | if opt.language_eval == 1: 86 | json.load(open(result_fn, 'r')) 87 | print('Result is already there') 88 | os._exit(0) 89 | except: 90 | pass 91 | 92 | # Setup the model 93 | opt.vocab = vocab 94 | model = models.setup(opt) 95 | del opt.vocab 96 | model.load_state_dict(torch.load(opt.model, map_location='cpu')) 97 | model.to(opt.device) 98 | model.eval() 99 | crit = losses.LanguageModelCriterion() 100 | 101 | # Create the Data Loader instance 102 | if len(opt.image_folder) == 0: 103 | loader = DataLoader(opt) 104 | else: 105 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 106 | 'coco_json': opt.coco_json, 107 | 'batch_size': opt.batch_size, 108 | 'cnn_model': opt.cnn_model}) 109 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 110 | # So make sure to use the vocab in infos file. 
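# Note (illustrative, matching the format written by the prepro_labels.py scripts):
# infos['vocab'] is assumed to be the ix -> word dict saved with the checkpoint,
# e.g. {'1': 'a', '2': 'man', ...}, 1-indexed with 0 used for padding, so the
# assignment below keeps decoding consistent with the checkpoint even when the
# local cocotalk.json was built with a different vocabulary.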
111 | loader.dataset.ix_to_word = infos['vocab'] 112 | 113 | 114 | # Set sample options 115 | opt.dataset = opt.input_json 116 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 117 | vars(opt)) 118 | 119 | print('loss: ', loss) 120 | if lang_stats: 121 | print(lang_stats) 122 | 123 | if opt.dump_json == 1: 124 | # dump the json 125 | json.dump(split_predictions, open('vis/vis.json', 'w')) 126 | -------------------------------------------------------------------------------- /tools/eval_clip_retrieval.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | # import requests 4 | 5 | from transformers import CLIPProcessor, CLIPModel 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | import json 13 | import argparse 14 | import numpy as np 15 | 16 | class COCODataset(Dataset): 17 | def __init__(self, 18 | coco_root="/nas-ssd/jmincho/datasets/COCO/", 19 | gen_caption_path=None, 20 | is_gt=True): 21 | super().__init__() 22 | 23 | self.coco_root = Path(coco_root) 24 | 25 | self.image_dir = self.coco_root.joinpath('images/val2014') 26 | 27 | if is_gt: 28 | print("Loading karpathy splits") 29 | data_info_path = self.coco_root.joinpath('dataset_coco.json') 30 | with open(data_info_path) as f: 31 | karpathy_data = json.load(f) 32 | 33 | data = [] 34 | for datum in karpathy_data['images']: 35 | # karpathy test split 36 | if datum['split'] == 'test': 37 | img_id = datum['filename'].split('.')[0] 38 | new_datum = { 39 | 'img_id': img_id, 40 | 'captions': [d['raw'].strip() for d in datum['sentences']], 41 | } 42 | data.append(new_datum) 43 | else: 44 | print("Loading generated captions") 45 | gen_caption_path = Path(gen_caption_path) 46 | with open(gen_caption_path) as f: 47 | # karpathy_data = json.load(f) 48 | imgTogen_results = json.load(f)['imgToEval'] 49 | data = [] 50 | for img_id, img_data in imgTogen_results.items(): 51 | new_datum = { 52 | 'img_id': img_id, 53 | 'captions': [img_data['caption']], 54 | } 55 | data.append(new_datum) 56 | 57 | self.data = data 58 | print('# images:', len(self.data)) 59 | 60 | self.img_transform = processor.feature_extractor 61 | self.tokenizer = processor.tokenizer 62 | 63 | def __len__(self): 64 | return len(self.data) 65 | 66 | def __getitem__(self, idx): 67 | datum = self.data[idx] 68 | img_id = datum['img_id'] 69 | if 'COCO' not in img_id: 70 | img_id = f'COCO_val2014_{str(img_id).zfill(12)}' 71 | img_fname = f"{img_id}.jpg" 72 | # COCO_val2014_000000522418.jpg 73 | img_path = self.image_dir.joinpath(img_fname) 74 | img = Image.open(img_path).convert("RGB") 75 | 76 | # take first caption 77 | caption = datum['captions'][0] 78 | 79 | return { 80 | "img": img, 81 | "caption": caption, 82 | } 83 | 84 | def collate_fn(self, datum_list): 85 | B = len(datum_list) 86 | imgs = [datum['img'] for datum in datum_list] 87 | images = self.img_transform(imgs, return_tensors="pt") 88 | 89 | captions = [datum['caption'] for datum in datum_list] 90 | 91 | text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True) 92 | batch = { 93 | 'images': images, 94 | 'captions': text_tokens, 95 | } 96 | return batch 97 | 98 | 99 | def compute_similarity(image_features, text_features, bs = 1000): 100 | # compute similarity 101 | max_pairs = image_features.shape[0] 102 | similarity_scores = torch.zeros(max_pairs, max_pairs) 103 | for v in range(0, max_pairs, bs): 104 | for t in range(0, 
max_pairs, bs): 105 | # print('Processing Visual '+str(v)+' Text '+str(t), end='\r') 106 | batch_visual_emb = image_features[v:v+bs] 107 | batch_caption_emb = text_features[t:t+bs] 108 | 109 | logits = batch_visual_emb @ batch_caption_emb.t() 110 | similarity_scores[v:v+bs,t:t+bs] = logits 111 | 112 | print('Done similarity') 113 | return similarity_scores 114 | 115 | def compute_retrieval(a2b_sims, return_ranks=True): 116 | """ 117 | Args: 118 | a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T) 119 | with shape (num_datapoints, num_datapoints). 120 | 121 | Returns: 122 | Retrieval metrics for that similarity. 123 | """ 124 | npts = a2b_sims.shape[0] 125 | ranks = np.zeros(npts) 126 | top1 = np.zeros(npts) 127 | # loop source embedding indices 128 | for index in range(npts): 129 | # get order of similarities to target embeddings 130 | inds = np.argsort(a2b_sims[index])[::-1] 131 | # find where the correct embedding is ranked 132 | where = np.where(inds == index) 133 | rank = where[0][0] 134 | ranks[index] = rank 135 | # save the top1 result as well 136 | top1[index] = inds[0] 137 | 138 | # Compute metrics 139 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 140 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 141 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 142 | r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks) 143 | medr = np.floor(np.median(ranks)) + 1 144 | meanr = ranks.mean() + 1 145 | 146 | report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50, "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10} 147 | 148 | if return_ranks: 149 | return report_dict, (ranks, top1) 150 | else: 151 | return report_dict 152 | 153 | 154 | if __name__ == '__main__': 155 | 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/") 158 | parser.add_argument('--gt', action='store_true') 159 | parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json") 160 | args = parser.parse_args() 161 | 162 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 163 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 164 | 165 | device = "cuda" 166 | model = model.to(device) 167 | model.eval() 168 | print(f"Loaded CLIP at {device}") 169 | 170 | batch_size = 1000 171 | 172 | dataset = COCODataset( 173 | coco_root="/nas-ssd/jmincho/datasets/COCO/", 174 | gen_caption_path=args.gen_caption_path, 175 | is_gt=args.gt 176 | ) 177 | data_loader = DataLoader( 178 | dataset, 179 | batch_size=batch_size, 180 | collate_fn=dataset.collate_fn, 181 | shuffle=False, 182 | num_workers=8) 183 | 184 | # fwd all samples 185 | image_features = [] 186 | text_features = [] 187 | for batch_idx, batch in enumerate(tqdm(data_loader)): 188 | # print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r") 189 | # images, texts = batch 190 | 191 | with torch.no_grad(): 192 | images = batch["images"].to(device) 193 | texts = batch["captions"].to(device) 194 | 195 | vision_outputs = model.vision_model(**batch['images']) 196 | text_outputs = model.text_model(**batch['captions']) 197 | 198 | image_embeds = vision_outputs[1] 199 | image_embeds = model.visual_projection(image_embeds) 200 | 201 | text_embeds = text_outputs[1] 202 | text_embeds = model.text_projection(text_embeds) 203 | 204 | # normalized features 205 | image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) 206 | text_embeds = 
text_embeds / text_embeds.norm(dim=-1, keepdim=True) 207 | 208 | text_features.append(text_embeds.detach().cpu()) 209 | image_features.append(image_embeds.detach().cpu()) 210 | 211 | image_features = torch.cat(image_features, 0) 212 | text_features = torch.cat(text_features, 0) 213 | print('Done forward') 214 | 215 | # normalized features 216 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 217 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 218 | 219 | # if not single_caption: 220 | # for cap_idx in range(text_features.shape[1]): 221 | # similarity_scores = compute_similarity(image_features, text_features[:,cap_idx,:]) 222 | # i2t_dict = compute_retrieval(similarity_scores.numpy()) 223 | # t2i_dict = compute_retrieval(similarity_scores.t().numpy()) 224 | # print(cap_idx, 'i2t', i2t_dict) 225 | # print(cap_idx, 't2i', t2i_dict) 226 | # else: 227 | similarity_scores = compute_similarity(image_features, text_features) 228 | i2t_dict = compute_retrieval(similarity_scores.numpy()) 229 | t2i_dict = compute_retrieval(similarity_scores.t().numpy()) 230 | print('i2t', i2t_dict) 231 | print('t2i', t2i_dict) 232 | -------------------------------------------------------------------------------- /tools/eval_finecapeval.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | from pprint import pprint 4 | import pandas as pd 5 | import argparse 6 | import re 7 | import json 8 | import nltk 9 | from nltk.tokenize import word_tokenize 10 | from nltk.stem.porter import PorterStemmer 11 | p_stemmer = PorterStemmer() 12 | 13 | # nltk.download('punkt') 14 | # nltk.download('wordnet') 15 | # nltk.download('stopwords') 16 | 17 | import language_evaluation 18 | evaluator = language_evaluation.CocoEvaluator() 19 | 20 | 21 | def nltk_process(text): 22 | # Tokenization 23 | nltk_tokenList = word_tokenize(text) 24 | 25 | # Stemming 26 | nltk_stemedList = [] 27 | for word in nltk_tokenList: 28 | nltk_stemedList.append(p_stemmer.stem(word)) 29 | 30 | filtered_sentence = nltk_stemedList 31 | 32 | # Removing Punctuation 33 | 34 | tokens = [re.sub(r'[^a-zA-Z0-9]', '', tok) for tok in filtered_sentence] 35 | 36 | text = " ".join(tokens) 37 | 38 | return text 39 | 40 | 41 | def calculate_finegrained_scores(pred_id2sent, id2caption, use_coco_eval=False): 42 | if use_coco_eval: 43 | n_total = 0 44 | refs = [] 45 | hyps = [] 46 | for id, gt_captions in id2caption.items(): 47 | pred_sent = pred_id2sent[id] 48 | 49 | refs.append(gt_captions) 50 | hyps.append(pred_sent) 51 | 52 | n_total += 1 53 | 54 | print('caption') 55 | results = evaluator.run_evaluation(hyps, refs) 56 | pprint(results) 57 | 58 | n_total = 0 59 | total_score = 0 60 | for id, gt_phrases in id2background.items(): 61 | pred_sent = pred_id2sent[id] 62 | 63 | score = 0 64 | n_phrases = len(gt_phrases) 65 | 66 | for gt_phrase in gt_phrases: 67 | word_score = 0 68 | for gt_word in gt_phrase.split(): 69 | if gt_word in pred_sent: 70 | word_score += 1 71 | if len(gt_phrase.split()) > 0: 72 | score += word_score / len(gt_phrase.split()) 73 | 74 | if n_phrases > 0: 75 | score /= n_phrases 76 | 77 | total_score += score 78 | n_total += 1 79 | print('background') 80 | # print('# retrieved words:', n_retrieved) 81 | print(f'Acc: {total_score / n_total * 100:.2f}') 82 | 83 | n_total = 0 84 | total_score = 0 85 | for id, gt_phrases in id2object.items(): 86 | pred_sent = pred_id2sent[id] 87 | 88 | score = 0 89 | n_phrases = len(gt_phrases) 90 | 91 | for 
gt_phrase in gt_phrases: 92 | word_score = 0 93 | for gt_word in gt_phrase.split(): 94 | if gt_word in pred_sent: 95 | word_score += 1 96 | if len(gt_phrase.split()) > 0: 97 | score += word_score / len(gt_phrase.split()) 98 | 99 | if n_phrases > 0: 100 | score /= n_phrases 101 | 102 | total_score += score 103 | n_total += 1 104 | print('object') 105 | # print('# retrieved words:', n_retrieved) 106 | print(f'Acc: {total_score / n_total * 100:.2f}') 107 | 108 | n_total = 0 109 | total_score = 0 110 | for id, gt_phrases in id2relation.items(): 111 | pred_sent = pred_id2sent[id] 112 | 113 | score = 0 114 | n_phrases = len(gt_phrases) 115 | 116 | for gt_phrase in gt_phrases: 117 | word_score = 0 118 | for gt_word in gt_phrase.split(): 119 | if gt_word in pred_sent: 120 | word_score += 1 121 | if len(gt_phrase.split()) > 0: 122 | score += word_score / len(gt_phrase.split()) 123 | 124 | if n_phrases > 0: 125 | score /= n_phrases 126 | 127 | total_score += score 128 | n_total += 1 129 | print('relation') 130 | # print('# retrieved words:', n_retrieved) 131 | print(f'Acc: {total_score / n_total * 100:.2f}') 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('--finecapeval_path', type=str, default="data/FineCapEval.csv") 137 | parser.add_argument('--generated_id2caption', type=str, default="FineCapEval_results/mle.json") 138 | args = parser.parse_args() 139 | 140 | df = pd.read_csv(args.finecapeval_path) 141 | assert df.shape == (5000, 5) 142 | 143 | generated_id2caption = json.load(open(args.generated_id2caption, 'r')) 144 | 145 | print("Preprocessing GT FineCapEval data...") 146 | id2caption = {} 147 | id2background = {} 148 | id2object = {} 149 | id2relation = {} 150 | 151 | for row in tqdm(df.itertuples(), total=len(df)): 152 | 153 | id = row.image.split('.')[0] 154 | caption = row.caption 155 | background = row.background 156 | object = row.object 157 | relation = row.relation 158 | 159 | if not isinstance(caption, str): 160 | continue 161 | if not isinstance(background, str): 162 | continue 163 | if not isinstance(object, str): 164 | continue 165 | if not isinstance(relation, str): 166 | continue 167 | 168 | if id not in id2caption: 169 | id2caption[id] = [] 170 | id2background[id] = [] 171 | id2object[id] = [] 172 | id2relation[id] = [] 173 | 174 | id2caption[id].append(caption) 175 | 176 | phrases = [] 177 | for phrase in background.lower().split('\;'): 178 | if len(phrase) > 1: 179 | phrase = nltk_process(phrase) 180 | phrases.append(phrase) 181 | id2background[id].extend(phrases) 182 | 183 | phrases = [] 184 | for phrase in object.lower().split('\;'): 185 | if len(phrase) > 1: 186 | phrase = nltk_process(phrase) 187 | phrases.append(phrase) 188 | id2object[id].extend(phrases) 189 | 190 | phrases = [] 191 | for phrase in relation.lower().split('\;'): 192 | if len(phrase) > 1: 193 | phrase = nltk_process(phrase) 194 | phrases.append(phrase) 195 | id2relation[id].extend(phrases) 196 | 197 | print("Calculating scores...") 198 | calculate_finegrained_scores( 199 | generated_id2caption, 200 | id2caption, 201 | use_coco_eval=True) 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /tools/finecapeval_inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | import numpy as np 8 | 9 | import time 10 | import os 11 | 
from collections import defaultdict 12 | import json 13 | 14 | import captioning.utils.opts as opts 15 | import captioning.models as models 16 | from captioning.data.pth_loader import CaptionDataset 17 | import captioning.utils.eval_utils as eval_utils 18 | # import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils 19 | import captioning.utils.misc as utils 20 | from captioning.utils.rewards import init_scorer, get_self_critical_reward 21 | from captioning.modules.loss_wrapper import LossWrapper 22 | 23 | import pytorch_lightning as pl 24 | 25 | 26 | class ModelCheckpoint(pl.callbacks.ModelCheckpoint): 27 | 28 | def on_keyboard_interrupt(self, trainer, pl_module): 29 | # Save model when keyboard interrupt 30 | filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') 31 | self._save_model(filepath) 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | device = 'cuda' 37 | 38 | import argparse 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--reward', type=str, default='mle') 41 | args = parser.parse_args() 42 | 43 | if args.reward == 'mle': 44 | cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml' 45 | else: 46 | cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml' 47 | 48 | print("Loading cfg from", cfg) 49 | 50 | opt = opts.parse_opt(parse=False, cfg=cfg) 51 | 52 | dataset = CaptionDataset(opt) 53 | 54 | opt.vocab_size = dataset.vocab_size 55 | opt.seq_length = dataset.seq_length 56 | 57 | opt.batch_size = 40 58 | 59 | opt.vocab = dataset.get_vocab() 60 | 61 | model = models.setup(opt) 62 | del opt.vocab 63 | 64 | ckpt_path = opt.checkpoint_path + '-last.ckpt' 65 | 66 | print("Loading checkpoint from", ckpt_path) 67 | raw_state_dict = torch.load( 68 | ckpt_path, 69 | map_location=device) 70 | 71 | strict = True 72 | 73 | state_dict = raw_state_dict['state_dict'] 74 | 75 | if '_vocab' in state_dict: 76 | model.vocab = utils.deserialize(state_dict['_vocab']) 77 | del state_dict['_vocab'] 78 | elif strict: 79 | raise KeyError 80 | if '_opt' in state_dict: 81 | saved_model_opt = utils.deserialize(state_dict['_opt']) 82 | del state_dict['_opt'] 83 | # Make sure the saved opt is compatible with the curren topt 84 | need_be_same = ["caption_model", 85 | "rnn_type", "rnn_size", "num_layers"] 86 | for checkme in need_be_same: 87 | if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \ 88 | getattr(opt, checkme) in ['updown', 'topdown']: 89 | continue 90 | assert getattr(saved_model_opt, checkme) == getattr( 91 | opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme 92 | elif strict: 93 | raise KeyError 94 | res = model.load_state_dict(state_dict, strict) 95 | print(res) 96 | 97 | opt.use_grammar = False 98 | 99 | lw_model = LossWrapper(model, opt) 100 | 101 | split = 'test' 102 | 103 | print("Building dataloader...") 104 | 105 | test_dataset = torch.utils.data.Subset( 106 | dataset, 107 | dataset.split_ix[split] 108 | ) 109 | test_loader = torch.utils.data.DataLoader( 110 | test_dataset, 111 | batch_size=opt.batch_size, 112 | shuffle=False, 113 | num_workers=4, 114 | drop_last=False, 115 | collate_fn=dataset.collate_func 116 | ) 117 | 118 | eval_kwargs = {'dataset': opt.input_json} 119 | eval_kwargs.update(vars(opt)) 120 | 121 | verbose = eval_kwargs.get('verbose', True) 122 | verbose_beam = eval_kwargs.get('verbose_beam', 0) 123 | verbose_loss = eval_kwargs.get('verbose_loss', 1) 124 | # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 125 | # lang_eval = 
eval_kwargs.get('language_eval', 0) 126 | dataset = eval_kwargs.get('dataset', 'coco') 127 | beam_size = eval_kwargs.get('beam_size', 1) 128 | sample_n = eval_kwargs.get('sample_n', 1) 129 | remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) 130 | 131 | crit = lw_model.crit 132 | 133 | model = model.to(device) 134 | 135 | from tqdm import tqdm 136 | 137 | test_id2sent = {} 138 | 139 | model.eval() 140 | 141 | print("running inference...") 142 | 143 | for data in tqdm(test_loader): 144 | with torch.no_grad(): 145 | # forward the model to get loss 146 | tmp = [data['fc_feats'], data['att_feats'], 147 | data['labels'], data['masks'], data['att_masks']] 148 | tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp] 149 | 150 | fc_feats, att_feats, labels, masks, att_masks = tmp 151 | 152 | loss = crit(model(fc_feats, att_feats, 153 | labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) 154 | 155 | # forward the model to also get generated samples for each image 156 | # Only leave one feature for each image, in case duplicate sample 157 | tmp_eval_kwargs = eval_kwargs.copy() 158 | tmp_eval_kwargs.update({'sample_n': 1}) 159 | seq, seq_logprobs = model( 160 | fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') 161 | seq = seq.data 162 | entropy = - (F.softmax(seq_logprobs, dim=2) * 163 | seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) 164 | perplexity = - \ 165 | seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze( 166 | 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) 167 | 168 | # Print beam search 169 | if beam_size > 1 and verbose_beam: 170 | for i in range(fc_feats.shape[0]): 171 | print('\n'.join([utils.decode_sequence(model.vocab, _[ 172 | 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) 173 | print('--' * 10) 174 | sents = utils.decode_sequence(model.vocab, seq) 175 | 176 | for d, sent in zip(data['infos'], sents): 177 | test_id2sent[d['id']] = sent 178 | 179 | res_path = f'FineCapEval_results/clipRN50_{args.reward}.json' 180 | 181 | print("Results save at {}".format(res_path)) 182 | 183 | with open(res_path, 'w') as f: 184 | json.dump(test_id2sent, f) 185 | 186 | 187 | --------------------------------------------------------------------------------
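The feature-extraction scripts above write one file per image id rather than a single archive: clip_prepro_feats.py saves a global feature to <output_dir>_clip_<model>_fc/<id>.npy and a spatial grid to <output_dir>_clip_<model>_att/<id>.npz under the key 'feat', while clipscore_prepro_feats.py saves a 512-dimensional CLIP image feature to <output_dir>_clipscore_vis/<id>.npy. The snippet below is a minimal sketch of reading these back with NumPy; the 'data/FineCapEval' prefix, the RN50 model type, and the image id are assumptions chosen for illustration, not values fixed by the scripts.

import numpy as np

img_id = '391895'  # hypothetical image id taken from the input json

# Global ('fc') feature written by clip_prepro_feats.py via np.save (one .npy per image)
fc_feat = np.load(f'data/FineCapEval_clip_RN50_fc/{img_id}.npy')

# Spatial/attention grid written via np.savez_compressed under the key 'feat'
att_feat = np.load(f'data/FineCapEval_clip_RN50_att/{img_id}.npz')['feat']

# 512-d CLIP image feature written by clipscore_prepro_feats.py for the CLIPScore reward
clipscore_vis = np.load(f'data/FineCapEval_clipscore_vis/{img_id}.npy')

print(fc_feat.shape, att_feat.shape, clipscore_vis.shape)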