├── visdial ├── __init__.py ├── data │ ├── __init__.py │ ├── vocabulary.py │ ├── readers.py │ └── dataset.py ├── utils │ ├── __init__.py │ ├── utils.py │ └── checkpointing.py ├── optim │ ├── __init__.py │ ├── lr_scheduler.py │ └── adam.py ├── decoders │ ├── __init__.py │ ├── decoder.py │ ├── disc_decoder.py │ └── gen_decoder.py ├── common │ ├── __init__.py │ ├── utils.py │ ├── self_attention.py │ ├── embeddings.py │ └── dynamic_rnn.py ├── encoders │ ├── __init__.py │ ├── encoder.py │ ├── img_encoder.py │ ├── attn_encoder.py │ └── text_encoder.py ├── loss.py ├── model.py └── metrics.py ├── datasets ├── annotations │ └── .gitignore ├── bottom-up-attention │ └── .gitignore └── genome │ └── 1600-400-20 │ ├── attributes_vocab.txt │ └── objects_vocab.txt ├── .gitignore ├── evaluate.py ├── requirements.txt ├── options.py ├── finetune.py ├── train.py ├── others └── generate_visdial.py └── README.md /visdial/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/annotations/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/bottom-up-attention/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visdial/data/__init__.py: -------------------------------------------------------------------------------- 1 | from visdial.data.dataset import VisDialDataset 2 | -------------------------------------------------------------------------------- /visdial/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import check_flag, clones, move_to_cuda, get_num_params 2 | -------------------------------------------------------------------------------- /visdial/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam, get_weight_decay_params 2 | from .lr_scheduler import LRScheduler 3 | -------------------------------------------------------------------------------- /visdial/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import Decoder 2 | from .disc_decoder import DiscriminativeDecoder 3 | from .gen_decoder import GenerativeDecoder 4 | -------------------------------------------------------------------------------- /visdial/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import clones, check_flag 2 | from .dynamic_rnn import DynamicRNN 3 | from .embeddings import PositionalEmbedding 4 | from .self_attention import SelfAttention 5 | -------------------------------------------------------------------------------- /visdial/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from visdial.encoders.text_encoder import HistEncoder 2 | from visdial.encoders.text_encoder import TextEncoder 3 | from visdial.encoders.text_encoder import QuesEncoder 4 | from visdial.encoders.encoder import Encoder 5 | from visdial.encoders.img_encoder import ImageEncoder 6 | from visdial.encoders.attn_encoder import AttentionStackEncoder 7 | -------------------------------------------------------------------------------- 
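The modules exported above are composed into a full model by `get_model` in visdial/model.py. A minimal end-to-end sketch of how the pieces fit together (the `config` dictionary contents, device string, and batch size are illustrative assumptions, not values shipped with the repo):

import torch
from torch.utils.data import DataLoader
from visdial.data import VisDialDataset
from visdial.model import get_model
from visdial.utils import move_to_cuda

# `config` is the nested dict with 'model' and 'dataset' sections used throughout this repo
model = get_model(config).to('cuda:0')          # encoder + decoder(s), GloVe embeddings loaded
dataset = VisDialDataset(config, split='val')
dataloader = DataLoader(dataset, batch_size=2)

model.eval()
with torch.no_grad():
    batch = move_to_cuda(next(iter(dataloader)), 'cuda:0')
    output = model(batch)                       # dict with 'opt_scores' and/or 'opts_out_scores'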
/visdial/common/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | 5 | def clones(module, N): 6 | "Produce N identical modules" 7 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 8 | 9 | 10 | def check_flag(d, key): 11 | "Check whether the dictionary `d` has `key` and `d[key]` is True" 12 | return d.get(key) is not None and d.get(key) 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Large file 7 | *.pkl 8 | *.h5 9 | 10 | 11 | # C extensions 12 | *.so 13 | 14 | 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | 35 | # Datasets, pretrained models, checkpoints and preprocessed files 36 | !visdial/data/ 37 | checkpoints/ 38 | logs/ 39 | .idea 40 | .DS_Store 41 | 42 | 43 | # IPython Notebook 44 | .ipynb_checkpoints 45 | -------------------------------------------------------------------------------- /visdial/utils/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | 5 | def clones(module, N): 6 | "Produce N identical layers." 7 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 8 | 9 | 10 | def check_flag(d, key): 11 | return d.get(key) is not None and d.get(key) 12 | 13 | 14 | def get_num_params(module): 15 | """Compute the number of parameters of the module""" 16 | pp = 0 17 | for p in list(module.parameters()): 18 | nn = 1 19 | for s in list(p.size()): 20 | nn = nn * s 21 | pp += nn 22 | return pp 23 | 24 | 25 | def move_to_cuda(batch, device): 26 | for key in batch: 27 | batch[key] = batch[key].to(device) 28 | return batch 29 | -------------------------------------------------------------------------------- /visdial/common/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SelfAttention(nn.Module): 6 | """This module perform self-attention on an utility 7 | to summarize it into a single vector.""" 8 | 9 | def __init__(self, hidden_size): 10 | super(SelfAttention, self).__init__() 11 | self.attn_linear = nn.Sequential( 12 | nn.Linear(hidden_size, hidden_size), 13 | nn.ReLU(inplace=True), 14 | nn.Linear(hidden_size, 1) 15 | ) 16 | self.attn_weights = None 17 | 18 | def forward(self, x, mask_x): 19 | """ 20 | Arguments 21 | --------- 22 | x: torch.FloatTensor 23 | The input tensor which is a sequence of tokens 24 | Shape [batch_size, M, hidden_size] 25 | mask_x: torch.LongTensor 26 | The mask of the input x where 0 represents the token 27 | Shape [batch_size, M] 28 | Returns 29 | ------- 30 | summarized_vector: torch.FloatTensor 31 | The summarized vector of the utility (the context vector for this utility) 32 | Shape [batch_size, hidden_size] 33 | """ 34 | 35 | # shape [bs, M, 1] 36 | attn_weights = self.attn_linear(x) 37 | attn_weights = attn_weights.masked_fill(mask_x.unsqueeze(-1) == 0, value=-9e10) 38 | attn_weights = torch.softmax(attn_weights, dim=-2) 39 | self.attn_weights = attn_weights 40 | 41 | # shape [bs, 1, hidden_size] 42 | 
summarized_vector = torch.matmul(attn_weights.transpose(-2, -1), x) 43 | summarized_vector = summarized_vector.squeeze(dim=-2) 44 | return summarized_vector 45 | -------------------------------------------------------------------------------- /visdial/common/embeddings.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | """ 8 | Compute the Positional Embedding based on 9 | BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 10 | https://arxiv.org/pdf/1810.04805 11 | """ 12 | 13 | def __init__(self, embedding_size, max_len=512): 14 | super().__init__() 15 | 16 | # Compute the positional encodings once in log space. 17 | pe = torch.zeros(max_len, embedding_size).float() 18 | pe.require_grad = False 19 | 20 | position = torch.arange(0, max_len).float().unsqueeze(1) 21 | div_term = (torch.arange(0, embedding_size, 2).float() * -(math.log(10000.0) / embedding_size)).exp() 22 | 23 | pe[:, 0::2] = torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | 26 | pe = pe.unsqueeze(0) 27 | self.register_buffer('pe', pe) 28 | 29 | def forward(self, x): 30 | """ 31 | Arguments 32 | --------- 33 | x: torch.FloatTensor 34 | The input tensor which is a sequence of tokens 35 | Shape [batch_size, seq_len, ...] is expected! 36 | Returns 37 | ------- 38 | pos_embedding: torch.FloatTensor 39 | The positional embeddings for all the tokens in the sequence! 40 | Shape [batch_size, seq_len, hidden_size] 41 | """ 42 | 43 | # shape [BS, seq_len] 44 | bs, seq_len = x.size(0), x.size(1) 45 | 46 | # shape [1, seq_len, embedding_size] 47 | pos_embedding = self.pe[:, :seq_len] 48 | 49 | # shape [BS, seq_len, embedding_size] 50 | pos_embedding = pos_embedding.repeat(bs, 1, 1) 51 | 52 | return pos_embedding 53 | -------------------------------------------------------------------------------- /visdial/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | 7 | def __init__(self, config, text_encoder, img_encoder, attn_encoder): 8 | super(Encoder, self).__init__() 9 | self.text_encoder = text_encoder 10 | self.img_encoder = img_encoder 11 | self.attn_encoder = attn_encoder 12 | self.config = config 13 | 14 | def forward(self, batch, test_mode=False): 15 | """ 16 | Arguments 17 | --------- 18 | batch: Dictionary 19 | This provides a dictionary of inputs. 20 | test_mode: Boolean 21 | Whether the forward is performed on test data 22 | Returns 23 | ------- 24 | batch_output: a tuple of the following 25 | im: torch.FloatTensor 26 | The representation of image utility 27 | Shape [batch_size x NH, K, hidden_size] 28 | qe: torch.FloatTensor 29 | The representation of question utility 30 | Shape [batch_size x NH, N, hidden_size] 31 | hi: torch.FloatTensor 32 | The representation of history utility 33 | Shape [batch_size x NH, T, hidden_size] 34 | mask_im: torch.LongTensor 35 | Shape [batch_size x NH, K] 36 | mask_qe: torch.LongTensor 37 | Shape [batch_size x NH, N] 38 | mask_hi: torch.LongTensor 39 | Shape [batch_size x NH, T] 40 | 41 | It is noted 42 | K is num_proposals, 43 | T is the number of rounds 44 | N is the max sequence length in the question. 
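NH is the number of question rounds processed in parallel (reduced to 1 in test mode).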
45 | """ 46 | 47 | # [BS x NH, T, HS] hist 48 | # [BS x NH, N, HS] ques 49 | # [BS x NH, T] hist_mask 50 | # [BS x NH, N] ques_mask 51 | hist, ques, hist_mask, ques_mask = self.text_encoder(batch, test_mode=test_mode) 52 | 53 | # [BS x NH, K, HS] img 54 | # [BS x NH, K] img_mask 55 | img, img_mask = self.img_encoder(batch, test_mode=test_mode) 56 | 57 | batch_input = img, ques, hist, img_mask, ques_mask, hist_mask 58 | batch_output = self.attn_encoder(batch_input) 59 | return batch_output 60 | -------------------------------------------------------------------------------- /visdial/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FinetuneLoss(nn.Module): 7 | """ 8 | Compute the loss during the fine-tuning 9 | """ 10 | 11 | def __init__(self): 12 | super(FinetuneLoss, self).__init__() 13 | 14 | def forward(self, scores, batch): 15 | """ 16 | Arguments 17 | --------- 18 | scores: torch.FloatTensor 19 | The prediction scores from the model 20 | Shape [N, num_classes] 21 | batch: Dictionary 22 | The input batch provides the relevance scores 23 | Returns 24 | ------- 25 | Loss: torch.FloatTensor 26 | The computed loss (the mean) 27 | Shape [] 28 | """ 29 | # scores [BS, NH, NO] 30 | BS, NH, NO = scores.size() 31 | relev_round_indices = batch['round_id'] - 1 # Must be -1 32 | # [BS, 1, NO] 33 | relev_round_indices = relev_round_indices[:, None, None].repeat(1, 1, NO) 34 | # [BS, 1, NO] 35 | scores = torch.gather(scores, 1, relev_round_indices) 36 | # [BS, NO] 37 | scores = scores.squeeze(dim=1) 38 | scores = nn.functional.log_softmax(scores, dim=-1) 39 | 40 | loss = torch.mean((batch['gt_relevance'] * scores)) * (-1) 41 | return loss 42 | 43 | 44 | def convert_to_one_hot(target, num_classes): 45 | """ 46 | Arguments 47 | --------- 48 | target: torch.LongTensor 49 | The input tensor of ground truth 50 | Shape [N, ] 51 | Returns 52 | ------- 53 | one_hot: torch.LongTensor 54 | The summarized vector of the utility (the context vector for this utility) 55 | Shape [N, num_classes] 56 | """ 57 | one_hot = torch.zeros(*target.size(), num_classes, device=target.device) 58 | return one_hot.scatter_(-1, target.unsqueeze(-1), 1.0) 59 | 60 | 61 | class DiscLoss(nn.Module): 62 | 63 | def __init__(self, return_mean=True): 64 | """ 65 | Arguments 66 | --------- 67 | return_mean: torch.FloatTensor 68 | Whether to return the mean 69 | If not, return the summation of all element loss 70 | """ 71 | super(DiscLoss, self).__init__() 72 | self.return_mean = return_mean 73 | 74 | def forward(self, outputs, target): 75 | """ 76 | Arguments 77 | --------- 78 | outputs: torch.FloatTensor 79 | The prediction scores from the model 80 | Shape [N, num_classes] 81 | target: torch.LongTensor 82 | The input tensor of ground truth 83 | Shape [N, ] 84 | Returns 85 | ------- 86 | Loss: torch.FloatTensor 87 | The computed loss (the summation or the mean) 88 | Shape [] 89 | """ 90 | num_classes = outputs.size(-1) 91 | batch_size = torch.prod(torch.tensor(outputs.size()[:-1])) 92 | 93 | one_hot_target = convert_to_one_hot(target, num_classes=num_classes) 94 | 95 | log_prob = F.log_softmax(outputs, dim=-1) 96 | loss = -1 * torch.sum(one_hot_target * log_prob) 97 | 98 | if self.return_mean: 99 | return loss / batch_size 100 | return loss 101 | -------------------------------------------------------------------------------- /visdial/common/dynamic_rnn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class DynamicRNN(nn.Module): 7 | """ 8 | The wrapper version of recurrent modules including RNN, LSTM 9 | that support packed sequence batch. 10 | """ 11 | 12 | def __init__(self, rnn_module): 13 | super().__init__() 14 | 15 | if isinstance(rnn_module, nn.LSTM): 16 | self.bidirectional = rnn_module.bidirectional 17 | 18 | self.rnn_module = rnn_module 19 | 20 | def forward(self, x, len_x, initial_state=None): 21 | """ 22 | Arguments 23 | --------- 24 | x: torch.FloatTensor 25 | padded input sequence tensor for RNN model 26 | Shape [batch_size, max_seq_len, embed_size] 27 | len_x: torch.LongTensor 28 | Length of sequences (b, ) 29 | initial_state: torch.FloatTensor 30 | Initial (hidden, cell) states of RNN model. 31 | Returns 32 | ------- 33 | A tuple of (padded_output, h_n) or (padded_output, (h_n, c_n)) 34 | padded_output: torch.FloatTensor 35 | The output of all hidden for each elements. 36 | Shape [batch_size, max_seq_len, hidden_size] 37 | h_n: torch.FloatTensor 38 | The hidden state of the last step for each packed sequence (not including padding elements) 39 | Shape [batch_size, hidden_size] 40 | c_n: torch.FloatTensor 41 | If rnn_model is RNN, c_n = None 42 | The cell state of the last step for each packed sequence (not including padding elements) 43 | Shape [batch_size, hidden_size] 44 | """ 45 | 46 | # First sort the sequences in batch in the descending order of length 47 | sorted_len, idx = len_x.sort(dim=0, descending=True) 48 | sorted_x = x[idx] 49 | 50 | # Convert to packed sequence batch 51 | packed_x = pack_padded_sequence(sorted_x, lengths=sorted_len, batch_first=True) 52 | 53 | # Check init_state 54 | if initial_state is not None: 55 | if isinstance(initial_state, tuple): # (h_0, c_0) in LSTM 56 | hx = [state[:, idx] for state in initial_state] 57 | else: 58 | hx = initial_state[:, idx] # h_0 in RNN 59 | else: 60 | hx = None 61 | 62 | # Do forward pass 63 | self.rnn_module.flatten_parameters() 64 | packed_output, last_s = self.rnn_module(packed_x, hx) 65 | 66 | # Pad the packed_output 67 | max_seq_len = x.size(1) 68 | padded_output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=max_seq_len) 69 | 70 | # Reverse to the original order 71 | _, reverse_idx = idx.sort(dim=0, descending=False) 72 | 73 | # shape: [BS, PaddedSEQ, HS] 74 | padded_output = padded_output[reverse_idx] 75 | 76 | if isinstance(self.rnn_module, nn.RNN): 77 | h_n, c_n = last_s[:, reverse_idx], None 78 | else: 79 | # shape: [num_layers x 2, BS, HS] if bidirectional 80 | # shape: [num_layers, BS, HS] if None 81 | h_n, c_n = [s[:, reverse_idx] for s in last_s] 82 | 83 | # The hidden cells of last layer is (h_n, h_n_inverse) is h_n[-2:, :, ] 84 | return padded_output, (h_n, c_n) 85 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | import os 5 | import torch 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | from visdial.data.dataset import VisDialDataset 9 | from visdial.metrics import SparseGTMetrics, NDCG, scores_to_ranks 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--model_path", default='checkpoints/model_v1.pth') 13 | parser.add_argument("--split", 
default="test") 14 | parser.add_argument("--decoder_type", default='disc') 15 | parser.add_argument("--device", default="cuda:0") 16 | parser.add_argument("--output_path", default="checkpoints/val.json") 17 | 18 | args = parser.parse_args() 19 | device = args.device 20 | split = args.split 21 | decoder_type = args.decoder_type 22 | model = torch.load(args.model_path) 23 | config = model.encoder.config 24 | 25 | test_mode = False 26 | if args.split == 'test': 27 | test_mode = True 28 | config['dataset']['test_feat_img_path'] = config['dataset']['train_feat_img_path'].replace( 29 | "trainval_resnet101_faster_rcnn_genome__num_boxes", 30 | "test2018_resnet101_faster_rcnn_genome__num_boxes" 31 | ) 32 | config['dataset']['test_json_dialog_path'] = config['dataset']['train_json_dialog_path'].replace( 33 | 'visdial_1.0_train.json', 34 | 'visdial_1.0_test.json' 35 | ) 36 | 37 | model = model.to(device) 38 | 39 | sparse_metrics = SparseGTMetrics() 40 | ndcg = NDCG() 41 | 42 | dataset = VisDialDataset(config, split=args.split) 43 | dataloader = DataLoader(dataset, batch_size=1) 44 | 45 | model = model.eval() 46 | ranks_json = [] 47 | 48 | for idx, batch in enumerate(tqdm(dataloader)): 49 | torch.cuda.empty_cache() 50 | for key in batch: 51 | batch[key] = batch[key].to(device) 52 | 53 | with torch.no_grad(): 54 | output = model(batch, test_mode=test_mode) 55 | 56 | if decoder_type == 'misc': 57 | output = (output['opts_out_scores'] + output['opt_scores']) / 2.0 58 | elif decoder_type == 'disc': 59 | output = output['opt_scores'] 60 | elif decoder_type == 'gen': 61 | output = output['opts_out_scores'] 62 | ranks = scores_to_ranks(output) 63 | 64 | for i in range(len(batch["img_ids"])): 65 | if split == split: 66 | ranks_json.append( 67 | { 68 | "image_id": batch["img_ids"][i].item(), 69 | "round_id": int(batch["num_rounds"][i].item()), 70 | "ranks": [ 71 | rank.item() 72 | for rank in ranks[i][0] 73 | ], 74 | } 75 | ) 76 | else: 77 | for j in range(batch["num_rounds"][i]): 78 | ranks_json.append( 79 | { 80 | "image_id": batch["img_ids"][i].item(), 81 | "round_id": int(j + 1), 82 | "ranks": [rank.item() for rank in ranks[i][j]], 83 | } 84 | ) 85 | 86 | if split == 'val' and not config['dataset']['v0.9']: 87 | sparse_metrics.observe(output, batch['ans_ind']) 88 | output = output[torch.arange(output.size(0)), batch['round_id'] - 1, :] 89 | ndcg.observe(output, batch["gt_relevance"]) 90 | 91 | jpath = args.output_path 92 | 93 | print("Writing ranks to {}".format(jpath)) 94 | os.makedirs(os.path.dirname(jpath), exist_ok=True) 95 | json.dump(ranks_json, open(jpath, "w")) 96 | 97 | all_metrics = {} 98 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 99 | all_metrics.update(ndcg.retrieve(reset=True)) 100 | for metric_name, metric_value in all_metrics.items(): 101 | print(f"{metric_name}: {metric_value}") 102 | -------------------------------------------------------------------------------- /visdial/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from visdial.decoders import DiscriminativeDecoder, GenerativeDecoder, Decoder 5 | from visdial.encoders import ImageEncoder, TextEncoder, HistEncoder, QuesEncoder, AttentionStackEncoder, Encoder 6 | 7 | 8 | class VisdialModel(nn.Module): 9 | def __init__(self, encoder, decoder, init_type='kaiming_uniform'): 10 | super(VisdialModel, self).__init__() 11 | self.encoder = encoder 12 | self.decoder = decoder 13 | self.init_type = init_type 14 | 
self.apply(self.weight_init) 15 | 16 | def load_my_state_dict(self, state_dict): 17 | own_state = self.state_dict() 18 | for name, param in state_dict.items(): 19 | if name not in own_state: 20 | continue 21 | if isinstance(param, nn.Parameter): 22 | param = param.data 23 | own_state[name].copy_(param) 24 | 25 | def forward(self, batch, test_mode=False): 26 | return self.decoder(batch, self.encoder(batch, test_mode=test_mode), test_mode=test_mode) 27 | 28 | def weight_init(self, m): 29 | init_dict = { 30 | 'kaiming_uniform': nn.init.kaiming_uniform_, 31 | 'kaiming_normal': nn.init.kaiming_normal_, 32 | } 33 | 34 | if isinstance(m, nn.Linear): 35 | init_dict[self.init_type](m.weight) 36 | 37 | if m.bias is not None: 38 | nn.init.constant_(m.bias, 0) 39 | 40 | 41 | def get_attn_encoder(config): 42 | encoder = Encoder( 43 | config=config, 44 | text_encoder=TextEncoder(config, HistEncoder(config), QuesEncoder(config)), 45 | img_encoder=ImageEncoder(config), 46 | attn_encoder=AttentionStackEncoder(config), 47 | ) 48 | return encoder 49 | 50 | 51 | def get_disc_model(config): 52 | encoder = get_attn_encoder(config) 53 | encoder.img_encoder.text_embedding = encoder.text_encoder.text_embedding 54 | 55 | disc_decoder = DiscriminativeDecoder(config) 56 | disc_decoder.text_embedding = encoder.text_encoder.text_embedding 57 | 58 | gen_decoder = None 59 | decoder = Decoder(config, disc_decoder, gen_decoder) 60 | 61 | model = VisdialModel(encoder, decoder) 62 | return model 63 | 64 | 65 | def get_gen_model(config): 66 | encoder = get_attn_encoder(config) 67 | encoder.img_encoder.text_embedding = encoder.text_encoder.text_embedding 68 | 69 | disc_decoder = None 70 | gen_decoder = GenerativeDecoder(config) 71 | gen_decoder.text_embedding = encoder.text_encoder.text_embedding 72 | decoder = Decoder(config, disc_decoder, gen_decoder) 73 | return VisdialModel(encoder, decoder) 74 | 75 | 76 | def get_misc_model(config): 77 | encoder = get_attn_encoder(config) 78 | encoder.img_encoder.text_embedding = encoder.text_encoder.text_embedding 79 | 80 | disc_decoder = DiscriminativeDecoder(config) 81 | disc_decoder.text_embedding = encoder.text_encoder.text_embedding 82 | 83 | gen_decoder = GenerativeDecoder(config) 84 | gen_decoder.text_embedding = encoder.text_encoder.text_embedding 85 | 86 | decoder = Decoder(config, disc_decoder, gen_decoder) 87 | return VisdialModel(encoder, decoder) 88 | 89 | 90 | def get_model(config): 91 | get_model_dict = { 92 | 'gen': get_gen_model, 93 | 'disc': get_disc_model, 94 | 'misc': get_misc_model 95 | } 96 | 97 | model = get_model_dict[config['model']['decoder_type']](config) 98 | glove_path = os.path.expanduser(config['dataset']['glove_path']) 99 | glove_weights = torch.load(glove_path) 100 | model.encoder.text_encoder.text_embedding.load_state_dict(glove_weights) 101 | return model 102 | -------------------------------------------------------------------------------- /visdial/data/vocabulary.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Vocabulary maintains a mapping between words and corresponding unique 3 | integers, holds special integers (tokens) for indicating start and end of 4 | sequence, and offers functionality to map out-of-vocabulary words to the 5 | corresponding token. 
6 | Credit: 7 | https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/data/vocabulary.py 8 | """ 9 | import json 10 | import os 11 | from typing import List 12 | 13 | 14 | class Vocabulary(object): 15 | """ 16 | A simple Vocabulary class which maintains a mapping between words and 17 | integer tokens. Can be initialized either by word counts from the VisDial 18 | v1.0 train dataset, or a pre-saved vocabulary mapping. 19 | Parameters 20 | ---------- 21 | word_counts_path: str 22 | Path to a json file containing counts of each word across captions, 23 | questions and answers of the VisDial v1.0 train dataset. 24 | min_count : int, optional (default=0) 25 | When initializing the vocabulary from word counts, you can specify a 26 | minimum count, and every token with a count less than this will be 27 | excluded from vocabulary. 28 | """ 29 | 30 | PAD_TOKEN = "" 31 | SOS_TOKEN = "" 32 | EOS_TOKEN = "" 33 | UNK_TOKEN = "" 34 | 35 | PAD_INDEX = 0 36 | SOS_INDEX = 1 37 | EOS_INDEX = 2 38 | UNK_INDEX = 3 39 | 40 | def __init__(self, word_counts_path: str, min_count: int = 5): 41 | if not os.path.exists(word_counts_path): 42 | raise FileNotFoundError( 43 | f"Word counts do not exist at {word_counts_path}" 44 | ) 45 | 46 | with open(word_counts_path, "r") as word_counts_file: 47 | word_counts = json.load(word_counts_file) 48 | 49 | # form a list of (word, count) tuples and apply min_count threshold 50 | word_counts = [ 51 | (word, count) 52 | for word, count in word_counts.items() 53 | if count >= min_count 54 | ] 55 | # sort in descending order of word counts 56 | word_counts = sorted(word_counts, key=lambda wc: -wc[1]) 57 | words = [w[0] for w in word_counts] 58 | 59 | self.word2index = {} 60 | self.word2index[self.PAD_TOKEN] = self.PAD_INDEX 61 | self.word2index[self.SOS_TOKEN] = self.SOS_INDEX 62 | self.word2index[self.EOS_TOKEN] = self.EOS_INDEX 63 | self.word2index[self.UNK_TOKEN] = self.UNK_INDEX 64 | for index, word in enumerate(words): 65 | self.word2index[word] = index + 4 66 | 67 | self.index2word = { 68 | index: word for word, index in self.word2index.items() 69 | } 70 | 71 | @classmethod 72 | def from_saved(cls, saved_vocabulary_path: str) -> "Vocabulary": 73 | """Build the vocabulary from a json file saved by ``save`` method. 74 | Parameters 75 | ---------- 76 | saved_vocabulary_path : str 77 | Path to a json file containing word to integer mappings 78 | (saved vocabulary). 
79 | """ 80 | with open(saved_vocabulary_path, "r") as saved_vocabulary_file: 81 | cls.word2index = json.load(saved_vocabulary_file) 82 | cls.index2word = { 83 | index: word for word, index in cls.word2index.items() 84 | } 85 | 86 | def convert_tokens_to_ids(self, words: List[str]) -> List[int]: 87 | return [self.word2index.get(word, self.UNK_INDEX) for word in words] 88 | 89 | def convert_ids_to_tokens(self, indices: List[int]) -> List[str]: 90 | return [ 91 | self.index2word.get(index, self.UNK_TOKEN) for index in indices 92 | ] 93 | 94 | def save(self, save_vocabulary_path: str) -> None: 95 | with open(save_vocabulary_path, "w") as save_vocabulary_file: 96 | json.dump(self.word2index, save_vocabulary_file) 97 | 98 | def __len__(self): 99 | return len(self.index2word) 100 | -------------------------------------------------------------------------------- /visdial/decoders/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from visdial.common import SelfAttention 4 | 5 | 6 | class Decoder(nn.Module): 7 | """The wrapper Decoder which includes Discriminative or Generative decoders 8 | """ 9 | 10 | def __init__(self, config, disc_decoder=None, gen_decoder=None): 11 | super(Decoder, self).__init__() 12 | self.disc_decoder = disc_decoder 13 | self.gen_decoder = gen_decoder 14 | self.config = config 15 | hidden_size = self.config['model']['hidden_size'] 16 | 17 | if self.config['model'].get('encoder_out') is not None: 18 | num_feats = len(self.config['model']['encoder_out']) 19 | 20 | if 'img' in self.config['model']['encoder_out']: 21 | self.img_summary = SelfAttention(hidden_size) 22 | 23 | if 'hist' in self.config['model']['encoder_out']: 24 | self.hist_summary = SelfAttention(hidden_size) 25 | 26 | if 'ques' in self.config['model']['encoder_out']: 27 | self.ques_summary = SelfAttention(hidden_size) 28 | 29 | self.context_linear = nn.Sequential( 30 | nn.Linear(hidden_size * num_feats, hidden_size), 31 | nn.ReLU(inplace=True), 32 | nn.Linear(hidden_size, hidden_size), 33 | nn.LayerNorm(hidden_size)) 34 | 35 | def forward(self, batch, encoder_output, test_mode=False): 36 | """ 37 | Arguments 38 | --------- 39 | batch: Dictionary 40 | This provides a dictionary of inputs. 
41 | encoder_output: A tuple of encoder output: 42 | img: torch.FloatTensor 43 | Shape [batch_size x NH, K, hidden_size] 44 | ques: torch.FloatTensor 45 | Shape [batch_size x NH, N, hidden_size] 46 | hist: torch.FloatTensor 47 | Shape [batch_size x NH, T, hidden_size] 48 | img_mask: torch.LongTensor 49 | Shape [batch_size x NH, K] 50 | ques_mask: torch.LongTensor 51 | Shape [batch_size x NH, N] 52 | hist_mask: torch.LongTensor 53 | Shape [batch_size x NH, T] 54 | 55 | test_mode: Boolean 56 | Whether the forward is performed on test data 57 | Returns 58 | ------- 59 | output : Dictionary 60 | output['opt_scores']: torch.FloatTensor 61 | The output from Discriminative Decoder 62 | Shape: [batch_size, NH, num_options] 63 | 64 | output['opts_out_scores']: torch.FloatTensor 65 | The output from Generative Decoder (test mode or validation mode) 66 | Shape: [batch_size, NH, num_options] 67 | 68 | output['ans_out_scores']: torch.FloatTensor 69 | The output from Generative Decoder (training mode) 70 | Shape: Shape [batch_size, N, vocab_size] 71 | """ 72 | img, ques, hist, img_mask, ques_mask, hist_mask = encoder_output 73 | 74 | BS, NH = batch['ques_len'].shape 75 | if self.config['model']['test_mode'] or test_mode: 76 | NH = 1 77 | 78 | # Perform self-attention on each utility 79 | encoder_output = [] 80 | if self.config['model'].get('encoder_out') is not None: 81 | if 'img' in self.config['model']['encoder_out']: 82 | encoder_output.append(self.img_summary(img, img_mask)) 83 | 84 | if 'hist' in self.config['model']['encoder_out']: 85 | encoder_output.append(self.hist_summary(hist, hist_mask)) 86 | 87 | if 'ques' in self.config['model']['encoder_out']: 88 | encoder_output.append(self.ques_summary(ques, ques_mask)) 89 | encoder_output = torch.cat(encoder_output, dim=-1) 90 | 91 | # shape [BS x NH, HS] 92 | context_vec = self.context_linear(encoder_output) 93 | # shape [BS, NH, HS] 94 | context_vec = context_vec.view(BS, NH, -1) 95 | 96 | output = {} 97 | if self.disc_decoder is not None: 98 | output['opt_scores'] = self.disc_decoder(batch, 99 | context_vec, 100 | test_mode=test_mode)['opt_scores'] 101 | if self.gen_decoder is not None: 102 | if self.training: 103 | output['ans_out_scores'] = self.gen_decoder(batch, 104 | context_vec, 105 | test_mode=test_mode)['ans_out_scores'] 106 | else: 107 | output['opts_out_scores'] = self.gen_decoder(batch, 108 | context_vec, 109 | test_mode=test_mode)['opts_out_scores'] 110 | return output 111 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | alabaster==0.7.12 3 | anaconda-client==1.7.2 4 | anaconda-navigator==1.9.7 5 | anaconda-project==0.8.3 6 | asn1crypto==0.24.0 7 | astroid==2.2.5 8 | astropy==3.2.1 9 | atomicwrites==1.3.0 10 | attrs==19.1.0 11 | Babel==2.7.0 12 | backcall==0.1.0 13 | backports.functools-lru-cache==1.5 14 | backports.os==0.1.1 15 | backports.shutil-get-terminal-size==1.0.0 16 | backports.tempfile==1.0 17 | backports.weakref==1.0.post1 18 | beautifulsoup4==4.7.1 19 | bitarray==0.9.3 20 | bkcharts==0.2 21 | bleach==3.1.0 22 | bokeh==1.2.0 23 | boto==2.49.0 24 | Bottleneck==1.2.1 25 | certifi==2019.6.16 26 | cffi==1.12.3 27 | chardet==3.0.4 28 | Click==7.0 29 | cloudpickle==1.2.1 30 | clyent==1.2.2 31 | colorama==0.4.1 32 | conda==4.7.10 33 | conda-build==3.18.8 34 | conda-package-handling==1.3.11 35 | conda-verify==3.4.2 36 | configobj==5.0.6 37 | contextlib2==0.5.5 38 | 
cryptography==2.7 39 | cycler==0.10.0 40 | Cython==0.29.12 41 | cytoolz==0.10.0 42 | dask==2.1.0 43 | decorator==4.4.0 44 | defusedxml==0.6.0 45 | distributed==2.1.0 46 | docutils==0.14 47 | entrypoints==0.3 48 | et-xmlfile==1.0.1 49 | everett==1.0.2 50 | fastcache==1.1.0 51 | filelock==3.0.12 52 | Flask==1.1.1 53 | future==0.17.1 54 | gevent==1.4.0 55 | glob2==0.7 56 | gmpy2==2.0.8 57 | greenlet==0.4.15 58 | grpcio==1.24.0 59 | h5py==2.9.0 60 | heapdict==1.0.0 61 | html5lib==1.0.1 62 | idna==2.8 63 | imageio==2.5.0 64 | imagesize==1.1.0 65 | importlib-metadata==0.17 66 | ipykernel==5.1.1 67 | ipython==7.6.1 68 | ipython-genutils==0.2.0 69 | ipywidgets==7.5.0 70 | isort==4.3.21 71 | itsdangerous==1.1.0 72 | jdcal==1.4.1 73 | jedi==0.13.3 74 | jeepney==0.4 75 | Jinja2==2.10.1 76 | joblib==0.13.2 77 | json5==0.8.4 78 | jsonschema==3.0.1 79 | jupyter==1.0.0 80 | jupyter-client==5.3.1 81 | jupyter-console==6.0.0 82 | jupyter-core==4.5.0 83 | jupyterlab==1.0.2 84 | jupyterlab-server==1.0.0 85 | keyring==18.0.0 86 | kiwisolver==1.1.0 87 | lazy-object-proxy==1.4.1 88 | libarchive-c==2.8 89 | lief==0.9.0 90 | llvmlite==0.29.0 91 | locket==0.2.0 92 | lxml==4.3.4 93 | Markdown==3.1.1 94 | MarkupSafe==1.1.1 95 | matplotlib==3.1.0 96 | mccabe==0.6.1 97 | mistune==0.8.4 98 | mkl-fft==1.0.12 99 | mkl-random==1.0.2 100 | mkl-service==2.0.2 101 | mock==3.0.5 102 | more-itertools==7.0.0 103 | mpmath==1.1.0 104 | msgpack==0.6.1 105 | multipledispatch==0.6.0 106 | navigator-updater==0.2.1 107 | nbconvert==5.5.0 108 | nbformat==4.4.0 109 | netifaces==0.10.9 110 | networkx==2.3 111 | nltk==3.4.4 112 | nose==1.3.7 113 | notebook==6.0.0 114 | numba==0.44.1 115 | numexpr==2.6.9 116 | numpy==1.16.4 117 | numpydoc==0.9.1 118 | nvidia-ml-py3==7.352.0 119 | olefile==0.46 120 | opencv-python==4.1.1.26 121 | openpyxl==2.6.2 122 | packaging==19.0 123 | pandas==0.24.2 124 | pandocfilters==1.4.2 125 | parso==0.5.0 126 | partd==1.0.0 127 | path.py==12.0.1 128 | pathlib2==2.3.4 129 | patsy==0.5.1 130 | pep8==1.7.1 131 | pexpect==4.7.0 132 | pickleshare==0.7.5 133 | Pillow==6.1.0 134 | pkginfo==1.5.0.1 135 | pluggy==0.12.0 136 | ply==3.11 137 | prometheus-client==0.7.1 138 | prompt-toolkit==2.0.9 139 | protobuf==3.9.2 140 | psutil==5.6.3 141 | ptyprocess==0.6.0 142 | py==1.8.0 143 | pycodestyle==2.5.0 144 | pycosat==0.6.3 145 | pycparser==2.19 146 | pycrypto==2.6.1 147 | pycurl==7.43.0.3 148 | pyflakes==2.1.1 149 | Pygments==2.4.2 150 | pylint==2.3.1 151 | pyodbc==4.0.26 152 | pyOpenSSL==19.0.0 153 | pyparsing==2.4.0 154 | pyrsistent==0.14.11 155 | PySocks==1.7.0 156 | pytest==5.0.1 157 | pytest-arraydiff==0.3 158 | pytest-astropy==0.5.0 159 | pytest-doctestplus==0.3.0 160 | pytest-openfiles==0.3.2 161 | pytest-remotedata==0.3.1 162 | python-dateutil==2.8.0 163 | pytz==2019.1 164 | PyWavelets==1.0.3 165 | PyYAML==5.1.1 166 | pyzmq==18.0.0 167 | QtAwesome==0.5.7 168 | qtconsole==4.5.1 169 | QtPy==1.8.0 170 | requests==2.22.0 171 | rope==0.14.0 172 | ruamel-yaml==0.15.46 173 | scikit-image==0.15.0 174 | scikit-learn==0.21.2 175 | scipy==1.3.0 176 | seaborn==0.9.0 177 | SecretStorage==3.1.1 178 | Send2Trash==1.5.0 179 | simplegeneric==0.8.1 180 | singledispatch==3.4.0.3 181 | six==1.12.0 182 | snowballstemmer==1.9.0 183 | sortedcollections==1.1.2 184 | sortedcontainers==2.1.0 185 | soupsieve==1.8 186 | Sphinx==2.1.2 187 | sphinxcontrib-applehelp==1.0.1 188 | sphinxcontrib-devhelp==1.0.1 189 | sphinxcontrib-htmlhelp==1.0.2 190 | sphinxcontrib-jsmath==1.0.1 191 | sphinxcontrib-qthelp==1.0.2 192 | 
sphinxcontrib-serializinghtml==1.1.3 193 | sphinxcontrib-websupport==1.1.2 194 | spyder==3.3.6 195 | spyder-kernels==0.5.1 196 | SQLAlchemy==1.3.5 197 | statsmodels==0.10.0 198 | sympy==1.4 199 | tables==3.5.2 200 | tb-nightly==2.1.0a20190928 201 | tblib==1.4.0 202 | terminado==0.8.2 203 | testpath==0.4.2 204 | toolz==0.10.0 205 | torch==1.2.0 206 | torchstat==0.0.7 207 | torchtext==0.4.0 208 | torchvision==0.4.0 209 | tornado==6.0.3 210 | tqdm==4.32.1 211 | traitlets==4.3.2 212 | unicodecsv==0.14.1 213 | urllib3==1.24.2 214 | wcwidth==0.1.7 215 | webencodings==0.5.1 216 | websocket-client==0.56.0 217 | Werkzeug==0.15.4 218 | widgetsnbextension==3.5.0 219 | wrapt==1.11.2 220 | wurlitzer==1.0.2 221 | xlrd==1.2.0 222 | XlsxWriter==1.1.8 223 | xlwt==1.3.0 224 | zict==1.0.0 225 | zipp==0.5.1 226 | -------------------------------------------------------------------------------- /visdial/decoders/disc_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from visdial.common import DynamicRNN 4 | 5 | 6 | class DiscriminativeDecoder(nn.Module): 7 | """The Discriminative Decoder computes the rankings (prediction scores) for each option (candidate answers) 8 | Given `context_vec` and sequences candidate options. 9 | """ 10 | 11 | def __init__(self, config): 12 | super().__init__() 13 | self.config = config 14 | self.test_mode = self.config['model']['test_mode'] 15 | 16 | self.text_embedding = nn.Embedding(config['model']['txt_vocab_size'], 17 | config['model']['txt_embedding_size'], 18 | padding_idx=0) 19 | 20 | self.opt_lstm = nn.LSTM(config['model']['txt_embedding_size'], 21 | config['model']['hidden_size'], 22 | num_layers=2, 23 | batch_first=True, 24 | dropout=config['model']['dropout'], 25 | bidirectional=config['model']['txt_bidirectional']) 26 | 27 | if config['model']['txt_has_decoder_layer_norm']: 28 | self.layer_norm = nn.LayerNorm(config['model']['hidden_size']) 29 | 30 | self.option_linear = nn.Linear(config['model']['hidden_size'] * 2, 31 | config['model']['hidden_size']) 32 | 33 | # Options are variable length padded sequences, use DynamicRNN. 34 | self.opt_lstm = DynamicRNN(self.opt_lstm) 35 | 36 | def forward(self, batch, context_vec, test_mode=False): 37 | """ 38 | Arguments 39 | --------- 40 | batch: Dictionary 41 | This provides a dictionary of inputs. 42 | context_vec: torch.FloatTensor 43 | The context vector summarized from all utilities 44 | Shape [batch_size, NH, hidden_size] 45 | 46 | test_mode: Boolean 47 | Whether the forward is performed on test data 48 | Returns 49 | ------- 50 | output : Dictionary 51 | output['opt_scores']: torch.FloatTensor 52 | The output from Discriminative Decoder 53 | Shape: [batch_size, NH, num_options] 54 | """ 55 | # shape: [BS, NH, NO, SEQ] 56 | options = batch["opts"] 57 | 58 | # batch_size, num_rounds, num_opts, seq_len 59 | BS, NH, NO, SEQ = options.size() 60 | HS = self.config['model']['hidden_size'] 61 | 62 | # shape: [BS x NH x NO, SEQ] 63 | options = options.view(BS * NH * NO, SEQ) 64 | 65 | # shape: [BS, NH, NO] 66 | options_length = batch["opts_len"] 67 | 68 | # shape: [BS x NH x NO] 69 | options_length = options_length.view(BS * NH * NO) 70 | 71 | # Pick options with non-zero length (relevant for test split). 
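# On the test split, candidate options are provided only for the round that is actually
# answered, so the remaining rounds are zero-length padding; those rows are skipped here
# and their embeddings are scattered back into a zero tensor further below.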
72 | # shape: [BS x (nR x NO)] <- nR ~= 1 or 10 for test: nR = 1, for train, val nR = 10 73 | nonzero_options_length_indices = options_length.nonzero().squeeze() 74 | 75 | # shape: [BS x (nR x NO)] 76 | nonzero_options_length = options_length[nonzero_options_length_indices] 77 | 78 | # shape: [BS x (nR x NO)] 79 | nonzero_options = options[nonzero_options_length_indices] 80 | 81 | # shape: [BS x NH x NO, SEQ, WE] 82 | # shape: [BS x 1 x NO, SEQ, WE] <- FOR TEST SPLIT 83 | nonzero_options_embed = self.text_embedding(nonzero_options) 84 | 85 | # shape: [lstm_layers x bi, BS x NH x NO, HS] 86 | # shape: [lstm_layers x bi, BS x 1 x NO, HS] FOR TEST SPLIT, 87 | _, (nonzero_options_embed, _) = self.opt_lstm(nonzero_options_embed, nonzero_options_length) 88 | 89 | # shape: [2, BS x NH x NO, HS] 90 | nonzero_options_embed = nonzero_options_embed[-2:] 91 | 92 | # shape: [BS x NH x NO, HS x 2] 93 | nonzero_options_embed = torch.cat([nonzero_options_embed[0], nonzero_options_embed[1]], dim=-1) 94 | 95 | # shape: [BS x NH x NO, HS] 96 | nonzero_options_embed = self.option_linear(nonzero_options_embed) 97 | 98 | # shape: [BS x NH x NO, HS] <- move back to standard for TEST split 99 | options_embed = torch.zeros(BS * NH * NO, HS, device=options.device) 100 | 101 | if self.config['model']['txt_has_decoder_layer_norm']: 102 | options_embed = self.layer_norm(options_embed) 103 | 104 | # shape: [BS x NH x NO, HS] 105 | options_embed[nonzero_options_length_indices] = nonzero_options_embed 106 | 107 | # shape: [BS, NH, HS] -> [BS, NH, HS, 1] 108 | context_vec = context_vec.unsqueeze(-1) 109 | # shape: [BS, NH, NO, HS] 110 | options_embed = options_embed.view(BS, NH, NO, -1) 111 | 112 | # shape: [BS, NH, NO, 1] 113 | scores = torch.matmul(options_embed, context_vec) 114 | 115 | # shape: [BS, NH, NO] 116 | scores = scores.squeeze(-1) 117 | 118 | if self.test_mode: 119 | scores = scores[:, batch['num_rounds'] - 1] 120 | 121 | return {'opt_scores': scores} 122 | -------------------------------------------------------------------------------- /visdial/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect 3 | 4 | 5 | class LRScheduler(object): 6 | 7 | def __init__(self, optimizer, 8 | batch_size, num_samples, num_epochs, 9 | init_lr=0.001, 10 | min_lr=1e-4, 11 | warmup_factor=0.1, 12 | warmup_epochs=1, 13 | scheduler_type='CosineLR', 14 | milestone_steps=[3, 5, 7, 9, 11, 13], 15 | linear_gama=0.5, 16 | **kwargs 17 | ): 18 | self.optimizer = optimizer 19 | self.scheduler_type = scheduler_type 20 | 21 | self._SCHEDULER = { 22 | 'CosineLR': self.cosine_step, 23 | 'LinearLR': self.linear_step, 24 | 'CosineStepLR': self.cosine_multi_step 25 | } 26 | self.scheduler_step = self._SCHEDULER[scheduler_type] 27 | self.scheduler_type = scheduler_type 28 | 29 | self.batch_size = batch_size 30 | self.num_samples = num_samples 31 | 32 | self.init_lr = init_lr 33 | self.min_lr = min_lr 34 | self.num_epochs = num_epochs 35 | 36 | if num_samples % batch_size == 0: 37 | self.total_iters_per_epoch = num_samples // batch_size 38 | else: 39 | self.total_iters_per_epoch = num_samples // batch_size + 1 40 | 41 | self.warmup_factor = warmup_factor 42 | self.warmup_epochs = warmup_epochs 43 | 44 | self.linear_gama = linear_gama 45 | self.milestone_steps = milestone_steps 46 | 47 | def step(self, cur_iter): 48 | current_epoch = cur_iter / self.total_iters_per_epoch 49 | if current_epoch < 1: 50 | lr = self.warmup_step(cur_iter) 51 | else: 52 | lr 
= self.scheduler_step(cur_iter) 53 | 54 | lr = self.min_lr if lr < self.min_lr else lr 55 | self.update(lr) 56 | return lr 57 | 58 | def update(self, lr): 59 | for param_group in self.optimizer.param_groups: 60 | param_group['lr'] = lr 61 | 62 | def warmup_step(self, cur_iter): 63 | """ 64 | current_iteration: 1 iter = 1 batch step 65 | (should be understood as the global_steps accumulated from the beginning.) 66 | return the factor. 67 | """ 68 | current_epoch = float(cur_iter) / float(self.total_iters_per_epoch) 69 | alpha = current_epoch / self.warmup_epochs 70 | return self.init_lr * (self.warmup_factor * (1.0 - alpha) + alpha) 71 | 72 | def linear_step(self, cur_iter): 73 | current_epoch = cur_iter // self.total_iters_per_epoch 74 | idx = bisect(self.milestone_steps, current_epoch) 75 | return self.init_lr * pow(self.linear_gama, idx) 76 | 77 | def cosine_step(self, cur_iter): 78 | # current_epoch = cur_iter // self.total_iters_per_epoch 79 | # return self.init_lr * (1 + math.cos(math.pi * current_epoch / self.num_epochs)) / 2 80 | total_iters = self.num_epochs * self.total_iters_per_epoch 81 | return self.init_lr * (1 + math.cos(math.pi * cur_iter / total_iters)) / 2 82 | 83 | def cosine_multi_step(self, cur_iter): 84 | milestones = [1, 8, 16, 24, 32] 85 | 86 | def find_range(cur_epoch): 87 | for i, milestone in enumerate(milestones): 88 | if cur_epoch >= milestone: 89 | continue 90 | else: 91 | return i - 1, milestones[i - 1], milestones[i] 92 | 93 | cur_epoch = int(cur_iter / self.total_iters_per_epoch) 94 | idx, low_epoch, high_epoch = find_range(cur_epoch) 95 | rel_iter = cur_iter - low_epoch * self.total_iters_per_epoch 96 | if idx < 3: 97 | lr = self.init_lr * pow(self.linear_gama, idx) 98 | else: 99 | lr = 1.25e-4 100 | total_iters = (high_epoch - low_epoch) * self.total_iters_per_epoch 101 | return lr * (1 + math.cos(math.pi * rel_iter / total_iters)) / 2 102 | 103 | 104 | def test_lr_scheduler(scheduler_type='LinearLR'): 105 | import torch 106 | import numpy as np 107 | import matplotlib.pyplot as plt 108 | 109 | adam = torch.optim.Adam(torch.nn.Linear(2, 3).parameters(), lr=0.01) 110 | num_epochs = 30 111 | num_iter_per_epoch = 10 112 | batch_size = 8 113 | num_samples = 80 114 | init_lr = 0.01 115 | min_lr = 1e-5 116 | scheduler = LRScheduler(adam, batch_size, 117 | num_samples, num_epochs, 118 | scheduler_type=scheduler_type, 119 | init_lr=init_lr, min_lr=min_lr) 120 | 121 | lr = [] 122 | global_steps = 0 123 | for epoch in range(num_epochs): 124 | for i in range(num_iter_per_epoch): 125 | lr.append(scheduler.step(global_steps)) 126 | # print(adam.state_dict) 127 | global_steps += 1 128 | 129 | for i, l in enumerate(lr): 130 | print("%.5f" % l, end=' ') 131 | if (i + 1) % num_iter_per_epoch == 0: 132 | print("") 133 | 134 | plt.plot(np.arange(global_steps), lr) 135 | plt.ylim(0, init_lr) 136 | plt.show() 137 | -------------------------------------------------------------------------------- /datasets/genome/1600-400-20/attributes_vocab.txt: -------------------------------------------------------------------------------- 1 | gray,grey 2 | multi colored,multi-colored,multicolored 3 | double decker,double-decker 4 | unmade 5 | red 6 | camouflage 7 | blue 8 | white 9 | green 10 | pink 11 | yellow 12 | black 13 | ivory 14 | throwing 15 | orange 16 | spiky 17 | plaid 18 | purple 19 | soccer 20 | brake 21 | blonde 22 | american 23 | flat screen 24 | brown 25 | wooden 26 | performing 27 | pulled back 28 | windshield 29 | bald 30 | chocolate 31 | khaki 32 | apple 33 | blowing 34 | 
parked 35 | sticking out 36 | fluorescent 37 | glazed 38 | cooking 39 | brick 40 | home 41 | palm 42 | curly 43 | cheese 44 | crashing 45 | calm 46 | christmas 47 | gravel 48 | chain link,chainlink 49 | clear 50 | cloudy 51 | curled 52 | striped 53 | flying 54 | pine 55 | arched 56 | hardwood 57 | silver 58 | framed 59 | one way,oneway 60 | tall 61 | muscular 62 | skiing 63 | tiled 64 | bare 65 | surfing 66 | stuffed 67 | wii 68 | taking off 69 | sleeping 70 | jumping 71 | metal 72 | fire 73 | neon green 74 | soap 75 | park 76 | chalk 77 | license 78 | powdered 79 | up 80 | woven 81 | baby 82 | polar 83 | floppy 84 | toasted 85 | coffee 86 | potted 87 | wet 88 | tennis 89 | dry 90 | balding 91 | carpeted 92 | deep blue 93 | cardboard 94 | pointed 95 | sandy 96 | snow-covered,snow covered 97 | sheer 98 | wood 99 | swimming 100 | traffic 101 | crouching 102 | short 103 | melted 104 | marble 105 | rock 106 | open 107 | paper 108 | stacked 109 | stainless 110 | cluttered 111 | dirt 112 | waving 113 | ripe 114 | salt 115 | rolling 116 | long 117 | clock 118 | maroon 119 | little 120 | triangle 121 | large 122 | sand 123 | fallen 124 | foamy 125 | stack 126 | sliced 127 | blond 128 | plain 129 | straw 130 | busy 131 | checkered 132 | extended 133 | stainless steel,stainless-steel 134 | stone 135 | rocky 136 | laying down 137 | grazing 138 | porcelain 139 | snowboarding 140 | stop 141 | leather 142 | gold 143 | cargo 144 | playing tennis 145 | winter 146 | walking 147 | roman 148 | peeled 149 | plastic 150 | colorful 151 | shining 152 | burnt 153 | messy 154 | tile 155 | cloudless 156 | glass 157 | smiling 158 | fruit 159 | overcast 160 | adult 161 | water 162 | round 163 | birthday 164 | dark 165 | snowy 166 | leafless 167 | young 168 | wicker 169 | skateboarding 170 | cooked 171 | huge 172 | dress 173 | wire 174 | cracked 175 | concrete 176 | laying 177 | grassy 178 | foggy 179 | fried 180 | slice 181 | batting 182 | mountain 183 | halved 184 | ski 185 | statue 186 | still 187 | octagonal 188 | side view 189 | sitting 190 | wavy 191 | floral 192 | running 193 | moving 194 | small 195 | door 196 | wine 197 | closed 198 | cement 199 | splashing 200 | empty 201 | eating 202 | skating 203 | playing 204 | old 205 | tan 206 | leafy 207 | down 208 | electrical 209 | manicured 210 | standing 211 | blurry 212 | choppy 213 | driving 214 | watching 215 | parking 216 | pointy 217 | covering 218 | for sale 219 | reflecting 220 | railroad 221 | golden brown 222 | steep 223 | granite 224 | roll 225 | train 226 | spotted 227 | fluffy 228 | bending 229 | tarmacked 230 | furry 231 | dirty 232 | hanging 233 | above 234 | half full 235 | bright 236 | chrome 237 | toilet paper 238 | squatting 239 | chopped 240 | flowing 241 | neon 242 | skate 243 | rusty 244 | male 245 | covered 246 | outstretched 247 | lit 248 | riding 249 | shirtless 250 | reaching 251 | baseball 252 | iron 253 | night 254 | speckled 255 | bright blue 256 | horizontal 257 | denim 258 | cake 259 | hazy 260 | chipped 261 | police 262 | off 263 | dead 264 | nike 265 | steamed 266 | beige 267 | brunette 268 | short sleeved 269 | laptop 270 | decorated 271 | sharp 272 | perched 273 | clay 274 | made 275 | mesh 276 | street 277 | burgundy 278 | bent 279 | rusted 280 | paved 281 | patterned 282 | painted 283 | flat 284 | landing 285 | light blue 286 | puffy 287 | shaggy 288 | resting 289 | overgrown 290 | bending over 291 | circular 292 | curved 293 | cast 294 | rainbow colored,rainbow 295 | lime green 296 | ceramic 297 | dried 298 | styrofoam 299 | 
long sleeved,long sleeve 300 | wispy 301 | ocean 302 | big 303 | teal 304 | oval 305 | greenish 306 | murky 307 | tomato 308 | letter 309 | bricked 310 | in air 311 | distant 312 | full 313 | opened 314 | looking 315 | power 316 | holding 317 | browned 318 | growing 319 | backwards 320 | clean 321 | racing 322 | grilled 323 | seasoned 324 | barefoot 325 | kneeling 326 | digital 327 | herd 328 | sliding 329 | recessed 330 | lying 331 | serving 332 | polka dot 333 | cut 334 | ornate 335 | piled 336 | steel 337 | muddy 338 | hilly 339 | raised 340 | hitting 341 | evergreen 342 | sunny 343 | wrist 344 | half 345 | blank 346 | numbered 347 | electric 348 | computer 349 | rolled 350 | whole 351 | lush 352 | daytime 353 | toilet 354 | pointing 355 | asphalt 356 | public 357 | alone 358 | posing 359 | bunch 360 | square 361 | safety 362 | wearing 363 | stripes 364 | bathroom 365 | reflective 366 | assorted 367 | swinging 368 | airborne 369 | dark blue 370 | grass 371 | burned 372 | telephone 373 | docked 374 | pile 375 | laughing 376 | brass 377 | rubber 378 | frosted 379 | hairy 380 | overhead 381 | glowing 382 | soda 383 | number 384 | talking 385 | barren 386 | shaved 387 | shiny 388 | rough 389 | written 390 | older 391 | thin 392 | decorative 393 | wrinkled 394 | peeling 395 | golden 396 | metallic 397 | back 398 | thick 399 | black and white 400 | leaning -------------------------------------------------------------------------------- /visdial/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import warnings 3 | import os 4 | 5 | import torch 6 | from torch import nn, optim 7 | import yaml 8 | from visdial.utils import check_flag 9 | 10 | 11 | class CheckpointManager(object): 12 | """A checkpoint manager saves state dicts of model and optimizer 13 | as .pth files in a specified directory. This class closely follows 14 | the API of PyTorch optimizers and learning rate schedulers. 15 | 16 | Note:: 17 | For ``DataParallel`` modules, ``model.module.state_dict()`` is 18 | saved, instead of ``model.state_dict()``. 19 | 20 | Arguments 21 | ---------- 22 | model: nn.Module 23 | Wrapped model, which needs to be checkpointed. 24 | optimizer: optim.Optimizer 25 | Wrapped optimizer which needs to be checkpointed. 26 | checkpoint_dirpath: str 27 | Path to an empty or non-existent directory to save checkpoints. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | model, 33 | optimizer, 34 | checkpoint_dirpath, 35 | config={}): 36 | 37 | if not isinstance(model, nn.Module): 38 | raise TypeError("{} is not a Module".format(type(model).__name__)) 39 | 40 | if not isinstance(optimizer, optim.Optimizer): 41 | raise TypeError( 42 | "{} is not an Optimizer".format(type(optimizer).__name__) 43 | ) 44 | 45 | self.model = model 46 | self.optimizer = optimizer 47 | self.ckpt_dirpath = Path(checkpoint_dirpath) 48 | self.best_ndcg = 0.0 49 | self.best_mean = 100. 50 | self.init_directory(config) 51 | self.best_ndcg_epoch = 0 52 | self.best_mean_epoch = 0 53 | 54 | def init_directory(self, config={}): 55 | """init""" 56 | self.ckpt_dirpath.mkdir(parents=True, exist_ok=True) 57 | 58 | import json 59 | with open(self.ckpt_dirpath / 'config.json', 'w') as f: 60 | json.dump(config, f) 61 | 62 | def step(self, epoch=None, only_best=False, metrics=None, key=''): 63 | """Save checkpoint if step size conditions meet. 
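        Arguments
        ---------
        epoch: int
            Current epoch number; used to name the per-epoch checkpoint files.
        only_best: bool
            If False, a numbered checkpoint is written every epoch; if True, checkpoints
            are written only when the ndcg or mean metric improves.
        metrics: Dictionary
            Validation metrics for this epoch, looked up as metrics[key + 'ndcg'] and
            metrics[key + 'mean'].
        key: str
            Optional prefix for the metric names.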
""" 64 | if check_flag(self.model.encoder.config['dataset'], 'v0.9'): 65 | self._save_state_dict(str(epoch), epoch, metrics) 66 | return 67 | 68 | if not only_best: 69 | self._save_state_dict(str(epoch), epoch, metrics) 70 | 71 | if metrics[key + 'ndcg'] >= self.best_ndcg: 72 | self.best_ndcg = metrics[key + 'ndcg'] 73 | self.best_ndcg_epoch = epoch 74 | 75 | if metrics[key + 'mean'] >= self.best_ndcg: 76 | self.best_mean = metrics[key + 'mean'] 77 | self.best_mean_epoch = epoch 78 | 79 | else: 80 | if metrics[key + 'ndcg'] >= self.best_ndcg: 81 | self.best_ndcg = metrics[key + 'ndcg'] 82 | self.best_ndcg_epoch = epoch 83 | print('Save best ndcg {} at epoch {}'.format(self.best_ndcg, epoch)) 84 | self._save_state_dict('best_ndcg', epoch, metrics) 85 | 86 | if metrics[key + 'mean'] <= self.best_mean: 87 | self.best_mean = metrics[key + 'mean'] 88 | self.best_ndcg_epoch = epoch 89 | print('Save best mean {} at epoch {}'.format(self.best_mean, epoch)) 90 | self._save_state_dict('best_mean', epoch, metrics) 91 | 92 | self._save_state_dict('last', epoch, metrics) 93 | 94 | def _save_state_dict(self, name, epoch, metrics): 95 | """save state_dict""" 96 | state_dict = {'model': self._get_model(), 97 | 'optimizer': self.optimizer, 98 | 'epoch': epoch, 99 | 'metrics': metrics} 100 | ckpt_path = self.ckpt_dirpath / f"checkpoint_{name}.pth" 101 | torch.save(state_dict, ckpt_path) 102 | 103 | def _get_model(self): 104 | """Returns state dict of model, taking care of DataParallel case.""" 105 | if isinstance(self.model, nn.DataParallel): 106 | return self.model.module 107 | else: 108 | return self.model 109 | 110 | 111 | def load_checkpoint(model, optimizer, checkpoint_pthpath=None, device='cuda', resume=False): 112 | """Load checkpoint including: 113 | the model, optimizer state_dicts 114 | """ 115 | # load 116 | 117 | if checkpoint_pthpath is not None: 118 | components = torch.load(checkpoint_pthpath, map_location=device) 119 | print("Loaded model from {}".format(checkpoint_pthpath)) 120 | print('At epoch:', components.get('epoch')) 121 | print('Metrics score:', components.get('metrics')) 122 | else: 123 | print("Can't load weight from {}".format(checkpoint_pthpath)) 124 | return 0, model, optimizer 125 | 126 | if resume: 127 | # "path/to/checkpoint_xx.pth" -> xx 128 | print('Resume training....') 129 | start_epoch = components['epoch'] 130 | model = components["model"] 131 | optimizer = components["optimizer"] 132 | return start_epoch, model, optimizer 133 | 134 | else: 135 | model = components["model"] 136 | return 0, model, optimizer 137 | 138 | 139 | def load_checkpoint_from_config(model, optimizer, config): 140 | return load_checkpoint(model, optimizer, 141 | checkpoint_pthpath=config['callbacks']['path_pretrained_ckpt'], 142 | resume=config['callbacks']['resume'], 143 | device='cuda') -------------------------------------------------------------------------------- /visdial/decoders/gen_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from visdial.common import DynamicRNN 4 | 5 | 6 | class GenerativeDecoder(nn.Module): 7 | """This Generative Decoder learn to predict ground-truth answer word-by-word during training 8 | and assign log-likelihood scores to all answer options (candidate answers) during evaluation. 9 | Given `context_vec` and and the sequences of candidate options. 
10 | """ 11 | 12 | def __init__(self, config): 13 | super(GenerativeDecoder, self).__init__() 14 | self.config = config 15 | self.test_mode = self.config['model']['test_mode'] 16 | self.text_embedding = nn.Embedding(config['model']['txt_vocab_size'], 17 | config['model']['txt_embedding_size'], 18 | padding_idx=0) 19 | 20 | self.answer_lstm = nn.LSTM( 21 | config['model']['txt_embedding_size'], 22 | config['model']['hidden_size'], 23 | num_layers=2, 24 | batch_first=True, 25 | dropout=config['model']['dropout'] 26 | ) 27 | 28 | if config['model']['txt_has_decoder_layer_norm']: 29 | self.layer_norm = nn.LayerNorm(config['model']['hidden_size']) 30 | 31 | self.lstm_to_words = nn.Linear( 32 | config['model']['hidden_size'], config['model']['txt_vocab_size'] 33 | ) 34 | 35 | # self.dropout = nn.Dropout(p=config['model']['dropout']) 36 | self.logsoftmax = nn.LogSoftmax(dim=-1) 37 | 38 | def forward(self, batch, context_vec, test_mode=False): 39 | """ 40 | Arguments 41 | --------- 42 | batch: Dictionary 43 | This provides a dictionary of inputs. 44 | context_vec: torch.FloatTensor 45 | The context vector summarized from all utilities 46 | Shape [batch_size, NH, hidden_size] 47 | 48 | test_mode: Boolean 49 | Whether the forward is performed on test data 50 | Returns 51 | ------- 52 | output : Dictionary 53 | output['opts_out_scores']: torch.FloatTensor 54 | The output from Generative Decoder (test mode or validation mode) 55 | Shape: [batch_size, NH, num_options] 56 | 57 | output['ans_out_scores']: torch.FloatTensor 58 | The output from Generative Decoder (training mode) 59 | Shape: Shape [batch_size, N, vocab_size] 60 | """ 61 | self.answer_lstm.flatten_parameters() 62 | 63 | if self.training: 64 | # shape: [BS, NH, SEQ] 65 | ans_in = batch["ans_in"] 66 | (BS, NH, SEQ), HS = ans_in.size(), self.config['model']['hidden_size'] 67 | 68 | # shape: [BS x NH, SEQ] 69 | ans_in = ans_in.view(BS * NH, SEQ) 70 | 71 | # shape: [BS x NH, SEQ, WE] 72 | ans_in_embed = self.text_embedding(ans_in) 73 | 74 | # reshape encoder output to be set as initial hidden state of LSTM. 75 | # shape: [lstm_layers, BS x NH, HS] 76 | num_lstm_layers = 2 77 | init_hidden = context_vec.view(1, BS * NH, -1).repeat(num_lstm_layers, 1, 1) 78 | 79 | init_cell = torch.zeros_like(init_hidden) 80 | 81 | # shape: [BS x NH, SEQ, HS] 82 | ans_out, (_, _) = self.answer_lstm(ans_in_embed, (init_hidden, init_cell)) 83 | # ans_out = self.dropout(ans_out) 84 | 85 | # shape: [BS, NH, SEQ, VC] 86 | return {'ans_out_scores': self.lstm_to_words(ans_out).view(BS, NH, SEQ, -1)} 87 | 88 | else: 89 | opts_in = batch["opts_in"] 90 | target_opts_out = batch["opts_out"] 91 | 92 | if self.test_mode or test_mode: 93 | # shape: [BS, NH, NO, SEQ] 94 | opts_in = opts_in[:, batch['num_rounds'] - 1] 95 | # shape: [BS x NH x NO, SEQ] 96 | target_opts_out = batch["opts_out"][:, batch['num_rounds'] - 1] 97 | 98 | BS, NH, NO, SEQ = opts_in.size() 99 | 100 | target_opts_out = target_opts_out.view(BS * NH * NO, -1) 101 | 102 | # shape: [BS x NH x NO, SEQ] 103 | opts_in = opts_in.view(BS * NH * NO, SEQ) 104 | 105 | # shape: [BS x NH x NO, WE] 106 | opts_in_embed = self.text_embedding(opts_in) 107 | 108 | # reshape encoder output to be set as initial hidden state of LSTM. 
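# The same context vector is reused for every one of the NO candidate options (and for both
# LSTM layers), so each option is decoded from an identical initial hidden state; an option's
# final score is the sum of its token log-likelihoods,
#     score(option) = sum_t log P(w_t | w_1..w_{t-1}, context_vec),
# with padding positions masked out (computed at the end of this branch).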
109 | # shape: [BS, NH, 1, HS] 110 | init_hidden = context_vec.view(BS, NH, 1, -1) 111 | 112 | # shape: [BS, NH, NO, HS] 113 | init_hidden = init_hidden.repeat(1, 1, NO, 1) 114 | 115 | # shape: [1, BS x NH x NO, HS] 116 | init_hidden = init_hidden.view(1, BS * NH * NO, -1) 117 | 118 | num_lstm_layers = 2 119 | # shape: [lstm_layers, BS x NH x NO, HS] 120 | init_hidden = init_hidden.repeat(num_lstm_layers, 1, 1) 121 | 122 | init_cell = torch.zeros_like(init_hidden) 123 | 124 | # shape: [BS x NH x NO, SEQ, HS] 125 | opts_out, (_, _) = self.answer_lstm(opts_in_embed, (init_hidden, init_cell)) 126 | 127 | if self.config['model']['txt_has_decoder_layer_norm']: 128 | opts_out = self.layer_norm(opts_out) 129 | 130 | # shape: [BS x NH x NO, SEQ, VC] 131 | opts_word_scores = self.logsoftmax(self.lstm_to_words(opts_out)) 132 | 133 | # shape: [BS x NH x NO, SEQ] 134 | opts_out_scores = torch.gather(opts_word_scores, -1, target_opts_out.unsqueeze(-1)).squeeze() 135 | # ^ select the scores for target word in [vocab vector] of each word 136 | 137 | # shape: [BS x NH x NO, SEQ] <- remove the word 138 | opts_out_scores = (opts_out_scores * (target_opts_out > 0).float()) 139 | 140 | # sum all the scores for each word in the predicted answer -> final score 141 | # shape: [BS x NH x NO] 142 | opts_out_scores = torch.sum(opts_out_scores, dim=-1) 143 | 144 | # shape: [BS, NH, NO] 145 | opts_out_scores = opts_out_scores.view(BS, NH, NO) 146 | 147 | return {'opts_out_scores': opts_out_scores} 148 | -------------------------------------------------------------------------------- /visdial/optim/adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | # 8 | # ------------------------------------------------------------------------- 9 | # 10 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | 23 | import math 24 | import torch 25 | import torch.optim 26 | 27 | 28 | def get_weight_decay_params(model, weight_decay=1e-5, skip_list=[]): 29 | decay_params = [] 30 | no_decay_params = [] 31 | for name, param in model.named_parameters(): 32 | if not param.requires_grad: 33 | continue 34 | 35 | if len(param.shape) == 1 or name.endswith('.bias') or name in skip_list: 36 | no_decay_params.append(param) 37 | else: 38 | decay_params.append(param) 39 | 40 | return [ 41 | {'params': no_decay_params, 42 | 'weight_decay': 0. 43 | }, 44 | {'params': decay_params, 45 | 'weight_decay': weight_decay 46 | } 47 | ] 48 | 49 | 50 | class Adam(torch.optim.Optimizer): 51 | """Implements Adam algorithm. 
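With decoupled (fixed) weight decay, each update is roughly
p = p - lr * weight_decay * p - step_size * m_hat / (sqrt(v_hat) + eps),
where m_hat and v_hat are the bias-corrected first and second moment estimates
(the bias correction is folded into step_size inside step()).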
52 | This implementation is modified from torch.optim.Adam based on: 53 | `Fixed Weight Decay Regularization in Adam` 54 | (see https://arxiv.org/abs/1711.05101) 55 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 56 | Arguments: 57 | params (iterable): iterable of parameters to optimize or dicts defining 58 | parameter groups 59 | lr (float, optional): learning rate (default: 1e-3) 60 | betas (Tuple[float, float], optional): coefficients used for computing 61 | running averages of gradient and its square (default: (0.9, 0.999)) 62 | eps (float, optional): term added to the denominator to improve 63 | numerical stability (default: 1e-8) 64 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 65 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 66 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 67 | .. _Adam\: A Method for Stochastic Optimization: 68 | https://arxiv.org/abs/1412.6980 69 | .. _On the Convergence of Adam and Beyond: 70 | https://openreview.net/forum?id=ryQu7f-RZ 71 | """ 72 | 73 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.997), eps=1e-9, 74 | weight_decay=1e-6, amsgrad=False): 75 | """ 76 | :param params: 77 | :param lr: 78 | :param betas: 79 | :param eps: 80 | :param weight_decay: 1e-5 -> 1e-6 is the best reported in the paper. 81 | :param amsgrad: 82 | """ 83 | defaults = dict(lr=lr, betas=betas, eps=eps, 84 | weight_decay=weight_decay, amsgrad=amsgrad) 85 | super(Adam, self).__init__(params, defaults) 86 | 87 | def step(self, closure=None): 88 | """Performs a single optimization step. 89 | Arguments: 90 | closure (callable, optional): A closure that reevaluates the model 91 | and returns the loss. 92 | """ 93 | loss = None 94 | if closure is not None: 95 | loss = closure() 96 | 97 | for group in self.param_groups: 98 | for p in group['params']: 99 | if p.grad is None: 100 | continue 101 | grad = p.grad.data 102 | if grad.is_sparse: 103 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 104 | amsgrad = group['amsgrad'] 105 | 106 | state = self.state[p] 107 | 108 | # State initialization 109 | if len(state) == 0: 110 | state['step'] = 0 111 | # Exponential moving average of gradient values 112 | state['exp_avg'] = torch.zeros_like(p.data) 113 | # Exponential moving average of squared gradient values 114 | state['exp_avg_sq'] = torch.zeros_like(p.data) 115 | if amsgrad: 116 | # Maintains max of all exp. moving avg. of sq. grad. values 117 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 118 | 119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 120 | if amsgrad: 121 | max_exp_avg_sq = state['max_exp_avg_sq'] 122 | beta1, beta2 = group['betas'] 123 | 124 | state['step'] += 1 125 | 126 | # Decay the first and second moment running average coefficient 127 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | if amsgrad: 130 | # Maintains the maximum of all 2nd moment running avg. till now 131 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 132 | # Use the max. for normalizing running avg. 
of gradient 133 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 134 | else: 135 | denom = exp_avg_sq.sqrt().add_(group['eps']) 136 | 137 | bias_correction1 = 1 - beta1 ** state['step'] 138 | bias_correction2 = 1 - beta2 ** state['step'] 139 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 140 | 141 | if group['weight_decay'] != 0: 142 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 143 | 144 | p.data.addcdiv_(-step_size, exp_avg, denom) 145 | return loss 146 | -------------------------------------------------------------------------------- /visdial/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Metric observes output of certain model, for example, in form of logits or 3 | scores, and accumulates a particular metric with reference to some provided 4 | targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean 5 | Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). 6 | 7 | Each ``Metric`` must atleast implement three methods: 8 | - ``observe``, update accumulated metric with currently observed outputs 9 | and targets. 10 | - ``retrieve`` to return the accumulated metric., an optionally reset 11 | internally accumulated metric (this is commonly done between two epochs 12 | after validation). 13 | - ``reset`` to explicitly reset the internally accumulated metric. 14 | 15 | Caveat, if you wish to implement your own class of Metric, make sure you call 16 | ``detach`` on output tensors (like logits), else it will cause memory leaks. 17 | 18 | Credit: 19 | https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/metrics.py 20 | """ 21 | import torch 22 | import pickle 23 | from nltk.tokenize.treebank import TreebankWordDetokenizer 24 | 25 | 26 | def scores_to_ranks(scores: torch.Tensor): 27 | """Convert model output scores into ranks.""" 28 | batch_size, num_rounds, num_options = scores.size() 29 | scores = scores.view(-1, num_options) 30 | 31 | # sort in descending order - largest score gets highest rank 32 | sorted_ranks, ranked_idx = scores.sort(1, descending=True) 33 | 34 | # i-th position in ranked_idx specifies which score shall take this 35 | # position but we want i-th position to have rank of score at that 36 | # position, do this conversion 37 | ranks = ranked_idx.clone().fill_(0) 38 | for i in range(ranked_idx.size(0)): 39 | for j in range(num_options): 40 | ranks[i][ranked_idx[i][j]] = j 41 | # convert from 0-99 ranks to 1-100 ranks 42 | ranks += 1 43 | ranks = ranks.view(batch_size, num_rounds, num_options) 44 | return ranks 45 | 46 | 47 | class SparseGTMetrics(object): 48 | """ 49 | A class to accumulate all metrics with sparse ground truth annotations. 50 | These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. 
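Concretely, with r_i the rank assigned to the ground-truth answer of observed round i among its
(typically 100) candidate options: R@k = mean(r_i <= k), Mean = mean(r_i), MRR = mean(1 / r_i).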
51 | """ 52 | 53 | def __init__(self): 54 | self._rank_list = [] 55 | 56 | def observe( 57 | self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor 58 | ): 59 | predicted_scores = predicted_scores.detach() 60 | 61 | # shape: (batch_size, num_rounds, num_options) 62 | predicted_ranks = scores_to_ranks(predicted_scores) 63 | batch_size, num_rounds, num_options = predicted_ranks.size() 64 | 65 | # collapse batch dimension 66 | predicted_ranks = predicted_ranks.view( 67 | batch_size * num_rounds, num_options 68 | ) 69 | 70 | # shape: (batch_size * num_rounds, ) 71 | target_ranks = target_ranks.view(batch_size * num_rounds).long() 72 | 73 | # shape: (batch_size * num_rounds, ) 74 | predicted_gt_ranks = predicted_ranks[ 75 | torch.arange(batch_size * num_rounds), target_ranks 76 | ] 77 | self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) 78 | 79 | def retrieve(self, reset: bool = True, key=""): 80 | num_examples = len(self._rank_list) 81 | if num_examples > 0: 82 | # convert to numpy array for easy calculation. 83 | __rank_list = torch.tensor(self._rank_list).float() 84 | metrics = { 85 | key + "r@1": torch.mean((__rank_list <= 1).float()).item(), 86 | key + "r@5": torch.mean((__rank_list <= 5).float()).item(), 87 | key + "r@10": torch.mean((__rank_list <= 10).float()).item(), 88 | key + "mean": torch.mean(__rank_list).item(), 89 | key + "mrr": torch.mean(__rank_list.reciprocal()).item(), 90 | } 91 | else: 92 | metrics = {} 93 | 94 | if reset: 95 | self.reset() 96 | return metrics 97 | 98 | def reset(self): 99 | self._rank_list = [] 100 | 101 | 102 | class NDCG(object): 103 | def __init__(self): 104 | self._ndcg_numerator = 0.0 105 | self._ndcg_denominator = 0.0 106 | 107 | def observe( 108 | self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor 109 | ): 110 | """ 111 | Observe model output scores and target ground truth relevance and 112 | accumulate NDCG metric. 113 | 114 | Parameters 115 | ---------- 116 | predicted_scores: torch.Tensor 117 | A tensor of shape (batch_size, num_options), because dense 118 | annotations are available for 1 randomly picked round out of 10. 119 | target_relevance: torch.Tensor 120 | A tensor of shape same as predicted scores, indicating ground truth 121 | relevance of each answer option for a particular round. 122 | """ 123 | predicted_scores = predicted_scores.detach() 124 | 125 | # shape: (batch_size, 1, num_options) 126 | predicted_scores = predicted_scores.unsqueeze(1) 127 | predicted_ranks = scores_to_ranks(predicted_scores) 128 | 129 | # shape: (batch_size, num_options) 130 | predicted_ranks = predicted_ranks.squeeze(1) 131 | batch_size, num_options = predicted_ranks.size() 132 | 133 | k = torch.sum(target_relevance != 0, dim=-1) 134 | 135 | # shape: (batch_size, num_options) 136 | _, rankings = torch.sort(predicted_ranks, dim=-1) 137 | # Sort relevance in descending order so highest relevance gets top rnk. 
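# NDCG is DCG@k / IDCG@k with k = the number of options whose relevance is non-zero:
# DCG@k = sum_{i=1..k} rel(option the model ranks i-th) / log2(i + 1), and IDCG@k applies the
# same sum to the ideal, relevance-sorted ordering computed next.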
138 | _, best_rankings = torch.sort( 139 | target_relevance, dim=-1, descending=True 140 | ) 141 | 142 | # shape: (batch_size, ) 143 | batch_ndcg = [] 144 | for batch_index in range(batch_size): 145 | num_relevant = k[batch_index] 146 | dcg = self._dcg( 147 | rankings[batch_index][:num_relevant], 148 | target_relevance[batch_index], 149 | ) 150 | best_dcg = self._dcg( 151 | best_rankings[batch_index][:num_relevant], 152 | target_relevance[batch_index], 153 | ) 154 | batch_ndcg.append(dcg / best_dcg) 155 | 156 | self._ndcg_denominator += batch_size 157 | self._ndcg_numerator += sum(batch_ndcg) 158 | 159 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): 160 | sorted_relevance = relevance[rankings].cpu().float() 161 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2) 162 | return torch.sum(sorted_relevance / discounts, dim=-1) 163 | 164 | def retrieve(self, reset: bool = True, key=""): 165 | if self._ndcg_denominator > 0: 166 | metrics = { 167 | key + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) 168 | } 169 | else: 170 | metrics = {} 171 | 172 | if reset: 173 | self.reset() 174 | return metrics 175 | 176 | def reset(self): 177 | self._ndcg_numerator = 0.0 178 | self._ndcg_denominator = 0.0 179 | -------------------------------------------------------------------------------- /visdial/encoders/img_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ImageEncoder(nn.Module): 6 | 7 | def __init__(self, config): 8 | super(ImageEncoder, self).__init__() 9 | 10 | self.config = config 11 | 12 | self.img_linear = nn.Sequential( 13 | nn.Linear(config['model']['img_feat_size'], 14 | config['model']['hidden_size']), 15 | nn.ReLU(inplace=True), 16 | nn.Dropout(p=config['model']['dropout']), 17 | nn.LayerNorm(config['model']['hidden_size']), 18 | ) 19 | 20 | if self.config['model']['img_has_classes'] or \ 21 | self.config['model']['img_has_attributes'] or \ 22 | self.config['model']['img_has_bboxes']: 23 | self.img_norm = nn.LayerNorm(config['model']['hidden_size']) 24 | 25 | self.text_embedding = nn.Embedding(config['model']['txt_vocab_size'], 26 | config['model']['txt_embedding_size']) 27 | 28 | if self.config['model']['img_has_classes']: 29 | self.cls_linear = nn.Sequential( 30 | nn.Linear(config['model']['txt_embedding_size'], 31 | config['model']['hidden_size']), 32 | nn.ReLU(inplace=True), 33 | nn.Dropout(p=config['model']['dropout']), 34 | nn.LayerNorm(config['model']['hidden_size']) 35 | ) 36 | 37 | if self.config['model']['img_has_attributes']: 38 | self.attr_linear = nn.Sequential( 39 | nn.Linear(config['model']['txt_embedding_size'], 40 | config['model']['hidden_size']), 41 | nn.ReLU(inplace=True), 42 | nn.Dropout(p=config['model']['dropout']), 43 | nn.LayerNorm(config['model']['hidden_size']) 44 | ) 45 | 46 | if self.config['model']['img_has_bboxes']: 47 | self.x1_embedding = nn.Embedding(600, config['model']['hidden_size']) 48 | self.x2_embedding = nn.Embedding(600, config['model']['hidden_size']) 49 | 50 | self.x3_embedding = nn.Embedding(600, config['model']['hidden_size']) 51 | self.x4_embedding = nn.Embedding(600, config['model']['hidden_size']) 52 | 53 | self.bbox_linear = nn.Sequential( 54 | nn.Linear(config['model']['hidden_size'], 55 | config['model']['hidden_size']), 56 | nn.ReLU(inplace=True), 57 | nn.Dropout(p=config['model']['dropout']), 58 | nn.LayerNorm(config['model']['hidden_size']) 59 | ) 60 | 61 | def forward(self, batch, 
test_mode=False): 62 | """ 63 | Arguments 64 | --------- 65 | batch: Dictionary 66 | This provides a dictionary of inputs. 67 | test_mode: Boolean 68 | Whether the forward is performed on test data 69 | Returns 70 | ------- 71 | output : A tuples of the following: 72 | img: torch.FloatTensor 73 | The representation of image utility 74 | Shape [batch_size * NH, K, hidden_size] 75 | img_mask: torch.LongTensor 76 | The mask of image utility 77 | Shape [batch_size * NH, K] 78 | """ 79 | bs, num_hist, _ = batch['ques_tokens'].size() 80 | hidden_size = self.config['model']['hidden_size'] 81 | 82 | if self.config['model']['test_mode'] or test_mode: 83 | num_hist = 1 84 | 85 | # shape: [batch_size, num_proposals, img_feat_size] 86 | img_feat = batch['img_feat'] 87 | 88 | # img_feat: shape [bs, num_proposals, hidden_size] 89 | img_feat = self.img_linear(img_feat) 90 | 91 | # shape [bs * num_hist, num_proposals, hidden_size] 92 | img_feat = img_feat.unsqueeze(1).repeat(1, num_hist, 1, 1) 93 | img_feat = img_feat.view(bs * num_hist, -1, img_feat.size(-1)) 94 | 95 | if batch.get('num_boxes', None) is None: 96 | # shape [bs * num_hist, num_proposals] 97 | img_mask = img_feat.new_ones(img_feat.shape[:-1], dtype=torch.long) 98 | else: 99 | num_boxes = batch['num_boxes'] 100 | 101 | # [bs * num_hist, num_proposals] 102 | img_mask = torch.arange(img_feat.shape[-2], device=img_feat.device) 103 | img_mask = img_mask.repeat(bs * num_hist, 1) 104 | 105 | # [bs * num_hist, 1] 106 | num_boxes = num_boxes[:, None, None].repeat(1, num_hist, 1) 107 | num_boxes = num_boxes.view(bs * num_hist, 1) 108 | 109 | # [bs * num_hist, num_proposals] 110 | img_mask = (img_mask < num_boxes).long() 111 | 112 | if self.config['model']['img_has_classes']: 113 | # [bs, num_proposals] 114 | classes = batch['classes'] 115 | 116 | # [bs * num_hist, num_proposals] 117 | classes = classes.unsqueeze(1).repeat(1, num_hist, 1) 118 | classes = classes.view(bs * num_hist, -1) 119 | 120 | classes = self.text_embedding(classes) 121 | classes = self.cls_linear(classes) 122 | img_feat += classes 123 | 124 | if self.config['model']['img_has_attributes']: 125 | # [bs, num_proposals, num_attrs] 126 | attrs = batch['attrs'] 127 | 128 | # [bs, num_proposals, num_attrs, hidden_size] 129 | attrs = self.text_embedding(attrs) 130 | attrs = self.attr_linear(attrs) 131 | 132 | # [bs, num_proposals, num_attrs] 133 | attr_scores = batch['attr_scores'] 134 | 135 | # [bs, num_proposals, hidden_size] 136 | attrs = torch.matmul(attr_scores.unsqueeze(-2), attrs).squeeze(-2) 137 | 138 | # [bs * num_hist, num_proposals, hidden_size] 139 | attrs = attrs.unsqueeze(1).repeat(1, num_hist, 1, 1) 140 | attrs = attrs.view(bs * num_hist, -1, hidden_size) 141 | 142 | img_feat += attrs 143 | 144 | if self.config['model']['img_has_bboxes']: 145 | # [bs, num_proposals, hidden_size] 146 | w = batch['img_w'].unsqueeze(-1).float() 147 | h = batch['img_h'].unsqueeze(-1).float() 148 | 149 | x1 = (self.bbox_linear(self.x1_embedding((batch['boxes'][:, :, 0] * 600 / w).long()))) 150 | x2 = (self.bbox_linear(self.x2_embedding((batch['boxes'][:, :, 1] * 600 / h).long()))) 151 | x3 = (self.bbox_linear(self.x3_embedding((batch['boxes'][:, :, 2] * 600 / w).long()))) 152 | x4 = (self.bbox_linear(self.x4_embedding((batch['boxes'][:, :, 3] * 600 / h).long()))) 153 | 154 | # [bs, num_proposals, hidden_size] 155 | bboxes = (x1 + x2 + x3 + x4) / 4.0 156 | 157 | # [bs, num_hist, num_proposals, hidden_size] 158 | bboxes = bboxes.unsqueeze(1).repeat(1, num_hist, 1, 1) 159 | 160 | # [bs, num_hist, 
num_proposals, hidden_size] 161 | bboxes = bboxes.view(bs * num_hist, -1, hidden_size) 162 | 163 | img_feat += bboxes 164 | 165 | if self.config['model']['img_has_classes'] or \ 166 | self.config['model']['img_has_attributes'] or \ 167 | self.config['model']['img_has_bboxes']: 168 | img_feat = self.img_norm(img_feat) 169 | 170 | return img_feat, img_mask 171 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def get_training_config_and_args(): 6 | parser = get_training_parser() 7 | args = parser.parse_args() 8 | 9 | config = {} 10 | for group in parser._action_groups: 11 | if group.title in ['optional arguments', 'positional arguments']: 12 | group_dict = {arg.dest: getattr(args, arg.dest, None) for arg in group._group_actions} 13 | for key in group_dict: 14 | config[key] = group_dict[key] 15 | print(key) 16 | else: 17 | group_dict = {arg.dest: getattr(args, arg.dest, None) for arg in group._group_actions} 18 | config[group.title] = group_dict 19 | 20 | for dir in [config['callbacks']['log_dir'], config['callbacks']['save_dir']]: 21 | if not os.path.exists(dir): 22 | os.system(f"mkdir -p {dir}") 23 | 24 | return config, args 25 | 26 | 27 | def get_training_parser(): 28 | parser = argparse.ArgumentParser(description='Visual Dialog Toolkit', 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | parser.add_argument('--seed', metavar='N', type=int, 31 | default=0) 32 | add_dataset_args(parser) 33 | add_model_args(parser) 34 | add_solver_args(parser) 35 | add_callback_args(parser) 36 | parser.add_argument('--config_name', metavar='S', default='c1.0.0') 37 | return parser 38 | 39 | 40 | def add_dataset_args(parser): 41 | group = parser.add_argument_group("dataset") 42 | group.add_argument('--v0.9', action='store_true', 43 | default=False) 44 | group.add_argument('--overfit', action='store_true', 45 | default=False, 46 | help='overfit on small dataset') 47 | group.add_argument('--concat_hist', action='store_true', 48 | default=False, 49 | help='concat history rounds into a single vector') 50 | group.add_argument('--max_seq_len', type=int, metavar='N', 51 | default=20, 52 | help='max number of tokens in a sequence') 53 | group.add_argument('--vocab_min_count', type=int, metavar='N', 54 | default=5, 55 | help='Words occurring at least this many times are kept in the vocabulary') 56 | group.add_argument('--finetune', default=False, action='store_true') 57 | group.add_argument('--is_add_boundaries', default=True, action='store_true') 58 | group.add_argument('--is_return_options', default=True, action='store_true') 59 | group.add_argument('--num_boxes', choices=['fixed', 'adaptive'], 60 | default='fixed', metavar='S', 61 | help='The number of boxes per image from Faster R-CNN') 62 | group.add_argument('--glove_path', metavar='PATH', 63 | default='datasets/glove/embedding_Glove_840_300d.pkl') 64 | group.add_argument('--train_feat_img_path', metavar='PATH', 65 | default='datasets/bottom-up-attention/trainval_resnet101_faster_rcnn_genome_100.h5') 66 | group.add_argument('--val_feat_img_path', metavar='PATH', 67 | default='datasets/bottom-up-attention/val2018_resnet101_faster_rcnn_genome_100.h5') 68 | group.add_argument('--test_feat_img_path', metavar='PATH', 69 | default='datasets/bottom-up-attention/test2018_resnet101_faster_rcnn_genome_100.h5') 70 | group.add_argument('--train_json_dialog_path', metavar='PATH', 71 |
default='datasets/annotations/visdial_1.0_train.json') 72 | group.add_argument('--val_json_dialog_path', metavar='PATH', 73 | default='datasets/annotations/visdial_1.0_val.json') 74 | group.add_argument('--test_json_dialog_path', metavar='PATH', 75 | default='datasets/annotations/visdial_1.0_test.json') 76 | group.add_argument('--val_json_dense_dialog_path', metavar='PATH', 77 | default='datasets/annotations/visdial_1.0_val_dense_annotations.json') 78 | group.add_argument('--train_json_word_count_path', metavar='PATH', 79 | default='datasets/annotations/visdial_1.0_word_counts_train.json') 80 | return group 81 | 82 | 83 | def add_solver_args(parser): 84 | group = parser.add_argument_group('solver') 85 | """Adam Optimizer""" 86 | group.add_argument('--optimizer', default='adam', 87 | choices=['sgd', 'adam', 'adamax']) 88 | group.add_argument('--adam_betas', nargs='+', type=float, default=[0.9, 0.997]) 89 | group.add_argument('--adam_eps', type=float, default=1e-9) 90 | group.add_argument('--weight_decay', '--wd', default=1e-5, type=float, metavar='WD', 91 | help='weight decay') 92 | group.add_argument('--clip_norm', default=None, type=float, 93 | metavar='N', 94 | help='clip threshold of gradients') 95 | 96 | """Dataloader""" 97 | group.add_argument('--num_epochs', default=30, type=int, metavar='N', 98 | help='Total number of epochs') 99 | group.add_argument('--batch_size', default=8, type=int, 100 | metavar='N', 101 | help="Batch_size for training") 102 | group.add_argument('--cpu_workers', default=8, type=int) 103 | group.add_argument('--batch_size_multiplier', default=1, type=int, 104 | metavar='N', 105 | help='Cumsum of loss in N batches and update optimizer once') 106 | 107 | """Learning Rate Scheduler""" 108 | group.add_argument('--scheduler_type', default='LinearLR', 109 | help='learning rate scheduler type', 110 | choices=['CosineLR', 'LinearLR', "CosineStepLR"]) 111 | group.add_argument('--init_lr', default=5e-3, type=float, 112 | help='initial learning rate') 113 | group.add_argument('--min_lr', default=1e-5, type=float, metavar='LR', 114 | help='minimum learning rate') 115 | group.add_argument('--num_samples', default=123287, type=int, 116 | help='The number of training samples') 117 | 118 | """Warmup Scheduler""" 119 | group.add_argument('--warmup_factor', default=0.2, type=float, 120 | metavar='N', 121 | help='lr will increase from 0 -> init_lr with warm_factor:' 122 | 'after every batch, lr = lr * warmup_factor') 123 | group.add_argument('--warmup_epochs', default=1, type=int, metavar='N') 124 | 125 | """Linear Scheduler""" 126 | group.add_argument('--linear_gama', default=0.5, type=float, metavar='LG', 127 | help='learning rate shrink factor for step reduce, lr_new = (lr * lr_gama) at milestone step') 128 | group.add_argument('--milestone_steps', nargs='+', type=int, metavar='LS', default=[3, 6, 8, 10, 11], 129 | help='If we use step_lr_scheduler rather than cosine') 130 | group.add_argument('--fp16', default=False, action='store_true') 131 | return group 132 | 133 | 134 | def add_callback_args(parser): 135 | group = parser.add_argument_group('callbacks') 136 | group.add_argument('--resume', default=False, action='store_true') 137 | group.add_argument('--validate', default=True, action='store_true') 138 | group.add_argument('--path_pretrained_ckpt', metavar='DIR', default=None, 139 | help='filename in save-dir from which to load checkpoint, checkpoint_last.pt') 140 | group.add_argument('--save_dir', default='checkpoints/') 141 | group.add_argument('--log_dir', 
default='checkpoints/tensorboard/') 142 | return group 143 | 144 | 145 | def add_model_args(parser): 146 | group = parser.add_argument_group('model') 147 | group.add_argument('--decoder_type', choices=['misc', 'disc', 'gen'], default='misc', help='Type of decoder') 148 | group.add_argument('--encoder_out', type=str, nargs='+', default=['img', 'ques'], ) 149 | group.add_argument('--hidden_size', type=int, metavar='N', default=512) 150 | group.add_argument('--dropout', type=float, metavar='N', default=0.1) 151 | group.add_argument('--test_mode', action='store_true', default=False) 152 | 153 | """Image Feature""" 154 | group.add_argument('--img_feat_size', type=int, metavar='N', default=2048) 155 | group.add_argument('--img_num_attns', type=int, metavar='N', default=None) 156 | group.add_argument('--img_has_bboxes', action='store_true', default=False) 157 | group.add_argument('--img_has_attributes', action='store_true', default=False) 158 | group.add_argument('--img_has_classes', action='store_true', default=False) 159 | 160 | """Text Feature""" 161 | group.add_argument('--txt_vocab_size', type=int, metavar='N', default=11322) 162 | group.add_argument('--txt_tokenizer', choices=['nlp', 'bert'], default='nlp') 163 | group.add_argument('--txt_bidirectional', action='store_true', default=True) 164 | group.add_argument('--txt_embedding_size', type=int, default=300) 165 | group.add_argument('--txt_has_pos_embedding', action='store_true', default=False) 166 | group.add_argument('--txt_has_layer_norm', action='store_true', default=False) 167 | group.add_argument('--txt_has_decoder_layer_norm', action='store_true', default=False) 168 | 169 | """Cross-Attention""" 170 | group.add_argument('--ca_has_shared_attns', action='store_true', default=False) 171 | group.add_argument('--ca_has_proj_linear', action='store_true', default=False) 172 | group.add_argument('--ca_has_layer_norm', action='store_true', default=False) 173 | group.add_argument('--ca_has_residual', action='store_true', default=False) 174 | group.add_argument('--ca_num_attn_stacks', type=int, metavar='N', default=1) 175 | group.add_argument('--ca_num_attn_heads', type=int, metavar='N', default=4) 176 | group.add_argument('--ca_pad_size', type=int, default=2) 177 | # computing the avg attention maps for further visualization 178 | group.add_argument('--ca_has_avg_attns', action='store_true', default=False) 179 | group.add_argument('--ca_has_self_attns', action='store_true', default=False) 180 | 181 | return group 182 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import torch 5 | import random 6 | import logging 7 | import numpy as np 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from visdial.model import get_model 11 | from visdial.data.dataset import VisDialDataset 12 | from visdial.metrics import SparseGTMetrics, NDCG 13 | from visdial.utils.checkpointing import CheckpointManager 14 | from visdial.utils import move_to_cuda 15 | from torch.utils.tensorboard import SummaryWriter 16 | from visdial.optim import Adam, LRScheduler 17 | from visdial.loss import FinetuneLoss 18 | import argparse 19 | from tqdm import tqdm 20 | import itertools 21 | 22 | # Load config 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--model_path', default='checkpoints/model_v1.pth') 25 | parser.add_argument('--save_path', 
default='checkpoints/finetune') 26 | parser.add_argument('--num_epochs', type=int, default=2) 27 | parser.add_argument('--init_lr', type=float, default=1e-4) 28 | parser.add_argument('--scheduler_type', type=str, default='CosineLR') 29 | parser.add_argument('--batch_size', type=int, default=8) 30 | parser.add_argument('--overfit', action="store_true", default=False) 31 | 32 | args = parser.parse_args() 33 | config_path = os.path.expanduser(args.cpath) 34 | model = torch.load(args.model_path) 35 | config = model.encoder.config 36 | 37 | config['dataset']['train_json_dense_dialog_path'] = 'datasets/annotations/visdial_1.0_train_dense_sample.json' 38 | config['dataset']['overfit'] = args.overfit 39 | config['dataset']['finetune'] = True 40 | config['dataset']['evaluate'] = False 41 | config['solver']['num_epochs'] = args.num_epochs 42 | 43 | # For reproducibility 44 | seed = config['seed'] 45 | random.seed(seed) 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | torch.backends.cudnn.benchmark = False 51 | torch.backends.cudnn.deterministic = True 52 | os.environ['PYTHONHASHSEED'] = str(seed) 53 | 54 | # datasets 55 | print(f"CUDA number: {torch.cuda.device_count()}") 56 | 57 | """DATASET INIT""" 58 | print("Loading dataset...") 59 | val_dataset = VisDialDataset(config, split='val') 60 | 61 | val_dataloader = DataLoader(val_dataset, 62 | batch_size=args.batch_size, 63 | num_workers=config['solver']['cpu_workers'], 64 | shuffle=True) 65 | 66 | if config['dataset']['overfit']: 67 | train_dataset = val_dataset 68 | train_dataloader = val_dataloader 69 | 70 | train_dataset = VisDialDataset(config, split='train') 71 | 72 | train_dataloader = DataLoader(train_dataset, 73 | batch_size=args.batch_size, 74 | num_workers=config['solver']['cpu_workers'], 75 | shuffle=True) 76 | 77 | eval_dataset = VisDialDataset(config, split='val') 78 | 79 | eval_dataloader = DataLoader(eval_dataset, 80 | batch_size=2, 81 | num_workers=config['solver']['cpu_workers']) 82 | 83 | """MODEL INIT""" 84 | 85 | print("Move model to GPU...") 86 | device = torch.device('cuda') 87 | model = model.to(device) 88 | 89 | """LOSS FUNCTION""" 90 | disc_criterion = FinetuneLoss() 91 | 92 | """OPTIMIZER""" 93 | optimizer = Adam(model.parameters(), lr=2e-5) 94 | init_lr = args.init_lr 95 | scheduler_type = args.scheduler_type 96 | num_epochs = args.num_epochs 97 | lr_scheduler = LRScheduler(optimizer, 98 | batch_size=args.batch_size, 99 | num_samples=2064 + 2000, 100 | num_epochs=args.num_epochs, 101 | min_lr=1e-5, 102 | init_lr=args.init_lr, 103 | warmup_epochs=1, 104 | scheduler_type=args.scheduler_type, 105 | milestone_steps=config['solver']['milestone_steps'], 106 | linear_gama=config['solver']['linear_gama'] 107 | ) 108 | 109 | # ============================================================================= 110 | # SETUP BEFORE TRAINING LOOP 111 | # ============================================================================= 112 | 113 | save_path = args.save_path 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | print(save_path) 117 | 118 | summary_writer = SummaryWriter(log_dir=save_path) 119 | checkpoint_manager = CheckpointManager(model, optimizer, save_path, config=config) 120 | sparse_metrics = SparseGTMetrics() 121 | disc_metrics = SparseGTMetrics() 122 | gen_metrics = SparseGTMetrics() 123 | ndcg = NDCG() 124 | disc_ndcg = NDCG() 125 | gen_ndcg = NDCG() 126 | 127 | if torch.cuda.device_count() > 1: 128 | print("NUMBER OF 
CUDA", torch.cuda.device_count()) 129 | model = nn.DataParallel(model) 130 | 131 | # ============================================================================= 132 | # TRAINING LOOP 133 | # ============================================================================= 134 | config["solver"]["training_splits"] = 'trainval' 135 | 136 | start_epoch = 0 137 | if config["solver"]["training_splits"] == "trainval": 138 | iterations = (len(train_dataset) + len(val_dataset)) // ( 139 | args.batch_size) + 1 140 | num_examples = torch.tensor(len(train_dataset) + len(val_dataset), dtype=torch.float) 141 | else: 142 | iterations = len(train_dataset) // (args.batch_size) + 1 143 | num_examples = torch.tensor(len(train_dataset), dtype=torch.float) 144 | 145 | global_iteration_step = start_epoch * iterations 146 | 147 | for epoch in range(start_epoch, config['solver']['num_epochs']): 148 | print(f"Training for epoch {epoch}:") 149 | 150 | if epoch == 6: 151 | break 152 | 153 | with tqdm(total=iterations) as pbar: 154 | if config["solver"]["training_splits"] == "trainval": 155 | combined_dataloader = itertools.chain(train_dataloader, val_dataloader) 156 | else: 157 | combined_dataloader = itertools.chain(train_dataloader) 158 | 159 | epoch_loss = torch.tensor(0.0) 160 | for i, batch in enumerate(combined_dataloader): 161 | batch = move_to_cuda(batch, device) 162 | 163 | # zero out gradients 164 | lr = lr_scheduler.step(global_iteration_step) 165 | optimizer.zero_grad() 166 | 167 | # do forward 168 | out = model(batch) 169 | 170 | # compute loss 171 | batch_loss = torch.tensor(0.0, requires_grad=True, device='cuda') 172 | if out.get('opt_scores') is not None: 173 | scores = out['opt_scores'] 174 | 175 | sparse_metrics.observe(out['opt_scores'], batch['ans_ind']) 176 | batch_loss = disc_criterion(scores, batch) 177 | 178 | # compute gradients 179 | batch_loss.backward() 180 | 181 | # update params 182 | optimizer.step() 183 | 184 | pbar.update(1) 185 | pbar.set_postfix(epoch=epoch, 186 | batch_loss=batch_loss.item()) 187 | 188 | # log metrics 189 | summary_writer.add_scalar(f'{config["config_name"]}-train/batch_loss', 190 | batch_loss.item(), global_iteration_step) 191 | 192 | # experiment.log_metric('train/lr', lr) 193 | summary_writer.add_scalar("train/batch_lr", lr, global_iteration_step) 194 | 195 | global_iteration_step += 1 196 | torch.cuda.empty_cache() 197 | 198 | epoch_loss += batch["ans"].size(0) * batch_loss.detach() 199 | 200 | if out.get('opt_scores') is not None: 201 | avg_metric_dict = {} 202 | avg_metric_dict.update(sparse_metrics.retrieve(reset=True)) 203 | 204 | for metric_name, metric_value in avg_metric_dict.items(): 205 | print(f"{metric_name}: {metric_value}") 206 | 207 | summary_writer.add_scalars(f"{config['config_name']}-train/metrics", 208 | avg_metric_dict, global_iteration_step) 209 | 210 | epoch_loss /= num_examples 211 | print(f"train/epoch_loss: {epoch_loss.item()}\n") 212 | summary_writer.add_scalar(f'{config["config_name"]}-train/epoch_loss', 213 | epoch_loss.item(), global_iteration_step) 214 | 215 | # ------------------------------------------------------------------------- 216 | # ON EPOCH END (checkpointing and validation) 217 | # ------------------------------------------------------------------------- 218 | # Validate and report automatic metrics. 219 | 220 | if True: 221 | # Switch dropout, batchnorm etc to the correct mode. 
222 | model.eval() 223 | 224 | print(f"\nValidation after epoch {epoch}:") 225 | 226 | for batch in tqdm(eval_dataloader): 227 | torch.cuda.empty_cache() 228 | 229 | move_to_cuda(batch, device) 230 | 231 | with torch.no_grad(): 232 | out = model(batch) 233 | 234 | if out.get('opt_scores') is not None: 235 | scores = out['opt_scores'] 236 | disc_metrics.observe(scores, batch["ans_ind"]) 237 | 238 | if "gt_relevance" in batch: 239 | scores = scores[ 240 | torch.arange(scores.size(0)), 241 | batch["round_id"] - 1, :] 242 | 243 | disc_ndcg.observe(scores, batch["gt_relevance"]) 244 | 245 | if out.get('opts_out_scores') is not None: 246 | scores = out['opts_out_scores'] 247 | gen_metrics.observe(scores, batch["ans_ind"]) 248 | 249 | if "gt_relevance" in batch: 250 | scores = scores[ 251 | torch.arange(scores.size(0)), 252 | batch["round_id"] - 1, :] 253 | 254 | gen_ndcg.observe(scores, batch["gt_relevance"]) 255 | 256 | if out.get('opt_scores') is not None and out.get('opts_out_scores') is not None: 257 | scores = (out['opts_out_scores'] + out['opt_scores']) / 2 258 | 259 | sparse_metrics.observe(scores, batch["ans_ind"]) 260 | if "gt_relevance" in batch: 261 | scores = scores[ 262 | torch.arange(scores.size(0)), 263 | batch["round_id"] - 1, :] 264 | 265 | ndcg.observe(scores, batch["gt_relevance"]) 266 | 267 | avg_metric_dict = {} 268 | avg_metric_dict.update(sparse_metrics.retrieve(reset=True, key='avg_')) 269 | avg_metric_dict.update(ndcg.retrieve(reset=True, key='avg_')) 270 | 271 | disc_metric_dict = {} 272 | disc_metric_dict.update(disc_metrics.retrieve(reset=True, key='disc_')) 273 | disc_metric_dict.update(disc_ndcg.retrieve(reset=True, key='disc_')) 274 | 275 | gen_metric_dict = {} 276 | gen_metric_dict.update(gen_metrics.retrieve(reset=True, key='gen_')) 277 | gen_metric_dict.update(gen_ndcg.retrieve(reset=True, key='gen_')) 278 | 279 | for metric_dict in [avg_metric_dict, disc_metric_dict, gen_metric_dict]: 280 | for metric_name, metric_value in metric_dict.items(): 281 | print(f"{metric_name}: {metric_value}") 282 | 283 | summary_writer.add_scalars(f"{config['config_name']}-val/metrics", 284 | metric_dict, global_iteration_step) 285 | 286 | model.train() 287 | torch.cuda.empty_cache() 288 | 289 | # Checkpoint 290 | checkpoint_manager.step(epoch=epoch, only_best=False, 291 | metrics=disc_metric_dict, key='disc_') 292 | if epoch == 5: 293 | break 294 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from torch import nn 6 | from tqdm import tqdm 7 | from visdial.model import get_model 8 | from torch.utils.data import DataLoader 9 | from visdial.data.dataset import VisDialDataset 10 | from visdial.metrics import SparseGTMetrics, NDCG 11 | from visdial.utils.checkpointing import CheckpointManager, load_checkpoint_from_config 12 | from visdial.utils import move_to_cuda 13 | from visdial.common.utils import check_flag 14 | from options import get_training_config_and_args 15 | from torch.utils.tensorboard import SummaryWriter 16 | from visdial.optim import Adam, LRScheduler, get_weight_decay_params 17 | 18 | config, args = get_training_config_and_args() 19 | 20 | seed = config['seed'] 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | torch.backends.cudnn.benchmark = False 27 | 
torch.backends.cudnn.deterministic = True 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | 30 | print(f"CUDA number: {torch.cuda.device_count()}") 31 | 32 | """DATASET INIT""" 33 | print("Loading val dataset...") 34 | val_dataset = VisDialDataset(config, split='val') 35 | 36 | if check_flag(config['dataset'], 'v0.9'): 37 | val_dataset.dense_ann_feat_reader = None 38 | 39 | val_dataloader = DataLoader(val_dataset, 40 | batch_size=config['solver']['batch_size'] / 2 * torch.cuda.device_count(), 41 | num_workers=config['solver']['cpu_workers']) 42 | 43 | print("Loading train dataset...") 44 | if config['dataset']['overfit']: 45 | train_dataset = val_dataset 46 | train_dataloader = val_dataloader 47 | else: 48 | train_dataset = VisDialDataset(config, split='train') 49 | if check_flag(config['dataset'], 'v0.9'): 50 | train_dataset.dense_ann_feat_reader = None 51 | 52 | train_dataloader = DataLoader(train_dataset, 53 | batch_size=config['solver']['batch_size'] * torch.cuda.device_count(), 54 | num_workers=config['solver']['cpu_workers'], 55 | shuffle=True) 56 | 57 | """MODEL INIT""" 58 | print("Init model...") 59 | device = torch.device('cuda') 60 | model = get_model(config) 61 | model = model.to(device) 62 | 63 | """LOSS FUNCTION""" 64 | from visdial.loss import DiscLoss 65 | 66 | disc_criterion = DiscLoss(return_mean=True) 67 | gen_criterion = nn.CrossEntropyLoss(ignore_index=0) 68 | 69 | """OPTIMIZER""" 70 | parameters = get_weight_decay_params(model, weight_decay=config['solver']['weight_decay']) 71 | 72 | optimizer = Adam(parameters, 73 | betas=config['solver']['adam_betas'], 74 | eps=config['solver']['adam_eps'], 75 | weight_decay=config['solver']['weight_decay']) 76 | 77 | lr_scheduler = LRScheduler(optimizer, 78 | batch_size=config['solver']['batch_size'] * torch.cuda.device_count(), 79 | num_samples=config['solver']['num_samples'], 80 | num_epochs=config['solver']['num_epochs'], 81 | min_lr=config['solver']['min_lr'], 82 | init_lr=config['solver']['init_lr'], 83 | warmup_factor=config['solver']['warmup_factor'], 84 | warmup_epochs=config['solver']['warmup_epochs'], 85 | scheduler_type=config['solver']['scheduler_type'], 86 | milestone_steps=config['solver']['milestone_steps'], 87 | linear_gama=config['solver']['linear_gama'] 88 | ) 89 | 90 | # ============================================================================= 91 | # SETUP BEFORE TRAINING LOOP 92 | # ============================================================================= 93 | summary_writer = SummaryWriter(log_dir=config['callbacks']['log_dir']) 94 | 95 | checkpoint_manager = CheckpointManager(model, optimizer, config['callbacks']['save_dir'], config=config) 96 | sparse_metrics = SparseGTMetrics() 97 | disc_metrics = SparseGTMetrics() 98 | gen_metrics = SparseGTMetrics() 99 | ndcg = NDCG() 100 | disc_ndcg = NDCG() 101 | gen_ndcg = NDCG() 102 | 103 | print("Loading checkpoints...") 104 | start_epoch, model, optimizer = load_checkpoint_from_config(model, optimizer, config) 105 | 106 | if torch.cuda.device_count() > 1: 107 | model = nn.DataParallel(model) 108 | 109 | # ============================================================================= 110 | # TRAINING LOOP 111 | # ============================================================================= 112 | 113 | iterations = len(train_dataset) // (config['solver']['batch_size'] * torch.cuda.device_count()) + 1 114 | num_examples = torch.tensor(len(train_dataset), dtype=torch.float) 115 | global_iteration_step = start_epoch * iterations 116 | 117 | for epoch in 
range(start_epoch, config['solver']['num_epochs']): 118 | print(f"Training for epoch {epoch}:") 119 | print(f"Training for epoch {epoch}:") 120 | if check_flag(config['dataset'], 'v0.9') and epoch > 6: 121 | break 122 | 123 | epoch_loss = torch.tensor(0.0) 124 | for batch in tqdm(train_dataloader, total=iterations, unit="batch"): 125 | batch = move_to_cuda(batch, device) 126 | 127 | # zero out gradients 128 | optimizer.zero_grad() 129 | 130 | # do forward 131 | out = model(batch) 132 | 133 | # compute loss 134 | gen_loss = torch.tensor(0.0, requires_grad=True, device='cuda') 135 | disc_loss = torch.tensor(0.0, requires_grad=True, device='cuda') 136 | batch_loss = torch.tensor(0.0, requires_grad=True, device='cuda') 137 | if out.get('opt_scores') is not None: 138 | scores = out['opt_scores'].view(-1, 100) 139 | target = batch['ans_ind'].view(-1) 140 | 141 | sparse_metrics.observe(out['opt_scores'], batch['ans_ind']) 142 | disc_loss = disc_criterion(scores, target) 143 | batch_loss = batch_loss + disc_loss 144 | 145 | if out.get('ans_out_scores') is not None: 146 | scores = out['ans_out_scores'].view(-1, config['model']['txt_vocab_size']) 147 | target = batch['ans_out'].view(-1) 148 | gen_loss = gen_criterion(scores, target) 149 | batch_loss = batch_loss + gen_loss 150 | 151 | # compute gradients 152 | batch_loss.backward() 153 | 154 | # update params 155 | lr = lr_scheduler.step(global_iteration_step) 156 | optimizer.step() 157 | 158 | # logging 159 | if config['dataset']['overfit']: 160 | print("epoch={:02d}, steps={:03d}K: batch_loss:{:.03f} " 161 | "disc_loss:{:.03f} gen_loss:{:.03f} lr={:.05f}".format( 162 | epoch, int(global_iteration_step / 1000), batch_loss.item(), 163 | disc_loss.item(), gen_loss.item(), lr)) 164 | 165 | if global_iteration_step % 1000 == 0: 166 | print("epoch={:02d}, steps={:03d}K: batch_loss:{:.03f} " 167 | "disc_loss:{:.03f} gen_loss:{:.03f} lr={:.05f}".format( 168 | epoch, int(global_iteration_step / 1000), batch_loss.item(), 169 | disc_loss.item(), gen_loss.item(), lr)) 170 | 171 | summary_writer.add_scalar(config['config_name'] + "-train/batch_loss", 172 | batch_loss.item(), global_iteration_step) 173 | summary_writer.add_scalar("train/batch_lr", lr, global_iteration_step) 174 | 175 | global_iteration_step += 1 176 | torch.cuda.empty_cache() 177 | 178 | epoch_loss += batch["ans"].size(0) * batch_loss.detach() 179 | 180 | if out.get('opt_scores') is not None: 181 | avg_metric_dict = {} 182 | avg_metric_dict.update(sparse_metrics.retrieve(reset=True)) 183 | 184 | summary_writer.add_scalars(config['config_name'] + "-train/metrics", 185 | avg_metric_dict, global_iteration_step) 186 | 187 | for metric_name, metric_value in avg_metric_dict.items(): 188 | print(f"{metric_name}: {metric_value}") 189 | 190 | epoch_loss /= num_examples 191 | summary_writer.add_scalar(config['config_name'] + "-train/epoch_loss", 192 | epoch_loss.item(), global_iteration_step) 193 | 194 | # ------------------------------------------------------------------------- 195 | # ON EPOCH END (checkpointing and validation) 196 | # ------------------------------------------------------------------------- 197 | # Validate and report automatic metrics. 198 | 199 | if config['callbacks']['validate']: 200 | # Switch dropout, batchnorm etc to the correct mode. 
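# Validation below runs under torch.no_grad(); 'disc_' and 'gen_' metrics track each decoder
# separately, while the 'avg_' metrics are computed from the two decoders' option scores
# averaged together.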
201 | model.eval() 202 | 203 | print(f"\nValidation after epoch {epoch}:") 204 | 205 | for batch in val_dataloader: 206 | move_to_cuda(batch, device) 207 | 208 | with torch.no_grad(): 209 | out = model(batch) 210 | 211 | if out.get('opt_scores') is not None: 212 | scores = out['opt_scores'] 213 | disc_metrics.observe(scores, batch["ans_ind"]) 214 | 215 | if "gt_relevance" in batch: 216 | scores = scores[ 217 | torch.arange(scores.size(0)), 218 | batch["round_id"] - 1, :] 219 | 220 | disc_ndcg.observe(scores, batch["gt_relevance"]) 221 | 222 | if out.get('opts_out_scores') is not None: 223 | scores = out['opts_out_scores'] 224 | gen_metrics.observe(scores, batch["ans_ind"]) 225 | 226 | if "gt_relevance" in batch: 227 | scores = scores[ 228 | torch.arange(scores.size(0)), 229 | batch["round_id"] - 1, :] 230 | 231 | gen_ndcg.observe(scores, batch["gt_relevance"]) 232 | 233 | if out.get('opt_scores') is not None and out.get('opts_out_scores') is not None: 234 | scores = (out['opts_out_scores'] + out['opt_scores']) / 2 235 | 236 | sparse_metrics.observe(scores, batch["ans_ind"]) 237 | if "gt_relevance" in batch: 238 | scores = scores[ 239 | torch.arange(scores.size(0)), 240 | batch["round_id"] - 1, :] 241 | 242 | ndcg.observe(scores, batch["gt_relevance"]) 243 | 244 | avg_metric_dict = {} 245 | avg_metric_dict.update(sparse_metrics.retrieve(reset=True, key='avg_')) 246 | avg_metric_dict.update(ndcg.retrieve(reset=True, key='avg_')) 247 | 248 | disc_metric_dict = {} 249 | disc_metric_dict.update(disc_metrics.retrieve(reset=True, key='disc_')) 250 | disc_metric_dict.update(disc_ndcg.retrieve(reset=True, key='disc_')) 251 | 252 | gen_metric_dict = {} 253 | gen_metric_dict.update(gen_metrics.retrieve(reset=True, key='gen_')) 254 | gen_metric_dict.update(gen_ndcg.retrieve(reset=True, key='gen_')) 255 | 256 | for metric_dict in [avg_metric_dict, disc_metric_dict, gen_metric_dict]: 257 | for metric_name, metric_value in metric_dict.items(): 258 | print(f"{metric_name}: {metric_value}") 259 | summary_writer.add_scalars(config['config_name'] + "-val/metrics", 260 | metric_dict, global_iteration_step) 261 | 262 | model.train() 263 | torch.cuda.empty_cache() 264 | 265 | # Checkpoint 266 | if not args.overfit: 267 | if 'disc' in config['model']['decoder_type']: 268 | checkpoint_manager.step(epoch=epoch, only_best=False, metrics=disc_metric_dict, key='disc_') 269 | 270 | elif 'gen' in config['model']['decoder_type']: 271 | checkpoint_manager.step(epoch=epoch, only_best=False, metrics=gen_metric_dict, key='gen_') 272 | 273 | elif 'misc' in config['model']['decoder_type']: 274 | checkpoint_manager.step(epoch=epoch, only_best=False, metrics=disc_metric_dict, key='disc_') 275 | -------------------------------------------------------------------------------- /visdial/encoders/attn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from visdial.common.utils import clones, check_flag 4 | 5 | 6 | class NormalSubLayer(nn.Module): 7 | """Perform Linear Projection with Dropout and RelU activation inside the for all MHAttn""" 8 | 9 | def __init__(self, hidden_size, dropout): 10 | super(NormalSubLayer, self).__init__() 11 | self.linear = nn.Sequential(nn.Linear(hidden_size * 3, hidden_size), 12 | nn.ReLU(inplace=True), 13 | nn.Dropout(p=dropout)) 14 | 15 | def forward(self, x): 16 | """x: shape [batch_size, M, hidden_size*3]""" 17 | return self.linear(x) 18 | 19 | 20 | class MultiHeadAttention(nn.Module): 21 | """This module perform 
MultiHeadAttention for 2 utilities X, and Y as follows: 22 | MHA_Y(X) = MHA(X, Y, Y) and 23 | MHA_X(Y) = MHA(Y, X, X). 24 | This can be done with sharing similarity matrix since 25 | X_query = X_key = X_value 26 | Y_query = Y_key = Y_value 27 | Then sim_matrix(X_query, Y_key) = sim_matrix(Y_query, X_key) 28 | Please refer to our paper and supplementary for more details. 29 | """ 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | self.config = config 34 | 35 | self.hidden_size = config['model']['hidden_size'] 36 | self.num_heads = config['model']['ca_num_attn_heads'] 37 | self.pad_size = config['model']['ca_pad_size'] 38 | self.d_h = self.hidden_size // self.num_heads 39 | 40 | self.pad_x = torch.empty(self.pad_size, self.hidden_size) 41 | self.pad_x = nn.Parameter(nn.init.kaiming_uniform_(self.pad_x)) 42 | self.pad_y = torch.empty(self.pad_size, self.hidden_size) 43 | self.pad_y = nn.Parameter(nn.init.kaiming_uniform_(self.pad_y)) 44 | 45 | self.attn_X_guided_by_Y = None 46 | self.attn_Y_guided_by_X = None 47 | 48 | def project(self, X, pad_x): 49 | """ 50 | Project X into X_query, X_key, X_value (all are X_proj) by 51 | splitting along last indexes mechanically. 52 | Note that: X_query = X_key = X_value = X_proj since W_Q = W_K = W_V. 53 | Arguments 54 | --------- 55 | X: torch.FloatTensor 56 | The input tensor with 57 | Shape [batch_size, M, hidden_size] 58 | pad_x: torch.FloatTensor 59 | The padding vectors we would like to put at the beginning of X 60 | Shape [batch_size, pad_size, hidden_size] 61 | Returns 62 | ------- 63 | X_proj: torch.FloatTensor 64 | The summarized vector of the utility (the context vector for this utility) 65 | Shape [batch_size, M + pad_size, num_heads, d_h] 66 | """ 67 | size = X.size(0), self.pad_size, self.hidden_size 68 | X = torch.cat([pad_x.unsqueeze(0).expand(*size), X], dim=1) 69 | 70 | X_proj = X.view(X.size(0), X.size(1), self.num_heads, self.d_h) 71 | return X_proj 72 | 73 | def forward(self, X, Y, mask_X, mask_Y): 74 | """ 75 | Arguments 76 | --------- 77 | X: torch.FloatTensor 78 | The input tensor of utility X 79 | Shape [batch_size, M, hidden_size] 80 | Y: torch.FloatTensor 81 | The input tensor of utility Y 82 | Shape [batch_size, N, hidden_size] 83 | mask_X: torch.LongTensor 84 | The mask of utility X where 0 denotes 85 | Shape [batch_size, M] 86 | mask_Y: torch.LongTensor 87 | The mask of utility Y where 0 denotes 88 | Shape [batch_size, N] 89 | Returns 90 | ------- 91 | A tuple of two MultiHeadAttention 92 | A_X(Y): torch.FloatTensor 93 | The attention from the source Y to X: Y_attends_in_X 94 | Shape [batch_size, M, hidden_size] 95 | A_Y(X): torch.FloatTensor 96 | The attention from the source X to Y: X_attends_in_Y 97 | Shape [batch_size, N, hidden_size] 98 | """ 99 | pad_mask = X.new_ones((X.size(0), self.pad_size)).long() 100 | 101 | mask_X = torch.cat([pad_mask, mask_X], dim=1) 102 | mask_Y = torch.cat([pad_mask, mask_Y], dim=1) 103 | M_pad, N_pad = mask_X.size(1), mask_Y.size(1) 104 | mask_X = mask_X[:, None, :, None].repeat(1, self.num_heads, 1, N_pad) 105 | mask_Y = mask_Y[:, None, None, :].repeat(1, self.num_heads, M_pad, 1) 106 | 107 | # X_proj: [bs, pad_size + M, num_heads, d_h] 108 | X_proj = self.project(X, self.pad_x) 109 | 110 | # Y_proj [bs, pad_size + N, num_heads, d_h] 111 | Y_proj = self.project(Y, self.pad_y) 112 | 113 | # (1) shape [bs, num_heads, pad_size + M, d_h] 114 | # (2) shape [bs, num_heads, d_h, pad_size + N] 115 | X_proj = X_proj.permute(0, 2, 1, 3) 116 | Y_proj = Y_proj.permute(0, 2, 3, 1) 117 | 
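# A single batched matmul below yields the shared similarity matrix
# sim_matrix[b, h, i, j] = dot(X_i, Y_j) for every head; softmax over dim=2 normalises over
# X's positions and is used to aggregate X's values into Y (X_attends_in_Y), while softmax
# over dim=3 normalises over Y's positions and aggregates Y's values into X (Y_attends_in_X),
# so both attention directions reuse the same matrix.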
118 | """ 119 | Note that: 120 | X_query = X_key = X_value = X_proj, 121 | Y_query = Y_key = Y_value = Y_proj 122 | Then, we have sim_matrix(X_query, Y_key) = sim_matrix(Y_query, X_key) = sim_matrix 123 | """ 124 | # shape: [bs, num_heads, pad_size + M, pad_size + N] 125 | sim_matrix = torch.matmul(X_proj, Y_proj) 126 | sim_matrix = sim_matrix.masked_fill(mask_X == 0, -1e9) 127 | sim_matrix = sim_matrix.masked_fill(mask_Y == 0, -1e9) 128 | 129 | # shape: [bs, num_heads, pad_size + M, pad_size + N] 130 | attn_X_guided_by_Y = torch.softmax(sim_matrix, dim=2) 131 | attn_Y_guided_by_X = torch.softmax(sim_matrix, dim=3) 132 | 133 | # shape [bs, num_heads, pad_size + M, d_h] 134 | X_value = X_proj 135 | # shape [bs, num_heads, pad_size + N, d_h] 136 | X_attends_in_Y = torch.matmul(attn_X_guided_by_Y.transpose(2, 3), X_value) 137 | # shape [bs, num_heads, N, d_h] 138 | X_attends_in_Y = X_attends_in_Y[:, :, self.pad_size:, :] 139 | # shape [bs, N, num_heads, d_h] 140 | X_attends_in_Y = X_attends_in_Y.permute(0, 2, 1, 3).contiguous() 141 | # shape [bs, N, num_heads, hidden_size] 142 | X_attends_in_Y = X_attends_in_Y.view(X_attends_in_Y.size(0), X_attends_in_Y.size(1), -1) 143 | 144 | # shape [bs, num_heads, pad_size + N, d_h] 145 | Y_value = Y_proj.permute(0, 1, 3, 2).contiguous() 146 | # shape [bs, num_heads, pad_size + M, d_h] 147 | Y_attends_in_X = torch.matmul(attn_Y_guided_by_X, Y_value) 148 | # shape [bs, num_heads, M, d_h] 149 | Y_attends_in_X = Y_attends_in_X[:, :, self.pad_size:, :] 150 | # shape [bs, M, num_heads, d_h] 151 | Y_attends_in_X = Y_attends_in_X.permute(0, 2, 1, 3).contiguous() 152 | # shape [bs, M, hidden_size] 153 | Y_attends_in_X = Y_attends_in_X.view(Y_attends_in_X.size(0), Y_attends_in_X.size(1), -1) 154 | 155 | # for later visualization 156 | if self.config['model']['ca_has_avg_attns']: 157 | X_value = X_value.permute(0, 2, 1, 3).contiguous() 158 | # shape [bs, pad_size + M, hidden_size] 159 | X_value = X_value.view(X_value.size(0), X_value.size(1), -1) 160 | Y_value = Y_value.permute(0, 2, 1, 3).contiguous() 161 | # shape [bs, pad_size + N, hidden_size] 162 | Y_value = Y_value.view(Y_value.size(0), Y_value.size(1), -1) 163 | attn_X_guided_by_Y = torch.mean(attn_X_guided_by_Y, dim=1) 164 | attn_Y_guided_by_X = torch.mean(attn_Y_guided_by_X, dim=1) 165 | self.attn_X_guided_by_Y = attn_X_guided_by_Y 166 | self.attn_Y_guided_by_X = attn_Y_guided_by_X 167 | # shape: [bs, pad_size + N, hidden_size] 168 | X_attends_in_Y = torch.matmul(attn_X_guided_by_Y.transpose(1, 2), X_value) 169 | # shape: [bs, pad_size + M, hidden_size] 170 | Y_attends_in_X = torch.matmul(attn_Y_guided_by_X, Y_value) 171 | # shape: [bs, N, hidden_size] 172 | X_attends_in_Y = X_attends_in_Y[:, self.pad_size:, :] 173 | # shape: [bs, M, hidden_size] 174 | Y_attends_in_X = Y_attends_in_X[:, self.pad_size:, :] 175 | return X_attends_in_Y, Y_attends_in_X 176 | 177 | 178 | class AttentionStack(nn.Module): 179 | """ 180 | The Attention Stack include of 3 blocks (i.e. 
9 MHAttentions) to compute the 181 | attention from all sources to one target (including itself) 182 | Attention from X -> Y and Y -> X can be wrapped into a single MultiHeadAttention 183 | And self-attention X -> X: can be wrapped into MultiHeadAttention(X, X) 184 | """ 185 | 186 | def __init__(self, config): 187 | super(AttentionStack, self).__init__() 188 | self.config = config 189 | hidden_size = config['model']['hidden_size'] 190 | dropout = config['model']['dropout'] 191 | 192 | self.co_attns = clones(MultiHeadAttention(config), 3) 193 | if check_flag(self.config['model'], 'ca_has_self_attns'): 194 | self.self_attns = clones(MultiHeadAttention(config), 3) 195 | 196 | self.im_mlp = NormalSubLayer(hidden_size, dropout) 197 | self.qe_mlp = NormalSubLayer(hidden_size, dropout) 198 | self.hi_mlp = NormalSubLayer(hidden_size, dropout) 199 | 200 | if self.config['model']['ca_has_layer_norm']: 201 | self.im_norm = nn.LayerNorm(hidden_size) 202 | self.qe_norm = nn.LayerNorm(hidden_size) 203 | self.hi_norm = nn.LayerNorm(hidden_size) 204 | 205 | def forward(self, triples): 206 | """ 207 | Arguments 208 | --------- 209 | triples: A tuple of the following: 210 | im: torch.FloatTensor 211 | The representation of image utility 212 | Shape [batch_size x NH, K, hidden_size] 213 | qe: torch.FloatTensor 214 | The representation of question utility 215 | Shape [batch_size x NH, N, hidden_size] 216 | hi: torch.FloatTensor 217 | The representation of history utility 218 | Shape [batch_size x NH, T, hidden_size] 219 | mask_im: torch.LongTensor 220 | Shape [batch_size x NH, K] 221 | mask_qe: torch.LongTensor 222 | Shape [batch_size x NH, N] 223 | mask_hi: torch.LongTensor 224 | Shape [batch_size x NH, T] 225 | Returns 226 | ------- 227 | output : A tuples of the updated representations of inputs as the triples. 
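Concretely, the tuple (im, qe, hi, mask_im, mask_qe, mask_hi) with the same shapes as the inputs,
so that several stacks can be chained with ``nn.Sequential`` (see ``AttentionStackEncoder`` below).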
228 | """ 229 | im, qe, hi, mask_im, mask_qe, mask_hi = triples 230 | im_in_qe, qe_in_im = self.co_attns[0](im, qe, mask_im, mask_qe) 231 | im_in_hi, hi_in_im = self.co_attns[1](im, hi, mask_im, mask_hi) 232 | qe_in_hi, hi_in_qe = self.co_attns[2](qe, hi, mask_qe, mask_hi) 233 | 234 | if check_flag(self.config['model'], 'ca_has_self_attns'): 235 | im_in_im, _ = self.self_attns[0](im, im, mask_im, mask_im) 236 | hi_in_hi, _ = self.self_attns[1](hi, hi, mask_hi, mask_hi) 237 | qe_in_qe, _ = self.self_attns[2](qe, qe, mask_qe, mask_qe) 238 | a_im = self.im_mlp(torch.cat([im_in_im, qe_in_im, hi_in_im], dim=-1)) 239 | a_qe = self.qe_mlp(torch.cat([qe_in_qe, hi_in_qe, im_in_qe], dim=-1)) 240 | a_hi = self.hi_mlp(torch.cat([hi_in_hi, qe_in_hi, im_in_hi], dim=-1)) 241 | else: 242 | a_im = self.im_mlp(torch.cat([im, qe_in_im, hi_in_im], dim=-1)) 243 | a_qe = self.qe_mlp(torch.cat([qe, hi_in_qe, im_in_qe], dim=-1)) 244 | a_hi = self.hi_mlp(torch.cat([hi, qe_in_hi, im_in_hi], dim=-1)) 245 | 246 | if self.config['model']['ca_has_residual']: 247 | im = im + a_im 248 | qe = qe + a_qe 249 | hi = hi + a_hi 250 | else: 251 | im = a_im 252 | qe = a_qe 253 | hi = a_hi 254 | 255 | if self.config['model']['ca_has_layer_norm']: 256 | im = self.im_norm(im) 257 | qe = self.qe_norm(qe) 258 | hi = self.hi_norm(hi) 259 | 260 | return im, qe, hi, mask_im, mask_qe, mask_hi 261 | 262 | 263 | class AttentionStackEncoder(nn.Module): 264 | """ 265 | This provide L attention stacks in the encoder 266 | """ 267 | 268 | def __init__(self, config): 269 | super(AttentionStackEncoder, self).__init__() 270 | self.config = config 271 | 272 | num_cross_attns = self.config['model']['ca_num_attn_stacks'] 273 | 274 | # whether to share the attention weights or not 275 | if self.config['model']['ca_has_shared_attns']: 276 | layers = [AttentionStack(config)] * num_cross_attns 277 | else: 278 | layers = [AttentionStack(config) for _ in range(num_cross_attns)] 279 | 280 | self.cross_attn_encoder = nn.Sequential(*layers) 281 | 282 | def forward(self, triples): 283 | return self.cross_attn_encoder(triples) 284 | -------------------------------------------------------------------------------- /visdial/encoders/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from visdial.common import PositionalEmbedding, DynamicRNN, check_flag 4 | 5 | 6 | class TextEncoder(nn.Module): 7 | def __init__(self, config, hist_encoder, ques_encoder): 8 | super(TextEncoder, self).__init__() 9 | self.text_embedding = nn.Embedding(config['model']['txt_vocab_size'], 10 | config['model']['txt_embedding_size'], 11 | padding_idx=0) 12 | 13 | self.hist_encoder = hist_encoder 14 | self.ques_encoder = ques_encoder 15 | 16 | def forward(self, batch, test_mode=False): 17 | """ 18 | Arguments 19 | --------- 20 | batch: Dictionary 21 | This provides a dictionary of inputs. 
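The keys read by this module are 'ques_tokens', 'hist_tokens', 'ques_len' and 'hist_len'
(see the shape comments in the method body below).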
22 | test_mode: Boolean 23 | Whether the forward is performed on test data 24 | Returns 25 | ------- 26 | output : A tuples of the following: 27 | hist: torch.FloatTensor 28 | The representation of history utility 29 | Shape [batch_size * NH, T, hidden_size] 30 | ques: torch.FloatTensor 31 | The representation of question utility 32 | Shape [batch_size * NH, N, hidden_size] 33 | hist_mask: torch.LongTensor 34 | The mask of history utility 35 | Shape [batch_size * NH, T] 36 | ques_mask: torch.LongTensor 37 | The mask of question utility 38 | Shape [batch_size * NH, N] 39 | """ 40 | 41 | # ques_tokens: shape [BS, num_hist, N] 42 | # hist_tokens: shape [BS, num_hist, num_rounds, N] 43 | ques_tokens = batch['ques_tokens'] 44 | hist_tokens = batch['hist_tokens'] 45 | 46 | # ques_len: shape [BS, num_hist] 47 | # hist_len: shape [BS, num_hist, num_rounds] 48 | ques_len = batch['ques_len'] 49 | hist_len = batch['hist_len'] 50 | 51 | # ques: shape [BS, num_hist, N, embedding_size(or hidden_size)] 52 | # hist: shape [BS, num_hist, num_rounds, N, embedding_size(or hidden_size)] 53 | ques = self.text_embedding(ques_tokens) 54 | hist = self.text_embedding(hist_tokens) 55 | 56 | # ques: shape [BS x NH, N, hidden_size] 57 | # hist: shape [BS x NH, num_rounds, hidden_size] 58 | # ques_mask: shape [BS x NH, N] 59 | # hist_mask: shape [BS x NH, num_rounds] 60 | ques, ques_mask = self.ques_encoder(ques, ques_len, test_mode=test_mode) 61 | hist, hist_mask = self.hist_encoder(hist, hist_len, test_mode=test_mode) 62 | return hist, ques, hist_mask, ques_mask 63 | 64 | 65 | class QuesEncoder(nn.Module): 66 | 67 | def __init__(self, config): 68 | super(QuesEncoder, self).__init__() 69 | 70 | self.config = config 71 | 72 | self.ques_linear = nn.Linear(config['model']['hidden_size'] * 2, 73 | config['model']['hidden_size']) 74 | 75 | self.ques_lstm = DynamicRNN(nn.LSTM(config['model']['txt_embedding_size'], 76 | config['model']['hidden_size'], 77 | num_layers=2, 78 | bidirectional=config['model']['txt_bidirectional'], 79 | batch_first=True)) 80 | 81 | if config['model']['txt_has_layer_norm']: 82 | self.layer_norm = nn.LayerNorm(config['model']['hidden_size']) 83 | 84 | if config['model']['txt_has_pos_embedding']: 85 | self.pos_embedding = PositionalEmbedding(config['model']['hidden_size'], 86 | config['dataset']['max_seq_len']) 87 | 88 | self.config = config 89 | 90 | def forward(self, ques, ques_len, test_mode=False): 91 | """ 92 | Arguments 93 | --------- 94 | ques: torch.FloatTensor 95 | The embedding of question tokens 96 | Shape [batch_size, num_hist, N, embedding_size] 97 | test_mode: Boolean 98 | Whether the forward is performed on test data 99 | Returns 100 | ------- 101 | output : A tuples of the following: 102 | ques: torch.FloatTensor 103 | The representation of question utility 104 | Shape [batch_size * num_hist, N, hidden_size] 105 | ques_mask: torch.LongTensor 106 | The mask of question utility 107 | Shape [batch_size x NH, N] 108 | """ 109 | # for test only 110 | if self.config['model']['test_mode'] or test_mode: 111 | # get only the last question 112 | last_idx = (ques_len > 0).sum() 113 | ques = ques[:, last_idx - 1:last_idx] 114 | ques_len = ques_len[:, last_idx - 1:last_idx] 115 | 116 | bs, num_hist, max_seq_len, embedding_size = ques.size() 117 | 118 | # shape [BS x NH, N, hidden_size] 119 | ques = ques.view(bs * num_hist, max_seq_len, embedding_size) 120 | 121 | # shape [BS * num_hist] 122 | ques_len = ques_len.view(bs * num_hist) 123 | 124 | # shape [BS x NH, N] 125 | ques_mask = 
torch.arange(max_seq_len, device=ques.device).repeat(bs * num_hist, 1)
126 | ques_mask = ques_mask < ques_len.unsqueeze(-1)
127 |
128 | if isinstance(self.ques_lstm, DynamicRNN):
129 | # LSTM
130 | if not self.ques_lstm.bidirectional:
131 | # shape: ques [BS x NH, N, hidden_size]
132 | ques, (_, _) = self.ques_lstm(ques, ques_len)
133 |
134 | # shape: ques [BS x NH, N, hidden_size]
135 | # shape: ques_mask [BS x NH, N]
136 | return ques, ques_mask.long()
137 |
138 | # BiLSTM
139 | else:
140 | # [BS x NH, SEQ, HS x 2]
141 | ques, (_, _) = self.ques_lstm(ques, ques_len)
142 |
143 | # [BS x NH, SEQ, HS]
144 | ques = self.ques_linear(ques)
145 | if self.config['model']['txt_has_pos_embedding']:
146 | ques = ques + self.pos_embedding(ques)
147 |
148 | if self.config['model']['txt_has_layer_norm']:
149 | ques = self.layer_norm(ques)
150 |
151 | return ques, ques_mask.long()
152 |
153 |
154 | class HistEncoder(nn.Module):
155 |
156 | def __init__(self, config):
157 | super(HistEncoder, self).__init__()
158 | self.config = config
159 |
160 | self.hist_linear = nn.Linear(config['model']['hidden_size'] * 2,
161 | config['model']['hidden_size'])
162 |
163 | self.hist_lstm = DynamicRNN(nn.LSTM(config['model']['txt_embedding_size'],
164 | config['model']['hidden_size'],
165 | num_layers=2,
166 | bidirectional=config['model']['txt_bidirectional'],
167 | batch_first=True))
168 |
169 | if config['model']['txt_has_layer_norm']:
170 | self.layer_norm = nn.LayerNorm(config['model']['hidden_size'])
171 |
172 | if config['model']['txt_has_pos_embedding']:
173 | self.pos_embedding = PositionalEmbedding(config['model']['hidden_size'],
174 | max_len=10)
175 |
176 | self.config = config
177 | self.hidden_size = self.config['model']['hidden_size']
178 |
179 | def forward(self, hist, hist_len, test_mode=False):
180 | """
181 | Arguments
182 | ---------
183 | hist: torch.FloatTensor
184 | The embedding of history tokens
185 | Shape [batch_size, num_hist, T, embedding_size]
186 | hist_len: torch.LongTensor
187 | The length of each dialog history round
188 | test_mode: Boolean
189 | Whether the forward is performed on test data
190 | Returns
191 | -------
192 | output : A tuple of the following:
193 | hist: torch.FloatTensor
194 | The representation of history utility
195 | Shape [batch_size x NH, T, hidden_size]
196 | hist_mask: torch.LongTensor
197 | The mask of history utility
198 | Shape [batch_size x NH, T]
199 | """
200 |
201 | bs, num_rounds, max_seq_len, embedding_size = hist.size()
202 |
203 | if self.config['dataset']['concat_hist']:
204 | # for test only
205 |
206 | bs, num_hist, max_seq_len, embedding_size = hist.size()
207 |
208 | # shape [BS x NH, 2N, hidden_size]
209 | hist = hist.view(bs * num_hist, max_seq_len, embedding_size)
210 |
211 | # shape [BS * num_hist]
212 | hist_len = hist_len.view(bs * num_hist)
213 |
214 | # shape [BS x NH, 2N]
215 | hist_mask = torch.arange(max_seq_len, device=hist.device).repeat(bs * num_hist, 1)
216 | hist_mask = hist_mask < hist_len.unsqueeze(-1)
217 |
218 | if isinstance(self.hist_lstm, DynamicRNN):
219 | # [BS x NH, SEQ, HS x 2]
220 | hist, (_, _) = self.hist_lstm(hist, hist_len)
221 |
222 | # [BS x NH, SEQ, HS]
223 | hist = self.hist_linear(hist)
224 | if self.config['model']['txt_has_pos_embedding']:
225 | hist = hist + self.pos_embedding(hist)
226 |
227 | if self.config['model']['txt_has_layer_norm']:
228 | hist = self.layer_norm(hist)
229 | return (hist, # shape: [BS x NH, T, hidden_size]
230 | hist_mask.long(), # shape [BS x NH, T]
231 | )
232 | else:
233 | # shape
[BS * num_rounds, 2N, hidden_size] 234 | hist = hist.view(bs * num_rounds, max_seq_len, embedding_size) 235 | 236 | # shape [BS * num_rounds] 237 | hist_len = hist_len.view(bs * num_rounds) 238 | 239 | if self.config['model']['test_mode'] or test_mode: 240 | num_hist = 1 241 | round_mask = torch.ones(bs, num_hist, num_rounds, 1, device=hist.device) 242 | hist_mask = torch.ones(bs * num_hist, num_rounds, device=hist.device) 243 | else: 244 | num_hist = 10 245 | # shape [10, 10] 246 | MASK = torch.tensor([ 247 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 248 | [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], 249 | [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 250 | [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 251 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], 252 | [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], 253 | [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 254 | [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], 255 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], 256 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 257 | ], device=hist.device) 258 | 259 | # shape [BS, NH, T, 1] 260 | round_mask = MASK[None, :, :, None].repeat(bs, 1, 1, 1) 261 | 262 | hist_mask = MASK[None, :, :].repeat(bs, 1, 1).view(bs * num_hist, num_rounds) 263 | 264 | if isinstance(self.hist_lstm, DynamicRNN): 265 | 266 | if not self.hist_lstm.bidirectional: # LSTM 267 | # shape: [num_layers, BS, HS] if Not bidirectional 268 | # shape: hn = [num_layers, bs * num_rounds, hidden_size] 269 | y, (hn, cn) = self.hist_lstm(hist, hist_len) 270 | 271 | # shape: [BS * num_rounds, hidden_size] 272 | hist = hn[-1] 273 | 274 | # shape: [BS, num_rounds, hidden_size] 275 | hist = hist.view(bs, num_rounds, self.hidden_size) 276 | 277 | else: # BiLSTM 278 | 279 | # hn [num_layers x 2 (bidirectional), BS x NR, HS] 280 | y, (hn, cn) = self.hist_lstm(hist, hist_len) 281 | 282 | # shape: [2, BS x NR, HS] 283 | hn = hn[-2:] 284 | # shape: [BS x NR, HS x 2] 285 | hist = torch.cat([hn[0], hn[1]], dim=-1) 286 | 287 | # shape: [BS x NR, HS] 288 | hist = self.hist_linear(hist) 289 | 290 | # shape: [BS, NR, HS] 291 | hist = hist.view(bs, num_rounds, self.hidden_size) 292 | 293 | if self.config['model']['txt_has_pos_embedding']: 294 | hist = hist + self.pos_embedding(hist) 295 | 296 | if self.config['model']['txt_has_layer_norm']: 297 | hist = self.layer_norm(hist) 298 | 299 | # shape: [BS, NH, T, hidden_size] 300 | hist = hist[:, None, :, :].repeat(1, num_hist, 1, 1) 301 | 302 | # shape: [BS, NH, T, hidden_size] 303 | hist = hist.masked_fill(round_mask == 0, 0.0) 304 | 305 | # shape: [BS x NH, T, hidden_size] 306 | hist = hist.view(bs * num_hist, num_rounds, self.hidden_size) 307 | 308 | return (hist, # shape: [BS x NH, T, hidden_size] 309 | hist_mask.long(), # shape [BS x NH, T] 310 | ) 311 | -------------------------------------------------------------------------------- /others/generate_visdial.py: -------------------------------------------------------------------------------- 1 | # This code can be copy-paste to the Jupyter Notebook in the same folder to run. 
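# In short: the script loads a bottom-up-attention Faster R-CNN model in Caffe, runs it over a
# folder of images, keeps the highest-scoring boxes per image (see NUM_BOXES / conf_thresh below),
# and writes the image ids, boxes, class/attribute indices and pool5 features to an HDF5 file
# through the Dataset class defined towards the end of this file.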
2 | # set up Python environment: numpy for numerical routines, and matplotlib for plotting 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import h5py 8 | import csv 9 | import pylab 10 | import base64 11 | import argparse 12 | import numpy as np 13 | from glob import glob 14 | from tqdm import tqdm 15 | from skimage import transform 16 | import matplotlib.pyplot as plt 17 | 18 | # set display defaults 19 | plt.rcParams['figure.figsize'] = (20, 12) # small images 20 | plt.rcParams['image.interpolation'] = 'nearest' # don't interpolate: show square pixels 21 | plt.rcParams['image.cmap'] = 'gray' # use grayscale output rather than a (potentially misleading) color heatmap 22 | 23 | # Change dir to caffe root or prototxt database paths won't work wrong 24 | os.chdir('..') 25 | os.getcwd() 26 | 27 | FIRST_STAGE = False 28 | conf_thresh = 0.2 29 | attr_thresh = 0.1 30 | 31 | NUM_BOXES = [(100, 100)] # The min boxes and the max boxes. 32 | FIELDNAMES = ['image_id', 'image_h', 'image_w', 'num_boxes', 'top_attrs', 'top_attrs_scores', 'boxes', 'features', 33 | 'st1_boxes'] 34 | 35 | # The caffe module needs to be on the Python path; 36 | # we'll add it here explicitly. 37 | sys.path.insert(0, './caffe/python/') 38 | sys.path.insert(0, './lib/') 39 | sys.path.insert(0, './tools/') 40 | 41 | import caffe 42 | # Check object extraction 43 | from fast_rcnn.config import cfg, cfg_from_file 44 | from fast_rcnn.test import im_detect, _get_blobs 45 | from fast_rcnn.nms_wrapper import nms 46 | 47 | data_path = './data/genome/1600-400-20' 48 | cfg_from_file('experiments/cfgs/faster_rcnn_end2end_resnet.yml') 49 | 50 | # Load classes 51 | CLASSES = ['__background__'] 52 | with open(os.path.join(data_path, 'objects_vocab.txt')) as f: 53 | for object in f.readlines(): 54 | CLASSES.append(object.split(',')[0].lower().strip()) 55 | 56 | # Load attributes 57 | ATTRIBUTES = ['__no_attribute__'] 58 | with open(os.path.join(data_path, 'attributes_vocab.txt')) as f: 59 | for att in f.readlines(): 60 | ATTRIBUTES.append(att.split(',')[0].lower().strip()) 61 | 62 | GPU_ID = 0 # if we have multiple GPUs, pick one 63 | caffe.set_device(GPU_ID) 64 | caffe.set_mode_gpu() 65 | net = None 66 | 67 | 68 | def get_nms_boxes(st2_scores, st1_boxes): 69 | max_cls_scores = np.zeros((st1_boxes.shape[0])) 70 | max_cls_indices = np.zeros((st1_boxes.shape[0]), dtype=int) 71 | 72 | # Keep only the best class_box for each box (each row in st2_boxes has 1601 boxes for 1 ROI) 73 | for cls_ind in range(1, st2_scores.shape[1]): 74 | cls_ind_scores = st2_scores[:, cls_ind] 75 | dets = np.hstack((st1_boxes, cls_ind_scores[:, np.newaxis])).astype(np.float32) 76 | keep = np.array(nms(dets, cfg.TEST.NMS)) 77 | max_cls_indices[keep] = np.where(cls_ind_scores[keep] > max_cls_scores[keep], 78 | cls_ind, max_cls_indices[keep]) 79 | max_cls_scores[keep] = np.where(cls_ind_scores[keep] > max_cls_scores[keep], 80 | cls_ind_scores[keep], max_cls_scores[keep]) 81 | 82 | return max_cls_scores, max_cls_indices 83 | 84 | 85 | def extract_im(net, img_path): 86 | img = cv2.imread(img_path) 87 | 88 | st2_scores, st2_boxes, st2_attr_scores, st2_rel_scores = im_detect(net, img) 89 | pool5 = net.blobs['pool5_flat'].data 90 | 91 | # unscale back to raw image space 92 | blobs, im_scales = _get_blobs(img, None) 93 | 94 | # Keep the original boxes, don't worry about the regression bbox outputs 95 | rois = net.blobs['rois'].data.copy() 96 | st1_scores = rois[:, 0] 97 | st1_boxes = rois[:, 1:5] / im_scales[0] 98 | 99 | # Keep only the best class_box of each row in 
st2_boxes has 1601 boxes for 1 ROI 100 | max_cls_scores, max_cls_indices = get_nms_boxes(st2_scores, st1_boxes) 101 | 102 | # For each threshold of boxes, 103 | # save (keep_box_indices, keep_box_cls_indices) 104 | keep_ind = [] 105 | for (min_boxes, max_boxes) in NUM_BOXES: 106 | keep_box_indices = np.where(max_cls_scores >= conf_thresh)[0] 107 | 108 | if len(keep_box_indices) < min_boxes: 109 | keep_box_indices = np.argsort(max_cls_scores)[::-1][:min_boxes] 110 | elif len(keep_box_indices) > max_boxes: 111 | keep_box_indices = np.argsort(max_cls_scores)[::-1][:max_boxes] 112 | 113 | # print("keep_box_indices len", len(keep_box_indices)) 114 | keep_box_cls_indices = max_cls_indices[keep_box_indices] 115 | keep_ind.append((keep_box_indices, keep_box_cls_indices)) 116 | 117 | return { 118 | "image_id": image_id_from_path(img_path), 119 | "image_h": np.size(img, 0), 120 | "image_w": np.size(img, 1), 121 | "keep_ind": keep_ind, 122 | "st2_scores": st2_scores, 123 | "st2_boxes": st2_boxes, 124 | "st2_attr_scores": st2_attr_scores, 125 | "pool5": pool5, 126 | "st1_boxes": st1_boxes 127 | } 128 | 129 | 130 | def get_topN_attrs(st2_attr_scores, keep_box_indices, topN=20): 131 | attrs = st2_attr_scores[keep_box_indices] 132 | 133 | # shape [num_boxes, topN] 134 | top_attrs = np.zeros((attrs.shape[0], topN), dtype=int) 135 | top_attrs_scores = np.zeros((attrs.shape[0], topN)) 136 | 137 | for i, box_attr in enumerate(attrs): 138 | top_attr = np.argsort(box_attr)[::-1][:topN] 139 | top_attr_score = box_attr[top_attr] 140 | top_attrs[i] = top_attr 141 | top_attrs_scores[i] = top_attr_score 142 | return top_attrs, top_attrs_scores 143 | 144 | 145 | def get_topN_attrs(st2_attr_scores, keep_box_indices, topN=20, attr_thresh=0.1): 146 | attrs = st2_attr_scores[keep_box_indices] 147 | 148 | # shape [num_boxes, topN] 149 | top_attrs = np.zeros((attrs.shape[0], topN), dtype=int) 150 | top_attrs_scores = np.zeros((attrs.shape[0], topN)) 151 | 152 | for i, box_attr in enumerate(attrs): 153 | # except __no_attribute__ 154 | top_attr = np.argsort(box_attr[:])[::-1][:topN] 155 | top_attr_score = box_attr[top_attr] 156 | top_attrs[i] = top_attr 157 | top_attrs_scores[i] = top_attr_score 158 | 159 | # No need to add 1. just ATTRIBUTE[attr_idx] 160 | # where ATTRIBUTE[0] = __no_attribute__ 161 | top_attrs = np.where(top_attrs_scores < attr_thresh, 0, top_attrs) 162 | top_attrs_scores = np.where(top_attrs == 0, 0.0, top_attrs_scores) 163 | 164 | return top_attrs, top_attrs_scores 165 | 166 | 167 | def get_cls_boxes(st2_boxes, keep_box_indices, keep_box_cls_indices): 168 | boxes = st2_boxes[keep_box_indices].reshape(-1, 1601, 4) 169 | # shape [K, 4] 170 | final_boxes = np.zeros((boxes.shape[0], 4)) 171 | 172 | for i in range(len(keep_box_cls_indices)): 173 | final_boxes[i] = boxes[i, keep_box_cls_indices[i]] 174 | 175 | return final_boxes 176 | 177 | 178 | def get_cls_indices(st2_scores, keep_box_indices): 179 | # No need to add 1. just CLASSES[attr_idx] where CLASSES[0] = __background__ 180 | # To get the class: CLASSES[cls_index] 181 | return np.argmax(st2_scores[keep_box_indices][:, 1:], axis=1) + 1 182 | 183 | 184 | def load_img_paths(dir_path): 185 | img_paths = glob(os.path.join(dir_path, "*.jpg")) 186 | return img_paths 187 | 188 | 189 | def image_id_from_path(image_path): 190 | """Given a path to an image, return its id. 
191 | Parameters 192 | ---------- 193 | image_path : str 194 | Path to image, e.g.: coco_train2014/COCO_train2014/000000123456.jpg 195 | Returns 196 | ------- 197 | int 198 | Corresponding image id (123456) 199 | """ 200 | 201 | return int(image_path.split("/")[-1][-16:-4]) 202 | 203 | 204 | """ 205 | How to convert back to attr and cls to words, please 206 | see each functions and demo.ipynb. 207 | """ 208 | 209 | 210 | class Dataset: 211 | 212 | def __init__(self, num_boxes, num_images, split, topNattr): 213 | self.file_name = args.out_path 214 | 215 | self.save_h5 = h5py.File(self.file_name, 'a') 216 | self.save_h5.attrs['split'] = split 217 | 218 | self.datasets = {} 219 | min_boxes, max_boxes = num_boxes 220 | for field in ['image_id', 'image_h', 'image_w', 'num_boxes']: 221 | self.datasets[field] = self.save_h5.create_dataset(field, (num_images,), dtype='int') 222 | 223 | self.datasets['cls_indices'] = self.save_h5.create_dataset('cls_indices', (num_images, max_boxes), dtype='int') 224 | self.datasets['top_attrs'] = self.save_h5.create_dataset('top_attrs', (num_images, max_boxes, topNattr), 225 | dtype='int') 226 | self.datasets['top_attrs_scores'] = self.save_h5.create_dataset('top_attrs_scores', 227 | (num_images, max_boxes, topNattr), 228 | dtype='float32') 229 | self.datasets['boxes'] = self.save_h5.create_dataset('boxes', (num_images, max_boxes, 4), dtype='float32') 230 | self.datasets['features'] = self.save_h5.create_dataset('features', (num_images, max_boxes, 2048), 231 | dtype='float32') 232 | self.datasets['st1_boxes'] = self.save_h5.create_dataset('st1_boxes', (num_images, max_boxes, 4), 233 | dtype='float32') 234 | 235 | # for f in FIELDNAMES: 236 | # print(f, self.datasets[f].shape) 237 | 238 | self.cur_idx = 0 239 | 240 | def update(self, out): 241 | for key in out: 242 | if isinstance(out[key], np.ndarray): 243 | pass 244 | # print("key", key, out[key].shape) 245 | self.datasets[key][self.cur_idx] = out[key] 246 | self.cur_idx += 1 247 | self.save_h5.attrs['cur_idx'] = self.cur_idx 248 | 249 | def close(self): 250 | self.save_h5.close() 251 | 252 | 253 | def pad(x, max_boxes): 254 | """ 255 | :param x: [K, N] 256 | :param max_boxes: K needs to pad up to max_boxes 257 | :return: [max_boxes, N] 258 | """ 259 | num_boxes = x.shape[0] 260 | if num_boxes == max_boxes: 261 | return x 262 | if len(x.shape) == 1: 263 | return np.pad(x, (0, max_boxes - num_boxes), 'constant', constant_values=0) 264 | else: 265 | return np.pad(x, ((0, max_boxes - num_boxes), (0, 0)), 'constant', constant_values=0) 266 | 267 | 268 | if __name__ == '__main__': 269 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 270 | parser.add_argument('--split', default="demo") 271 | parser.add_argument('--topNattr', default=20, type=int) 272 | parser.add_argument('--data_path', default="data/demo") 273 | parser.add_argument('--num_images', default=30, type=int) 274 | parser.add_argument('--out_path', default=None) 275 | parser.add_argument('--prototxt', default='models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt') 276 | parser.add_argument('--weights') 277 | args = parser.parse_args() 278 | 279 | # Load model 280 | net = caffe.Net(args.prototxt, caffe.TEST, weights=args.weights) 281 | 282 | img_paths = load_img_paths(args.data_path) 283 | print("number of images", len(img_paths)) 284 | datasets = [Dataset(num_boxes=num_boxes, num_images=args.num_images, split=args.split, topNattr=args.topNattr) for 285 | num_boxes in NUM_BOXES] 286 | 287 | with 
tqdm(total=len(img_paths)) as pbar: 288 | for img_path in img_paths: 289 | caffe.set_mode_gpu() 290 | caffe.set_device(0) 291 | preds = extract_im(net, img_path) 292 | 293 | for i, (keep_box_indices, keep_box_cls_indices) in enumerate(preds["keep_ind"]): 294 | # shape [K, topN] 295 | # shape [K, topN] 296 | top_attrs, top_attrs_scores = get_topN_attrs(preds["st2_attr_scores"], 297 | keep_box_indices, 298 | topN=20, attr_thresh=attr_thresh) 299 | # shape [K, 4] 300 | boxes = get_cls_boxes(preds["st2_boxes"], 301 | keep_box_indices=keep_box_indices, 302 | keep_box_cls_indices=keep_box_cls_indices) 303 | 304 | # shape [K, 2048] 305 | features = preds["pool5"][keep_box_indices] 306 | 307 | # shape [K, 4] 308 | st1_boxes = preds["st1_boxes"][keep_box_indices] 309 | 310 | # shape [K, ] 311 | cls_indices = get_cls_indices(preds['st2_scores'], keep_box_indices) 312 | 313 | min_boxes, max_boxes = NUM_BOXES[i] 314 | num_boxes = len(keep_box_indices) 315 | out = { 316 | "image_id": preds["image_id"], 317 | "image_h": preds["image_h"], 318 | "image_w": preds["image_w"], 319 | "num_boxes": num_boxes, 320 | "cls_indices": pad(cls_indices, max_boxes), 321 | "top_attrs": pad(top_attrs, max_boxes), 322 | "top_attrs_scores": pad(top_attrs_scores, max_boxes), 323 | "boxes": pad(boxes, max_boxes), 324 | "features": pad(features, max_boxes), 325 | "st1_boxes": pad(st1_boxes, max_boxes), 326 | } 327 | datasets[i].update(out) 328 | 329 | pbar.update(1) 330 | 331 | for i in range(len(datasets)): 332 | datasets[i].close() 333 | -------------------------------------------------------------------------------- /visdial/data/readers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Reader simply reads data from disk and returns it almost as is, based on 3 | a "primary key", which for the case of VisDial v1.0 dataset, is the 4 | ``image_id``. Readers should be utilized by torch ``Dataset``s. Any type of 5 | data pre-processing is not recommended in the reader, such as tokenizing words 6 | to integers, embedding tokens, or passing an image through a pre-trained CNN. 7 | 8 | Each reader must atleast implement three methods: 9 | - ``__len__`` to return the length of data this Reader can read. 10 | - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 11 | dataset. 12 | - ``keys`` to return a list of possible ``image_id``s this Reader can 13 | provide data of. 14 | 15 | Credit: 16 | https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/data/readers.py 17 | With some modification in For ImageReader 18 | """ 19 | import nltk 20 | import copy 21 | import json 22 | from typing import Dict, List, Union 23 | 24 | import h5py 25 | 26 | # A bit slow, and just splits sentences to list of words, can be doable in 27 | # `DialogsReader`. 28 | from nltk.tokenize import word_tokenize 29 | # from pytorch_pretrained_bert import BertTokenizer 30 | from tqdm import tqdm 31 | import os 32 | from nltk.tokenize import word_tokenize 33 | 34 | 35 | class DialogsReader(object): 36 | """ 37 | A simple reader for VisDial v1.0 dialog data. The json file must have the 38 | same structure as mentioned on ``https://visualdialog.org/data``. 39 | 40 | Parameters 41 | ---------- 42 | dialogs_jsonpath : str 43 | Path to json file containing VisDial v1.0 train, val or test data. 
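config : Dict
    The loaded configuration; ``config['dataset'][f'{split}_json_dialog_path']`` points to
    that json file and ``config['model']['txt_tokenizer']`` selects the tokenizer.
split : str
    One of ``train``, ``val`` or ``test`` (default: ``train``).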
44 | """ 45 | 46 | def __init__(self, config, split='train'): 47 | if config['model']['txt_tokenizer'] == 'nlp': 48 | self.tokenize = word_tokenize 49 | elif config['model']['txt_tokenizer'] == 'bert': 50 | pass 51 | # self.tokenize = BertTokenizer.from_pretrained('bert-base-uncased').tokenize 52 | 53 | self.config = config 54 | 55 | path_json_dialogs = config['dataset'][f"{split}_json_dialog_path"] 56 | path_json_dialogs = os.path.expanduser(path_json_dialogs) 57 | 58 | with open(path_json_dialogs, "r") as visdial_file: 59 | visdial_data = json.load(visdial_file) 60 | self._split = visdial_data["split"] 61 | 62 | self.questions = visdial_data["data"]["questions"] 63 | self.answers = visdial_data["data"]["answers"] 64 | 65 | # Add empty question, answer at the end, useful for padding dialog 66 | # rounds for test. 67 | self.questions.append("") 68 | self.answers.append("") 69 | 70 | # Image_id serves as key for all three dicts here. 71 | self.captions = {} 72 | self.dialogs = {} 73 | self.num_rounds = {} 74 | 75 | for dialog_for_image in visdial_data["data"]["dialogs"]: 76 | self.captions[dialog_for_image["image_id"]] = dialog_for_image[ 77 | "caption" 78 | ] 79 | 80 | # Record original length of dialog, before padding. 81 | # 10 for train and val splits, 10 or less for test split. 82 | self.num_rounds[dialog_for_image["image_id"]] = len( 83 | dialog_for_image["dialog"] 84 | ) 85 | 86 | # Pad dialog at the end with empty question and answer pairs 87 | # (for test split). 88 | while len(dialog_for_image["dialog"]) < 10: 89 | dialog_for_image["dialog"].append( 90 | {"question": -1, "answer": -1} 91 | ) 92 | 93 | # Add empty answer /answer options if not provided 94 | # (for test split). 95 | for i in range(len(dialog_for_image["dialog"])): 96 | if "answer" not in dialog_for_image["dialog"][i]: 97 | dialog_for_image["dialog"][i]["answer"] = -1 98 | if "answer_options" not in dialog_for_image["dialog"][i]: 99 | dialog_for_image["dialog"][i]["answer_options"] = [-1] * 100 100 | 101 | self.dialogs[dialog_for_image["image_id"]] = dialog_for_image["dialog"] 102 | 103 | print(f"[{self._split}] Tokenizing questions...") 104 | for i in range(len(self.questions)): 105 | # print('len(self.questions[i])', len(self.questions[i])) 106 | self.questions[i] = self.questions[i] + "?" if len(self.questions[i]) > 0 else self.questions[i] 107 | self.questions[i] = self.do_tokenize(self.questions[i]) 108 | 109 | print(f"[{self._split}] Tokenizing answers...") 110 | for i in range(len(self.answers)): 111 | self.answers[i] = self.do_tokenize(self.answers[i]) 112 | 113 | print(f"[{self._split}] Tokenizing captions...") 114 | for image_id, caption in self.captions.items(): 115 | self.captions[image_id] = self.do_tokenize(caption) 116 | 117 | if config['model']['txt_tokenizer'] == 'bert': 118 | path_feat_questions = config['dataset'][split]['path_feat_questions'] 119 | path_feat_history = config['dataset'][split]['path_feat_history'] 120 | self.question_reader = QuestionFeatureReader(path_feat_questions) 121 | self.history_reader = HistoryFeatureReader(path_feat_history) 122 | 123 | def __len__(self): 124 | return len(self.dialogs) 125 | 126 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: 127 | caption_for_image = self.captions[image_id] 128 | dialog_for_image = copy.deepcopy(self.dialogs[image_id]) 129 | num_rounds = self.num_rounds[image_id] 130 | 131 | # Replace question and answer indices with actual word tokens. 
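# (questions and answers are stored once in self.questions / self.answers and are referenced
# by index inside each dialog round, so this expands them round by round)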
132 | dialog_for_image = self.replace_ids_by_tokens( 133 | dialog_for_image, 134 | keys=['question', 'answer', 'answer_options']) 135 | 136 | item = { 137 | 'image_id': image_id, 138 | 'num_rounds': num_rounds, # 10 139 | "caption": caption_for_image, 140 | "dialog": dialog_for_image, 141 | } 142 | 143 | if self.config['model']['txt_tokenizer'] == 'bert': 144 | # Replace question and answer indices with actual word tokens. 145 | dialog_for_image = self.replace_ids_by_tokens( 146 | dialog_for_image, 147 | keys=['answer', 'answer_options']) 148 | 149 | ques_feats = [] 150 | ques_masks = [] 151 | for i in range(len(dialog_for_image)): 152 | ques = self.question_reader[dialog_for_image[i]['question']] 153 | question_feature, question_mask = ques 154 | ques_feats.append(question_feature) 155 | ques_masks.append(question_mask) 156 | 157 | hist_feats = self.history_reader[image_id] 158 | 159 | bert_return = { 160 | 'ques_feats': ques_feats, # shape [10, 23, 768] 161 | 'ques_masks': ques_masks, # shape [10, 23] 162 | 'hist_feats': hist_feats, # shape [11, 768] 163 | 'dialog': dialog_for_image 164 | } 165 | item.update(bert_return) 166 | 167 | return item 168 | 169 | def do_tokenize(self, text): 170 | tokenized_text = self.tokenize(text) 171 | return tokenized_text 172 | 173 | def replace_ids_by_tokens(self, 174 | dialog_for_image, 175 | keys=['question', 'answer', 'answer_options']): 176 | for dialog_round in dialog_for_image: 177 | for key in keys: 178 | if key == 'answer_options': 179 | for i, ans_opt in enumerate(dialog_round[key]): 180 | dialog_round[key][i] = self.answers[ans_opt] 181 | elif key == 'answer': 182 | dialog_round[key] = self.answers[dialog_round[key]] 183 | elif key == 'question': 184 | dialog_round[key] = self.questions[dialog_round[key]] 185 | 186 | return dialog_for_image 187 | 188 | def keys(self) -> List[int]: 189 | return list(self.dialogs.keys()) 190 | 191 | @property 192 | def split(self): 193 | return self._split 194 | 195 | 196 | class DenseAnnotationsReader(object): 197 | """ 198 | A reader for dense annotations for val split. The json file must have the 199 | same structure as mentioned on ``https://visualdialog.org/data``. 
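Each entry holds the ``image_id``, the ``round_id`` and the ``gt_relevance`` scores
for one image (see ``__getitem__`` below).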
200 | 201 | Parameters 202 | ---------- 203 | dense_annotations_jsonpath : str 204 | Path to a json file containing VisDial v1.0 205 | """ 206 | 207 | def __init__(self, dense_annotations_jsonpath: str): 208 | if '~' in dense_annotations_jsonpath: 209 | dense_annotations_jsonpath = os.path.expanduser(dense_annotations_jsonpath) 210 | 211 | with open(dense_annotations_jsonpath, "r") as visdial_file: 212 | self._visdial_data = json.load(visdial_file) 213 | self._image_ids = [ 214 | entry["image_id"] for entry in self._visdial_data 215 | ] 216 | 217 | def __len__(self): 218 | return len(self._image_ids) 219 | 220 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: 221 | index = self._image_ids.index(image_id) 222 | # keys: {"image_id", "round_id", "gt_relevance"} 223 | return self._visdial_data[index] 224 | 225 | @property 226 | def split(self): 227 | # always 228 | return "val" 229 | 230 | 231 | class HistoryFeatureReader(object): 232 | def __init__(self, path_hdf_hist): 233 | self.path_hdf_hist = path_hdf_hist 234 | with h5py.File(path_hdf_hist, 'r') as features_hdf: 235 | self.image_ids = list(features_hdf['image_ids']) 236 | 237 | def __len__(self): 238 | return len(self.image_ids) 239 | 240 | def __getitem__(self, image_id): 241 | index = self.image_ids.index(image_id) 242 | with h5py.File(self.path_hdf_hist, 'r') as features_hdf: 243 | history_feature = features_hdf['features'][index] 244 | 245 | return history_feature 246 | 247 | 248 | class QuestionFeatureReader(object): 249 | 250 | def __init__(self, path_hdf_ques): 251 | self.path_hdf_ques = path_hdf_ques 252 | with h5py.File(self.path_hdf_ques, 'r') as hdf: 253 | self.num_questions = hdf.attrs['num_questions'] 254 | 255 | def __len__(self): 256 | return self.num_questions 257 | 258 | def __getitem__(self, question_id): 259 | if not os.path.isfile(self.path_hdf_ques): 260 | return None 261 | 262 | with h5py.File(self.path_hdf_ques, 'r') as hdf: 263 | question_feature = hdf['features'][question_id] 264 | question_mask = hdf['masks'][question_id] 265 | return question_feature, question_mask 266 | 267 | 268 | class ImageFeaturesHdfReader: 269 | """ 270 | Parameters 271 | ---------- 272 | features_hdfpath : str 273 | Path to an HDF file containing VisDial v1.0 train, val or test split 274 | image features. 275 | in_memory : bool 276 | Whether to load the whole HDF file in memory. Beware, these files are 277 | sometimes tens of GBs in size. Set this to true if you have sufficient 278 | RAM - trade-off between speed and memory. 
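The HDF5 file (as written by ``others/generate_visdial.py``) exposes the following
datasets, one entry per image: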
279 | 280 | ['boxes', 281 | 'cls_indices', 282 | 'features', 283 | 'image_h', 284 | 'image_id', 285 | 'image_w', 286 | 'num_boxes', 287 | 'st1_boxes', 288 | 'top_attrs', 289 | 'top_attrs_scores'] 290 | """ 291 | 292 | def __init__(self, 293 | hdf_path, 294 | genome_path='datasets/genome/1600-400-20'): 295 | 296 | self.hdf_path = hdf_path 297 | genome_path = os.path.expanduser(genome_path) 298 | self.genome_path = genome_path 299 | with h5py.File(self.hdf_path, "r") as hdf: 300 | self._split = hdf.attrs["split"] 301 | self.image_id_list = list(hdf["image_id"]) 302 | self.dataset_names = list(hdf.keys()) 303 | 304 | # Load classes 305 | self.CLASSES = ['__background__'] 306 | with open(os.path.join(genome_path, 'objects_vocab.txt')) as f: 307 | for object in f.readlines(): 308 | self.CLASSES.append(object.split(',')[0].lower().strip()) 309 | 310 | # Load attributes 311 | self.ATTRIBUTES = ['__no_attribute__'] 312 | with open(os.path.join(genome_path, 'attributes_vocab.txt')) as f: 313 | for att in f.readlines(): 314 | self.ATTRIBUTES.append(att.split(',')[0].lower().strip()) 315 | 316 | def __len__(self): 317 | return len(self.image_id_list) 318 | 319 | def __getitem__(self, image_id: int): 320 | index = self.image_id_list.index(image_id) 321 | 322 | output = {} 323 | 324 | with h5py.File(self.hdf_path, 'r') as hdf: 325 | for name in self.dataset_names: 326 | if name not in ['st1_boxes']: 327 | output[name] = hdf[name][index] 328 | 329 | if 'cls_indices' in self.dataset_names: 330 | output['cls_names'] = self.get_cls_name(output['cls_indices']) 331 | if 'top_attrs' in self.dataset_names: 332 | output['top_attr_names'] = self.get_attr_names(output['top_attrs']) 333 | return output 334 | 335 | def get_cls_name(self, cls_indices): 336 | cls_names = [] 337 | for cls_idx in cls_indices: 338 | cls_names.append(self.CLASSES[cls_idx]) 339 | return cls_names 340 | 341 | def get_attr_names(self, top_attrs): 342 | top_word_attrs = [] 343 | for row_attrs in top_attrs: 344 | box_word_attrs = [] 345 | for attr_idx in row_attrs: 346 | box_word_attrs.append(self.ATTRIBUTES[attr_idx]) 347 | top_word_attrs.append(box_word_attrs) 348 | return top_word_attrs 349 | 350 | def keys(self) -> List[int]: 351 | return self.image_id_list 352 | 353 | @property 354 | def split(self): 355 | return self._split 356 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | This is the code implementation for the paper titled: "**Efficient Attention Mechanism for Visual Dialog that can Handle All the Interactions between Multiple Inputs**" 3 | (**Accepted to ECCV 2020**). 
4 |
5 | Table of Contents
6 | ----------------------
7 | * [Setup and Environment](#setup-and-environment)
8 | * [Download Data](#download-data)
9 | * [Training](#training)
10 | * [Evaluation](#evaluation)
11 | * [Result of Checkpoints](#result-of-checkpoints)
12 | * [Acknowledgements](#acknowledgements)
13 |
14 | If you find this code useful or use our method as the baseline for comparison, please cite the paper with the following BibTeX entry or the plain-text citation:
15 |
16 | ```
17 | @inproceedings{nguyen2020efficient,
18 | title={Efficient attention mechanism for visual dialog that can handle all the interactions between multiple inputs},
19 | author={Nguyen, Van-Quang and Suganuma, Masanori and Okatani, Takayuki},
20 | booktitle={Computer Vision--ECCV 2020: 16th European Conference, Glasgow, UK, August 23--28, 2020, Proceedings, Part XXIV 16},
21 | pages={223--240},
22 | year={2020},
23 | organization={Springer}
24 | }
25 |
26 | or, in plain text:
27 |
28 | Nguyen, Van-Quang, Masanori Suganuma, and Takayuki Okatani. "Efficient attention mechanism for visual dialog that can handle all the interactions between multiple inputs." Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part XXIV 16. Springer International Publishing, 2020.
29 |
30 | ```
31 | ![](https://imgur.com/download/YRhZk2M)
32 |
33 |
34 | Setup and Environment
35 | ----------------------
36 | This code was implemented with the following environment configuration:
37 |
38 | Component | Details |
39 | ------- | ------ |
40 | PyTorch | version 1.2 |
41 | Python | version 3.7 |
42 | GPU | Tesla V100-SXM2 (16GB) |
43 | No. of GPUs | 4 |
44 | CUDA | 10.0 |
45 | GPU Driver | 410.104 |
46 | RAM | 376GB |
47 | CPU | Xeon(R) Gold 6148 CPU @ 2.40GHz |
48 |
49 |
50 | To set up the environment, we recommend using a virtual environment managed by Anaconda.
51 |
52 | 1. Install an Anaconda or Miniconda distribution based on Python 3+ from the [downloads site](https://www.anaconda.com/distribution/).
53 | 2. Clone this repository and create an environment.
54 | 3. Install all the dependencies:
55 |
56 | ```sh
57 | conda create -n visdial python=3.7
58 |
59 | # activate the environment and install all dependencies
60 | conda activate visdial
61 |
62 | # Install the dependencies
63 | export PROJ_ROOT='/path/to/visualdialog/'
64 | cd $PROJ_ROOT/
65 | pip install -r requirements.txt
66 | ```
67 |
68 | Download Data
69 | -------------
70 |
71 | 1. Download the following `json` files for VisDial v1.0 and put them in `$PROJ_ROOT/datasets/annotations/`:
72 | * For the training set [here](https://www.dropbox.com/s/ix8keeudqrd8hn8/visdial_1.0_train.zip?dl=0).
73 | * For the validation set [here](https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0)
74 | as well as the [dense answer annotations](https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0).
75 | * For the test set [here](https://www.dropbox.com/s/o7mucbre2zm7i5n/visdial_1.0_test.zip?dl=0).
76 |
77 |
78 | 2. Download the following `json` files for VisDial v0.9 and also put them in `$PROJ_ROOT/datasets/annotations/`:
79 | * For the training set [here](https://s3.amazonaws.com/visual-dialog/v0.9/visdial_0.9_train.zip).
80 | * For the validation set [here](https://s3.amazonaws.com/visual-dialog/v0.9/visdial_0.9_val.zip).
81 |
82 |
83 | 3.
Get the word counts for the VisDial v1.0 train split [here](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/visdial_1.0_word_counts_train.json) and put it in `$PROJ_ROOT/datasets/annotations/`. It is used to build the vocabulary.
84 |
85 | 4. Get the image features. We use features extracted from VisDial v1.0 images with a Faster R-CNN pre-trained on Visual Genome.
86 |
87 | * First, download the images for the training set of VisDial v1.0 from COCO train2014 and val2014, which are available [here](http://cocodataset.org/#download), and also download the images for the validation and test sets of VisDial v1.0 from [here](https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0)
88 | and [here](https://www.dropbox.com/s/mwlrg31hx0430mt/VisualDialog_test2018.zip?dl=0).
89 | * Then follow the instructions [here](https://github.com/peteanderson80/bottom-up-attention) to extract the bottom-up-attention features for the images with the pre-trained Faster R-CNN:
90 | * First, clone the code provided by the authors at https://github.com/peteanderson80/bottom-up-attention.
91 | * Second, set up the environment as described [here](https://github.com/peteanderson80/bottom-up-attention).
92 | * Then, extract the features as described in our paper. We provide our code for the extraction; please copy `$PROJ_ROOT/others/generate_visdial.py` from our project to `bottom-up-attention/tools`.
93 | * Run the following commands to extract:
94 | ```sh
95 | # Estimated time: 10 hours
96 | # Extract the image features for the training split
97 | /usr/bin/python generate_visdial.py \
98 | --split "train" \
99 | --topNattr 20 \
100 | --num_images 123287 \
101 | --data_path '/path_to_the_image_dir/trainval2014' \
102 | --out_path '$PROJ_ROOT/datasets/bottom-up-attention/trainval_resnet101_faster_rcnn_genome_num_boxes_100.h5' \
103 | --prototxt 'models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt' \
104 | --weights '/path_to_bottom_up_attention_checkpoints/bottom-up-attention/resnet101_faster_rcnn_final.caffemodel'
105 |
106 | # Estimated time: 35 minutes
107 | # Extract the image features for the validation split
108 | /usr/bin/python generate_visdial.py \
109 | --split "val" \
110 | --topNattr 20 \
111 | --num_images 2064 \
112 | --data_path '/path_to_the_image_dir/VisualDialog_val2018' \
113 | --out_path '$PROJ_ROOT/datasets/bottom-up-attention/val2018_resnet101_faster_rcnn_genome_num_boxes_100.h5' \
114 | --prototxt 'models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt' \
115 | --weights '/path_to_bottom_up_attention_checkpoints/bottom-up-attention/resnet101_faster_rcnn_final.caffemodel'
116 |
117 | # Estimated time: 2 hours
118 | # Extract the image features for the test split
119 | /usr/bin/python generate_visdial.py \
120 | --split "test" \
121 | --topNattr 20 \
122 | --num_images 8000 \
123 | --data_path '/path_to_the_image_dir/VisualDialog_test2018' \
124 | --out_path '$PROJ_ROOT/datasets/bottom-up-attention/test2018_resnet101_faster_rcnn_genome_num_boxes_100.h5' \
125 | --prototxt 'models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt' \
126 | --weights '/path_to_bottom_up_attention_checkpoints/bottom-up-attention/resnet101_faster_rcnn_final.caffemodel'
127 | ```
128 | * At the end, the directory `$PROJ_ROOT/datasets/bottom-up-attention/` should contain the following files:
129 | ```sh
130 | ./trainval_resnet101_faster_rcnn_genome_num_boxes_100.h5
131 | ./val2018_resnet101_faster_rcnn_genome_num_boxes_100.h5
132 | ./test2018_resnet101_faster_rcnn_genome_num_boxes_100.h5
133 | ```
134 | * In `$PROJ_ROOT/datasets/`, we also provide
the data files that you will need:
135 | ```
136 | $PROJ_ROOT/datasets/glove/embedding_Glove_840_300d.pkl
137 | $PROJ_ROOT/datasets/genome/1600-400-20/attributes_vocab.txt
138 | $PROJ_ROOT/datasets/genome/1600-400-20/objects_vocab.txt
139 | ```
140 |
141 | Training
142 | --------
143 |
144 | Our code supports generative and discriminative decoders, as well as training both jointly, which we call `misc`.
145 | The training script supports both VisDial v1.0 and VisDial v0.9.
146 |
147 | **Note**: If CUDA runs out of memory, please consider decreasing the `batch_size`.
148 |
149 | ### Training on VisDial v1.0
150 | To reproduce our results on VisDial v1.0, please run the following command (the remaining hyperparameters default to the values used in our paper):
151 |
152 | ```sh
153 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
154 | --config_name model_v10 \
155 | --save_dir checkpoints \
156 | --batch_size 8 \
157 | --decoder_type misc \
158 | --init_lr 0.001 \
159 | --scheduler_type "LinearLR" \
160 | --num_epochs 15 \
161 | --num_samples 123287 \
162 | --milestone_steps 3 5 7 9 11 13 \
163 | --encoder_out 'img' 'ques' \
164 | --dropout 0.1 \
165 | --img_has_bboxes \
166 | --ca_has_layer_norm \
167 | --ca_num_attn_stacks 2 \
168 | --ca_has_residual \
169 | --ca_has_self_attns \
170 | --txt_has_layer_norm \
171 | --txt_has_decoder_layer_norm \
172 | --txt_has_pos_embedding
173 | ```
174 | **Note 1**: The `batch_size` is set per GPU. With 4 GPUs, the effective batch size is 32, as in our setup.
175 |
176 | **Note 2**: You can also train
177 | a discriminative model or a generative model by specifying `--decoder_type` as `disc` or `gen`, respectively.
178 |
179 | ### Training on VisDial v0.9
180 | To reproduce our results on VisDial v0.9, please run the following command (the remaining hyperparameters default to the values used in our paper):
181 |
182 | ```sh
183 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
184 | --config_name misc_v0.9 \
185 | --save_dir checkpoints \
186 | --v0.9 \
187 | --batch_size 8 \
188 | --decoder_type misc \
189 | --init_lr 0.001 \
190 | --scheduler_type "LinearLR" \
191 | --num_epochs 5 \
192 | --num_samples 123287 \
193 | --milestone_steps 3 5 \
194 | --encoder_out 'img' 'ques' \
195 | --dropout 0.1 \
196 | --img_has_bboxes \
197 | --ca_has_layer_norm \
198 | --ca_num_attn_stacks 2 \
199 | --ca_has_residual \
200 | --ca_has_self_attns \
201 | --txt_has_layer_norm \
202 | --txt_has_decoder_layer_norm \
203 | --txt_has_pos_embedding \
204 | --val_feat_img_path "datasets/bottom-up-attention/trainval_resnet101_faster_rcnn_genome_num_boxes_100.h5" \
205 | --train_feat_img_path "datasets/bottom-up-attention/trainval_resnet101_faster_rcnn_genome_num_boxes_100.h5" \
206 | --val_json_dialog_path "datasets/annotations/visdial_0.9_val.json" \
207 | --train_json_dialog_path "datasets/annotations/visdial_0.9_val.json"
208 | ```
209 | **Note 1**: You must turn on the `--v0.9` flag so that the corresponding `VisDialDataset` for VisDial v0.9 is built.
210 |
211 | **Note 2**: `val_feat_img_path` is the same as `train_feat_img_path`
212 | since the v0.9 validation images are part of the `trainval` split of VisDial v1.0. This will not cause any conflict
213 | since the v0.9 validation split is built from the `image_ids` we get from `val_json_dialog_path`.
214 |
215 | As in the original testbed, we also provide an `--overfit` flag, which can be useful for debugging.
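For example, a quick single-GPU debugging run could look like the sketch below (illustrative only: `--overfit` restricts each split to its first 64 images, as set in `VisDialDataset`, so the resulting metrics are not meaningful, and any flag that is not listed keeps its default value):

```sh
CUDA_VISIBLE_DEVICES=0 python train.py \
    --config_name debug_overfit \
    --save_dir checkpoints \
    --batch_size 2 \
    --decoder_type misc \
    --num_epochs 1 \
    --overfit
```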
216 |
217 | ### Saving model checkpoints
218 | A checkpoint is saved after every epoch in the directory you specify with `--save_dir`. The default directory is `checkpoint/`.
219 |
220 | ### Logging
221 |
222 | TensorBoard is used to log the training progress. Please go to the `checkpoints/tensorboard` directory and execute the following:
223 | ```sh
224 | tensorboard --logdir ./ --port 8008
225 | # and open `localhost:8008` in the browser.
226 | ```
227 |
228 | ### Finetuning for Ensemble
229 | Run the following command to perform finetuning:
230 |
231 | ```sh
232 | python finetune.py \
233 | --model_path path/to/checkpoint/model_v10.pth \
234 | --save_path path/to/saved/checkpoint
235 | ```
236 | Evaluation
237 | ----------
238 |
239 | The evaluation of a trained model checkpoint on the validation set can be done as follows:
240 |
241 | ```sh
242 | python evaluate.py \
243 | --model_path 'checkpoints/model_v10.pth' \
244 | --split val \
245 | --decoder_type disc \
246 | --device 'cuda:0' \
247 | --output_path 'checkpoints/val_v1_disc.json'
248 | ```
249 | **Note 1**: You can evaluate with three kinds of decoders: `disc`, `gen`, and `misc`.
250 |
251 | **Note 2**: The above script is also applicable to the `test` split by changing the value of `--split` to `test`. After that,
252 | please submit the generated `test_v1_disc.json` to the EvalAI server for further evaluation.
253 |
254 | **Note 3**: The above script is also applicable to evaluation on VisDial v0.9.
255 |
256 |
257 | This will generate an EvalAI submission file and report the retrieval metrics (mean reciprocal rank, R@{1, 5, 10}, mean rank) as well as Normalized Discounted Cumulative Gain (NDCG), introduced in the first Visual Dialog Challenge (2018).
258 |
259 | Result of Checkpoints
260 | ----------------------------------
261 | ### The overall architecture
262 |
263 | To get a summary of the overall architecture, run the following Python code:
264 |
265 | ```python
266 | import torch
267 |
268 | model = torch.load('checkpoints/model_v10.pth')
269 | print(model)
270 | ```
271 |
272 | ### The number of parameters in the attention stacks
273 | To compute the number of parameters in our proposed attention stacks, run the following Python code:
274 |
275 | ```python
276 | import torch
277 | from visdial.utils import get_num_params
278 |
279 | model = torch.load('checkpoints/model_v10.pth')
280 | # The number of parameters per stack
281 | print(get_num_params(model.encoder.attn_encoder.cross_attn_encoder[0]))
282 |
283 | # The number of parameters of the whole attention encoder
284 | print(get_num_params(model.encoder.attn_encoder))
285 | ```
286 |
287 |
288 | Performance on the `v1.0 validation` split (trained on `v1.0` train + val):
289 |
290 | Model | R@1 | R@5 | R@10 | MeanR | MRR | NDCG |
291 | ------- | ------ | ------ | ------ | ------ | ------ | ------ |
292 | [model-v1.0] with outputs from disc | 0.4894 | 0.7865 | 0.8788 | 4.8589 | 0.6232 | 0.6272 |
293 | [model-v1.0] with outputs from gen | 0.4044 | 0.6161 | 0.6971 | 14.9274 | 0.5074 | 0.6358 |
294 | [model-v1.0] with outputs from the avg of the two | 0.4303 | 0.6663 | 0.7567 | 10.6030 | 0.5436 | 0.6575 |
295 |
296 | Performance on the `v1.0 test` split (trained on `v1.0` train + val):
297 |
298 | Model | R@1 | R@5 | R@10 | MeanR | MRR | NDCG |
299 | ------- | ------ | ------ | ------ | ------ | ------ | ------ |
300 | [disc-model-v1.0] | 0.4700 | 0.7703 | 0.8775 | 4.90 | 0.6065 | 0.6092 |
301 |
302 | Performance on the `v0.9 validation` split (trained on `v0.9` train):
303 |
304 | Model | R@1 | R@5 | R@10 | MeanR | MRR |
305 | ------- | ------ | ------ | ------ | ------ | ------ |
306 | [disc-model-v0.9] | 55.05 | 83.98 | 91.58 | 3.69 | 67.94 |
307 |
308 |
309 | Acknowledgements
310 | ----------------
311 |
312 | * This code is built upon a fork of [visdial-challenge-starter-pytorch](https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch)
313 | developed by the team of researchers from the Machine Learning and Perception Lab at Georgia Tech
314 | for the [Visual Dialog Challenge 2019](https://visualdialog.org/challenge/2019).
315 | We would like to thank them for providing this testbed.
316 |
317 |
318 |
-------------------------------------------------------------------------------- /visdial/data/dataset.py: --------------------------------------------------------------------------------
1 | from typing import List
2 | import os
3 |
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import Dataset
7 | # from pytorch_pretrained_bert import BertTokenizer
8 | from visdial.data.vocabulary import Vocabulary
9 | from visdial.utils import move_to_cuda
10 |
11 | from visdial.data.readers import DialogsReader, DenseAnnotationsReader, ImageFeaturesHdfReader
12 |
13 | PADDING_IDX = 0
14 | SEP_TOKEN = 102
15 |
16 |
17 | class VisDialDataset(Dataset):
18 | """
19 | A full representation of the VisDial v1.0 (train/val/test) dataset. According
20 | to the appropriate split, it returns a dictionary of question, image,
21 | history, ground truth answer, answer options, dense annotations, etc.
22 | """
23 |
24 | def __init__(self, config, split="train"):
25 | super().__init__()
26 | self.config = config
27 | self.split = split
28 | self.tokenizer = self._get_tokenizer(config)
29 | self.is_add_boundaries = self._get_is_add_boundaries(config)
30 | self.is_return_options = self._get_is_return_options(config)
31 |
32 | self.dialogs_reader = DialogsReader(config, split)
33 | self.img_feat_reader = self._get_img_feat_reader(config, split)
34 | self.dense_ann_feat_reader = self._get_dense_ann_feat_reader(config, split)
35 | self.image_ids = list(self.dialogs_reader.dialogs.keys())
36 | self.image_ids = list(self.dialogs_reader.dialogs.keys())
37 |
38 | if config['dataset']['overfit']:
39 | self.image_ids = self.image_ids[:64]
40 | if config['dataset']['finetune'] and split != 'test':
41 | self.image_ids = self.dense_ann_feat_reader._image_ids
42 |
43 | def __len__(self):
44 | return len(self.image_ids)
45 |
46 | def __getitem__(self, index, is_monitor=False):
47 |
48 | if is_monitor:
49 | out = self.getimage(index)
50 | res = {}
51 | for key in out:
52 | res[key] = out[key].unsqueeze(0)
53 | res = move_to_cuda(res, 'cuda:0')
54 | return res
55 | else:
56 | image_id = self.image_ids[index]
57 | out = self.getimage(image_id)
58 | return out
59 |
60 | def getimage(self, image_id, is_monitor=False):
61 | # Get image_id, which serves as a primary key for current instance.
62 | 63 | visdial_instance = self.dialogs_reader[image_id] 64 | dialog = visdial_instance['dialog'] 65 | 66 | if is_monitor: 67 | return self.monitor_output(image_id) 68 | 69 | item = dict() 70 | item['img_ids'] = torch.tensor(image_id) 71 | 72 | item['num_rounds'] = torch.tensor(visdial_instance['num_rounds']) 73 | 74 | return_elements = [ 75 | self.return_options_to_item(dialog), 76 | self.return_answers_to_item(dialog), 77 | self.return_gt_inds_to_item(dialog), 78 | self.return_gt_relev_to_item(image_id), 79 | self.return_img_feat_to_item(image_id), 80 | self.return_token_feats_to_item(visdial_instance) 81 | ] 82 | 83 | for elem in return_elements: 84 | item.update(elem) 85 | 86 | return item 87 | 88 | def _get_is_add_boundaries(self, config): 89 | return config['dataset']['is_add_boundaries'] 90 | 91 | def _get_is_return_options(self, config): 92 | return config['dataset']['is_return_options'] 93 | 94 | def _get_dense_ann_feat_reader(self, config, split): 95 | path = config['dataset'].get(f'{split}_json_dense_dialog_path', None) 96 | 97 | return DenseAnnotationsReader(os.path.expanduser(path)) if path is not None else None 98 | 99 | def _get_img_feat_reader(self, config, split): 100 | path = config['dataset'][f'{split}_feat_img_path'] 101 | path = os.path.expanduser(path) 102 | 103 | genome_path = config['dataset'].get('genome_path', None) 104 | if genome_path is None: 105 | hdf_reader = ImageFeaturesHdfReader(path) 106 | else: 107 | hdf_reader = ImageFeaturesHdfReader(path, genome_path=os.path.expanduser(genome_path)) 108 | return hdf_reader 109 | 110 | def _get_tokenizer(self, config): 111 | if config['model']['txt_tokenizer'] == 'bert': 112 | pass 113 | # return BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 114 | else: 115 | path = config['dataset']['train_json_word_count_path'] 116 | path = os.path.expanduser(path) 117 | return Vocabulary(word_counts_path=path) 118 | 119 | def _pad_sequences(self, sequences: List[List[int]], max_seq_len=None): 120 | """Given tokenized sequences (either questions, answers or answer 121 | options, tokenized in ``__getitem__``), padding them to maximum 122 | specified sequence length. Return as a tensor of size 123 | ``(*, max_sequence_length)``. 124 | 125 | This method is only called in ``__getitem__``, chunked out separately 126 | for readability. 127 | 128 | Parameters 129 | ---------- 130 | sequences : List[List[int]] 131 | List of tokenized sequences, each sequence is typically a 132 | List[int]. 133 | 134 | Returns 135 | ------- 136 | torch.Tensor, torch.Tensor 137 | Tensor of sequences padded to max length, and length of sequences 138 | before padding. 139 | """ 140 | if max_seq_len is None: 141 | max_seq_len = self.config['dataset']['max_seq_len'] 142 | 143 | for i in range(len(sequences)): 144 | if self.is_add_boundaries: 145 | sequences[i] = sequences[i][: max_seq_len] # + 1 146 | else: 147 | sequences[i] = sequences[i][: max_seq_len] # -1 148 | 149 | sequence_lengths = [len(sequence) for sequence in sequences] 150 | 151 | PAD_INDEX = 0 152 | 153 | # Pad all sequences to max_sequence_length. 
154 | maxpadded_sequences = torch.full( 155 | (len(sequences), max_seq_len), 156 | fill_value=PAD_INDEX, 157 | ) 158 | padded_sequences = pad_sequence( 159 | [torch.tensor(sequence) for sequence in sequences], 160 | batch_first=True, 161 | padding_value=PAD_INDEX, 162 | ) 163 | maxpadded_sequences[:, : padded_sequences.size(1)] = padded_sequences 164 | return maxpadded_sequences.long(), torch.tensor(sequence_lengths).long() 165 | 166 | def tokens_to_ids(self, tokens, is_caption=False): 167 | if is_caption: 168 | tokens = tokens[:self.config['dataset']['max_seq_len'] * 2] 169 | tokens = tokens[:self.config['dataset']['max_seq_len']] 170 | return self.tokenizer.convert_tokens_to_ids(tokens) 171 | 172 | def tokens_to_ids_with_boundary(self, tokens): 173 | tokens = tokens[:self.config['dataset']['max_seq_len'] - 2] 174 | tokens = [self.tokenizer.SOS_TOKEN] + tokens + [self.tokenizer.EOS_TOKEN] 175 | return self.tokens_to_ids(tokens) 176 | 177 | def convert_opt_tokens_to_ids(self, dialog): 178 | for round in range(len(dialog)): 179 | for j in range(len(dialog[round]["answer_options"])): 180 | tokens = dialog[round]["answer_options"][j] 181 | if self.is_add_boundaries: 182 | dialog[round]["answer_options"][j] = self.tokens_to_ids_with_boundary(tokens) 183 | else: 184 | dialog[round]["answer_options"][j] = self.tokens_to_ids(tokens) 185 | return dialog 186 | 187 | def do_padding(self, sequences, start=None, end=None, max_seq_len=None): 188 | sequences = [seq[start:end] for seq in sequences] 189 | sequences, seq_lens = self._pad_sequences(sequences, max_seq_len) 190 | return sequences, seq_lens 191 | 192 | def return_options_to_item(self, dialog): 193 | self.convert_opt_tokens_to_ids(dialog) 194 | ans_opts_in, ans_opts_out = [], [] 195 | ans_opts_in_len = [] 196 | ans_opts_out_len = [] 197 | ans_opts_len = [] 198 | ans_opts = [] 199 | 200 | for dialog_round in dialog: 201 | # for boundary 202 | # answer options input 203 | result = self.do_padding(dialog_round['answer_options'], end=-1) 204 | ans_opts_in.append(result[0]) 205 | ans_opts_in_len.append(result[1]) 206 | 207 | # answer options output 208 | result = self.do_padding(dialog_round['answer_options'], start=1) 209 | ans_opts_out.append(result[0]) 210 | ans_opts_out_len.append(result[1]) 211 | 212 | # for normal case 213 | result = self.do_padding(dialog_round['answer_options'], start=1, end=-1) 214 | ans_opts.append(result[0]) 215 | ans_opts_len.append(result[1]) 216 | 217 | ans_opts = torch.stack(ans_opts, dim=0) 218 | ans_opts_in = torch.stack(ans_opts_in, dim=0) 219 | ans_opts_out = torch.stack(ans_opts_out, dim=0) 220 | 221 | return { 222 | 'opts': ans_opts, 223 | 'opts_in': ans_opts_in, 224 | 'opts_out': ans_opts_out, 225 | 'opts_len': torch.stack(ans_opts_len, dim=0), 226 | 'opts_in_len': torch.stack(ans_opts_in_len, dim=0), 227 | 'opts_out_len': torch.stack(ans_opts_out_len, dim=0) 228 | } 229 | 230 | def return_answers_to_item(self, dialog): 231 | if self.split == 'test': 232 | return {} 233 | 234 | for round_idx in range(len(dialog)): 235 | tokens = [self.tokenizer.SOS_TOKEN] + dialog[round_idx]['answer'] + [self.tokenizer.EOS_TOKEN] 236 | dialog[round_idx]['answer'] = self.tokens_to_ids(tokens) 237 | 238 | round_answers = [dialog_round["answer"] for dialog_round in dialog] 239 | result = self.do_padding(round_answers, start=1, end=-1) 240 | result_in = self.do_padding(round_answers, end=-1) 241 | result_out = self.do_padding(round_answers, start=1) 242 | 243 | return { 244 | 'ans': result[0], 245 | 'ans_in': result_in[0], 246 | 
'ans_out': result_out[0], 247 | 'ans_len': result[1], 248 | 'ans_in_len': result_in[1], 249 | 'ans_out_len': result_out[1], 250 | } 251 | 252 | def return_gt_inds_to_item(self, dialog): 253 | if 'test' not in self.split: 254 | answer_indices = [dialog_round['gt_index'] for dialog_round in dialog] 255 | return {'ans_ind': torch.tensor(answer_indices).long()} 256 | else: 257 | return {} 258 | 259 | def return_gt_relev_to_item(self, image_id): 260 | if self.dense_ann_feat_reader is not None: 261 | dense_annotations = self.dense_ann_feat_reader[image_id] 262 | if self.split == 'train': 263 | return { 264 | "gt_relevance": torch.tensor(dense_annotations["relevance"]).float(), 265 | "round_id": torch.tensor(dense_annotations["round_id"]).long() 266 | } 267 | else: 268 | return { 269 | "gt_relevance": torch.tensor(dense_annotations["gt_relevance"]).float(), 270 | "round_id": torch.tensor(dense_annotations["round_id"]).long() 271 | } 272 | else: 273 | return {} 274 | 275 | def _get_history(self, caption, questions, answers): 276 | # Allow double length of caption, equivalent to a concatenated QA pair. 277 | 278 | caption = caption[: self.config['dataset']['max_seq_len'] * 2] 279 | 280 | for i in range(len(questions)): 281 | questions[i] = questions[i][: self.config['dataset']['max_seq_len']] 282 | 283 | for i in range(len(answers)): 284 | if self.config['dataset']['is_add_boundaries'] and self.split != 'test': 285 | answers[i] = answers[i][1: -1] 286 | else: 287 | answers[i] = answers[i][: self.config['dataset']['max_seq_len']] 288 | 289 | # History for first round is caption, else concatenated QA pair of 290 | # previous round. 291 | history = [] 292 | history.append(caption + [self.tokenizer.EOS_INDEX]) 293 | 294 | for question, answer in zip(questions, answers): 295 | if len(question) == 0: 296 | break 297 | history.append(question + answer + [self.tokenizer.EOS_INDEX]) 298 | 299 | # Drop last entry from history (there's no eleventh question). 300 | history = history[:-1] 301 | max_history_length = self.config['dataset']['max_seq_len'] * 2 302 | round_tokens, round_lens = self.do_padding(history, max_seq_len=max_history_length) 303 | 304 | if self.config['dataset']['concat_hist']: 305 | # 10 dialog histories 306 | # 1 - caption 307 | # 2 - caption, round1 308 | # 3 - caption, round1, round2 309 | # .... 
310 | # 10 caption, round1, round2, ..., round9 311 | concat_hist_tokens = [] 312 | for i in range(0, len(history)): 313 | concat_hist_tokens.append([]) 314 | for j in range(i + 1): 315 | concat_hist_tokens[i].extend(history[j]) 316 | 317 | concat_hist_tokens, concat_hist_lens = self.do_padding( 318 | concat_hist_tokens, 319 | max_seq_len=max_history_length * 10) 320 | 321 | return round_tokens, round_lens, concat_hist_tokens, concat_hist_lens 322 | else: 323 | return round_tokens, round_lens, None, None 324 | 325 | def return_token_feats_to_item(self, visdial_instance): 326 | if self.config['model']['txt_tokenizer'] == 'nlp': 327 | caption = visdial_instance["caption"] 328 | dialog = visdial_instance["dialog"] 329 | 330 | # Convert word tokens of caption, question 331 | caption = self.tokens_to_ids(caption) 332 | 333 | for i in range(len(dialog)): 334 | dialog[i]["question"] = self.tokens_to_ids(dialog[i]["question"]) 335 | if self.split == 'test': 336 | dialog[i]["answer"] = self.tokens_to_ids(dialog[i]['answer']) 337 | 338 | sequences = [dialog_round["question"] for dialog_round in dialog] 339 | ques_tokens, ques_lens = self._pad_sequences(sequences) 340 | 341 | hist_tokens, hist_lens, concat_hist_tokens, concat_hist_lens = self._get_history( 342 | caption, 343 | [dialog_round["question"] for dialog_round in dialog], 344 | [dialog_round["answer"] for dialog_round in dialog], 345 | ) 346 | 347 | if self.config['dataset']['concat_hist']: 348 | return { 349 | 'ques_tokens': ques_tokens.long(), 350 | 'hist_tokens': hist_tokens.long(), 351 | 'ques_len': ques_lens.long(), 352 | 'hist_len': hist_lens.long(), 353 | 'concat_hist_tokens': concat_hist_tokens.long(), 354 | 'concat_hist_lens': concat_hist_lens.long() 355 | } 356 | 357 | return { 358 | 'ques_tokens': ques_tokens.long(), 359 | 'hist_tokens': hist_tokens.long(), 360 | 'ques_len': ques_lens.long(), 361 | 'hist_len': hist_lens.long() 362 | } 363 | else: 364 | return {} 365 | 366 | def return_img_feat_to_item(self, image_id): 367 | # Get image features for this image_id using hdf reader. 
368 | res = {} 369 | out = self.img_feat_reader[image_id] 370 | res['img_feat'] = torch.tensor(out['features']) 371 | 372 | if self.config['model']['img_has_bboxes']: 373 | res['num_boxes'] = torch.tensor(out['num_boxes']) 374 | res['img_w'] = torch.tensor(out['image_w']) 375 | res['img_h'] = torch.tensor(out['image_h']) 376 | res['boxes'] = torch.tensor(out['boxes']) 377 | 378 | if self.config['model']['img_has_attributes']: 379 | attrs = [] 380 | for box_attr in out['top_attr_names']: 381 | attrs.append(self.tokenizer.convert_tokens_to_ids(box_attr)) 382 | 383 | res['attrs'] = torch.tensor(attrs).long() 384 | res['attr_scores'] = torch.tensor(out['top_attrs_scores']) 385 | 386 | if self.config['model']['img_has_classes']: 387 | cls_ids = self.tokenizer.convert_tokens_to_ids(out['cls_names']) 388 | res['classes'] = torch.tensor(cls_ids).long() 389 | return res 390 | 391 | def monitor_output(self, image_id): 392 | visdial_instance = self.dialogs_reader[image_id] 393 | dialog = visdial_instance['dialog'] 394 | 395 | if "val" in self.split: 396 | dense_annotations = self.dense_ann_feat_reader[image_id] 397 | 398 | gt_relevance = torch.tensor(dense_annotations["gt_relevance"]).float() 399 | round_id = torch.tensor(dense_annotations["round_id"]).long() 400 | rel_ans_idx = dialog[round_id - 1]["gt_index"] 401 | caption = 'caption should be extracted' 402 | return { 403 | 'img_id': image_id, 404 | 'caption': caption, 405 | 'dialog': dialog, 406 | 'gt_relevance': gt_relevance, 407 | 'round_id': round_id, 408 | 'rel_ans_idx': rel_ans_idx 409 | } 410 | -------------------------------------------------------------------------------- /datasets/genome/1600-400-20/objects_vocab.txt: -------------------------------------------------------------------------------- 1 | yolk 2 | goal 3 | bathroom 4 | macaroni 5 | umpire 6 | toothpick 7 | alarm clock 8 | ceiling fan 9 | photos 10 | parrot 11 | tail fin 12 | birthday cake 13 | calculator 14 | catcher 15 | toilet 16 | batter 17 | stop sign,stopsign 18 | cone 19 | microwave,microwave oven 20 | skateboard ramp 21 | tea 22 | dugout 23 | products 24 | halter 25 | kettle 26 | kitchen 27 | refrigerator,fridge 28 | ostrich 29 | bathtub 30 | blinds 31 | court 32 | urinal 33 | knee pads 34 | bed 35 | flamingo 36 | giraffe 37 | helmet 38 | giraffes 39 | tennis court 40 | motorcycle 41 | laptop 42 | tea pot 43 | horse 44 | television,tv 45 | shorts 46 | manhole 47 | dishwasher 48 | jeans 49 | sail 50 | monitor 51 | man 52 | shirt 53 | car 54 | cat 55 | garage door 56 | bus 57 | radiator 58 | tights 59 | sailboat,sail boat 60 | racket,racquet 61 | plate 62 | rock wall 63 | beach 64 | trolley 65 | ocean 66 | headboard,head board 67 | tea kettle 68 | wetsuit 69 | tennis racket,tennis racquet 70 | sink 71 | train 72 | keyboard 73 | sky 74 | match 75 | train station 76 | stereo 77 | bats 78 | tennis player 79 | toilet brush 80 | lighter 81 | pepper shaker 82 | gazebo 83 | hair dryer 84 | elephant 85 | toilet seat 86 | zebra 87 | skateboard,skate board 88 | zebras 89 | floor lamp 90 | french fries 91 | woman 92 | player 93 | tower 94 | bicycle 95 | magazines 96 | christmas tree 97 | umbrella 98 | cow 99 | pants 100 | bike 101 | field 102 | living room 103 | latch 104 | bedroom 105 | grape 106 | castle 107 | table 108 | swan 109 | blender 110 | orange 111 | teddy bear 112 | net 113 | meter 114 | baseball field 115 | runway 116 | screen 117 | ski boot 118 | dog 119 | clock 120 | hair 121 | avocado 122 | highway 123 | skirt 124 | frisbee 125 | parasail 126 | desk 127 | pizza 128 
| mouse 129 | sign 130 | shower curtain 131 | polar bear 132 | airplane 133 | jersey 134 | reigns 135 | hot dog,hotdog 136 | surfboard,surf board 137 | couch 138 | glass 139 | snowboard 140 | girl 141 | plane 142 | elephants 143 | oven 144 | dirt bike 145 | tail wing 146 | area rug 147 | bear 148 | washer 149 | date 150 | bow tie 151 | cows 152 | fire extinguisher 153 | bamboo 154 | wallet 155 | tail feathers 156 | truck 157 | beach chair 158 | boat 159 | tablet 160 | ceiling 161 | chandelier 162 | sheep 163 | glasses 164 | ram 165 | kite 166 | salad 167 | pillow 168 | fire hydrant,hydrant 169 | mug 170 | tarmac 171 | computer 172 | swimsuit 173 | tomato 174 | tire 175 | cauliflower 176 | fireplace 177 | snow 178 | building 179 | sandwich 180 | weather vane 181 | bird 182 | jacket 183 | chair 184 | water 185 | cats 186 | soccer ball 187 | horses 188 | drapes 189 | barn 190 | engine 191 | cake 192 | head 193 | head band 194 | skier 195 | town 196 | bath tub 197 | bowl 198 | stove 199 | tongue 200 | coffee table 201 | floor 202 | uniform 203 | ottoman 204 | broccoli 205 | olive 206 | mound 207 | pitcher 208 | food 209 | paintings 210 | traffic light 211 | parking meter 212 | bananas 213 | mountain 214 | cage 215 | hedge 216 | motorcycles 217 | wet suit 218 | radish 219 | teddy bears 220 | monitors 221 | suitcase,suit case 222 | drawers 223 | grass 224 | apple 225 | lamp 226 | goggles 227 | boy 228 | armchair 229 | ramp 230 | burner 231 | lamb 232 | cup 233 | tank top 234 | boats 235 | hat 236 | soup 237 | fence 238 | necklace 239 | visor 240 | coffee 241 | bottle 242 | stool 243 | shoe 244 | surfer 245 | stop 246 | backpack 247 | shin guard 248 | wii remote 249 | wall 250 | pizza slice 251 | home plate 252 | van 253 | packet 254 | earrings 255 | wristband 256 | tracks 257 | mitt 258 | dome 259 | snowboarder 260 | faucet 261 | toiletries 262 | ski boots 263 | room 264 | fork 265 | snow suit 266 | banana slice 267 | bench 268 | tie 269 | burners 270 | stuffed animals 271 | zoo 272 | train platform 273 | cupcake 274 | curtain 275 | ear 276 | tissue box 277 | bread 278 | scissors 279 | vase 280 | herd 281 | smoke 282 | skylight 283 | cub 284 | tail 285 | cutting board 286 | wave 287 | hedges 288 | windshield 289 | apples 290 | mirror 291 | license plate 292 | tree 293 | wheel 294 | ski pole 295 | clock tower 296 | freezer 297 | luggage 298 | skateboarder 299 | mousepad 300 | road 301 | bat 302 | toilet tank 303 | vanity 304 | neck 305 | cliff 306 | tub 307 | sprinkles 308 | dresser 309 | street 310 | wing 311 | suit 312 | veggie 313 | palm trees 314 | urinals 315 | door 316 | propeller 317 | keys 318 | skate park 319 | platform 320 | pot 321 | towel 322 | computer monitor 323 | flip flop 324 | eggs 325 | shed 326 | moped 327 | sand 328 | face 329 | scissor 330 | carts 331 | squash 332 | pillows 333 | family 334 | glove 335 | rug 336 | watch 337 | grafitti 338 | dogs 339 | scoreboard 340 | basket 341 | poster 342 | duck 343 | horns 344 | bears 345 | jeep 346 | painting 347 | lighthouse 348 | remote control 349 | toaster 350 | vegetables 351 | surfboards 352 | ducks 353 | lane 354 | carrots 355 | market 356 | paper towels 357 | island 358 | blueberries 359 | smile 360 | balloons 361 | stroller 362 | napkin 363 | towels 364 | papers 365 | person 366 | train tracks 367 | child 368 | headband 369 | pool 370 | plant 371 | harbor 372 | counter 373 | hand 374 | house 375 | donut,doughnut 376 | knot 377 | soccer player 378 | seagull 379 | bottles 380 | buses 381 | coat 382 | trees 383 | geese 384 | bun 
385 | toilet bowl 386 | trunk 387 | station 388 | bikini 389 | goatee 390 | lounge chair 391 | breakfast 392 | nose 393 | moon 394 | river 395 | racer 396 | picture 397 | shaker 398 | sidewalk,side walk 399 | shutters 400 | stove top,stovetop 401 | church 402 | lampshade 403 | map 404 | shop 405 | platter 406 | airport 407 | hoodie 408 | oranges 409 | woods 410 | enclosure 411 | skatepark 412 | vases 413 | city 414 | park 415 | mailbox 416 | balloon 417 | billboard 418 | pasture 419 | portrait 420 | forehead 421 | ship 422 | cookie 423 | seaweed 424 | sofa 425 | slats 426 | tomato slice 427 | tractor 428 | bull 429 | suitcases 430 | graffiti 431 | policeman 432 | remotes 433 | pens 434 | window sill 435 | suspenders 436 | easel 437 | tray 438 | straw 439 | collar 440 | shower 441 | bag 442 | scooter 443 | tails 444 | toilet lid 445 | panda 446 | comforter 447 | outlet 448 | stems 449 | valley 450 | flag 451 | jockey 452 | gravel 453 | mouth 454 | window 455 | bridge 456 | corn 457 | mountains 458 | beer 459 | pitcher's mound 460 | palm tree 461 | crowd 462 | skis 463 | phone 464 | banana bunch 465 | tennis shoe 466 | ground 467 | carpet 468 | eye 469 | urn 470 | beak 471 | giraffe head 472 | steeple 473 | mattress 474 | baseball player 475 | wine 476 | water bottle 477 | kitten 478 | archway 479 | candle 480 | croissant 481 | tennis ball 482 | dress 483 | column 484 | utensils 485 | cell phone 486 | computer mouse 487 | cap 488 | lawn 489 | airplanes 490 | carriage 491 | snout 492 | cabinets 493 | lemons 494 | grill 495 | umbrellas 496 | meat 497 | wagon 498 | ipod 499 | bookshelf 500 | cart 501 | roof 502 | hay 503 | ski pants 504 | seat 505 | mane 506 | bikes 507 | drawer 508 | game 509 | clock face 510 | boys 511 | rider 512 | fire escape 513 | slope 514 | iphone 515 | pumpkin 516 | pan 517 | chopsticks 518 | hill 519 | uniforms 520 | cleat 521 | costume 522 | cabin 523 | police officer 524 | ears 525 | egg 526 | trash can 527 | horn 528 | arrow 529 | toothbrush 530 | carrot 531 | banana 532 | planes 533 | garden 534 | forest 535 | brocolli 536 | aircraft 537 | front window 538 | dashboard 539 | statue 540 | saucer 541 | people 542 | silverware 543 | fruit 544 | drain 545 | jet 546 | speaker 547 | eyes 548 | railway 549 | lid 550 | soap 551 | rocks 552 | office chair 553 | door knob 554 | banana peel 555 | baseball game 556 | asparagus 557 | spoon 558 | cabinet door 559 | pineapple 560 | traffic cone 561 | nightstand,night stand 562 | teapot 563 | taxi 564 | chimney 565 | lake 566 | suit jacket 567 | train engine 568 | ball 569 | wrist band 570 | pickle 571 | fruits 572 | pad 573 | dispenser 574 | bridle 575 | breast 576 | cones 577 | headlight 578 | necktie 579 | skater 580 | toilet paper 581 | skyscraper 582 | telephone 583 | ox 584 | roadway 585 | sock 586 | paddle 587 | dishes 588 | hills 589 | street sign 590 | headlights 591 | benches 592 | fuselage 593 | card 594 | napkins 595 | bush 596 | rice 597 | computer screen 598 | spokes 599 | flowers 600 | bucket 601 | rock 602 | pole 603 | pear 604 | sauce 605 | store 606 | juice 607 | knobs 608 | mustard 609 | ski 610 | stands 611 | cabinet 612 | dirt 613 | goats 614 | wine glass 615 | spectators 616 | crate 617 | pancakes 618 | kids 619 | engines 620 | shade 621 | feeder 622 | cellphone 623 | pepper 624 | blanket 625 | sunglasses 626 | train car 627 | magnet 628 | donuts,doughnuts 629 | sweater 630 | signal 631 | advertisement 632 | log 633 | vent 634 | whiskers 635 | adult 636 | arch 637 | locomotive 638 | tennis match 639 | tent 640 
| motorbike 641 | magnets 642 | night 643 | marina 644 | wool 645 | vest 646 | railroad tracks 647 | stuffed bear 648 | moustache 649 | bib 650 | frame 651 | snow pants 652 | tank 653 | undershirt 654 | icons 655 | neck tie 656 | beams 657 | baseball bat 658 | safety cone 659 | paper towel 660 | bedspread 661 | can 662 | container 663 | flower 664 | vehicle 665 | tomatoes 666 | back wheel 667 | soccer field 668 | nostril 669 | suv 670 | buildings 671 | canopy 672 | flame 673 | kid 674 | baseball 675 | throw pillow 676 | belt 677 | rainbow 678 | lemon 679 | oven door 680 | tag 681 | books 682 | monument 683 | men 684 | shadow 685 | bicycles 686 | cars 687 | lamp shade 688 | pine tree 689 | bouquet 690 | toothpaste 691 | potato 692 | sinks 693 | hook 694 | switch 695 | lamp post,lamppost 696 | lapel 697 | desert 698 | knob 699 | chairs 700 | pasta 701 | feathers 702 | hole 703 | meal 704 | station wagon 705 | kites 706 | boots 707 | baby 708 | biker 709 | gate 710 | signal light 711 | headphones 712 | goat 713 | waves 714 | bumper 715 | bud 716 | logo 717 | curtains 718 | american flag 719 | yacht 720 | box 721 | baseball cap 722 | fries 723 | controller 724 | awning 725 | path 726 | front legs 727 | life jacket 728 | purse 729 | outfield 730 | pigeon 731 | toddler 732 | beard 733 | thumb 734 | water tank 735 | board 736 | parade 737 | robe 738 | newspaper 739 | wires 740 | camera 741 | pastries 742 | deck 743 | watermelon 744 | clouds 745 | deer 746 | motorcyclist 747 | kneepad 748 | sneakers 749 | women 750 | onions 751 | eyebrow 752 | gas station 753 | vane 754 | girls 755 | trash 756 | numerals 757 | knife 758 | tags 759 | light 760 | bunch 761 | outfit 762 | groom 763 | infield 764 | frosting 765 | forks 766 | entertainment center 767 | stuffed animal 768 | yard 769 | numeral 770 | ladder 771 | shoes 772 | bracelet 773 | teeth 774 | guy 775 | display case 776 | cushion 777 | post 778 | pathway 779 | tablecloth 780 | skiers 781 | trouser 782 | cloud 783 | hands 784 | produce 785 | beam 786 | ketchup 787 | paw 788 | dish 789 | raft 790 | crosswalk 791 | front wheel 792 | toast 793 | cattle 794 | players 795 | group 796 | coffee pot 797 | track 798 | cowboy hat 799 | petal 800 | eyeglasses 801 | handle 802 | table cloth 803 | jets 804 | shakers 805 | remote 806 | snowsuit 807 | bushes 808 | dessert 809 | leg 810 | eagle 811 | fire truck,firetruck 812 | game controller 813 | smartphone 814 | backsplash 815 | trains 816 | shore 817 | signs 818 | bell 819 | cupboards 820 | sweat band 821 | sack 822 | ankle 823 | coin slot 824 | bagel 825 | masts 826 | police 827 | drawing 828 | biscuit 829 | toy 830 | legs 831 | pavement 832 | outside 833 | wheels 834 | driver 835 | numbers 836 | blazer 837 | pen 838 | cabbage 839 | trucks 840 | key 841 | saddle 842 | pillow case 843 | goose 844 | label 845 | boulder 846 | pajamas 847 | wrist 848 | shelf 849 | cross 850 | coffee cup 851 | foliage 852 | lot 853 | fry 854 | air 855 | officer 856 | pepperoni 857 | cheese 858 | lady 859 | kickstand 860 | counter top 861 | veggies 862 | baseball uniform 863 | book shelf 864 | bags 865 | pickles 866 | stand 867 | netting 868 | lettuce 869 | facial hair 870 | lime 871 | animals 872 | drape 873 | boot 874 | railing 875 | end table 876 | shin guards 877 | steps 878 | trashcan 879 | tusk 880 | head light 881 | walkway 882 | cockpit 883 | tennis net 884 | animal 885 | boardwalk 886 | keypad 887 | bookcase 888 | blueberry 889 | trash bag 890 | ski poles 891 | parking lot 892 | gas tank 893 | beds 894 | fan 895 | base 896 
| soap dispenser 897 | banner 898 | life vest 899 | train front 900 | word 901 | cab 902 | liquid 903 | exhaust pipe 904 | sneaker 905 | light fixture 906 | power lines 907 | curb 908 | scene 909 | buttons 910 | roman numerals 911 | muzzle 912 | sticker 913 | bacon 914 | pizzas 915 | paper 916 | feet 917 | stairs 918 | triangle 919 | plants 920 | rope 921 | beans 922 | brim 923 | beverage 924 | letters 925 | soda 926 | menu 927 | finger 928 | dvds 929 | candles 930 | picnic table 931 | wine bottle 932 | pencil 933 | tree trunk 934 | nail 935 | mantle 936 | countertop 937 | view 938 | line 939 | motor bike 940 | audience 941 | traffic sign 942 | arm 943 | pedestrian 944 | stabilizer 945 | dock 946 | doorway 947 | bedding 948 | end 949 | worker 950 | canal 951 | crane 952 | grate 953 | little girl 954 | rims 955 | passenger car 956 | plates 957 | background 958 | peel 959 | brake light 960 | roman numeral 961 | string 962 | tines 963 | turf 964 | armrest 965 | shower head 966 | leash 967 | stones 968 | stoplight 969 | handle bars 970 | front 971 | scarf 972 | band 973 | jean 974 | tennis 975 | pile 976 | doorknob 977 | foot 978 | houses 979 | windows 980 | restaurant 981 | booth 982 | cardboard box 983 | fingers 984 | mountain range 985 | bleachers 986 | rail 987 | pastry 988 | canoe 989 | sun 990 | eye glasses 991 | salt shaker 992 | number 993 | fish 994 | knee pad 995 | fur 996 | she 997 | shower door 998 | rod 999 | branches 1000 | birds 1001 | printer 1002 | sunset 1003 | median 1004 | shutter 1005 | slice 1006 | heater 1007 | prongs 1008 | bathing suit 1009 | skiier 1010 | rack 1011 | book 1012 | blade 1013 | apartment 1014 | manhole cover 1015 | stools 1016 | overhang 1017 | door handle 1018 | couple 1019 | picture frame 1020 | chicken 1021 | planter 1022 | seats 1023 | hour hand 1024 | dvd player 1025 | ski slope 1026 | french fry 1027 | bowls 1028 | top 1029 | landing gear 1030 | coffee maker 1031 | melon 1032 | computers 1033 | light switch 1034 | jar 1035 | tv stand 1036 | overalls 1037 | garage 1038 | tabletop 1039 | writing 1040 | doors 1041 | stadium 1042 | placemat 1043 | air vent 1044 | trick 1045 | sled 1046 | mast 1047 | pond 1048 | steering wheel 1049 | baseball glove 1050 | watermark 1051 | pie 1052 | sandwhich 1053 | cpu 1054 | mushroom 1055 | power pole 1056 | dirt road 1057 | handles 1058 | speakers 1059 | fender 1060 | telephone pole 1061 | strawberry 1062 | mask 1063 | children 1064 | crust 1065 | art 1066 | rim 1067 | branch 1068 | display 1069 | grasses 1070 | photo 1071 | receipt 1072 | instructions 1073 | herbs 1074 | toys 1075 | handlebars 1076 | trailer 1077 | sandal 1078 | skull 1079 | hangar 1080 | pipe 1081 | office 1082 | chest 1083 | lamps 1084 | horizon 1085 | calendar 1086 | foam 1087 | stone 1088 | bars 1089 | button 1090 | poles 1091 | heart 1092 | hose 1093 | jet engine 1094 | potatoes 1095 | rain 1096 | magazine 1097 | chain 1098 | footboard 1099 | tee shirt 1100 | design 1101 | walls 1102 | copyright 1103 | pictures 1104 | pillar 1105 | drink 1106 | barrier 1107 | boxes 1108 | chocolate 1109 | chef 1110 | slot 1111 | sweatpants 1112 | face mask 1113 | icing 1114 | wipers 1115 | circle 1116 | bin 1117 | kitty 1118 | electronics 1119 | wild 1120 | tiles 1121 | steam 1122 | lettering 1123 | bathroom sink 1124 | laptop computer 1125 | cherry 1126 | spire 1127 | conductor 1128 | sheet 1129 | slab 1130 | windshield wipers 1131 | storefront 1132 | hill side 1133 | spatula 1134 | tail light,taillight 1135 | bean 1136 | wire 1137 | intersection 1138 | pier 
1139 | snow board 1140 | trunks 1141 | website 1142 | bolt 1143 | kayak 1144 | nuts 1145 | holder 1146 | turbine 1147 | stop light 1148 | olives 1149 | ball cap 1150 | burger 1151 | barrel 1152 | fans 1153 | beanie 1154 | stem 1155 | lines 1156 | traffic signal 1157 | sweatshirt 1158 | handbag 1159 | mulch 1160 | socks 1161 | landscape 1162 | soda can 1163 | shelves 1164 | ski lift 1165 | cord 1166 | vegetable 1167 | apron 1168 | blind 1169 | bracelets 1170 | stickers 1171 | traffic 1172 | strip 1173 | tennis shoes 1174 | swim trunks 1175 | hillside 1176 | sandals 1177 | concrete 1178 | lips 1179 | butter knife 1180 | words 1181 | leaves 1182 | train cars 1183 | spoke 1184 | cereal 1185 | pine trees 1186 | cooler 1187 | bangs 1188 | half 1189 | sheets 1190 | figurine 1191 | park bench 1192 | stack 1193 | second floor 1194 | motor 1195 | hand towel 1196 | wristwatch 1197 | spectator 1198 | tissues 1199 | flip flops 1200 | quilt 1201 | floret 1202 | calf 1203 | back pack 1204 | grapes 1205 | ski tracks 1206 | skin 1207 | bow 1208 | controls 1209 | dinner 1210 | baseball players 1211 | ad 1212 | ribbon 1213 | hotel 1214 | sea 1215 | cover 1216 | tarp 1217 | weather 1218 | notebook 1219 | mustache 1220 | stone wall 1221 | closet 1222 | statues 1223 | bank 1224 | skateboards 1225 | butter 1226 | dress shirt 1227 | knee 1228 | wood 1229 | laptops 1230 | cuff 1231 | hubcap 1232 | wings 1233 | range 1234 | structure 1235 | balls 1236 | tunnel 1237 | globe 1238 | utensil 1239 | dumpster 1240 | cd 1241 | floors 1242 | wrapper 1243 | folder 1244 | pocket 1245 | mother 1246 | ski goggles 1247 | posts 1248 | power line 1249 | wake 1250 | roses 1251 | train track 1252 | reflection 1253 | air conditioner 1254 | referee 1255 | barricade 1256 | baseball mitt 1257 | mouse pad 1258 | garbage can 1259 | buckle 1260 | footprints 1261 | lights 1262 | muffin 1263 | bracket 1264 | plug 1265 | taxi cab 1266 | drinks 1267 | surfers 1268 | arrows 1269 | control panel 1270 | ring 1271 | twigs 1272 | soil 1273 | skies 1274 | clock hand 1275 | caboose 1276 | playground 1277 | mango 1278 | stump 1279 | brick wall 1280 | screw 1281 | minivan 1282 | leaf 1283 | fencing 1284 | ledge 1285 | clothes 1286 | grass field 1287 | plumbing 1288 | blouse 1289 | patch 1290 | scaffolding 1291 | hamburger 1292 | utility pole 1293 | teddy 1294 | rose 1295 | skillet 1296 | cycle 1297 | cable 1298 | gloves 1299 | bark 1300 | decoration 1301 | tables 1302 | palm 1303 | wii 1304 | mountain top 1305 | shrub 1306 | hoof 1307 | celery 1308 | beads 1309 | plaque 1310 | flooring 1311 | surf 1312 | cloth 1313 | passenger 1314 | spot 1315 | plastic 1316 | knives 1317 | case 1318 | railroad 1319 | pony 1320 | muffler 1321 | hot dogs,hotdogs 1322 | stripe 1323 | scale 1324 | block 1325 | recliner 1326 | body 1327 | shades 1328 | tap 1329 | tools 1330 | cupboard 1331 | wallpaper 1332 | sculpture 1333 | surface 1334 | sedan 1335 | distance 1336 | shrubs 1337 | skiis 1338 | lift 1339 | bottom 1340 | cleats 1341 | roll 1342 | clothing 1343 | bed frame 1344 | slacks 1345 | tail lights 1346 | doll 1347 | traffic lights 1348 | symbol 1349 | strings 1350 | fixtures 1351 | short 1352 | paint 1353 | candle holder 1354 | guard rail 1355 | cyclist 1356 | tree branches 1357 | ripples 1358 | gear 1359 | waist 1360 | trash bin 1361 | onion 1362 | home 1363 | side mirror 1364 | brush 1365 | sweatband 1366 | handlebar 1367 | light pole 1368 | street lamp 1369 | pads 1370 | ham 1371 | artwork 1372 | reflector 1373 | figure 1374 | tile 1375 | mountainside 1376 | 
black 1377 | bricks 1378 | paper plate 1379 | stick 1380 | beef 1381 | patio 1382 | weeds 1383 | back 1384 | sausage 1385 | paws 1386 | farm 1387 | decal 1388 | harness 1389 | monkey 1390 | fence post 1391 | door frame 1392 | stripes 1393 | clocks 1394 | ponytail 1395 | toppings 1396 | strap 1397 | carton 1398 | greens 1399 | chin 1400 | lunch 1401 | name 1402 | earring 1403 | area 1404 | tshirt,t-shirt,t shirt 1405 | cream 1406 | rails 1407 | cushions 1408 | lanyard 1409 | brick 1410 | hallway 1411 | cucumber 1412 | wire fence 1413 | fern 1414 | tangerine 1415 | windowsill 1416 | pipes 1417 | package 1418 | wheelchair 1419 | chips 1420 | driveway 1421 | tattoo 1422 | side window 1423 | stairway 1424 | basin 1425 | machine 1426 | table lamp 1427 | radio 1428 | pony tail 1429 | ocean water 1430 | inside 1431 | cargo 1432 | overpass 1433 | mat 1434 | socket 1435 | flower pot 1436 | tree line 1437 | sign post 1438 | tube 1439 | dial 1440 | splash 1441 | male 1442 | lantern 1443 | lipstick 1444 | lip 1445 | tongs 1446 | ski suit 1447 | trail 1448 | passenger train 1449 | bandana 1450 | antelope 1451 | designs 1452 | tents 1453 | photograph 1454 | catcher's mitt 1455 | electrical outlet 1456 | tires 1457 | boulders 1458 | mannequin 1459 | plain 1460 | layer 1461 | mushrooms 1462 | strawberries 1463 | piece 1464 | oar 1465 | bike rack 1466 | slices 1467 | arms 1468 | fin 1469 | shadows 1470 | hood 1471 | windshield wiper 1472 | letter 1473 | dot 1474 | bus stop 1475 | railings 1476 | pebbles 1477 | mud 1478 | claws 1479 | police car 1480 | crown 1481 | meters 1482 | name tag 1483 | entrance 1484 | staircase 1485 | shrimp 1486 | ladies 1487 | peak 1488 | vines 1489 | computer keyboard 1490 | glass door 1491 | pears 1492 | pant 1493 | wine glasses 1494 | stall 1495 | asphalt 1496 | columns 1497 | sleeve 1498 | pack 1499 | cheek 1500 | baskets 1501 | land 1502 | day 1503 | blocks 1504 | courtyard 1505 | pedal 1506 | panel 1507 | seeds 1508 | balcony 1509 | yellow 1510 | disc 1511 | young man 1512 | eyebrows 1513 | crumbs 1514 | spinach 1515 | emblem 1516 | object 1517 | bar 1518 | cardboard 1519 | tissue 1520 | light post 1521 | ski jacket 1522 | seasoning 1523 | parasol 1524 | terminal 1525 | surfing 1526 | streetlight,street light 1527 | alley 1528 | cords 1529 | image 1530 | jug 1531 | antenna 1532 | puppy 1533 | berries 1534 | diamond 1535 | pans 1536 | fountain 1537 | foreground 1538 | syrup 1539 | bride 1540 | spray 1541 | license 1542 | peppers 1543 | passengers 1544 | cement 1545 | flags 1546 | shack 1547 | trough 1548 | objects 1549 | arches 1550 | streamer 1551 | pots 1552 | border 1553 | baseboard 1554 | beer bottle 1555 | wrist watch 1556 | tile floor 1557 | page 1558 | pin 1559 | items 1560 | baseline 1561 | hanger 1562 | tree branch 1563 | tusks 1564 | donkey 1565 | containers 1566 | condiments 1567 | device 1568 | envelope 1569 | parachute 1570 | mesh 1571 | hut 1572 | butterfly 1573 | salt 1574 | restroom 1575 | twig 1576 | pilot 1577 | ivy 1578 | furniture 1579 | clay 1580 | print 1581 | sandwiches 1582 | lion 1583 | shingles 1584 | pillars 1585 | vehicles 1586 | panes 1587 | shoreline 1588 | stream 1589 | control 1590 | lock 1591 | microphone 1592 | blades 1593 | towel rack 1594 | coaster 1595 | star 1596 | petals 1597 | text 1598 | feather 1599 | spots 1600 | buoy --------------------------------------------------------------------------------
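
Reference note (not part of the repository files above): `objects_vocab.txt` lists one Visual Genome object class per line (1600 classes in the `1600-400-20` configuration), and a line such as `stop sign,stopsign` gives comma-separated aliases of the same class; `attributes_vocab.txt` follows the same one-class-per-line format for attributes. The repository itself consumes these files through its readers (see `visdial/data/readers.py` and the `genome_path` argument of `ImageFeaturesHdfReader`), whose internals are not shown here, so the snippet below is only a minimal sketch of how such a vocabulary file can be parsed; the helper name `load_genome_vocab` is illustrative, not an API of this repository.

```python
# Illustrative sketch (not part of this repository): parse a Visual Genome
# vocabulary file where each line is one class and commas separate aliases.
def load_genome_vocab(path):
    """Return a list of classes; each class is a list of alias strings."""
    classes = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                classes.append([alias.strip() for alias in line.split(',')])
    return classes


objects = load_genome_vocab('datasets/genome/1600-400-20/objects_vocab.txt')
print(len(objects))   # 1600 object classes
print(objects[16])    # ['stop sign', 'stopsign'] -- line 17 of the file
```

Splitting on commas keeps the first alias as the canonical class name while still allowing lookups by any of its synonyms.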