├── Decoders ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── decoder1.cpython-36.pyc │ └── decoder1_attention.cpython-36.pyc ├── decoder1.py └── decoder1_attention.py ├── LICENSE ├── README.md ├── attention.py ├── base_model.py ├── classifier.py ├── data └── features_faster_rcnn_x101_train_v0.9_imgid.json ├── dataset.py ├── dmrm_overview.png ├── eval_v0.9.py ├── eval_v1.0.py ├── fc.py ├── language_model.py ├── main_v0.9.py ├── main_v1.0.py ├── misc ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── dataLoader.cpython-36.pyc │ ├── dataLoader_v1.cpython-36.pyc │ ├── focalloss.cpython-36.pyc │ ├── readers.cpython-36.pyc │ └── utils.cpython-36.pyc ├── dataLoader.py ├── dataLoader_v1.py ├── focalloss.py ├── readers.py └── utils.py ├── requirements.txt ├── script ├── create_glove.py └── prepro.py ├── train.py └── utils.py /Decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/Decoders/__init__.py -------------------------------------------------------------------------------- /Decoders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/Decoders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Decoders/__pycache__/decoder1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/Decoders/__pycache__/decoder1.cpython-36.pyc -------------------------------------------------------------------------------- /Decoders/__pycache__/decoder1_attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/Decoders/__pycache__/decoder1_attention.cpython-36.pyc -------------------------------------------------------------------------------- /Decoders/decoder1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import pdb 5 | import math 6 | import numpy as np 7 | import torch.nn.functional as F 8 | # from misc.share_Linear import share_Linear 9 | from misc.utils import l2_norm 10 | from classifier import SimpleClassifier 11 | 12 | class _netG(nn.Module): 13 | def __init__(self, args): 14 | super(_netG, self).__init__() 15 | 16 | self.ninp = args.ninp 17 | self.nhid = args.nhid 18 | self.nlayers = args.nlayers 19 | self.dropout = args.dropout 20 | 21 | self.rnn = getattr(nn, 'LSTM')(self.ninp, self.nhid, self.nlayers, bidirectional=True, dropout=self.dropout, batch_first=True) 22 | 23 | self.rnn_type = 'LSTM' 24 | 25 | self.decoder =SimpleClassifier(self.nhid, self.nhid*2 , args.vocab_size, self.dropout) 26 | self.d = args.dropout 27 | self.beta = 3 28 | self.vocab_size = args.vocab_size 29 | # self.init_weights() 30 | 31 | def init_weights(self): 32 | 33 | self.decoder.weight = nn.init.xavier_uniform(self.decoder.weight) 34 | self.decoder.bias.data.fill_(0) 35 | 36 | def forward(self, emb, hidden): 37 | 38 | output, hidden1 = self.rnn(emb, hidden) 39 | output = output[:,:,:self.nhid] 40 | output = F.dropout(output, self.d, 
training=self.training) 41 | decoded = self.decoder(output.view(emb.size(0)*emb.size(1), self.nhid)) 42 | logprob = F.log_softmax(self.beta * decoded, 1) 43 | 44 | return logprob.view(emb.size(0), emb.size(1), -1), hidden1 45 | 46 | def init_hidden(self, bsz): 47 | weight = next(self.parameters()).data 48 | if self.rnn_type == 'LSTM': 49 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()), 50 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())) 51 | else: 52 | return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()) 53 | 54 | def sample_beam(self, netW, input, hidden_state, opt={}): 55 | beam_size = opt.get('beam_size', 10) 56 | batch_size = input.size(1) 57 | 58 | # assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 59 | seq_all = torch.LongTensor(self.seq_length, batch_size, beam_size).zero_() 60 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 61 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 62 | # lets process every image independently for now, for simplicity 63 | 64 | self.done_beams = [[] for _ in range(batch_size)] 65 | for k in range(batch_size): 66 | # copy the hidden state for beam_size time. 67 | state = [] 68 | for state_tmp in hidden_state: 69 | state.append(state_tmp[:, k, :].view(1, 1, -1).expand(1, beam_size, self.nhid).clone()) 70 | 71 | state = tuple(state) 72 | 73 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 74 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 75 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 76 | for t in range(self.seq_length + 1): 77 | if t == 0: # input 78 | it = input.data.resize_(1, beam_size).fill_(self.vocab_size) 79 | xt = netW(Variable(it, requires_grad=False)) 80 | else: 81 | """perform a beam merge. 
that is, 82 | for every previous beam we now many new possibilities to branch out 83 | we need to resort our beams to maintain the loop invariant of keeping 84 | the top beam_size most likely sequences.""" 85 | logprobsf = logprobs.float() # lets go to CPU for more efficiency in indexing operations 86 | ys, ix = torch.sort(logprobsf, 1, 87 | True) # sorted array of logprobs along each previous beam (last true = descending) 88 | candidates = [] 89 | cols = min(beam_size, ys.size(1)) 90 | rows = beam_size 91 | if t == 1: # at first time step only the first beam is active 92 | rows = 1 93 | for cc in range(cols): # for each column (word, essentially) 94 | for qq in range(rows): # for each beam expansion 95 | # compute logprob of expanding beam q with word in (sorted) position c 96 | local_logprob = ys[qq, cc] 97 | if beam_seq[t - 2, qq] == self.vocab_size: 98 | local_logprob.data.fill_(-9999) 99 | 100 | candidate_logprob = beam_logprobs_sum[qq] + local_logprob 101 | candidates.append({'c': ix.data[qq, cc], 'q': qq, 'p': candidate_logprob.data[0], 102 | 'r': local_logprob.data[0]}) 103 | 104 | candidates = sorted(candidates, key=lambda x: -x['p']) 105 | 106 | # construct new beams 107 | new_state = [_.clone() for _ in state] 108 | if t > 1: 109 | # well need these as reference when we fork beams around 110 | beam_seq_prev = beam_seq[:t - 1].clone() 111 | beam_seq_logprobs_prev = beam_seq_logprobs[:t - 1].clone() 112 | for vix in range(beam_size): 113 | v = candidates[vix] 114 | # fork beam index q into index vix 115 | if t > 1: 116 | beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']] 117 | beam_seq_logprobs[:t - 1, vix] = beam_seq_logprobs_prev[:, v['q']] 118 | 119 | # rearrange recurrent states 120 | for state_ix in range(len(new_state)): 121 | # copy over state in previous beam q to new beam at vix 122 | new_state[state_ix][0, vix] = state[state_ix][0, v['q']] # dimension one is time step 123 | 124 | # append new end terminal at the end of this beam 125 | beam_seq[t - 1, vix] = v['c'] # c'th word is the continuation 126 | beam_seq_logprobs[t - 1, vix] = v['r'] # the raw logprob here 127 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 128 | 129 | if v['c'] == self.vocab_size or t == self.seq_length: 130 | # END token special case here, or we reached the end. 
131 | # add the beam to a set of done beams 132 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 133 | 'logps': beam_seq_logprobs[:, vix].clone(), 134 | 'p': beam_logprobs_sum[vix] 135 | }) 136 | 137 | # encode as vectors 138 | it = beam_seq[t - 1].view(1, -1) 139 | xt = netW(Variable(it.cuda())) 140 | 141 | if t >= 1: 142 | state = new_state 143 | 144 | output, state = self.rnn(xt, state) 145 | 146 | output = F.dropout(output, self.d, training=self.training) 147 | decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2))) 148 | logprobs = F.log_softmax(self.beta * decoded) 149 | 150 | self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 151 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 152 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 153 | for ii in range(beam_size): 154 | seq_all[:, k, ii] = self.done_beams[k][ii]['seq'] 155 | 156 | # return the samples and their log likelihoods 157 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 158 | 159 | def sample(self, netW, input, state, opt={}): 160 | sample_max = opt.get('sample_max', 1) 161 | beam_size = opt.get('beam_size', 1) 162 | temperature = opt.get('temperature', 1.0) 163 | seq_length = opt.get('seq_length', 9) 164 | self.seq_length = seq_length 165 | 166 | if beam_size > 1: 167 | return self.sample_beam(netW, input, state, opt) 168 | 169 | batch_size = input.size(1) 170 | seq = [] 171 | seqLogprobs = [] 172 | for t in range(self.seq_length + 1): 173 | if t == 0: # input 174 | it = input.data 175 | elif sample_max: 176 | sampleLogprobs, it = torch.max(logprobs.data, 1) 177 | it = it.view(-1).long() 178 | else: 179 | if temperature == 1.0: 180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 181 | else: 182 | # scale logprobs by temperature 183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 184 | it = torch.multinomial(prob_prev, 1).cuda() 185 | sampleLogprobs = logprobs.gather(1, Variable(it, 186 | requires_grad=False)) # gather the logprobs at sampled positions 187 | it = it.view(-1).long() # and flatten indices for downstream processing 188 | 189 | xt = netW(Variable(it.view(-1, 1), requires_grad=False)) 190 | 191 | if t >= 1: 192 | seq.append(it) # seq[t] the input of t+2 time step 193 | seqLogprobs.append(sampleLogprobs.view(-1)) 194 | it = torch.unsqueeze(it, 0) 195 | 196 | output, state = self.rnn(xt, state) 197 | output = F.dropout(output, self.d, training=self.training) 198 | decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2))) 199 | logprobs = F.log_softmax(self.beta * decoded, 1) 200 | 201 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /Decoders/decoder1_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from classifier import SimpleClassifier 6 | 7 | class _netG(nn.Module): 8 | def __init__(self, args): 9 | super(_netG, self).__init__() 10 | 11 | self.ninp = args.ninp 12 | self.nhid = args.nhid 13 | self.nlayers = args.nlayers 14 | self.dropout = args.dropout 15 | self.rnn = getattr(nn, 'LSTM')(self.ninp, self.nhid, self.nlayers, bidirectional=False, dropout=self.dropout, 
batch_first=True) 16 | self.rnn_type = 'LSTM' 17 | 18 | self.decoder =SimpleClassifier(self.nhid*2, self.nhid*4, args.vocab_size, self.dropout) 19 | self.d = args.dropout 20 | self.beta = 3 21 | self.vocab_size = args.vocab_size 22 | # self.init_weights() 23 | self.w_q = nn.Linear(self.nhid*2, self.nhid) 24 | self.ans_q = nn.Linear(self.nhid, self.nhid) 25 | self.Wa_q = nn.Linear(self.nhid, 1) 26 | 27 | self.w_h = nn.Linear(self.nhid*2, self.nhid) 28 | self.ans_h = nn.Linear(self.nhid, self.nhid) 29 | self.Wa_h = nn.Linear(self.nhid, 1) 30 | 31 | self.w_i = nn.Linear(self.nhid*2, self.nhid) 32 | self.ans_i = nn.Linear(self.nhid, self.nhid) 33 | self.Wa_i = nn.Linear(self.nhid, 1) 34 | 35 | self.concat = nn.Linear(self.nhid*3, self.nhid) 36 | # self.fusion = nn.Linear(self.nhid*2, self.nhid*2) 37 | 38 | def init_weights(self): 39 | self.decoder.weight = nn.init.xavier_uniform(self.decoder.weight) 40 | self.decoder.bias.data.fill_(0) 41 | 42 | def forward(self, emb, question, history, image, hidden): 43 | ques_length = question.size(1) 44 | his_length = history.size(1) 45 | img_length = image.size(1) 46 | batch_size, ans_length, _ = emb.size() 47 | question = question.contiguous() 48 | seqLogprobs = [] 49 | for index in range(ans_length): 50 | input_ans = emb[:, index, :].unsqueeze(1) 51 | output, hidden = self.rnn(input_ans, hidden) 52 | input_ans = output.squeeze(1) 53 | ques_emb = self.w_q(question.view(-1, 2*self.nhid)).view(-1, ques_length, self.nhid) 54 | input_ans_q = self.ans_q(input_ans).view(-1, 1, self.nhid) 55 | atten_emb_q = F.tanh(ques_emb + input_ans_q.expand_as(ques_emb)) 56 | ques_atten_weight = F.softmax(self.Wa_q(F.dropout(atten_emb_q, self.d, training=self.training).view(-1, self.nhid)).view(-1, ques_length), 1) 57 | ques_attn_feat = torch.bmm(ques_atten_weight.view(-1, 1, ques_length), ques_emb.view(-1,ques_length, self.nhid)) 58 | 59 | input_ans_h = self.ans_h(input_ans).view(-1, 1, self.nhid) 60 | his_emb = self.w_h(history.view(-1, 2* self.nhid)).view(-1, his_length, self.nhid) 61 | atten_emb_h = F.tanh(his_emb + input_ans_h.expand_as(his_emb)) 62 | his_atten_weight = F.softmax(self.Wa_h(F.dropout(atten_emb_h, self.d, training=self.training).view(-1, self.nhid)).view(-1, his_length), 1) 63 | his_attn_feat = torch.bmm(his_atten_weight.view(-1, 1, his_length), his_emb.view(-1, his_length, self.nhid)) 64 | 65 | input_ans_i = self.ans_i(input_ans).view(-1, 1, self.nhid) 66 | img_emb = self.w_i(image.view(-1, 2* self.nhid)).view(-1, img_length, self.nhid) 67 | atten_emb_i = F.tanh(img_emb + input_ans_i.expand_as(img_emb)) 68 | img_atten_weight = F.softmax(self.Wa_i(F.dropout(atten_emb_i, self.d, training=self.training).view(-1, self.nhid)).view(-1, img_length), 1) 69 | img_attn_feat = torch.bmm(img_atten_weight.view(-1, 1, img_length), img_emb.view(-1, img_length, self.nhid)) 70 | 71 | concat_feat = torch.cat((ques_attn_feat.view(-1, self.nhid), his_attn_feat.view(-1, self.nhid), img_attn_feat.view(-1, self.nhid)),1) 72 | concat_feat = F.tanh(self.concat(F.dropout(concat_feat, self.d, training=self.training))) 73 | fusion_feat = torch.cat((output.squeeze(1), concat_feat),1) 74 | 75 | fusion_feat = F.dropout(fusion_feat, self.d, training=self.training) 76 | decoded = self.decoder(fusion_feat.view(-1, 2*self.nhid)) 77 | logprob = F.log_softmax(self.beta * decoded, 1) 78 | seqLogprobs.append(logprob) 79 | 80 | return torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous(), hidden 81 | 82 | def init_hidden(self, bsz): 83 | weight = next(self.parameters()).data 84 | if 
self.rnn_type == 'LSTM': 85 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()), 86 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())) 87 | else: 88 | return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()) 89 | 90 | def sample_beam(self, netW, input, hidden_state, opt={}): 91 | beam_size = opt.get('beam_size', 10) 92 | batch_size = input.size(1) 93 | 94 | # assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 95 | seq_all = torch.LongTensor(self.seq_length, batch_size, beam_size).zero_() 96 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 97 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 98 | # lets process every image independently for now, for simplicity 99 | 100 | self.done_beams = [[] for _ in range(batch_size)] 101 | for k in range(batch_size): 102 | # copy the hidden state for beam_size time. 103 | state = [] 104 | for state_tmp in hidden_state: 105 | state.append(state_tmp[:, k, :].view(1, 1, -1).expand(1, beam_size, self.nhid).clone()) 106 | 107 | state = tuple(state) 108 | 109 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 110 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 111 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 112 | for t in range(self.seq_length + 1): 113 | if t == 0: # input 114 | it = input.data.resize_(1, beam_size).fill_(self.vocab_size) 115 | xt = netW(Variable(it, requires_grad=False)) 116 | else: 117 | """perform a beam merge. that is, 118 | for every previous beam we now many new possibilities to branch out 119 | we need to resort our beams to maintain the loop invariant of keeping 120 | the top beam_size most likely sequences.""" 121 | logprobsf = logprobs.float() # lets go to CPU for more efficiency in indexing operations 122 | ys, ix = torch.sort(logprobsf, 1, 123 | True) # sorted array of logprobs along each previous beam (last true = descending) 124 | candidates = [] 125 | cols = min(beam_size, ys.size(1)) 126 | rows = beam_size 127 | if t == 1: # at first time step only the first beam is active 128 | rows = 1 129 | for cc in range(cols): # for each column (word, essentially) 130 | for qq in range(rows): # for each beam expansion 131 | # compute logprob of expanding beam q with word in (sorted) position c 132 | local_logprob = ys[qq, cc] 133 | if beam_seq[t - 2, qq] == self.vocab_size: 134 | local_logprob.data.fill_(-9999) 135 | 136 | candidate_logprob = beam_logprobs_sum[qq] + local_logprob 137 | candidates.append({'c': ix.data[qq, cc], 'q': qq, 'p': candidate_logprob.data[0], 138 | 'r': local_logprob.data[0]}) 139 | 140 | candidates = sorted(candidates, key=lambda x: -x['p']) 141 | 142 | # construct new beams 143 | new_state = [_.clone() for _ in state] 144 | if t > 1: 145 | # well need these as reference when we fork beams around 146 | beam_seq_prev = beam_seq[:t - 1].clone() 147 | beam_seq_logprobs_prev = beam_seq_logprobs[:t - 1].clone() 148 | for vix in range(beam_size): 149 | v = candidates[vix] 150 | # fork beam index q into index vix 151 | if t > 1: 152 | beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']] 153 | beam_seq_logprobs[:t - 1, vix] = beam_seq_logprobs_prev[:, v['q']] 154 | 155 | # rearrange recurrent states 156 | for state_ix in range(len(new_state)): 157 | # copy over state in previous beam q to new beam at vix 158 | new_state[state_ix][0, vix] = state[state_ix][0, 
v['q']] # dimension one is time step 159 | 160 | # append new end terminal at the end of this beam 161 | beam_seq[t - 1, vix] = v['c'] # c'th word is the continuation 162 | beam_seq_logprobs[t - 1, vix] = v['r'] # the raw logprob here 163 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 164 | 165 | if v['c'] == self.vocab_size or t == self.seq_length: 166 | # END token special case here, or we reached the end. 167 | # add the beam to a set of done beams 168 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 169 | 'logps': beam_seq_logprobs[:, vix].clone(), 170 | 'p': beam_logprobs_sum[vix] 171 | }) 172 | 173 | # encode as vectors 174 | it = beam_seq[t - 1].view(1, -1) 175 | xt = netW(Variable(it.cuda())) 176 | 177 | if t >= 1: 178 | state = new_state 179 | 180 | output, state = self.rnn(xt, state) 181 | 182 | output = F.dropout(output, self.d, training=self.training) 183 | decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2))) 184 | logprobs = F.log_softmax(self.beta * decoded) 185 | 186 | self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 187 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 188 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 189 | for ii in range(beam_size): 190 | seq_all[:, k, ii] = self.done_beams[k][ii]['seq'] 191 | 192 | # return the samples and their log likelihoods 193 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 194 | 195 | def sample(self, netW, input, state, opt={}): 196 | sample_max = opt.get('sample_max', 1) 197 | beam_size = opt.get('beam_size', 1) 198 | temperature = opt.get('temperature', 1.0) 199 | seq_length = opt.get('seq_length', 9) 200 | self.seq_length = seq_length 201 | 202 | if beam_size > 1: 203 | return self.sample_beam(netW, input, state, opt) 204 | 205 | batch_size = input.size(1) 206 | seq = [] 207 | seqLogprobs = [] 208 | for t in range(self.seq_length + 1): 209 | if t == 0: # input 210 | it = input.data 211 | elif sample_max: 212 | sampleLogprobs, it = torch.max(logprobs.data, 1) 213 | it = it.view(-1).long() 214 | else: 215 | if temperature == 1.0: 216 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 217 | else: 218 | # scale logprobs by temperature 219 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 220 | it = torch.multinomial(prob_prev, 1).cuda() 221 | sampleLogprobs = logprobs.gather(1, Variable(it, 222 | requires_grad=False)) # gather the logprobs at sampled positions 223 | it = it.view(-1).long() # and flatten indices for downstream processing 224 | 225 | xt = netW(Variable(it.view(-1, 1), requires_grad=False)) 226 | 227 | if t >= 1: 228 | seq.append(it) # seq[t] the input of t+2 time step 229 | seqLogprobs.append(sampleLogprobs.view(-1)) 230 | it = torch.unsqueeze(it, 0) 231 | 232 | output, state = self.rnn(xt, state) 233 | output = F.dropout(output, self.d, training=self.training) 234 | decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2))) 235 | logprobs = F.log_softmax(self.beta * decoded, 1) 236 | 237 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 paper-coder 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DMRM: A Dual-channel Multi-hop Reasoning Model for Visual Dialog 2 | ======================================================================== 3 | 4 | Pytorch Implementation for the paper: 5 | 6 | **[DMRM: A Dual-channel Multi-hop Reasoning Model for Visual Dialog][11]**
7 | Feilong Chen, Fandong Meng, Jiaming Xu, Peng Li, Bo Xu, and Jie Zhou
8 | In AAAI 2020 9 | 10 | 11 | 12 | 13 | 14 | Setup and Dependencies 15 | ---------------------- 16 | This code is implemented using **PyTorch v0.3.0** with **CUDA 8 and cuDNN 7**.
17 | We recommend setting up this codebase in an Anaconda or Miniconda environment.
18 | 19 | 1. Install the Anaconda or Miniconda distribution based on **Python 3.6+** from the [downloads site][2]. 20 | 2. Clone this repository and create an environment: 21 | 22 | ```sh 23 | git clone https://github.com/phellonchen/DMRM.git 24 | conda create -n dmrm_visdial python=3.6 25 | 26 | # activate the environment and install all dependencies 27 | conda activate dmrm_visdial 28 | cd $PROJECT_ROOT/ 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | Download Features 33 | ---------------------- 34 | 1. Download the VisDial dialog json files from [here](https://visualdialog.org/data) and keep them under the `$PROJECT_ROOT/data` directory so that the default arguments work. 35 | 36 | 2. We use image features extracted by a Faster R-CNN pre-trained on Visual Genome. Download the image features below and put each file under the `$PROJECT_ROOT/data` directory. 37 | * [`features_faster_rcnn_x101_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_train.h5): Bottom-up features of 36 proposals from images of the `train` split. 38 | * [`features_faster_rcnn_x101_val.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_val.h5): Bottom-up features of 36 proposals from images of the `val` split. 39 | * [`features_faster_rcnn_x101_test.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_test.h5): Bottom-up features of 36 proposals from images of the `test` split. 40 | 41 | 3. Download the pretrained GloVe word vectors from [here][9] and keep `glove.6B.300d.txt` under the `$PROJECT_ROOT/data` directory. 42 | 43 | Data preprocessing & Word embedding initialization 44 | ---------------------- 45 | ```sh 46 | # data preprocessing 47 | cd $PROJECT_ROOT/script/ 48 | python prepro.py 49 | 50 | # Word embedding vector initialization (GloVe) 51 | cd $PROJECT_ROOT/script/ 52 | python create_glove.py 53 | ``` 54 | 55 | Training 56 | -------- 57 | Simply run: 58 | ```sh 59 | python main_v0.9.py   # or: python main_v1.0.py 60 | ``` 61 | 62 | ### Saving model checkpoints 63 | Our model saves a checkpoint at every epoch and updates the best one. You can change this behavior by editing `train.py`. 64 | 65 | ### Logging 66 | The training log `$PROJECT_ROOT/save_models/time/log.txt` records the epoch, loss, and learning rate. 
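For example, you can follow a run's log as it is written. This is only a minimal sketch: one sub-directory is created per run, so adjust the glob to the run you care about, and note that the default `--output` directory in `main_v0.9.py`/`main_v1.0.py` is `saved_models/`.
```sh
# follow the training log(s); one sub-directory per run
tail -f $PROJECT_ROOT/saved_models/*/log.txt
```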
67 | 68 | Evaluation 69 | -------- 70 | Evaluation of a trained model checkpoint can be evaluated as follows: 71 | ```sh 72 | python eval_v0.9.py or python eval_v1.0.py 73 | ``` 74 | 75 | Results 76 | -------- 77 | Performance on `v0.9 val-std` (trained on `v0.9` train): 78 | 79 | Model | MRR | R@1 | R@5 | R@10 | Mean | 80 | ------- | ------ | ------ | ------ | ------ | ------ | 81 | DMRM | 55.96 | 46.20 | 66.02 | 72.43 | 13.15 | 82 | 83 | Performance on `v1.0 val-std` (trained on `v1.0` train): 84 | 85 | Model | MRR | R@1 | R@5 | R@10 | Mean | 86 | ------- | ------ | ------ | ------ | ------ | ------ | 87 | DMRM | 50.16 | 40.15 | 60.02 | 67.21 | 15.19 | 88 | 89 | If you find this repository useful, please consider citing our work: 90 | ``` 91 | @inproceedings{chen2020dmrm, 92 | title={DMRM: A dual-channel multi-hop reasoning model for visual dialog}, 93 | author={Chen, Feilong and Meng, Fandong and Xu, Jiaming and Li, Peng and Xu, Bo and Zhou, Jie}, 94 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 95 | volume={34}, 96 | number={05}, 97 | pages={7504--7511}, 98 | year={2020} 99 | } 100 | ``` 101 | License 102 | -------- 103 | MIT License 104 | 105 | [1]: https://arxiv.org/abs/1902.09368 106 | [2]: https://conda.io/docs/user-guide/install/download.html 107 | [3]: https://drive.google.com/file/d/1NYlSSikwEAqpJDsNGqOxgc0ZOkpQtom9/view?usp=sharing 108 | [4]: https://drive.google.com/file/d/1QSi0Lr4XKdQ2LdoS1taS6P9IBVAKRntF/view?usp=sharing 109 | [5]: https://drive.google.com/file/d/1NI5TNKKhqm6ggpB2CK4k8yKiYQE3efW6/view?usp=sharing 110 | [6]: https://drive.google.com/file/d/1nTBaLziRIVkKAqFtQ-YIbXew2tYMUOSZ/view?usp=sharing 111 | [7]: https://drive.google.com/file/d/1BXWPV3k-HxlTw_k3-kTV6JhWrdzXsT7W/view?usp=sharing 112 | [8]: https://drive.google.com/file/d/1_32kGhd6wKzQLqfmqJzIHubfZwe9nhFy/view?usp=sharing 113 | [9]: http://nlp.stanford.edu/data/glove.6B.zip 114 | [10]: https://evalai.cloudcv.org/web/challenges/challenge-page/161/overview 115 | [11]: https://arxiv.org/abs/1912.08360 116 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | from fc import FCNet 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, v_dim, q_dim, num_hid): 9 | super(Attention, self).__init__() 10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid]) 11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None) 12 | 13 | def forward(self, v, q): 14 | """ 15 | v: [batch, k, vdim] 16 | q: [batch, qdim] 17 | """ 18 | logits = self.logits(v, q) 19 | w = nn.functional.softmax(logits, 1) 20 | return w 21 | 22 | def logits(self, v, q): 23 | num_objs = v.size(1) 24 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 25 | vq = torch.cat((v, q), 2) 26 | joint_repr = self.nonlinear(vq) 27 | logits = self.linear(joint_repr) 28 | return logits 29 | 30 | 31 | class NewAttention(nn.Module): 32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.5): 33 | super(NewAttention, self).__init__() 34 | 35 | self.v_proj = FCNet([v_dim, num_hid]) 36 | self.q_proj = FCNet([q_dim, num_hid]) 37 | self.dropout = nn.Dropout(dropout) 38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None) 39 | 40 | def forward(self, v, q): 41 | """ 42 | v: [batch, k, vdim] 43 | q: [batch, qdim] 44 | """ 45 | logits = self.logits(v, q) 46 | w = nn.functional.softmax(logits, 1) 47 | return w 48 | 49 
| def logits(self, v, q): 50 | batch, k, _ = v.size() 51 | v_proj = self.v_proj(v) # [batch, k, qdim] 52 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 53 | joint_repr = v_proj * q_proj 54 | joint_repr = self.dropout(joint_repr) 55 | logits = self.linear(joint_repr) 56 | return logits 57 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding, QuestionEmbedding2 5 | from classifier import SimpleClassifier 6 | from fc import FCNet 7 | from Decoders.decoder1 import _netG as netG 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from misc.utils import LayerNorm 11 | class BaseModel2(nn.Module): 12 | def __init__(self, w_emb, q_emb, h_emb, v_att, h_att, q_net, v_net, h_net, qih_att, qhi_att, qih_net, qhi_net, 13 | decoder, args, qhih_att, qihi_att): 14 | super(BaseModel2, self).__init__() 15 | self.ninp = args.ninp 16 | self.w_emb = w_emb 17 | self.q_emb = q_emb 18 | self.h_emb = h_emb 19 | self.decoder = decoder 20 | self.img_embed = nn.Linear(args.img_feat_size, 2 * args.nhid) 21 | self.w1 = nn.Linear(args.nhid*2, args.nhid*2) 22 | self.w2 = nn.Linear(args.nhid*2, args.nhid*2) 23 | self.track_1 = v_att 24 | self.locate_1 = h_att 25 | self.locate_2 = qih_att 26 | self.track_2 = qhi_att 27 | self.locate_3 = qhih_att 28 | self.track_3 = qihi_att 29 | self.q_net = q_net 30 | self.v_net = v_net 31 | self.h_net = h_net 32 | self.qih_net = qih_net 33 | self.qhi_net = qhi_net 34 | self.fc1 = nn.Linear(args.nhid * 4, self.ninp) 35 | self.dropout = args.dropout 36 | self.vocab_size = args.vocab_size 37 | # self.fch = FCNet([args.nhid * 2, args.nhid * 2]) 38 | # self.layernorm = LayerNorm(args.nhid*2) 39 | 40 | def forward(self, image, question, history, answer, tans, rnd, Training=True, sampling=False): 41 | 42 | # prepare I, Q, H 43 | image = self.img_embed(image) 44 | w_emb = self.w_emb(question) 45 | q_emb, ques_hidden = self.q_emb(w_emb) # [batch, q_dim] 46 | 47 | hw_emb = self.w_emb(history) 48 | h_emb, _ = self.h_emb(hw_emb) # [batch * rnd, h_dim] 49 | h_emb = h_emb.view(-1, rnd, h_emb.size(1)) 50 | 51 | # cap & image 52 | # qc_att = self.v_att(image, h_emb[:, 0, :]) 53 | # qc_emb = (qc_att * image).sum(1) 54 | # qc_emb = self.fch(qc_emb * q_emb) 55 | 56 | # question & image --> qi 57 | qv_att = self.track_1(image, q_emb) 58 | qv_emb = (qv_att * image).sum(1) # [batch, v_dim] 59 | 60 | # question & history --> qh 61 | qh_att = self.locate_1(h_emb, q_emb) 62 | qh_emb = (qh_att * h_emb).sum(1) # [batch, h_dim] 63 | # qh_emb = self.fch(qh_emb+q_emb) 64 | # qh_emb = self.layernorm(qh_emb+h_emb[:,0,:]) 65 | 66 | # qh & image --> qhi 67 | qhi_att = self.track_2(image, qh_emb) 68 | qhi_emb = (qhi_att * image).sum(1) # [batch, v_dim] 69 | 70 | # qi & history --> qih 71 | qih_att = self.locate_2(h_emb, qv_emb) 72 | qih_emb = (qih_att * h_emb).sum(1) # [batch, h_dim] 73 | 74 | q_re = self.q_net(q_emb) 75 | qih_emb = self.h_net(qih_emb) 76 | qih_emb = q_re * qih_emb 77 | 78 | qhi_emb = self.v_net(qhi_emb) 79 | qhi_emb = q_re * qhi_emb 80 | 81 | # qih & i --> qihi 82 | qihi_att = self.track_3(image, qih_emb) 83 | qihi_emb = (qihi_att * image).sum(1) 84 | 85 | # qhi & his --> qhih 86 | qhih_att = self.locate_3(h_emb, qhi_emb) 87 | qhih_emb = (qhih_att * h_emb).sum(1) 88 | 89 | q_repr = 
self.q_net(q_emb) 90 | qhi_repr = self.qhi_net(qihi_emb) 91 | qqhi_joint_repr = q_repr * qhi_repr 92 | 93 | qih_repr = self.qih_net(qhih_emb) 94 | qqih_joint_repr = q_repr * qih_repr 95 | 96 | joint_repr = torch.cat([self.w1(qqhi_joint_repr), self.w2(qqih_joint_repr)], 1) # [batch, h_dim * 2 97 | joint_repr = F.tanh(self.fc1(F.dropout(joint_repr, self.dropout, training=self.training))) 98 | 99 | _, ques_hidden = self.decoder(joint_repr.view(-1, 1, self.ninp), ques_hidden) 100 | 101 | if sampling: 102 | batch_size, _, _ = image.size() 103 | sample_ans_input = Variable(torch.LongTensor(batch_size, 1).fill_(2).cuda()) 104 | sample_opt = {'beam_size': 1} 105 | seq, seqLogprobs = self.decoder.sample(self.w_emb, sample_ans_input, ques_hidden, sample_opt) 106 | sample_ans = self.w_emb(Variable(seq)) 107 | ans_emb = self.w_emb(tans) 108 | sample_ans = torch.cat([w_emb, joint_repr.view(batch_size, -1, self.ninp),sample_ans], 1) 109 | ans_emb = torch.cat([w_emb, joint_repr.view(batch_size, -1, self.ninp), ans_emb], 1) 110 | return sample_ans, ans_emb 111 | 112 | if not Training: 113 | batch_size, _, hid_size = image.size() 114 | hid_size = int(hid_size / 2) 115 | hidden_replicated = [] 116 | for hid in ques_hidden: 117 | hidden_replicated.append(hid.view(2, batch_size, 1,hid_size).expand(2, 118 | batch_size, 100, hid_size).clone().view(2, -1, hid_size)) 119 | hidden_replicated = tuple(hidden_replicated) 120 | ques_hidden = hidden_replicated 121 | 122 | emb = self.w_emb(answer) 123 | pred, _ = self.decoder(emb, ques_hidden) 124 | return pred 125 | 126 | 127 | def build_baseline0_newatt2(args, num_hid): 128 | w_emb = WordEmbedding(args.vocab_size, args.ninp, 0.0) 129 | q_emb = QuestionEmbedding2(args.ninp, num_hid, args.nlayers, True, 0.0) 130 | h_emb = QuestionEmbedding2(args.ninp, num_hid, args.nlayers, True, 0.0) 131 | v_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 132 | h_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 133 | qih_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 134 | qhi_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 135 | q_net = FCNet([q_emb.num_hid*2, num_hid*2]) 136 | v_net = FCNet([args.nhid*2, num_hid*2]) 137 | h_net = FCNet([args.nhid*2, num_hid*2]) 138 | qih_net = FCNet([args.nhid*2, num_hid*2]) 139 | qhi_net = FCNet([args.nhid*2, num_hid*2]) 140 | qhih_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 141 | qihi_att = NewAttention(args.nhid*2, q_emb.num_hid*2, num_hid*2) 142 | 143 | decoder = netG(args) 144 | return BaseModel2(w_emb, q_emb, h_emb, v_att, h_att, q_net, v_net, h_net, qih_att, qhi_att, qih_net, qhi_net, 145 | decoder, args, qhih_att, qihi_att) 146 | 147 | class attflat(nn.Module): 148 | def __init__(self, args): 149 | super(attflat, self).__init__() 150 | self.mlp = FCNet([args.nhid * 2, args.nhid, 1]) 151 | self.fc = nn.Linear(args.nhid*2, args.nhid*2) 152 | 153 | def forward(self, x): 154 | batch_size, q_len, nhid = x.size() 155 | att = self.mlp(x.view(-1, nhid)) 156 | att = F.softmax(att, dim=1) 157 | x_atted = (att.view(batch_size, q_len, -1) * x.view(batch_size, q_len, -1)).sum(1) 158 | x_atted = self.fc(x_atted) 159 | 160 | return x_atted 161 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | 4 | 5 | class SimpleClassifier(nn.Module): 6 | def __init__(self, in_dim, hid_dim, out_dim, 
dropout): 7 | super(SimpleClassifier, self).__init__() 8 | layers = [ 9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 10 | nn.ReLU(), 11 | nn.Dropout(dropout, inplace=True), 12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 13 | ] 14 | self.main = nn.Sequential(*layers) 15 | 16 | def forward(self, x): 17 | logits = self.main(x) 18 | return logits 19 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | # import os 3 | # import json 4 | # import cPickle 5 | # import numpy as np 6 | # import utils 7 | # import h5py 8 | # import torch 9 | # from torch.utils.data import Dataset 10 | # 11 | # 12 | # class Dictionary(object): 13 | # def __init__(self, word2idx=None, idx2word=None): 14 | # if word2idx is None: 15 | # word2idx = {} 16 | # if idx2word is None: 17 | # idx2word = [] 18 | # self.word2idx = word2idx 19 | # self.idx2word = idx2word 20 | # 21 | # @property 22 | # def ntoken(self): 23 | # return len(self.word2idx) 24 | # 25 | # @property 26 | # def padding_idx(self): 27 | # return len(self.word2idx) 28 | # 29 | # def tokenize(self, sentence, add_word): 30 | # sentence = sentence.lower() 31 | # sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 32 | # words = sentence.split() 33 | # tokens = [] 34 | # if add_word: 35 | # for w in words: 36 | # tokens.append(self.add_word(w)) 37 | # else: 38 | # for w in words: 39 | # tokens.append(self.word2idx[w]) 40 | # return tokens 41 | # 42 | # def dump_to_file(self, path): 43 | # cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 44 | # print('dictionary dumped to %s' % path) 45 | # 46 | # @classmethod 47 | # def load_from_file(cls, path): 48 | # print('loading dictionary from %s' % path) 49 | # word2idx, idx2word = cPickle.load(open(path, 'rb')) 50 | # d = cls(word2idx, idx2word) 51 | # return d 52 | # 53 | # def add_word(self, word): 54 | # if word not in self.word2idx: 55 | # self.idx2word.append(word) 56 | # self.word2idx[word] = len(self.idx2word) - 1 57 | # return self.word2idx[word] 58 | # 59 | # def __len__(self): 60 | # return len(self.idx2word) 61 | # 62 | # 63 | # def _create_entry(img, question, answer): 64 | # answer.pop('image_id') 65 | # answer.pop('question_id') 66 | # entry = { 67 | # 'question_id' : question['question_id'], 68 | # 'image_id' : question['image_id'], 69 | # 'image' : img, 70 | # 'question' : question['question'], 71 | # 'answer' : answer} 72 | # return entry 73 | # 74 | # 75 | # def _load_dataset(dataroot, name, img_id2val): 76 | # """Load entries 77 | # 78 | # img_id2val: dict {img_id -> val} val can be used to retrieve image or features 79 | # dataroot: root path of dataset 80 | # name: 'train', 'val' 81 | # """ 82 | # question_path = os.path.join( 83 | # dataroot, 'v2_OpenEnded_mscoco_%s2014_questions.json' % name) 84 | # questions = sorted(json.load(open(question_path))['questions'], 85 | # key=lambda x: x['question_id']) 86 | # answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 87 | # answers = cPickle.load(open(answer_path, 'rb')) 88 | # answers = sorted(answers, key=lambda x: x['question_id']) 89 | # 90 | # utils.assert_eq(len(questions), len(answers)) 91 | # entries = [] 92 | # for question, answer in zip(questions, answers): 93 | # utils.assert_eq(question['question_id'], answer['question_id']) 94 | # utils.assert_eq(question['image_id'], answer['image_id']) 95 | # img_id 
= question['image_id'] 96 | # entries.append(_create_entry(img_id2val[img_id], question, answer)) 97 | # 98 | # return entries 99 | # 100 | # 101 | # class VQAFeatureDataset(Dataset): 102 | # def __init__(self, name, dictionary, dataroot='data'): 103 | # super(VQAFeatureDataset, self).__init__() 104 | # assert name in ['train', 'val'] 105 | # 106 | # ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 107 | # label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl') 108 | # self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 109 | # self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 110 | # self.num_ans_candidates = len(self.ans2label) 111 | # 112 | # self.dictionary = dictionary 113 | # 114 | # self.img_id2idx = cPickle.load( 115 | # open(os.path.join(dataroot, '%s36_imgid2idx.pkl' % name))) 116 | # print('loading features from h5 file') 117 | # h5_path = os.path.join(dataroot, '%s36.hdf5' % name) 118 | # with h5py.File(h5_path, 'r') as hf: 119 | # self.features = np.array(hf.get('image_features')) 120 | # self.spatials = np.array(hf.get('spatial_features')) 121 | # 122 | # self.entries = _load_dataset(dataroot, name, self.img_id2idx) 123 | # 124 | # self.tokenize() 125 | # self.tensorize() 126 | # self.v_dim = self.features.size(2) 127 | # self.s_dim = self.spatials.size(2) 128 | # 129 | # def tokenize(self, max_length=14): 130 | # """Tokenizes the questions. 131 | # 132 | # This will add q_token in each entry of the dataset. 133 | # -1 represent nil, and should be treated as padding_idx in embedding 134 | # """ 135 | # for entry in self.entries: 136 | # tokens = self.dictionary.tokenize(entry['question'], False) 137 | # tokens = tokens[:max_length] 138 | # if len(tokens) < max_length: 139 | # # Note here we pad in front of the sentence 140 | # padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 141 | # tokens = padding + tokens 142 | # utils.assert_eq(len(tokens), max_length) 143 | # entry['q_token'] = tokens 144 | # 145 | # def tensorize(self): 146 | # self.features = torch.from_numpy(self.features) 147 | # self.spatials = torch.from_numpy(self.spatials) 148 | # 149 | # for entry in self.entries: 150 | # question = torch.from_numpy(np.array(entry['q_token'])) 151 | # entry['q_token'] = question 152 | # 153 | # answer = entry['answer'] 154 | # labels = np.array(answer['labels']) 155 | # scores = np.array(answer['scores'], dtype=np.float32) 156 | # if len(labels): 157 | # labels = torch.from_numpy(labels) 158 | # scores = torch.from_numpy(scores) 159 | # entry['answer']['labels'] = labels 160 | # entry['answer']['scores'] = scores 161 | # else: 162 | # entry['answer']['labels'] = None 163 | # entry['answer']['scores'] = None 164 | # 165 | # def __getitem__(self, index): 166 | # entry = self.entries[index] 167 | # features = self.features[entry['image']] 168 | # spatials = self.spatials[entry['image']] 169 | # 170 | # question = entry['q_token'] 171 | # answer = entry['answer'] 172 | # labels = answer['labels'] 173 | # scores = answer['scores'] 174 | # target = torch.zeros(self.num_ans_candidates) 175 | # if labels is not None: 176 | # target.scatter_(0, labels, scores) 177 | # 178 | # return features, spatials, question, target 179 | # 180 | # def __len__(self): 181 | # return len(self.entries) 182 | -------------------------------------------------------------------------------- /dmrm_overview.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/dmrm_overview.png -------------------------------------------------------------------------------- /eval_v0.9.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | 7 | # from dataset import Dictionary, VQAFeatureDataset 8 | import base_model 9 | from train import evaluate 10 | import utils 11 | import dataLoader_bttd as dl 12 | import time 13 | import os 14 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0' 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--epochs', type=int, default=100) 18 | parser.add_argument('--model', type=str, default='baseline0_newatt2') 19 | parser.add_argument('--output', type=str, default='saved_models/') 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 22 | parser.add_argument('--gpuid', type=int, default=0) 23 | parser.add_argument('--input_img_h5', default='data/features_faster_rcnn_x101_train.h5', 24 | help='path to dataset, now hdf5 file') 25 | parser.add_argument('--input_imgid', default='data/features_faster_rcnn_x101_train_v0.9_imgid.json', help='path to dataset, now hdf5 file') 26 | 27 | parser.add_argument('--input_ques_h5', default='data/visdial_data.h5', 28 | help='path to dataset, now hdf5 file') 29 | parser.add_argument('--input_json', default='data/visdial_params.json', 30 | help='path to dataset, now hdf5 file') 31 | parser.add_argument('--model_path', default='', 32 | help='path to model, now pth file') 33 | parser.add_argument('--img_feat_size', type=int, default=2048, help='input batch size') 34 | parser.add_argument('--ninp', type=int, default=300, help='size of word embeddings') 35 | parser.add_argument('--nhid', type=int, default=512, help='humber of hidden units per layer') 36 | parser.add_argument('--nlayers', type=int, default=1, help='number of layers') 37 | parser.add_argument('--dropout', type=int, default=0.5, help='number of layers') 38 | parser.add_argument('--negative_sample', type=int, default=20, help='folder to output images and model checkpoints') 39 | parser.add_argument('--neg_batch_sample', type=int, default=30, 40 | help='folder to output images and model checkpoints') 41 | parser.add_argument('--num_val', default=1000, help='number of image split out as validation set.') 42 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=6) 43 | parser.add_argument('--lr', type=float, default=0.002, help='learning rate for, default=0.00005') 44 | 45 | args = parser.parse_args() 46 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid) 47 | return args 48 | 49 | 50 | if __name__ == '__main__': 51 | args = parse_args() 52 | 53 | torch.manual_seed(args.seed) 54 | torch.cuda.manual_seed(args.seed) 55 | torch.backends.cudnn.benchmark = True 56 | 57 | batch_size = args.batch_size 58 | 59 | eval_dset = dl.validate(input_img_h5=args.input_img_h5, input_imgid=args.input_imgid, input_ques_h5=args.input_ques_h5, 60 | input_json=args.input_json, negative_sample=args.negative_sample, 61 | num_val=args.num_val, data_split='test') 62 | 63 | eval_loader = torch.utils.data.DataLoader(eval_dset, batch_size=5, 64 | shuffle=False, num_workers=int(args.workers)) 65 | 66 | args.vocab_size = eval_dset.vocab_size 67 | args.ques_length = 
eval_dset.ques_length 68 | args.ans_length = eval_dset.ans_length + 1 69 | args.his_length = eval_dset.ques_length + eval_dset.ans_length 70 | args.seq_length = args.ques_length 71 | constructor = 'build_%s' % args.model 72 | vocab_size = eval_dset.vocab_size 73 | model = getattr(base_model, constructor)(args, args.nhid).cuda() 74 | model = nn.DataParallel(model).cuda() 75 | checkpoint = torch.load(args.model_path) 76 | # model_dict = model.state_dict() 77 | # keys = [] 78 | # for k, v in checkpoint['model'].items(): 79 | # keys.append(k) 80 | # i = 0 81 | # for k, v in model_dict.items(): 82 | # #if v.size() == checkpoint['model'][keys[i]].size(): 83 | # # print(k, ',', keys[i]) 84 | # model_dict[k] = checkpoint['model'][keys[i]] 85 | # i = i + 1 86 | # model.load_state_dict(model_dict) 87 | model.load_state_dict(checkpoint['model']) 88 | model.eval() 89 | print('Evaluating ... ') 90 | start_time = time.time() 91 | rank_all = evaluate(model, eval_loader, args, True) 92 | R1 = np.sum(np.array(rank_all) == 1) / float(len(rank_all)) 93 | R5 = np.sum(np.array(rank_all) <= 5) / float(len(rank_all)) 94 | R10 = np.sum(np.array(rank_all) <= 10) / float(len(rank_all)) 95 | ave = np.sum(np.array(rank_all)) / float(len(rank_all)) 96 | mrr = np.sum(1 / (np.array(rank_all, dtype='float'))) / float(len(rank_all)) 97 | #save_path = checkpoint['args'].save_path 98 | #logger = utils.Logger(os.path.join(save_path, 'eval-log.txt')) 99 | #logger.write('mrr: %f R1: %f R5 %f R10 %f Mean %f time: %.2f' % (mrr, R1, R5, R10, ave, time.time() - start_time)) 100 | print('mrr: %f R1: %f R5 %f R10 %f Mean %f time: %.2f' % (mrr, R1, R5, R10, ave, time.time() - start_time)) 101 | -------------------------------------------------------------------------------- /eval_v1.0.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | 7 | # from dataset import Dictionary, VQAFeatureDataset 8 | import base_model 9 | from train import evaluate 10 | import utils 11 | import misc.dataLoader as dl 12 | import time 13 | import os 14 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0' 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--epochs', type=int, default=100) 18 | parser.add_argument('--model', type=str, default='baseline0_newatt2') 19 | parser.add_argument('--output', type=str, default='saved_models/') 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 22 | parser.add_argument('--gpuid', type=int, default=0) 23 | parser.add_argument('--input_img_h5', default='data/features_faster_rcnn_x101_train.h5.h5', 24 | help='path to dataset, now hdf5 file') 25 | parser.add_argument('--input_ques_h5', default='data/visdial_data_v1.0.h5', 26 | help='path to dataset, now hdf5 file') 27 | parser.add_argument('--input_json', default='data/visdial_params_v1.0.json', 28 | help='path to dataset, now hdf5 file') 29 | parser.add_argument('--model_path', default='', 30 | help='path to model, now pth file') 31 | parser.add_argument('--img_feat_size', type=int, default=512, help='input batch size') 32 | parser.add_argument('--ninp', type=int, default=300, help='size of word embeddings') 33 | parser.add_argument('--nhid', type=int, default=512, help='humber of hidden units per layer') 34 | parser.add_argument('--nlayers', type=int, default=1, help='number of layers') 35 
| parser.add_argument('--dropout', type=int, default=0.5, help='number of layers') 36 | parser.add_argument('--negative_sample', type=int, default=20, help='folder to output images and model checkpoints') 37 | parser.add_argument('--neg_batch_sample', type=int, default=30, 38 | help='folder to output images and model checkpoints') 39 | parser.add_argument('--num_val', default=1000, help='number of image split out as validation set.') 40 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=6) 41 | parser.add_argument('--lr', type=float, default=0.002, help='learning rate for, default=0.00005') 42 | 43 | args = parser.parse_args() 44 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid) 45 | 46 | return args 47 | 48 | 49 | if __name__ == '__main__': 50 | args = parse_args() 51 | 52 | torch.manual_seed(args.seed) 53 | torch.cuda.manual_seed(args.seed) 54 | torch.backends.cudnn.benchmark = True 55 | 56 | batch_size = args.batch_size 57 | 58 | eval_dset = dl.validate(input_img_h5=args.input_img_h5, input_ques_h5=args.input_ques_h5, 59 | input_json=args.input_json, negative_sample=args.negative_sample, 60 | num_val=args.num_val, data_split='test') 61 | 62 | eval_loader = torch.utils.data.DataLoader(eval_dset, batch_size=5, 63 | shuffle=False, num_workers=int(args.workers)) 64 | 65 | args.vocab_size = eval_dset.vocab_size 66 | args.ques_length = eval_dset.ques_length 67 | args.ans_length = eval_dset.ans_length + 1 68 | args.his_length = eval_dset.ques_length + eval_dset.ans_length 69 | args.seq_length = args.ques_length 70 | constructor = 'build_%s' % args.model 71 | vocab_size = eval_dset.vocab_size 72 | model = getattr(base_model, constructor)(args, args.nhid).cuda() 73 | model = nn.DataParallel(model).cuda() 74 | checkpoint = torch.load(args.model_path) 75 | 76 | # model_dict = model.state_dict() 77 | # keys = [] 78 | # for k, v in checkpoint['model'].items(): 79 | # keys.append(k) 80 | # i = 0 81 | # for k, v in model_dict.items(): 82 | # #if v.size() == checkpoint['model'][keys[i]].size(): 83 | # # print(k, ',', keys[i]) 84 | # model_dict[k] = checkpoint['model'][keys[i]] 85 | # i = i + 1 86 | # model.load_state_dict(model_dict) 87 | model.load_state_dict(checkpoint['model']) 88 | model.eval() 89 | print('Evaluating ... 
') 90 | start_time = time.time() 91 | rank_all = evaluate(model, eval_loader, args, True) 92 | R1 = np.sum(np.array(rank_all) == 1) / float(len(rank_all)) 93 | R5 = np.sum(np.array(rank_all) <= 5) / float(len(rank_all)) 94 | R10 = np.sum(np.array(rank_all) <= 10) / float(len(rank_all)) 95 | ave = np.sum(np.array(rank_all)) / float(len(rank_all)) 96 | mrr = np.sum(1 / (np.array(rank_all, dtype='float'))) / float(len(rank_all)) 97 | 98 | print('mrr: %f R1: %f R5 %f R10 %f Mean %f time: %.2f' % (mrr, R1, R5, R10, ave, time.time() - start_time)) 99 | -------------------------------------------------------------------------------- /fc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | 5 | 6 | class FCNet(nn.Module): 7 | """Simple class for non-linear fully connect network 8 | """ 9 | def __init__(self, dims): 10 | super(FCNet, self).__init__() 11 | 12 | layers = [] 13 | for i in range(len(dims)-2): 14 | in_dim = dims[i] 15 | out_dim = dims[i+1] 16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 17 | layers.append(nn.ReLU()) 18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 19 | layers.append(nn.ReLU()) 20 | 21 | self.main = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.main(x) 25 | 26 | 27 | if __name__ == '__main__': 28 | fc1 = FCNet([10, 20, 10]) 29 | print(fc1) 30 | 31 | print('============') 32 | fc2 = FCNet([10, 20]) 33 | print(fc2) 34 | -------------------------------------------------------------------------------- /language_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | class WordEmbedding(nn.Module): 8 | """Word Embedding 9 | 10 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 11 | with the definition in Dictionary. 
12 | """ 13 | def __init__(self, ntoken, emb_dim, dropout): 14 | super(WordEmbedding, self).__init__() 15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=0) 16 | self.dropout = nn.Dropout(dropout) 17 | self.ntoken = ntoken 18 | self.emb_dim = emb_dim 19 | 20 | def init_embedding(self, np_file): 21 | weight_init = torch.from_numpy(np.load(np_file)) 22 | assert weight_init.shape == (self.ntoken, self.emb_dim) 23 | self.emb.weight.data[:self.ntoken] = weight_init 24 | 25 | 26 | 27 | def forward(self, x): 28 | emb = self.emb(x) 29 | emb = self.dropout(emb) 30 | return emb 31 | 32 | 33 | class QuestionEmbedding(nn.Module): 34 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='LSTM'): 35 | """Module for question embedding 36 | """ 37 | super(QuestionEmbedding, self).__init__() 38 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 39 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 40 | 41 | self.rnn = rnn_cls( 42 | in_dim, num_hid, nlayers, 43 | bidirectional=bidirect, 44 | dropout=dropout, 45 | batch_first=True) 46 | 47 | self.in_dim = in_dim 48 | self.num_hid = num_hid 49 | self.nlayers = nlayers 50 | self.rnn_type = rnn_type 51 | self.ndirections = 1 + int(bidirect) 52 | self.init_weights() 53 | 54 | def init_hidden(self, batch): 55 | # just to get the type of tensor 56 | weight = next(self.parameters()).data 57 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 58 | if self.rnn_type == 'LSTM': 59 | return (Variable(weight.new(*hid_shape).zero_()), 60 | Variable(weight.new(*hid_shape).zero_())) 61 | else: 62 | return Variable(weight.new(*hid_shape).zero_()) 63 | 64 | 65 | def forward(self, x): 66 | # x: [batch, sequence, in_dim] 67 | batch = x.size(0) 68 | hidden = self.init_hidden(batch) 69 | self.rnn.flatten_parameters() 70 | output, hidden = self.rnn(x, hidden) 71 | 72 | if self.ndirections == 1: 73 | return output[:, -1] 74 | 75 | forward_ = output[:, -1, :self.num_hid] 76 | backward = output[:, 0, self.num_hid:] 77 | return torch.cat((forward_, backward), dim=1) 78 | 79 | def forward_all(self, x): 80 | # x: [batch, sequence, in_dim] 81 | batch = x.size(0) 82 | hidden = self.init_hidden(batch) 83 | self.rnn.flatten_parameters() 84 | output, hidden = self.rnn(x, hidden) 85 | return output 86 | 87 | class QuestionEmbedding2(nn.Module): 88 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='LSTM'): 89 | """Module for question embedding 90 | """ 91 | super(QuestionEmbedding2, self).__init__() 92 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 93 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 94 | 95 | self.rnn = rnn_cls( 96 | in_dim, num_hid, nlayers, 97 | bidirectional=bidirect, 98 | dropout=dropout, 99 | batch_first=True) 100 | 101 | self.in_dim = in_dim 102 | self.num_hid = num_hid 103 | self.nlayers = nlayers 104 | self.rnn_type = rnn_type 105 | self.ndirections = 1 + int(bidirect) 106 | 107 | def init_hidden(self, batch): 108 | # just to get the type of tensor 109 | weight = next(self.parameters()).data 110 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 111 | if self.rnn_type == 'LSTM': 112 | return (Variable(weight.new(*hid_shape).zero_()), 113 | Variable(weight.new(*hid_shape).zero_())) 114 | else: 115 | return Variable(weight.new(*hid_shape).zero_()) 116 | 117 | def forward(self, x): 118 | # x: [batch, sequence, in_dim] 119 | batch = x.size(0) 120 | hidden = self.init_hidden(batch) 121 | self.rnn.flatten_parameters() 122 | output, hidden = self.rnn(x, hidden) 123 | 124 | if 
self.ndirections == 1: 125 | return output[:, -1], hidden 126 | 127 | forward_ = output[:, -1, :self.num_hid] 128 | backward = output[:, 0, self.num_hid:] 129 | return torch.cat([forward_, backward], dim=1), hidden 130 | 131 | def forward_all(self, x): 132 | # x: [batch, sequence, in_dim] 133 | batch = x.size(0) 134 | hidden = self.init_hidden(batch) 135 | self.rnn.flatten_parameters() 136 | output, hidden = self.rnn(x, hidden) 137 | return output, hidden 138 | -------------------------------------------------------------------------------- /main_v0.9.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import random 6 | import base_model as base_model 7 | from train import train 8 | import misc.dataLoader as dl 9 | import os 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--epochs', type=int, default=100) 14 | parser.add_argument('--model', type=str, default='baseline0_newatt2') 15 | parser.add_argument('--output', type=str, default='saved_models/') 16 | parser.add_argument('--batch_size', type=int, default=128) 17 | parser.add_argument('--gpuid', type=int, default=0) 18 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 19 | 20 | parser.add_argument('--input_img_h5', default='data/features_faster_rcnn_x101_train.h5', 21 | help='path to dataset, now hdf5 file') 22 | parser.add_argument('--input_imgid', default='data/features_faster_rcnn_x101_train_v0.9_imgid.json', 23 | help='path to dataset, now hdf5 file') 24 | parser.add_argument('--input_ques_h5', default='data/visdial_data.h5', 25 | help='path to dataset, now hdf5 file') 26 | parser.add_argument('--input_json', default='data/visdial_params.json', 27 | help='path to dataset, now hdf5 file') 28 | 29 | parser.add_argument('--img_feat_size', type=int, default=2048, help='input batch size') 30 | parser.add_argument('--ninp', type=int, default=300, help='size of word embeddings') 31 | parser.add_argument('--nhid', type=int, default=512, help='humber of hidden units per layer') 32 | parser.add_argument('--nlayers', type=int, default=1, help='number of layers') 33 | parser.add_argument('--dropout', type=int, default=0.5, help='number of layers') 34 | parser.add_argument('--negative_sample', type=int, default=20, help='folder to output images and model checkpoints') 35 | parser.add_argument('--neg_batch_sample', type=int, default=30, 36 | help='folder to output images and model checkpoints') 37 | parser.add_argument('--num_val', default=1000, help='number of image split out as validation set.') 38 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=6) 39 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate for, default=0.00005') 40 | parser.add_argument('--beta1', type=float, default=0.8, help='beta1 for adam. 
default=0.5') 41 | parser.add_argument('--margin', type=float, default=2, help='number of epochs to train for') 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid) 50 | args.seed = random.randint(1, 10000) 51 | torch.manual_seed(args.seed) 52 | torch.cuda.manual_seed(args.seed) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | batch_size = args.batch_size 56 | 57 | train_dataset = dl.train(input_img_h5=args.input_img_h5, input_imgid=args.input_imgid, input_ques_h5=args.input_ques_h5, 58 | input_json=args.input_json, negative_sample=args.negative_sample, 59 | num_val=args.num_val, data_split='train') 60 | 61 | eval_dateset = dl.validate(input_img_h5=args.input_img_h5, input_imgid=args.input_imgid, input_ques_h5=args.input_ques_h5, 62 | input_json=args.input_json, negative_sample=args.negative_sample, 63 | num_val=args.num_val, data_split='val') 64 | 65 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, 66 | shuffle=True, num_workers=int(args.workers)) 67 | 68 | eval_loader = torch.utils.data.DataLoader(eval_dateset, batch_size=5, 69 | shuffle=False, num_workers=int(args.workers)) 70 | 71 | args.vocab_size = train_dataset.vocab_size 72 | args.ques_length = train_dataset.ques_length 73 | args.ans_length = train_dataset.ans_length + 1 74 | args.his_length = train_dataset.ques_length + train_dataset.ans_length 75 | args.seq_length = args.ans_length 76 | constructor = 'build_%s' % args.model 77 | vocab_size = train_dataset.vocab_size 78 | model = getattr(base_model, constructor)(args, args.nhid).cuda() 79 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 80 | 81 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 82 | print("Training params: ", num_params) 83 | model = nn.DataParallel(model).cuda() 84 | model = train(model, train_loader, eval_loader, args) 85 | 86 | -------------------------------------------------------------------------------- /main_v1.0.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | 7 | # from dataset import Dictionary, VQAFeatureDataset 8 | import base_model as base_model 9 | from train import train 10 | import utils 11 | import misc.dataLoader_v1 as dl 12 | import os 13 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--epochs', type=int, default=100) 17 | parser.add_argument('--model', type=str, default='baseline0_newatt2') 18 | parser.add_argument('--output', type=str, default='saved_models/') 19 | parser.add_argument('--batch_size', type=int, default=128) 20 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 21 | 22 | 23 | parser.add_argument('--input_img_h5', default='data/features_faster_rcnn_x101_train.h5', 24 | help='path to dataset, now hdf5 file') 25 | parser.add_argument('--input_img_val_h5', default='data/features_faster_rcnn_x101_val.h5', 26 | help='path to dataset, now hdf5 file') 27 | parser.add_argument('--input_ques_h5', default='data/visdial_data_v1.0.h5', 28 | help='path to dataset, now hdf5 file') 29 | parser.add_argument('--input_json', default='data/visdial_params_v1.0.json', 30 | help='path to dataset, now hdf5 file') 31 | 32 | parser.add_argument('--img_feat_size', type=int, 
default=512, help='input batch size') 33 | parser.add_argument('--ninp', type=int, default=300, help='size of word embeddings') 34 | parser.add_argument('--nhid', type=int, default=512, help='humber of hidden units per layer') 35 | parser.add_argument('--nlayers', type=int, default=1, help='number of layers') 36 | parser.add_argument('--dropout', type=int, default=0.5, help='number of layers') 37 | parser.add_argument('--negative_sample', type=int, default=20, help='folder to output images and model checkpoints') 38 | parser.add_argument('--neg_batch_sample', type=int, default=30, 39 | help='folder to output images and model checkpoints') 40 | parser.add_argument('--num_val', default=1000, help='number of image split out as validation set.') 41 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=6) 42 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate for, default=0.00005') 43 | parser.add_argument('--beta1', type=float, default=0.8, help='beta1 for adam. default=0.5') 44 | args = parser.parse_args() 45 | 46 | args.input_encoding_size = args.ninp 47 | args.rnn_size = args.nhid 48 | args.num_layers = args.nlayers 49 | args.drop_prob_lm = args.dropout 50 | args.fc_feat_size = args.img_feat_size 51 | args.att_feat_size = args.img_feat_size 52 | args.att_hid_size = args.img_feat_size 53 | return args 54 | 55 | 56 | if __name__ == '__main__': 57 | args = parse_args() 58 | 59 | torch.manual_seed(args.seed) 60 | torch.cuda.manual_seed(args.seed) 61 | torch.backends.cudnn.benchmark = True 62 | 63 | batch_size = args.batch_size 64 | 65 | train_dset = dl.train(input_img_h5=args.input_img_h5, input_ques_h5=args.input_ques_h5, 66 | input_json=args.input_json, negative_sample=args.negative_sample, 67 | num_val=args.num_val, data_split='train') 68 | 69 | eval_dset = dl.validate(input_img_h5=args.input_img_val_h5, input_ques_h5=args.input_ques_h5, 70 | input_json=args.input_json, negative_sample=args.negative_sample, 71 | num_val=args.num_val, data_split='val') 72 | 73 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, 74 | shuffle=True, num_workers=int(args.workers)) 75 | 76 | eval_loader = torch.utils.data.DataLoader(eval_dset, batch_size=5, 77 | shuffle=False, num_workers=int(args.workers)) 78 | 79 | args.vocab_size = train_dset.vocab_size 80 | args.ques_length = train_dset.ques_length 81 | args.ans_length = train_dset.ans_length + 1 82 | args.his_length = train_dset.ques_length + train_dset.ans_length 83 | args.seq_length = args.ans_length 84 | constructor = 'build_%s' % args.model 85 | vocab_size = train_dset.vocab_size 86 | model = getattr(base_model, constructor)(args, args.nhid).cuda() 87 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.0.npy') 88 | 89 | model = nn.DataParallel(model).cuda() 90 | model = train(model, train_loader, eval_loader, args) 91 | 92 | 93 | # dis_model, model = train_D(model, train_loader, args) 94 | 95 | 96 | # train_RL(model, train_loader, eval_loader, args) 97 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__init__.py -------------------------------------------------------------------------------- /misc/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /misc/__pycache__/dataLoader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/dataLoader.cpython-36.pyc -------------------------------------------------------------------------------- /misc/__pycache__/dataLoader_v1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/dataLoader_v1.cpython-36.pyc -------------------------------------------------------------------------------- /misc/__pycache__/focalloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/focalloss.cpython-36.pyc -------------------------------------------------------------------------------- /misc/__pycache__/readers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/readers.cpython-36.pyc -------------------------------------------------------------------------------- /misc/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phellonchen/DMRM/4e530d6d6ac8e0f311baa186fd439421b76d6f70/misc/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /misc/dataLoader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torch 4 | import numpy as np 5 | import h5py 6 | import json 7 | import pdb 8 | import random 9 | from misc.utils import repackage_hidden, clip_gradient, adjust_learning_rate, decode_txt 10 | from misc.readers import ImageFeaturesHdfReader 11 | from torch.nn.functional import normalize 12 | 13 | 14 | class train(data.Dataset): # torch wrapper 15 | def __init__(self, input_img_h5, input_imgid, input_ques_h5, input_json, negative_sample, num_val, data_split): 16 | 17 | print(('DataLoader loading: %s' % data_split)) 18 | print(('Loading image feature from %s' % input_img_h5)) 19 | 20 | if data_split == 'test': 21 | split = 'val' 22 | else: 23 | split = 'train' # train and val split both corresponding to 'train' 24 | 25 | f = json.load(open(input_json, 'r')) 26 | self.itow = f['itow'] 27 | self.wtoi = f['wtoi'] 28 | self.img_info = f['img_' + split] 29 | 30 | # get the data split. 
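# --- Editor's note: illustrative sketch, not part of the repository ---
# The block just below carves the v0.9 train/val split out of img_info by
# index: the last num_val images of the 'train' split are held out as
# validation data. The same bookkeeping in isolation:
def split_bounds(total_num, num_val, data_split):
    if data_split == 'train':
        return 0, total_num - num_val          # everything except the held-out tail
    elif data_split == 'val':
        return total_num - num_val, total_num  # the held-out tail
    return 0, total_num                        # any other split: use everything

# e.g. split_bounds(10000, 1000, 'val') -> (9000, 10000)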
31 | total_num = len(self.img_info) 32 | if data_split == 'train': 33 | s = 0 34 | # e = int((total_num) * 1) 35 | e = int((total_num - num_val) * 1) 36 | # e = 1000 37 | elif data_split == 'val': 38 | s = total_num - num_val 39 | e = total_num 40 | else: 41 | s = 0 42 | e = total_num 43 | 44 | self.img_info = self.img_info[s:e] 45 | 46 | print(('%s number of data: %d' % (data_split, e - s))) 47 | self.hdf_reader = ImageFeaturesHdfReader( 48 | input_img_h5, False) 49 | 50 | self.imgid = json.load(open(input_imgid, 'r'))['imgid'][s:e] 51 | print(('Loading txt from %s' % input_ques_h5)) 52 | f = h5py.File(input_ques_h5, 'r') 53 | self.ques = f['ques_' + split][s:e] 54 | self.ans = f['ans_' + split][s:e] 55 | self.cap = f['cap_' + split][s:e] 56 | 57 | self.ques_len = f['ques_len_' + split][s:e] 58 | self.ans_len = f['ans_len_' + split][s:e] 59 | self.cap_len = f['cap_len_' + split][s:e] 60 | 61 | self.ans_ids = f['ans_index_' + split][s:e] 62 | self.opt_ids = f['opt_' + split][s:e] 63 | self.opt_list = f['opt_list_' + split][:] 64 | self.opt_len = f['opt_len_' + split][:] 65 | f.close() 66 | 67 | self.ques_length = self.ques.shape[2] 68 | self.ans_length = self.ans.shape[2] 69 | self.his_length = self.ques_length + self.ans_length 70 | 71 | self.vocab_size = len(self.itow) 72 | 73 | print(('Vocab Size: %d' % self.vocab_size)) 74 | self.split = split 75 | self.rnd = 10 76 | self.negative_sample = negative_sample 77 | 78 | def __getitem__(self, index): 79 | # get the image 80 | img_id = self.img_info[index]['imgId'] 81 | img = self.hdf_reader[img_id] 82 | img = torch.from_numpy(img) 83 | 84 | img = normalize(img, dim=0, p=2) 85 | # get the history 86 | his = np.zeros((self.rnd, self.his_length)) 87 | his[0, self.his_length - self.cap_len[index]:] = self.cap[index, :self.cap_len[index]] 88 | 89 | ques = np.zeros((self.rnd, self.ques_length)) 90 | ans = np.zeros((self.rnd, self.ans_length + 1)) 91 | ans_target = np.zeros((self.rnd, self.ans_length + 1)) 92 | ques_ori = np.zeros((self.rnd, self.ques_length)) 93 | 94 | opt_ans = np.zeros((self.rnd, self.negative_sample, self.ans_length + 1)) 95 | ans_len = np.zeros((self.rnd)) 96 | opt_ans_len = np.zeros((self.rnd, self.negative_sample)) 97 | 98 | ans_idx = np.zeros((self.rnd)) 99 | opt_ans_idx = np.zeros((self.rnd, self.negative_sample)) 100 | 101 | for i in range(self.rnd): 102 | # get the index 103 | q_len = self.ques_len[index, i] 104 | a_len = self.ans_len[index, i] 105 | qa_len = q_len + a_len 106 | 107 | if i + 1 < self.rnd: 108 | his[i + 1, self.his_length - qa_len:self.his_length - a_len] = self.ques[index, i, :q_len] 109 | his[i + 1, self.his_length - a_len:] = self.ans[index, i, :a_len] 110 | 111 | ques[i, self.ques_length - q_len:] = self.ques[index, i, :q_len] 112 | 113 | ques_ori[i, :q_len] = self.ques[index, i, :q_len] 114 | ans[i, 1:a_len + 1] = self.ans[index, i, :a_len] 115 | ans[i, 0] = self.wtoi[''] 116 | 117 | ans_target[i, :a_len] = self.ans[index, i, :a_len] 118 | ans_target[i, a_len] = self.wtoi[''] 119 | ans_len[i] = self.ans_len[index, i] 120 | 121 | opt_ids = self.opt_ids[index, i] # since python start from 0 122 | # random select the negative samples. 123 | ans_idx[i] = opt_ids[self.ans_ids[index, i]] 124 | # exclude the gt index. 
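# --- Editor's note: illustrative sketch, not part of the repository ---
# Two remarks on the __getitem__ code above:
# 1) The self.wtoi[''] lookups almost certainly refer to start/end-of-sequence
#    tokens whose angle-bracketed names were stripped when this file was
#    rendered: ans is the decoder input (start token + answer) and ans_target
#    is the shifted target (answer + end token).
# 2) Negative sampling drops the ground-truth option before drawing
#    negative_sample distractors from the remaining candidates. By position,
#    that is:
import random
import numpy as np

def sample_negatives(opt_ids, gt_pos, num_neg):
    """opt_ids: the 100 candidate-answer ids for one round; gt_pos: position of the ground truth."""
    negatives = np.delete(np.asarray(opt_ids), gt_pos).tolist()  # drop the ground truth
    random.shuffle(negatives)
    return negatives[:num_neg]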
125 | opt_ids = np.delete(opt_ids, ans_idx[i], 0) 126 | random.shuffle(opt_ids) 127 | for j in range(self.negative_sample): 128 | ids = opt_ids[j] 129 | opt_ans_idx[i, j] = ids 130 | 131 | opt_len = self.opt_len[ids] 132 | 133 | opt_ans_len[i, j] = opt_len 134 | opt_ans[i, j, :opt_len] = self.opt_list[ids, :opt_len] 135 | opt_ans[i, j, opt_len] = self.wtoi[''] 136 | 137 | his = torch.from_numpy(his) 138 | ques = torch.from_numpy(ques) 139 | ans = torch.from_numpy(ans) 140 | ans_target = torch.from_numpy(ans_target) 141 | ques_ori = torch.from_numpy(ques_ori) 142 | ans_len = torch.from_numpy(ans_len) 143 | opt_ans_len = torch.from_numpy(opt_ans_len) 144 | opt_ans = torch.from_numpy(opt_ans) 145 | ans_idx = torch.from_numpy(ans_idx) 146 | opt_ans_idx = torch.from_numpy(opt_ans_idx) 147 | return img, img_id, his, ques, ans, ans_target, ans_len, ans_idx, ques_ori, \ 148 | opt_ans, opt_ans_len, opt_ans_idx 149 | 150 | def __len__(self): 151 | return self.ques.shape[0] 152 | 153 | 154 | class validate(data.Dataset): # torch wrapper 155 | def __init__(self, input_img_h5, input_imgid, input_ques_h5, input_json, negative_sample, num_val, data_split): 156 | 157 | print(('DataLoader loading: %s' % data_split)) 158 | print(('Loading image feature from %s' % input_img_h5)) 159 | 160 | if data_split == 'test': 161 | split = 'val' 162 | else: 163 | split = 'train' # train and val split both corresponding to 'train' 164 | 165 | f = json.load(open(input_json, 'r')) 166 | self.itow = f['itow'] 167 | self.wtoi = f['wtoi'] 168 | self.img_info = f['img_' + split] 169 | 170 | # get the data split. 171 | total_num = len(self.img_info) 172 | if data_split == 'train': 173 | s = 0 174 | e = total_num - num_val 175 | elif data_split == 'val': 176 | s = total_num - num_val 177 | e = total_num 178 | else: 179 | s = 0 180 | e = total_num 181 | 182 | self.img_info = self.img_info[s:e] 183 | print(('%s number of data: %d' % (data_split, e - s))) 184 | self.imgid = json.load(open(input_imgid, 'r'))['imgid'][s:e] 185 | self.hdf_reader = ImageFeaturesHdfReader( 186 | input_img_h5, False) 187 | 188 | print(('Loading txt from %s' % input_ques_h5)) 189 | f = h5py.File(input_ques_h5, 'r') 190 | self.ques = f['ques_' + split][s:e] 191 | self.ans = f['ans_' + split][s:e] 192 | self.cap = f['cap_' + split][s:e] 193 | 194 | self.ques_len = f['ques_len_' + split][s:e] 195 | self.ans_len = f['ans_len_' + split][s:e] 196 | self.cap_len = f['cap_len_' + split][s:e] 197 | 198 | self.ans_ids = f['ans_index_' + split][s:e] 199 | self.opt_ids = f['opt_' + split][s:e] 200 | self.opt_list = f['opt_list_' + split][:] 201 | self.opt_len = f['opt_len_' + split][:] 202 | f.close() 203 | 204 | self.ques_length = self.ques.shape[2] 205 | self.ans_length = self.ans.shape[2] 206 | self.his_length = self.ques_length + self.ans_length 207 | 208 | self.vocab_size = len(self.itow) 209 | 210 | print(('Vocab Size: %d' % self.vocab_size)) 211 | self.split = split 212 | self.rnd = 10 213 | self.negative_sample = negative_sample 214 | 215 | def __getitem__(self, index): 216 | 217 | # get the image 218 | img_id = self.img_info[index]['imgId'] 219 | img = self.hdf_reader[img_id] 220 | img = torch.from_numpy(img) 221 | img = normalize(img, dim=0, p=2) 222 | 223 | # get the history 224 | his = np.zeros((self.rnd, self.his_length)) 225 | his[0, self.his_length - self.cap_len[index]:] = self.cap[index, :self.cap_len[index]] 226 | 227 | ques = np.zeros((self.rnd, self.ques_length)) 228 | ans = np.zeros((self.rnd, self.ans_length + 1)) 229 | ans_target = 
np.zeros((self.rnd, self.ans_length + 1)) 230 | quesL = np.zeros((self.rnd, self.ques_length)) 231 | 232 | opt_ans = np.zeros((self.rnd, 100, self.ans_length + 1)) 233 | ans_ids = np.zeros(self.rnd) 234 | opt_ans_target = np.zeros((self.rnd, 100, self.ans_length + 1)) 235 | 236 | ans_len = np.zeros((self.rnd)) 237 | opt_ans_len = np.zeros((self.rnd, 100)) 238 | 239 | for i in range(self.rnd): 240 | # get the index 241 | q_len = self.ques_len[index, i] 242 | a_len = self.ans_len[index, i] 243 | qa_len = q_len + a_len 244 | 245 | if i + 1 < self.rnd: 246 | ques_ans = np.concatenate([self.ques[index, i, :q_len], self.ans[index, i, :a_len]]) 247 | his[i + 1, self.his_length - qa_len:] = ques_ans 248 | 249 | ques[i, self.ques_length - q_len:] = self.ques[index, i, :q_len] 250 | quesL[i, :q_len] = self.ques[index, i, :q_len] 251 | ans[i, 1:a_len + 1] = self.ans[index, i, :a_len] 252 | ans[i, 0] = self.wtoi[''] 253 | 254 | ans_target[i, :a_len] = self.ans[index, i, :a_len] 255 | ans_target[i, a_len] = self.wtoi[''] 256 | 257 | ans_ids[i] = self.ans_ids[index, i] # since python start from 0 258 | opt_ids = self.opt_ids[index, i] # since python start from 0 259 | ans_len[i] = self.ans_len[index, i] 260 | ans_idx = self.ans_ids[index, i] 261 | 262 | for j, ids in enumerate(opt_ids): 263 | opt_len = self.opt_len[ids] 264 | opt_ans[i, j, 1:opt_len + 1] = self.opt_list[ids, :opt_len] 265 | opt_ans[i, j, 0] = self.wtoi[''] 266 | 267 | opt_ans_target[i, j, :opt_len] = self.opt_list[ids, :opt_len] 268 | opt_ans_target[i, j, opt_len] = self.wtoi[''] 269 | opt_ans_len[i, j] = opt_len 270 | 271 | opt_ans = torch.from_numpy(opt_ans) 272 | opt_ans_target = torch.from_numpy(opt_ans_target) 273 | ans_ids = torch.from_numpy(ans_ids) 274 | 275 | his = torch.from_numpy(his) 276 | ques = torch.from_numpy(ques) 277 | ans = torch.from_numpy(ans) 278 | ans_target = torch.from_numpy(ans_target) 279 | quesL = torch.from_numpy(quesL) 280 | 281 | ans_len = torch.from_numpy(ans_len) 282 | opt_ans_len = torch.from_numpy(opt_ans_len) 283 | 284 | return img, img_id, his, ques, ans, ans_target, quesL, opt_ans, \ 285 | opt_ans_target, ans_ids, ans_len, opt_ans_len 286 | 287 | def __len__(self): 288 | return self.ques.shape[0] 289 | -------------------------------------------------------------------------------- /misc/dataLoader_v1.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torch 4 | import numpy as np 5 | import h5py 6 | import json 7 | import pdb 8 | import random 9 | from misc.utils import repackage_hidden, clip_gradient, adjust_learning_rate, decode_txt 10 | from misc.readers import ImageFeaturesHdfReader 11 | 12 | class train(data.Dataset): # torch wrapper 13 | def __init__(self, input_img_h5, input_ques_h5, input_json, negative_sample, num_val, data_split): 14 | 15 | print(('DataLoader loading: %s' % data_split)) 16 | print(('Loading image feature from %s' % input_img_h5)) 17 | 18 | if data_split == 'test': 19 | split = 'val' 20 | else: 21 | split = 'train' # train and val split both corresponding to 'train' 22 | 23 | f = json.load(open(input_json, 'r')) 24 | self.itow = f['itow'] 25 | self.wtoi = f['wtoi'] 26 | self.img_info = f['img_' + split] 27 | 28 | # get the data split. 
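# --- Editor's note: illustrative sketch, not part of the repository ---
# Both loaders (misc/dataLoader.py above and misc/dataLoader_v1.py here) store
# the dialog history right-aligned: every history row has length
# ques_length + ans_length, tokens are packed at the end of the row, and the
# front is left as zero padding. In isolation:
import numpy as np

def pack_right_aligned(tokens, row_length):
    """Place tokens at the end of a zero row of size row_length."""
    row = np.zeros(row_length, dtype='int64')
    row[row_length - len(tokens):] = tokens
    return row

# e.g. pack_right_aligned([7, 8, 9], 6) -> array([0, 0, 0, 7, 8, 9])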
29 | total_num = len(self.img_info) 30 | if data_split == 'train': 31 | s = 0 32 | e = int((total_num) * 1) 33 | # e = int((total_num - num_val) * 1) 34 | # e = 1000 35 | elif data_split == 'val': 36 | s = total_num - num_val 37 | e = total_num 38 | else: 39 | s = 0 40 | e = total_num 41 | 42 | self.img_info = self.img_info[s:e] 43 | 44 | print(('%s number of data: %d' % (data_split, e - s))) 45 | # load the data. 46 | # f = h5py.File(input_img_h5, 'r') 47 | # # self.imgs = f['images_' + split][s:e] 48 | # self.imgs = f["features"][s:e] 49 | # f.close() 50 | self.hdf_reader = ImageFeaturesHdfReader( 51 | input_img_h5, False 52 | ) 53 | 54 | print(('Loading txt from %s' % input_ques_h5)) 55 | f = h5py.File(input_ques_h5, 'r') 56 | self.ques = f['ques_' + split][s:e] 57 | self.ans = f['ans_' + split][s:e] 58 | self.cap = f['cap_' + split][s:e] 59 | 60 | self.ques_len = f['ques_len_' + split][s:e] 61 | self.ans_len = f['ans_len_' + split][s:e] 62 | self.cap_len = f['cap_len_' + split][s:e] 63 | 64 | self.ans_ids = f['ans_index_' + split][s:e] 65 | self.opt_ids = f['opt_' + split][s:e] 66 | self.opt_list = f['opt_list_' + split][:] 67 | self.opt_len = f['opt_len_' + split][:] 68 | f.close() 69 | 70 | self.ques_length = self.ques.shape[2] 71 | self.ans_length = self.ans.shape[2] 72 | self.his_length = self.ques_length + self.ans_length 73 | 74 | # self.itow['0'] = '' 75 | # self.itow[str(len(self.itow))] = '' 76 | # # self.itow[str(len(self.itow))] = '' 77 | self.vocab_size = len(self.itow) 78 | 79 | print(('Vocab Size: %d' % self.vocab_size)) 80 | self.split = split 81 | self.rnd = 10 82 | self.negative_sample = negative_sample 83 | 84 | def __getitem__(self, index): 85 | # get the image 86 | # img = torch.from_numpy(self.imgs[index]) 87 | # img_id = self.img_info[index]['imgId'] 88 | img_id = self.img_info[index] 89 | image_features = self.hdf_reader[img_id] 90 | # image_features = torch.tensor(image_features) 91 | img = torch.from_numpy(image_features) 92 | # get the history 93 | his = np.zeros((self.rnd, self.his_length)) 94 | his[0, self.his_length - self.cap_len[index]:] = self.cap[index, :self.cap_len[index]] 95 | 96 | ques = np.zeros((self.rnd, self.ques_length)) 97 | ans = np.zeros((self.rnd, self.ans_length + 1)) 98 | ans_target = np.zeros((self.rnd, self.ans_length + 1)) 99 | ques_ori = np.zeros((self.rnd, self.ques_length)) 100 | 101 | opt_ans = np.zeros((self.rnd, self.negative_sample, self.ans_length + 1)) 102 | ans_len = np.zeros((self.rnd)) 103 | opt_ans_len = np.zeros((self.rnd, self.negative_sample)) 104 | 105 | ans_idx = np.zeros((self.rnd)) 106 | opt_ans_idx = np.zeros((self.rnd, self.negative_sample)) 107 | 108 | for i in range(self.rnd): 109 | # get the index 110 | q_len = self.ques_len[index, i] 111 | a_len = self.ans_len[index, i] 112 | qa_len = q_len + a_len 113 | 114 | if i + 1 < self.rnd: 115 | his[i + 1, self.his_length - qa_len:self.his_length - a_len] = self.ques[index, i, :q_len] 116 | his[i + 1, self.his_length - a_len:] = self.ans[index, i, :a_len] 117 | 118 | ques[i, self.ques_length - q_len:] = self.ques[index, i, :q_len] 119 | 120 | ques_ori[i, :q_len] = self.ques[index, i, :q_len] 121 | ans[i, 1:a_len + 1] = self.ans[index, i, :a_len] 122 | ans[i, 0] = self.wtoi[''] 123 | 124 | ans_target[i, :a_len] = self.ans[index, i, :a_len] 125 | ans_target[i, a_len] = self.wtoi[''] 126 | ans_len[i] = self.ans_len[index, i] 127 | 128 | opt_ids = self.opt_ids[index, i] # since python start from 0 129 | # random select the negative samples. 
130 | ans_idx[i] = opt_ids[self.ans_ids[index, i]] 131 | # exclude the gt index. 132 | opt_ids = np.delete(opt_ids, ans_idx[i], 0) 133 | random.shuffle(opt_ids) 134 | for j in range(self.negative_sample): 135 | ids = opt_ids[j] 136 | opt_ans_idx[i, j] = ids 137 | 138 | opt_len = self.opt_len[ids] 139 | 140 | opt_ans_len[i, j] = opt_len 141 | opt_ans[i, j, :opt_len] = self.opt_list[ids, :opt_len] 142 | opt_ans[i, j, opt_len] = self.wtoi[''] 143 | 144 | his = torch.from_numpy(his) 145 | ques = torch.from_numpy(ques) 146 | ans = torch.from_numpy(ans) 147 | ans_target = torch.from_numpy(ans_target) 148 | ques_ori = torch.from_numpy(ques_ori) 149 | ans_len = torch.from_numpy(ans_len) 150 | opt_ans_len = torch.from_numpy(opt_ans_len) 151 | opt_ans = torch.from_numpy(opt_ans) 152 | ans_idx = torch.from_numpy(ans_idx) 153 | opt_ans_idx = torch.from_numpy(opt_ans_idx) 154 | return img, img_id, his, ques, ans, ans_target, ans_len, ans_idx, ques_ori, \ 155 | opt_ans, opt_ans_len, opt_ans_idx 156 | 157 | def __len__(self): 158 | return self.ques.shape[0] 159 | 160 | 161 | class validate(data.Dataset): # torch wrapper 162 | def __init__(self, input_img_h5, input_ques_h5, input_json, negative_sample, num_val, data_split): 163 | 164 | print(('DataLoader loading: %s' % data_split)) 165 | print(('Loading image feature from %s' % input_img_h5)) 166 | 167 | # if data_split == 'test': 168 | # split = 'test' 169 | # else: 170 | # split = 'val' # train and val split both corresponding to 'train' 171 | split = data_split 172 | f = json.load(open(input_json, 'r')) 173 | self.itow = f['itow'] 174 | self.wtoi = f['wtoi'] 175 | self.img_info = f['img_' + split] 176 | 177 | # get the data split. 178 | total_num = len(self.img_info) 179 | if data_split == 'train': 180 | s = 0 181 | e = total_num - num_val 182 | elif data_split == 'val': 183 | s = 0 184 | e = total_num 185 | else: 186 | s = 0 187 | e = total_num 188 | 189 | self.img_info = self.img_info[s:e] 190 | print(('%s number of data: %d' % (data_split, e - s))) 191 | 192 | # load the data. 
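# --- Editor's note: descriptive comment, not part of the repository ---
# In this validate dataset every item carries all 100 candidate answers per
# round, so with the batch_size=5 used in main_v1.0.py a batch has roughly
# these shapes (10 rounds, L = ans_length + 1; the image tensor depends on the
# feature file):
#   his      (5, 10, his_length)
#   ques     (5, 10, ques_length)
#   opt_ans  (5, 10, 100, L)   decoder inputs for every candidate answer
#   ans_ids  (5, 10)           position of the ground truth among the 100
# The model scores all 100 candidates and the rank of the ground truth feeds
# the R@k / MRR computation in eval_v*.py.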
193 | # f = h5py.File(input_img_h5, 'r') 194 | # self.imgs = f["features"][s:e] 195 | self.hdf_reader = ImageFeaturesHdfReader( 196 | input_img_h5, False 197 | ) 198 | # self.imgs = f['images_'+split][s:e] 199 | # f.close() 200 | 201 | print(('Loading txt from %s' % input_ques_h5)) 202 | f = h5py.File(input_ques_h5, 'r') 203 | self.ques = f['ques_' + split][s:e] 204 | self.ans = f['ans_' + split][s:e] 205 | self.cap = f['cap_' + split][s:e] 206 | 207 | self.ques_len = f['ques_len_' + split][s:e] 208 | self.ans_len = f['ans_len_' + split][s:e] 209 | self.cap_len = f['cap_len_' + split][s:e] 210 | 211 | self.ans_ids = f['ans_index_' + split][s:e] 212 | self.opt_ids = f['opt_' + split][s:e] 213 | self.opt_list = f['opt_list_' + split][:] 214 | self.opt_len = f['opt_len_' + split][:] 215 | f.close() 216 | 217 | self.ques_length = self.ques.shape[2] 218 | self.ans_length = self.ans.shape[2] 219 | self.his_length = self.ques_length + self.ans_length 220 | 221 | # self.itow['0'] = '' 222 | # self.itow[str(len(self.itow))] = '' 223 | self.vocab_size = len(self.itow) 224 | 225 | print(('Vocab Size: %d' % self.vocab_size)) 226 | self.split = split 227 | self.rnd = 10 228 | self.negative_sample = negative_sample 229 | 230 | def __getitem__(self, index): 231 | 232 | # get the image 233 | img_id = self.img_info[index] 234 | image_features = self.hdf_reader[img_id] 235 | # image_features = torch.tensor(image_features) 236 | img = torch.from_numpy(image_features) 237 | # get the history 238 | his = np.zeros((self.rnd, self.his_length)) 239 | his[0, self.his_length - self.cap_len[index]:] = self.cap[index, :self.cap_len[index]] 240 | 241 | ques = np.zeros((self.rnd, self.ques_length)) 242 | ans = np.zeros((self.rnd, self.ans_length + 1)) 243 | ans_target = np.zeros((self.rnd, self.ans_length + 1)) 244 | quesL = np.zeros((self.rnd, self.ques_length)) 245 | 246 | opt_ans = np.zeros((self.rnd, 100, self.ans_length + 1)) 247 | ans_ids = np.zeros(self.rnd) 248 | opt_ans_target = np.zeros((self.rnd, 100, self.ans_length + 1)) 249 | 250 | ans_len = np.zeros((self.rnd)) 251 | opt_ans_len = np.zeros((self.rnd, 100)) 252 | 253 | for i in range(self.rnd): 254 | # get the index 255 | q_len = self.ques_len[index, i] 256 | a_len = self.ans_len[index, i] 257 | qa_len = q_len + a_len 258 | 259 | if i + 1 < self.rnd: 260 | ques_ans = np.concatenate([self.ques[index, i, :q_len], self.ans[index, i, :a_len]]) 261 | his[i + 1, self.his_length - qa_len:] = ques_ans 262 | 263 | ques[i, self.ques_length - q_len:] = self.ques[index, i, :q_len] 264 | quesL[i, :q_len] = self.ques[index, i, :q_len] 265 | ans[i, 1:a_len + 1] = self.ans[index, i, :a_len] 266 | ans[i, 0] = self.wtoi[''] 267 | 268 | ans_target[i, :a_len] = self.ans[index, i, :a_len] 269 | ans_target[i, a_len] = self.wtoi[''] 270 | 271 | ans_ids[i] = self.ans_ids[index, i] # since python start from 0 272 | opt_ids = self.opt_ids[index, i] # since python start from 0 273 | ans_len[i] = self.ans_len[index, i] 274 | ans_idx = self.ans_ids[index, i] 275 | 276 | for j, ids in enumerate(opt_ids): 277 | opt_len = self.opt_len[ids] 278 | opt_ans[i, j, 1:opt_len + 1] = self.opt_list[ids, :opt_len] 279 | opt_ans[i, j, 0] = self.wtoi[''] 280 | 281 | opt_ans_target[i, j, :opt_len] = self.opt_list[ids, :opt_len] 282 | opt_ans_target[i, j, opt_len] = self.wtoi[''] 283 | opt_ans_len[i, j] = opt_len 284 | 285 | opt_ans = torch.from_numpy(opt_ans) 286 | opt_ans_target = torch.from_numpy(opt_ans_target) 287 | ans_ids = torch.from_numpy(ans_ids) 288 | 289 | his = torch.from_numpy(his) 290 
| ques = torch.from_numpy(ques) 291 | ans = torch.from_numpy(ans) 292 | ans_target = torch.from_numpy(ans_target) 293 | quesL = torch.from_numpy(quesL) 294 | 295 | ans_len = torch.from_numpy(ans_len) 296 | opt_ans_len = torch.from_numpy(opt_ans_len) 297 | 298 | return img, img_id, his, ques, ans, ans_target, quesL, opt_ans, \ 299 | opt_ans_target, ans_ids, ans_len, opt_ans_len 300 | 301 | def __len__(self): 302 | return self.ques.shape[0] -------------------------------------------------------------------------------- /misc/focalloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import random 7 | 8 | class FocalLoss(nn.Module): 9 | def __init__(self, gamma=2, alpha=0.75, size_average=True): 10 | super(FocalLoss, self).__init__() 11 | self.gamma = gamma 12 | self.alpha = alpha 13 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 14 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 15 | self.size_average = size_average 16 | 17 | def forward(self, input, target): 18 | 19 | target = target.view(-1,1) 20 | 21 | # logpt = F.log_softmax(input, 1) 22 | logpt = input 23 | mask = target.data.gt(0) 24 | if isinstance(input, Variable): 25 | mask = Variable(mask, volatile=input.volatile) 26 | logpt = logpt.gather(1,target) 27 | logpt = torch.masked_select(logpt, mask) 28 | logpt = logpt.view(-1) 29 | pt = Variable(logpt.data.exp()) 30 | 31 | if self.alpha is not None: 32 | if self.alpha.type()!=input.data.type(): 33 | self.alpha = self.alpha.type_as(input.data) 34 | at = self.alpha.gather(0,target.data.view(-1)) 35 | logpt = logpt * Variable(at) 36 | 37 | loss = -1 * (1-pt)**self.gamma * logpt 38 | if self.size_average: return loss.mean() 39 | else: return loss.sum() 40 | 41 | class nPairLoss(nn.Module): 42 | """ 43 | Given the right, fake, wrong, wrong_sampled embedding, use the N Pair Loss 44 | objective (which is an extension to the triplet loss) 45 | 46 | Loss = log(1+exp(feat*wrong - feat*right + feat*fake - feat*right)) + L2 norm. 
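# --- Editor's note: illustrative smoke test, not part of the repository ---
# FocalLoss above expects log-probabilities and integer targets and masks out
# positions where target == 0. Note that the default alpha=0.75 builds a
# two-element class-weight table, so for vocabulary-sized targets pass
# alpha=None (or a per-class list). A minimal check in the torch 0.3 style
# used throughout this repo:
# from misc.focalloss import FocalLoss   # when used outside this file
import torch
import torch.nn.functional as F
from torch.autograd import Variable

if __name__ == '__main__':
    torch.manual_seed(0)
    logprobs = F.log_softmax(Variable(torch.randn(6, 20)), dim=1)  # 6 tokens, vocab of 20
    targets = Variable(torch.LongTensor([3, 5, 2, 7, 4, 1]))       # zeros would be masked out
    criterion = FocalLoss(gamma=2, alpha=None)
    print(criterion(logprobs, targets))                            # scalar loss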
47 | 48 | Improved Deep Metric Learning with Multi-class N-pair Loss Objective (NIPS) 49 | """ 50 | def __init__(self, ninp, margin): 51 | super(nPairLoss, self).__init__() 52 | self.ninp = ninp 53 | self.margin = np.log(margin) 54 | 55 | def forward(self, feat, right, wrong, batch_wrong, fake=None, fake_diff_mask=None): 56 | 57 | num_wrong = wrong.size(1) 58 | batch_size = feat.size(0) 59 | 60 | feat = feat.view(-1, self.ninp, 1) 61 | right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat) 62 | wrong_dis = torch.bmm(wrong, feat) 63 | batch_wrong_dis = torch.bmm(batch_wrong, feat) 64 | 65 | wrong_score = torch.sum(torch.exp(wrong_dis - right_dis.expand_as(wrong_dis)),1) \ 66 | + torch.sum(torch.exp(batch_wrong_dis - right_dis.expand_as(batch_wrong_dis)),1) 67 | 68 | loss_dis = torch.sum(torch.log(wrong_score + 1)) 69 | loss_norm = right.norm() + feat.norm() + wrong.norm() + batch_wrong.norm() 70 | 71 | if fake: 72 | fake_dis = torch.bmm(fake.view(-1, 1, self.ninp), feat) 73 | fake_score = torch.masked_select(torch.exp(fake_dis - right_dis), fake_diff_mask) 74 | 75 | margin_score = F.relu(torch.log(fake_score + 1) - self.margin) 76 | loss_fake = torch.sum(margin_score) 77 | loss_dis += loss_fake 78 | loss_norm += fake.norm() 79 | 80 | loss = (loss_dis + 0.1 * loss_norm) / batch_size 81 | if fake: 82 | return loss, loss_fake.data[0] / batch_size 83 | else: 84 | return loss 85 | 86 | def sample_batch_neg(answerIdx, negAnswerIdx, sample_idx, num_sample): 87 | """ 88 | input: 89 | answerIdx: batch_size 90 | negAnswerIdx: batch_size x opt.negative_sample 91 | 92 | output: 93 | sample_idx = batch_size x num_sample 94 | """ 95 | 96 | batch_size = answerIdx.size(0) 97 | num_neg = negAnswerIdx.size(0) * negAnswerIdx.size(1) 98 | negAnswerIdx = negAnswerIdx.clone().view(-1) 99 | for b in range(batch_size): 100 | gt_idx = answerIdx[b] 101 | for n in range(num_sample): 102 | while True: 103 | rand = int(random.random() * num_neg) 104 | neg_idx = negAnswerIdx[rand] 105 | if gt_idx != neg_idx: 106 | sample_idx.data[b, n] = rand 107 | break 108 | 109 | class G_loss(nn.Module): 110 | """ 111 | Generator loss: 112 | minimize right feature and fake feature L2 norm. 113 | maximinze the fake feature and wrong feature. 
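# --- Editor's note: illustrative usage sketch, not part of the repository ---
# nPairLoss above scores one query feature against the ground-truth answer
# embedding and sets of wrong answer embeddings via dot products, roughly
#   loss = ( log(1 + sum_j exp(s_wrong_j - s_right)) + 0.1 * L2 penalty ) / batch_size.
# A minimal call with random features (torch 0.3-style Variables, as used in this repo):
# from misc.focalloss import nPairLoss   # when used outside this file
import torch
from torch.autograd import Variable

if __name__ == '__main__':
    B, ninp, n_wrong, n_batch_wrong = 4, 512, 20, 30
    criterion = nPairLoss(ninp, margin=2)
    feat = Variable(torch.randn(B, ninp))                        # dialog/question feature
    right = Variable(torch.randn(B, ninp))                       # ground-truth answer embedding
    wrong = Variable(torch.randn(B, n_wrong, ninp))              # sampled wrong answers
    batch_wrong = Variable(torch.randn(B, n_batch_wrong, ninp))  # wrong answers drawn from the batch
    print(criterion(feat, right, wrong, batch_wrong))            # scalar loss (no fake features)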
114 | """ 115 | def __init__(self, ninp): 116 | super(G_loss, self).__init__() 117 | self.ninp = ninp 118 | 119 | def forward(self, feat, right, fake): 120 | 121 | batch_size = feat.size(0) 122 | fake = fake.view(batch_size, -1, self.ninp) 123 | num_fake = fake.size(1) 124 | feat = feat.view(batch_size, 1, self.ninp).repeat(1, num_fake, 1) 125 | right = right.view(batch_size, 1, self.ninp).repeat(1, num_fake, 1) 126 | 127 | feat = feat.view(-1, self.ninp, 1) 128 | #wrong_dis = torch.bmm(wrong, feat) 129 | #batch_wrong_dis = torch.bmm(batch_wrong, feat) 130 | fake_dis = torch.bmm(fake.view(-1, 1, self.ninp), feat) 131 | right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat) 132 | 133 | fake_score = torch.exp(right_dis - fake_dis) 134 | loss_fake = torch.sum(torch.log(fake_score + 1)) 135 | 136 | loss_norm = feat.norm() + fake.norm() + right.norm() 137 | loss = (loss_fake + 0.1 * loss_norm ) / batch_size / num_fake 138 | 139 | return loss, loss_fake.data[0]/batch_size -------------------------------------------------------------------------------- /misc/readers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Reader simply reads data from disk and returns it almost as is, based on 3 | a "primary key", which for the case of VisDial v1.0 dataset, is the 4 | ``image_id``. Readers should be utilized by torch ``Dataset``s. Any type of 5 | data pre-processing is not recommended in the reader, such as tokenizing words 6 | to integers, embedding tokens, or passing an image through a pre-trained CNN. 7 | 8 | Each reader must atleast implement three methods: 9 | - ``__len__`` to return the length of data this Reader can read. 10 | - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 11 | dataset. 12 | - ``keys`` to return a list of possible ``image_id``s this Reader can 13 | provide data of. 14 | """ 15 | 16 | import copy 17 | import json 18 | import multiprocessing as mp 19 | from typing import Any, Dict, List, Optional, Set, Union 20 | 21 | import h5py 22 | 23 | # A bit slow, and just splits sentences to list of words, can be doable in 24 | # `DialogsReader`. 25 | from nltk.tokenize import word_tokenize 26 | from tqdm import tqdm 27 | 28 | 29 | class DialogsReader(object): 30 | """ 31 | A simple reader for VisDial v1.0 dialog data. The json file must have the 32 | same structure as mentioned on ``https://visualdialog.org/data``. 33 | 34 | Parameters 35 | ---------- 36 | dialogs_jsonpath : str 37 | Path to json file containing VisDial v1.0 train, val or test data. 38 | num_examples: int, optional (default = None) 39 | Process first ``num_examples`` from the split. Useful to speed up while 40 | debugging. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dialogs_jsonpath: str, 46 | num_examples: Optional[int] = None, 47 | num_workers: int = 1, 48 | ): 49 | with open(dialogs_jsonpath, "r") as visdial_file: 50 | visdial_data = json.load(visdial_file) 51 | self._split = visdial_data["split"] 52 | 53 | # Maintain questions and answers as a dict instead of list because 54 | # they are referenced by index in dialogs. We drop elements from 55 | # these in "overfit" mode to save time (tokenization is slow). 56 | self.questions = { 57 | i: question for i, question in 58 | enumerate(visdial_data["data"]["questions"]) 59 | } 60 | self.answers = { 61 | i: answer for i, answer in 62 | enumerate(visdial_data["data"]["answers"]) 63 | } 64 | 65 | # Add empty question, answer - useful for padding dialog rounds 66 | # for test split. 
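# --- Editor's note: illustrative usage sketch, not part of the repository ---
# DialogsReader is keyed by image_id and returns the tokenized caption, the
# dialog padded to 10 rounds, and the original number of rounds. The json
# path below is an assumption; point it at your local VisDial v1.0 copy:
# from misc.readers import DialogsReader   # when used outside this file
if __name__ == '__main__':
    reader = DialogsReader('data/visdial_1.0_val.json', num_examples=10, num_workers=4)
    image_id = reader.keys()[0]
    item = reader[image_id]
    print(item['caption'][:5])   # first caption tokens
    print(len(item['dialog']))   # always 10 after padding
    print(item['num_rounds'])    # rounds actually present before padding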
67 | self.questions[-1] = "" 68 | self.answers[-1] = "" 69 | 70 | # ``image_id``` serves as key for all three dicts here. 71 | self.captions: Dict[int, Any] = {} 72 | self.dialogs: Dict[int, Any] = {} 73 | self.num_rounds: Dict[int, Any] = {} 74 | 75 | all_dialogs = visdial_data["data"]["dialogs"] 76 | 77 | # Retain only first ``num_examples`` dialogs if specified. 78 | if num_examples is not None: 79 | all_dialogs = all_dialogs[:num_examples] 80 | 81 | for _dialog in all_dialogs: 82 | 83 | self.captions[_dialog["image_id"]] = _dialog["caption"] 84 | 85 | # Record original length of dialog, before padding. 86 | # 10 for train and val splits, 10 or less for test split. 87 | self.num_rounds[_dialog["image_id"]] = len(_dialog["dialog"]) 88 | 89 | # Pad dialog at the end with empty question and answer pairs 90 | # (for test split). 91 | while len(_dialog["dialog"]) < 10: 92 | _dialog["dialog"].append({"question": -1, "answer": -1}) 93 | 94 | # Add empty answer (and answer options) if not provided 95 | # (for test split). We use "-1" as a key for empty questions 96 | # and answers. 97 | for i in range(len(_dialog["dialog"])): 98 | if "answer" not in _dialog["dialog"][i]: 99 | _dialog["dialog"][i]["answer"] = -1 100 | if "answer_options" not in _dialog["dialog"][i]: 101 | _dialog["dialog"][i]["answer_options"] = [-1] * 100 102 | 103 | self.dialogs[_dialog["image_id"]] = _dialog["dialog"] 104 | 105 | # If ``num_examples`` is specified, collect questions and answers 106 | # included in those examples, and drop the rest to save time while 107 | # tokenizing. Collecting these should be fast because num_examples 108 | # during debugging are generally small. 109 | if num_examples is not None: 110 | questions_included: Set[int] = set() 111 | answers_included: Set[int] = set() 112 | 113 | for _dialog in self.dialogs.values(): 114 | for _dialog_round in _dialog: 115 | questions_included.add(_dialog_round["question"]) 116 | answers_included.add(_dialog_round["answer"]) 117 | for _answer_option in _dialog_round["answer_options"]: 118 | answers_included.add(_answer_option) 119 | 120 | self.questions = { 121 | i: self.questions[i] for i in questions_included 122 | } 123 | self.answers = { 124 | i: self.answers[i] for i in answers_included 125 | } 126 | 127 | self._multiprocess_tokenize(num_workers) 128 | 129 | def _multiprocess_tokenize(self, num_workers: int): 130 | """ 131 | Tokenize captions, questions and answers in parallel processes. This 132 | method uses multiprocessing module internally. 133 | 134 | Since questions, answers and captions are dicts - and multiprocessing 135 | map utilities operate on lists, we convert these to lists first and 136 | then back to dicts. 137 | 138 | Parameters 139 | ---------- 140 | num_workers: int 141 | Number of workers (processes) to run in parallel. 142 | """ 143 | 144 | # While displaying progress bar through tqdm, specify total number of 145 | # sequences to tokenize, because tqdm won't know in case of pool.imap 146 | with mp.Pool(num_workers) as pool: 147 | print(f"[{self._split}] Tokenizing questions...") 148 | _question_tuples = self.questions.items() 149 | _question_indices = [t[0] for t in _question_tuples] 150 | _questions = list( 151 | tqdm( 152 | pool.imap(word_tokenize, [t[1] for t in _question_tuples]), 153 | total=len(self.questions) 154 | ) 155 | ) 156 | self.questions = { 157 | i: question + ["?"] for i, question in 158 | zip(_question_indices, _questions) 159 | } 160 | # Delete variables to free memory. 
161 | del _question_tuples, _question_indices, _questions 162 | 163 | print(f"[{self._split}] Tokenizing answers...") 164 | _answer_tuples = self.answers.items() 165 | _answer_indices = [t[0] for t in _answer_tuples] 166 | _answers = list( 167 | tqdm( 168 | pool.imap(word_tokenize, [t[1] for t in _answer_tuples]), 169 | total=len(self.answers) 170 | ) 171 | ) 172 | self.answers = { 173 | i: answer + ["?"] for i, answer in 174 | zip(_answer_indices, _answers) 175 | } 176 | # Delete variables to free memory. 177 | del _answer_tuples, _answer_indices, _answers 178 | 179 | print(f"[{self._split}] Tokenizing captions...") 180 | # Convert dict to separate lists of image_ids and captions. 181 | _caption_tuples = self.captions.items() 182 | _image_ids = [t[0] for t in _caption_tuples] 183 | _captions = list( 184 | tqdm( 185 | pool.imap(word_tokenize, [t[1] for t in _caption_tuples]), 186 | total=(len(_caption_tuples)) 187 | ) 188 | ) 189 | # Convert tokenized captions back to a dict. 190 | self.captions = {i: c for i, c in zip(_image_ids, _captions)} 191 | 192 | def __len__(self): 193 | return len(self.dialogs) 194 | 195 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: 196 | caption_for_image = self.captions[image_id] 197 | dialog = copy.copy(self.dialogs[image_id]) 198 | num_rounds = self.num_rounds[image_id] 199 | 200 | # Replace question and answer indices with actual word tokens. 201 | for i in range(len(dialog)): 202 | dialog[i]["question"] = self.questions[ 203 | dialog[i]["question"] 204 | ] 205 | dialog[i]["answer"] = self.answers[ 206 | dialog[i]["answer"] 207 | ] 208 | for j, answer_option in enumerate( 209 | dialog[i]["answer_options"] 210 | ): 211 | dialog[i]["answer_options"][j] = self.answers[ 212 | answer_option 213 | ] 214 | 215 | return { 216 | "image_id": image_id, 217 | "caption": caption_for_image, 218 | "dialog": dialog, 219 | "num_rounds": num_rounds, 220 | } 221 | 222 | def keys(self) -> List[int]: 223 | return list(self.dialogs.keys()) 224 | 225 | @property 226 | def split(self): 227 | return self._split 228 | 229 | 230 | class DenseAnnotationsReader(object): 231 | """ 232 | A reader for dense annotations for val split. The json file must have the 233 | same structure as mentioned on ``https://visualdialog.org/data``. 234 | 235 | Parameters 236 | ---------- 237 | dense_annotations_jsonpath : str 238 | Path to a json file containing VisDial v1.0 239 | """ 240 | 241 | def __init__(self, dense_annotations_jsonpath: str): 242 | with open(dense_annotations_jsonpath, "r") as visdial_file: 243 | self._visdial_data = json.load(visdial_file) 244 | self._image_ids = [ 245 | entry["image_id"] for entry in self._visdial_data 246 | ] 247 | 248 | def __len__(self): 249 | return len(self._image_ids) 250 | 251 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: 252 | index = self._image_ids.index(image_id) 253 | # keys: {"image_id", "round_id", "gt_relevance"} 254 | return self._visdial_data[index] 255 | 256 | @property 257 | def split(self): 258 | # always 259 | return "val" 260 | 261 | 262 | class ImageFeaturesHdfReader(object): 263 | """ 264 | A reader for HDF files containing pre-extracted image features. A typical 265 | HDF file is expected to have a column named "image_id", and another column 266 | named "features". 
267 | 268 | Example of an HDF file: 269 | ``` 270 | visdial_train_faster_rcnn_bottomup_features.h5 271 | |--- "image_id" [shape: (num_images, )] 272 | |--- "features" [shape: (num_images, num_proposals, feature_size)] 273 | +--- .attrs ("split", "train") 274 | ``` 275 | Refer ``$PROJECT_ROOT/data/extract_bottomup.py`` script for more details 276 | about HDF structure. 277 | 278 | Parameters 279 | ---------- 280 | features_hdfpath : str 281 | Path to an HDF file containing VisDial v1.0 train, val or test split 282 | image features. 283 | in_memory : bool 284 | Whether to load the whole HDF file in memory. Beware, these files are 285 | sometimes tens of GBs in size. Set this to true if you have sufficient 286 | RAM - trade-off between speed and memory. 287 | """ 288 | 289 | def __init__(self, features_hdfpath: str, in_memory: bool = False): 290 | self.features_hdfpath = features_hdfpath 291 | self._in_memory = in_memory 292 | 293 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 294 | self._split = features_hdf.attrs["split"] 295 | self._image_id_list = list(features_hdf["image_id"]) 296 | # "features" is List[np.ndarray] if the dataset is loaded in-memory 297 | # If not loaded in memory, then list of None. 298 | self.features = [None] * len(self._image_id_list) 299 | 300 | def __len__(self): 301 | return len(self._image_id_list) 302 | 303 | def __getitem__(self, image_id: int): 304 | index = self._image_id_list.index(image_id) 305 | if self._in_memory: 306 | # Load features during first epoch, all not loaded together as it 307 | # has a slow start. 308 | if self.features[index] is not None: 309 | image_id_features = self.features[index] 310 | else: 311 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 312 | image_id_features = features_hdf["features"][index] 313 | self.features[index] = image_id_features 314 | else: 315 | # Read chunk from file everytime if not loaded in memory. 316 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 317 | image_id_features = features_hdf["features"][index] 318 | 319 | return image_id_features 320 | 321 | def keys(self) -> List[int]: 322 | return self._image_id_list 323 | 324 | @property 325 | def split(self): 326 | return self._split 327 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import random 6 | from collections import OrderedDict 7 | import numpy as np 8 | #from ciderD.ciderD import CiderD 9 | """ 10 | Some utility Functions. 
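# --- Editor's note: illustrative usage sketch, not part of the repository ---
# ImageFeaturesHdfReader (misc/readers.py, above) looks region features up by
# image_id; in_memory=True caches rows after the first read at the cost of
# RAM. The path below is the default from main_v0.9.py and the file must
# follow the "image_id"/"features" layout documented above:
# from misc.readers import ImageFeaturesHdfReader   # when used outside misc/
if __name__ == '__main__':
    reader = ImageFeaturesHdfReader('data/features_faster_rcnn_x101_train.h5', in_memory=False)
    image_id = reader.keys()[0]
    features = reader[image_id]          # numpy array, (num_proposals, feature_size)
    print(reader.split, features.shape)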
11 | """ 12 | 13 | # CiderD_scorer = None 14 | #CiderD_scorer = CiderD(df='corpus') 15 | 16 | 17 | def repackage_hidden_volatile(h): 18 | if type(h) == Variable: 19 | return Variable(h.data, volatile=True) 20 | else: 21 | return tuple(repackage_hidden_volatile(v) for v in h) 22 | 23 | def repackage_hidden(h, batch_size): 24 | """Wraps hidden states in new Variables, to detach them from their history.""" 25 | if type(h) == Variable: 26 | return Variable(h.data.resize_(h.size(0), batch_size, h.size(2)).zero_()) 27 | else: 28 | return tuple(repackage_hidden(v, batch_size) for v in h) 29 | 30 | def clip_gradient(model): 31 | """Computes a gradient clipping coefficient based on gradient norm.""" 32 | totalnorm = 0 33 | for p in model.parameters(): 34 | p.grad.data.clamp_(-5, 5) 35 | 36 | def adjust_learning_rate(optimizer, epoch, lr): 37 | """Sets the learning rate to the initial LR decayed by 0.5 every 20 epochs""" 38 | if epoch < 20: 39 | lr = lr * (0.5 ** (epoch // 5)) 40 | if epoch < 30 and epoch >= 20: 41 | lr = 0.0001 42 | if epoch >= 30: 43 | lr = 0.00001 44 | for param_group in optimizer.param_groups: 45 | param_group['lr'] = lr 46 | return lr 47 | 48 | def decode_txt(itow, x): 49 | """Function to show decode the text.""" 50 | out = [] 51 | for b in range(x.size(1)): 52 | txt = '' 53 | for t in range(x.size(0)): 54 | idx = x[t,b] 55 | if idx == 0 or idx == 3: 56 | break 57 | txt += itow[str(int(idx))] 58 | txt += ' ' 59 | out.append(txt) 60 | 61 | return out 62 | 63 | def decode_txt_ques(itow, x): 64 | """Function to show decode the text.""" 65 | out = [] 66 | for b in range(x.size(1)): 67 | txt = '' 68 | for t in range(x.size(0)): 69 | idx = x[t,b] 70 | if idx == 3: 71 | break 72 | if idx == 0: 73 | continue 74 | txt += itow[str(int(idx))] 75 | if t != (x.size(0) -1): 76 | txt += ' ' 77 | out.append(txt) 78 | 79 | return out 80 | 81 | def l2_norm(input): 82 | """ 83 | input: feature that need to normalize. 84 | output: normalziaed feature. 
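# --- Editor's note: illustrative sketch, not part of the repository ---
# adjust_learning_rate above implements a step schedule: the base lr is halved
# every 5 epochs for the first 20 epochs (its docstring says "every 20 epochs",
# but the code uses epoch // 5), then pinned to 1e-4 until epoch 30 and to
# 1e-5 afterwards. Without the optimizer plumbing:
def lr_at_epoch(base_lr, epoch):
    if epoch < 20:
        return base_lr * (0.5 ** (epoch // 5))
    if epoch < 30:
        return 1e-4
    return 1e-5

# e.g. base_lr=0.0005: epochs 0-4 -> 5e-4, 5-9 -> 2.5e-4, 10-14 -> 1.25e-4,
# 15-19 -> 6.25e-5, 20-29 -> 1e-4, 30+ -> 1e-5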
85 | """ 86 | input_size = input.size() 87 | buffer = torch.pow(input, 2) 88 | 89 | normp = torch.sum(buffer, 1).add_(1e-10) 90 | norm = torch.sqrt(normp) 91 | _output = torch.div(input, norm.view(-1, 1).expand_as(input)) 92 | output = _output.view(input_size) 93 | 94 | return output 95 | 96 | 97 | def sample_batch_neg(answerIdx, negAnswerIdx, sample_idx, num_sample): 98 | """ 99 | input: 100 | answerIdx: batch_size 101 | negAnswerIdx: batch_size x opt.negative_sample 102 | 103 | output: 104 | sample_idx = batch_size x num_sample 105 | """ 106 | 107 | batch_size = answerIdx.size(0) 108 | num_neg = negAnswerIdx.size(0) * negAnswerIdx.size(1) 109 | negAnswerIdx = negAnswerIdx.clone().view(-1) 110 | for b in range(batch_size): 111 | gt_idx = answerIdx[b] 112 | for n in range(num_sample): 113 | while True: 114 | rand = int(random.random() * num_neg) 115 | neg_idx = negAnswerIdx[rand] 116 | if gt_idx != neg_idx: 117 | sample_idx.data[b, n] = rand 118 | break 119 | 120 | 121 | def to_contiguous(tensor): 122 | if tensor.is_contiguous(): 123 | return tensor 124 | else: 125 | return tensor.contiguous() 126 | 127 | class RewardCriterion(nn.Module): 128 | def __init__(self): 129 | super(RewardCriterion, self).__init__() 130 | 131 | def forward(self, input, seq, reward): 132 | #print(input[0]) 133 | input = to_contiguous(input).view(-1) 134 | reward = to_contiguous(reward).view(-1) 135 | mask = (seq>0).float() 136 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 137 | output = - input * reward * Variable(mask) 138 | output = torch.sum(output) / torch.sum(mask) 139 | 140 | return output 141 | 142 | def RLreward(input, target, reward): 143 | batch_size = input.size(0) 144 | seq_len = input.size(1) 145 | inp = input.permute(1, 0, 2) # seq_len x batch_size 146 | target = target.permute(1, 0) # seq_len x batch_size 147 | 148 | loss = 0 149 | for i in range(seq_len): 150 | # TODO: should h be detached from graph (.detach())? 
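# --- Editor's note: descriptive comment, not part of the repository ---
# Assuming reward holds one scalar per batch element (as reward[j] in the loop
# suggests), the double loop below is equivalent to the vectorized expression
#   -(input.gather(2, target.unsqueeze(2)).squeeze(2) * reward.unsqueeze(1)).sum() / batch_size
# applied to the original (batch, seq_len, vocab) input and (batch, seq_len)
# target; like the loop, it does not mask padding positions.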
151 | for j in range(batch_size): 152 | loss += -inp[i][j][target.data[i][j]] * reward[j] # log(P(y_t|Y_1:Y_{t-1})) * Q 153 | 154 | return loss / batch_size 155 | 156 | def array_to_str(arr): 157 | out = [] 158 | for i in range(len(arr)): 159 | out.append(str(arr[i])) 160 | if arr[i] == 0: 161 | break 162 | return out 163 | 164 | 165 | 166 | def init_cider_scorer(cached_tokens): 167 | global CiderD_scorer 168 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 169 | 170 | def get_self_critical_reward(netG, netW, sample_ans_input, ques_hidden, gen_result, ans_input, itows): 171 | batch_size = gen_result.size(0) # batch_size = sample_size * seq_per_img 172 | #seq_per_img = batch_size // len(data['gts']) 173 | 174 | # get greedy decoding baseline 175 | greedy_res, _ = netG.sample(netW, sample_ans_input, ques_hidden) 176 | ans_sample_txt = decode_txt(itows, greedy_res.t()) 177 | #print('greedy_ans: %s' % (ans_sample_txt)) 178 | res1 = OrderedDict() 179 | res2 = OrderedDict() 180 | # 181 | gen_result = gen_result.cpu().numpy() 182 | greedy_res = greedy_res.cpu().numpy() 183 | ans_input = ans_input.cpu().numpy() 184 | for i in range(batch_size): 185 | res1[i] = array_to_str(gen_result[i]) 186 | for i in range(batch_size): 187 | res2[i] = array_to_str(greedy_res[i]) 188 | # 189 | gts = OrderedDict() 190 | for i in range(len(ans_input)): 191 | gts[i] = array_to_str(ans_input[i]) 192 | # 193 | # # _, scores = Bleu(4).compute_score(gts, res) 194 | # # scores = np.array(scores[3]) 195 | # res = [{'image_id': i, 'caption': res[i]} for i in range(2 * batch_size)] 196 | # gts = {i: gts[i % batch_size] for i in range(2 * batch_size)} 197 | 198 | 199 | from nltk.translate import bleu 200 | from nltk.translate.bleu_score import SmoothingFunction 201 | smoothie = SmoothingFunction().method4 202 | scores = [] 203 | for i in range(len(gen_result)): 204 | score = bleu(gts[i], res1[i], weights=(0.5, 0.5)) 205 | # if score != 0: 206 | # print i , ': ' , score 207 | scores.append(score) 208 | for i in range(len(greedy_res)): 209 | score = bleu(gts[i], res2[i], weights=(0.5, 0.5)) 210 | scores.append(score) 211 | # scores = bleu(gts, res, smoothing_function=smoothie) 212 | 213 | # _, scores = CiderD_scorer.compute_score(gts, res) 214 | # print('Cider scores:', _) 215 | scores = np.array(scores) 216 | scores = scores[:batch_size] - scores[batch_size:] 217 | 218 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 219 | 220 | return rewards 221 | 222 | class LayerNorm(nn.Module): 223 | """ 224 | Layer Normalization 225 | """ 226 | def __init__(self, features, eps=1e-6): 227 | super().__init__() 228 | self.gamma = nn.Parameter(torch.ones(features)) 229 | self.beta = nn.Parameter(torch.zeros(features)) 230 | self.eps = eps 231 | 232 | def forward(self, x): 233 | mean = x.mean(-1, keepdim=True) 234 | std = x.std(-1, keepdim=True) 235 | return self.gamma * (x - mean) / (std + self.eps) + self.beta -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall==0.1.0 2 | certifi==2018.11.29 3 | cffi==1.12.1 4 | cycler==0.10.0 5 | decorator==4.3.2 6 | h5py==2.8.0 7 | ipython==6.4.0 8 | ipython-genutils==0.2.0 9 | jedi==0.13.3 10 | kiwisolver==1.1.0 11 | matplotlib==3.0.3 12 | mkl-fft==1.0.6 13 | mkl-random==1.0.1 14 | nltk==3.4.5 15 | numpy==1.14.5 16 | pandas==0.24.2 17 | parso==0.3.4 18 | patsy==0.5.1 19 | pexpect==4.6.0 20 | pickleshare==0.7.5 21 | Pillow==6.0.0 22 | 
prompt-toolkit==1.0.15 23 | protobuf==3.6.0 24 | ptyprocess==0.6.0 25 | pycparser==2.19 26 | Pygments==2.3.1 27 | pyparsing==2.4.0 28 | python-dateutil==2.8.0 29 | pytz==2019.1 30 | scipy==1.1.0 31 | seaborn==0.9.0 32 | setproctitle==1.1.10 33 | simplegeneric==0.8.1 34 | six==1.11.0 35 | statsmodels==0.9.0 36 | torch==0.3.1 37 | tornado==6.0.2 38 | tqdm==4.31.1 39 | traitlets==4.3.2 40 | wcwidth==0.1.7 41 | -------------------------------------------------------------------------------- /script/create_glove.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | import argparse 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | def create_glove_embedding_init(idx2word, glove_file): 10 | word2emb = {} 11 | with open(glove_file, 'r') as f: 12 | entries = f.readlines() 13 | emb_dim = len(entries[0].split(' ')) - 1 14 | print('embedding dim is %d' % emb_dim) 15 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 16 | 17 | for entry in entries: 18 | vals = entry.split(' ') 19 | word = vals[0] 20 | vals = list(map(float, vals[1:])) 21 | word2emb[word] = np.array(vals) 22 | for idx, word in enumerate(idx2word): 23 | if word not in word2emb: 24 | continue 25 | weights[idx] = word2emb[word] 26 | return weights, word2emb 27 | 28 | 29 | if __name__ == '__main__': 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--input_json', default='../data/visdial_params_v1.0.json', 33 | help='path to dataset, now hdf5 file') 34 | args = parser.parse_args() 35 | 36 | f = json.load(open(args.input_json, 'r')) 37 | itow = f['itow'] 38 | emb_dim = 300 39 | glove_file = '../data/glove.6B.%dd.txt' % emb_dim 40 | weights, word2emb = create_glove_embedding_init(itow, glove_file) 41 | np.save('../data/glove6b_init_%dd_v1.0.npy' % emb_dim, weights) 42 | -------------------------------------------------------------------------------- /script/prepro.py: -------------------------------------------------------------------------------- 1 | import json 2 | import h5py 3 | import argparse 4 | import numpy as np 5 | from nltk.tokenize import word_tokenize 6 | import nltk 7 | #from unidecode import unidecode 8 | import os 9 | import re 10 | import pdb 11 | 12 | def tokenize(sentence): 13 | return [i for i in re.split(r"([-.\"',:? 
!\$#@~()*&\^%;\[\]/\\\+<>\n=])", sentence) if i!='' and i!=' ' and i!='\n']; 14 | 15 | def tokenize_data(data): 16 | ''' 17 | Tokenize captions, questions and answers 18 | Also maintain word count if required 19 | ''' 20 | ques_toks, ans_toks, caption_toks = [], [], [] 21 | 22 | print(data['split']) 23 | print('Tokenizing captions...') 24 | for i in data['data']['dialogs']: 25 | caption = word_tokenize(i['caption']) 26 | caption_toks.append(caption) 27 | 28 | print('Tokenizing questions...') 29 | for i in data['data']['questions']: 30 | ques_tok = word_tokenize(i + '?') 31 | ques_toks.append(ques_tok) 32 | 33 | print('Tokenizing answers...') 34 | for i in data['data']['answers']: 35 | ans_tok = word_tokenize(i) 36 | ans_toks.append(ans_tok) 37 | 38 | return ques_toks, ans_toks, caption_toks 39 | 40 | def build_vocab(data, ques_toks, ans_toks, caption_toks, params): 41 | count_thr = args.word_count_threshold 42 | 43 | i = 0 44 | counts = {} 45 | for imgs in data['data']['dialogs']: 46 | caption = caption_toks[i] 47 | i += 1 48 | 49 | for w in caption: 50 | counts[w] = counts.get(w, 0) + 1 51 | 52 | for dialog in imgs['dialog']: 53 | question = ques_toks[dialog['question']] 54 | answer = ans_toks[dialog['answer']] 55 | 56 | for w in question + answer: 57 | counts[w] = counts.get(w, 0) + 1 58 | 59 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 60 | print('top words and their counts:') 61 | print('\n'.join(map(str,cw[:20]))) 62 | 63 | # print some stats 64 | total_words = sum(counts.values()) 65 | print('total words:', total_words) 66 | bad_words = [w for w,n in counts.items() if n <= count_thr] 67 | vocab = [w for w,n in counts.items() if n > count_thr] 68 | bad_count = sum(counts[w] for w in bad_words) 69 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 70 | print('number of words in vocab would be %d' % (len(vocab), )) 71 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 72 | 73 | print('inserting the special UNK token') 74 | vocab.append('UNK') 75 | 76 | return vocab 77 | 78 | 79 | def encode_vocab(questions, answers, captions, wtoi): 80 | 81 | ques_idx, ans_idx, cap_idx = [], [], [] 82 | 83 | for txt in captions: 84 | ind = [wtoi.get(w, wtoi['']) for w in txt] 85 | cap_idx.append(ind) 86 | 87 | for txt in questions: 88 | ind = [wtoi.get(w, wtoi['']) for w in txt] 89 | ques_idx.append(ind) 90 | 91 | for txt in answers: 92 | ind = [wtoi.get(w, wtoi['']) for w in txt] 93 | ans_idx.append(ind) 94 | 95 | return ques_idx, ans_idx, cap_idx 96 | 97 | 98 | def create_mats(data, ques_idx, ans_idx, cap_idx, params, test_type=False): 99 | N = len(data['data']['dialogs']) 100 | num_round = 10 101 | max_ques_len = params.max_ques_len 102 | max_cap_len = params.max_cap_len 103 | max_ans_len = params.max_ans_len 104 | 105 | captions = np.zeros([N, max_cap_len], dtype='uint32') 106 | questions = np.zeros([N, num_round, max_ques_len], dtype='uint32') 107 | answers = np.zeros([N, num_round, max_ans_len], dtype='uint32') 108 | 109 | answer_index = np.zeros([N, num_round], dtype='uint32') 110 | 111 | caption_len = np.zeros(N, dtype='uint32') 112 | question_len = np.zeros([N, num_round], dtype='uint32') 113 | answer_len = np.zeros([N, num_round], dtype='uint32') 114 | options = np.zeros([N, num_round, 100], dtype='uint32') 115 | 116 | image_list = [] 117 | for i, img in enumerate(data['data']['dialogs']): 118 | image_id = img['image_id'] 119 | image_list.append(image_id) 120 | 
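        # Each caption below is clipped to max_cap_len and copied into the
        # zero-padded 'captions' row for this image, with the clipped length
        # kept in caption_len; the per-round questions are treated the same
        # way, and for non-test splits the answers, their 100 answer_options
        # and the gt_index are recorded for every round as well.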
121 | cap_len = len(cap_idx[i][0:max_cap_len]) 122 | caption_len[i] = cap_len 123 | captions[i][0:cap_len] = cap_idx[i][0:cap_len] 124 | 125 | 126 | for j, dialog in enumerate(img['dialog']): 127 | 128 | ques_len = len(ques_idx[dialog['question']][0:max_ques_len]) 129 | question_len[i,j] = ques_len 130 | questions[i,j,:ques_len] = ques_idx[dialog['question']][0:ques_len] 131 | 132 | if not test_type: 133 | ans_len = len(ans_idx[dialog['answer']][0:max_ans_len]) 134 | answer_len[i,j] = ans_len 135 | answers[i,j,:ans_len] = ans_idx[dialog['answer']][0:ans_len] 136 | options[i,j] = dialog['answer_options'] 137 | answer_index[i,j] = dialog['gt_index'] 138 | 139 | options_list = np.zeros([len(ans_idx), max_ans_len], dtype='uint32') 140 | options_len = np.zeros(len(ans_idx), dtype='uint32') 141 | 142 | for i, ans in enumerate(ans_idx): 143 | options_len[i] = len(ans[0:max_ans_len]) 144 | options_list[i][0:options_len[i]] = ans[0:max_ans_len] 145 | 146 | return captions, caption_len, questions, question_len, answers, answer_len, \ 147 | options, options_list, options_len, answer_index, image_list 148 | 149 | def img_info(image_list, split): 150 | out = [] 151 | for i,imgId in enumerate(image_list): 152 | jimg = {} 153 | jimg['imgId'] = imgId 154 | file_name = 'COCO_%s_%012d.jpg' %(split, imgId) 155 | jimg['path'] = os.path.join(split, file_name) 156 | out.append(jimg) 157 | 158 | return out 159 | 160 | def get_image_ids(data): 161 | image_ids = [dialog['image_id'] for dialog in data['data']['dialogs']] 162 | return image_ids 163 | 164 | if __name__ == "__main__": 165 | 166 | parser = argparse.ArgumentParser() 167 | 168 | # Input files 169 | parser.add_argument('-input_json_train', default = '../data/visdial_1.0_train.json', help='Input `train` json file') 170 | parser.add_argument('-input_json_val', default='../data/visdial_1.0_val.json', help='Input `val` json file') 171 | parser.add_argument('-input_json_test',default='../data/visdial_1.0_test.json', help='Input `val` json file') 172 | # Output files 173 | parser.add_argument('-output_json', default='data/visdial_params_v1.0.json', help='Output json file') 174 | parser.add_argument('-output_h5', default='data/visdial_data_v1.0.h5', help='Output hdf5 file') 175 | # Options 176 | parser.add_argument('-max_ques_len', default=16, type=int, help='Max length of questions') 177 | parser.add_argument('-max_ans_len', default=8, type=int, help='Max length of answers') 178 | parser.add_argument('-max_cap_len', default=24, type=int, help='Max length of captions') 179 | parser.add_argument('-word_count_threshold', default=5, type=int, help='Min threshold of word count to include in vocabulary') 180 | 181 | args = parser.parse_args() 182 | 183 | print('Reading json...') 184 | data_train = json.load(open(args.input_json_train, 'r')) 185 | data_val = json.load(open(args.input_json_val, 'r')) 186 | data_test = json.load(open(args.input_json_test, 'r')) 187 | 188 | ques_tok_train, ans_tok_train, cap_tok_train = tokenize_data(data_train) 189 | ques_tok_val, ans_tok_val, cap_tok_val = tokenize_data(data_val) 190 | ques_tok_test, ans_tok_test, cap_tok_test = tokenize_data(data_test) 191 | 192 | vocab = build_vocab(data_train, ques_tok_train, ans_tok_train, cap_tok_train, args) 193 | 194 | word2index = {'': 0, '': 1, '': 2, '': 3} 195 | for vo in vocab: 196 | if word2index.get(vo) is None: 197 | word2index[vo] = len(word2index) 198 | index2word = {v: k for k, v in word2index.items()} 199 | itow = index2word # a 1-indexed vocab translation table 200 | wtoi = 
word2index # inverse table 201 | 202 | # itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 203 | # wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 204 | 205 | print('Encoding based on vocabulary...') 206 | ques_idx_train, ans_idx_train, cap_idx_train = encode_vocab(ques_tok_train, ans_tok_train, cap_tok_train, wtoi) 207 | ques_idx_val, ans_idx_val, cap_idx_val = encode_vocab(ques_tok_val, ans_tok_val, cap_tok_val, wtoi) 208 | ques_idx_test, ans_idx_test, cap_idx_test = encode_vocab(ques_tok_test, ans_tok_test, cap_tok_test, wtoi) 209 | 210 | cap_train, cap_len_train, ques_train, ques_len_train, ans_train, ans_len_train, \ 211 | opt_train, opt_list_train, opt_len_train, ans_index_train, image_list_train = \ 212 | create_mats(data_train, ques_idx_train, ans_idx_train, cap_idx_train, args) 213 | 214 | cap_val, cap_len_val, ques_val, ques_len_val, ans_val, ans_len_val, \ 215 | opt_val, opt_list_val, opt_len_val, ans_index_val, image_list_val = \ 216 | create_mats(data_val, ques_idx_val, ans_idx_val, cap_idx_val, args) 217 | 218 | cap_test, cap_len_test, ques_test, ques_len_test, ans_test, ans_len_test, \ 219 | opt_test, opt_list_test, opt_len_test, ans_index_test, image_list_test = \ 220 | create_mats(data_test, ques_idx_test, ans_idx_test, cap_idx_test, args, test_type=True) 221 | 222 | print('Saving hdf5...') 223 | f = h5py.File(args.output_h5, 'w') 224 | f.create_dataset('ques_train', dtype='uint32', data=ques_train) 225 | f.create_dataset('ques_len_train', dtype='uint32', data=ques_len_train) 226 | f.create_dataset('ans_train', dtype='uint32', data=ans_train) 227 | f.create_dataset('ans_len_train', dtype='uint32', data=ans_len_train) 228 | f.create_dataset('ans_index_train', dtype='uint32', data=ans_index_train) 229 | f.create_dataset('cap_train', dtype='uint32', data=cap_train) 230 | f.create_dataset('cap_len_train', dtype='uint32', data=cap_len_train) 231 | f.create_dataset('opt_train', dtype='uint32', data=opt_train) 232 | f.create_dataset('opt_len_train', dtype='uint32', data=opt_len_train) 233 | f.create_dataset('opt_list_train', dtype='uint32', data=opt_list_train) 234 | 235 | f.create_dataset('ques_val', dtype='uint32', data=ques_val) 236 | f.create_dataset('ques_len_val', dtype='uint32', data=ques_len_val) 237 | f.create_dataset('ans_val', dtype='uint32', data=ans_val) 238 | f.create_dataset('ans_len_val', dtype='uint32', data=ans_len_val) 239 | f.create_dataset('ans_index_val', dtype='uint32', data=ans_index_val) 240 | f.create_dataset('cap_val', dtype='uint32', data=cap_val) 241 | f.create_dataset('cap_len_val', dtype='uint32', data=cap_len_val) 242 | f.create_dataset('opt_val', dtype='uint32', data=opt_val) 243 | f.create_dataset('opt_len_val', dtype='uint32', data=opt_len_val) 244 | f.create_dataset('opt_list_val', dtype='uint32', data=opt_list_val) 245 | 246 | f.create_dataset('ques_test', dtype='uint32', data=ques_test) 247 | f.create_dataset('ques_len_test', dtype='uint32', data=ques_len_test) 248 | f.create_dataset('ans_test', dtype='uint32', data=ans_test) 249 | f.create_dataset('ans_len_test', dtype='uint32', data=ans_len_test) 250 | f.create_dataset('ans_index_test', dtype='uint32', data=ans_index_test) 251 | f.create_dataset('cap_test', dtype='uint32', data=cap_test) 252 | f.create_dataset('cap_len_test', dtype='uint32', data=cap_len_test) 253 | f.create_dataset('opt_test', dtype='uint32', data=opt_test) 254 | f.create_dataset('opt_len_test', dtype='uint32', data=opt_len_test) 255 | f.create_dataset('opt_list_test', 
dtype='uint32', data=opt_list_test) 256 | f.close() 257 | 258 | out = {} 259 | out['itow'] = itow # encode the (1-indexed) vocab 260 | out['wtoi'] = wtoi 261 | out['img_train'] = get_image_ids(data_train) 262 | out['img_val'] = get_image_ids(data_val) 263 | out['img_test'] = get_image_ids(data_test) 264 | json.dump(out, open(args.output_json, 'w')) 265 | 266 | 267 | #data_val_toks, ques_val_inds, ans_val_inds = encode_vocab(data_val_toks, ques_val_toks, ans_val_toks, word2ind) 268 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import utils 6 | from torch.autograd import Variable 7 | import datetime 8 | import numpy as np 9 | import sys 10 | 11 | def LangMCriterion(input, target): 12 | target = target.view(-1, 1) 13 | logprob_select = torch.gather(input, 1, target) 14 | 15 | mask = target.data.gt(0) # generate the mask 16 | if isinstance(input, Variable): 17 | mask = Variable(mask, volatile=input.volatile) 18 | out = torch.masked_select(logprob_select, mask) 19 | loss = -torch.sum(out) # get the average loss. 20 | return loss 21 | 22 | def train(model, train_loader, eval_loader, args): 23 | t = datetime.datetime.now() 24 | cur_time = '%s-%s-%s-%s-%s' % (t.year, t.month, t.day, t.hour, t.minute) 25 | save_path = os.path.join(args.output, cur_time) 26 | args.save_path = save_path 27 | utils.create_dir(save_path) 28 | optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) 29 | 30 | lr_default = args.lr if eval_loader is not None else 7e-4 31 | lr_decay_step = 2 32 | lr_decay_rate = .25 33 | lr_decay_epochs = range(10, 30, lr_decay_step) if eval_loader is not None else range(10, 20, lr_decay_step) 34 | gradual_warmup_steps = [0.5 * lr_default, 1.0 * lr_default, 1.5 * lr_default, 2.0 * lr_default] 35 | 36 | logger = utils.Logger(os.path.join(save_path, 'log.txt')) 37 | for arg in vars(args): 38 | logger.write('{:<20}: {}'.format(arg, getattr(args, arg))) 39 | best_eval_score = 0 40 | model.train() 41 | start_time = time.time() 42 | best_cnt = 0 43 | 44 | print('Training ... 
') 45 | for epoch in range(args.epochs): 46 | 47 | total_loss = 0 48 | count = 0 49 | train_score = 0 50 | t = time.time() 51 | train_iter = iter(train_loader) 52 | 53 | # TODO: get learning rate 54 | # lr = adjust_learning_rate(optim, epoch, args.lr) 55 | if epoch < 4: 56 | optim.param_groups[0]['lr'] = gradual_warmup_steps[epoch] 57 | lr = optim.param_groups[0]['lr'] 58 | elif epoch in lr_decay_epochs: 59 | optim.param_groups[0]['lr'] *= lr_decay_rate 60 | lr = optim.param_groups[0]['lr'] 61 | else: 62 | lr = optim.param_groups[0]['lr'] 63 | iter_step = 0 64 | for i in range(len(train_loader)): 65 | average_loss_tmp = 0 66 | count_tmp = 0 67 | train_data = next(train_iter) 68 | image, image_id, history, question, answer, answerT, ans_len, ans_idx, ques_ori, opt, opt_len, opt_idx = train_data 69 | batch_size = question.size(0) 70 | image = image.view(image.size(0), -1, args.img_feat_size) 71 | img_input = Variable(image).cuda() 72 | for rnd in range(10): 73 | ques = question[:, rnd, :] 74 | his = history[:, :rnd + 1, :].clone().view(-1, args.his_length) 75 | ans = answer[:, rnd, :] 76 | tans = answerT[:, rnd, :] 77 | opt_ans = opt[:, rnd, :].clone().view(-1, args.ans_length) 78 | 79 | ques = Variable(ques).cuda().long() 80 | his = Variable(his).cuda().long() 81 | ans = Variable(ans).cuda().long() 82 | tans = Variable(tans).cuda().long() 83 | opt_ans = Variable(opt_ans).cuda().long() 84 | 85 | pred = model(img_input, ques, his, ans, tans, rnd + 1) 86 | loss = LangMCriterion(pred.view(-1, args.vocab_size), tans) 87 | loss = loss / torch.sum(tans.data.gt(0)) 88 | 89 | loss.backward() 90 | nn.utils.clip_grad_norm(model.parameters(), 0.25) 91 | optim.step() 92 | model.zero_grad() 93 | 94 | average_loss_tmp += loss.data[0] 95 | total_loss += loss.data[0] 96 | count += 1 97 | count_tmp += 1 98 | sys.stdout.write('Training: Epoch {:d} Step {:d}/{:d} \r'.format(epoch + 1, i + 1, len(train_loader))) 99 | if (i+1) % 50 == 0: 100 | average_loss_tmp /= count_tmp 101 | print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}".format(i + 1, len(train_loader), epoch + 1, average_loss_tmp, lr)) 102 | iter_step += 1 103 | total_loss /= count 104 | logger.write('Epoch %d : learningRate %4f train loss %4f Time: %3f' % (epoch + 1, lr, total_loss, time.time() - start_time)) 105 | model.eval() 106 | print('Evaluating ... 
') 107 | start_time = time.time() 108 | rank_all = evaluate(model, eval_loader, args) 109 | R1 = np.sum(np.array(rank_all) == 1) / float(len(rank_all)) 110 | R5 = np.sum(np.array(rank_all) <= 5) / float(len(rank_all)) 111 | R10 = np.sum(np.array(rank_all) <= 10) / float(len(rank_all)) 112 | ave = np.sum(np.array(rank_all)) / float(len(rank_all)) 113 | mrr = np.sum(1 / (np.array(rank_all, dtype='float'))) / float(len(rank_all)) 114 | logger.write('Epoch %d: mrr: %f R1: %f R5 %f R10 %f Mean %f time: %.2f' % (epoch + 1, mrr, R1, R5, R10, ave, time.time()-start_time)) 115 | 116 | eval_score = mrr 117 | 118 | model_path = os.path.join(save_path, 'model_epoch_%d.pth' % (epoch + 1)) 119 | torch.save({'epoch': epoch, 120 | 'args': args, 121 | 'model': model.state_dict()}, model_path) 122 | 123 | if eval_score > best_eval_score: 124 | model_path = os.path.join(save_path, 'best_model.pth') 125 | torch.save({'epoch': epoch, 126 | 'args': args, 127 | 'model': model.state_dict()}, model_path) 128 | best_eval_score = eval_score 129 | best_cnt = 0 130 | else: 131 | best_cnt = best_cnt + 1 132 | if best_cnt > 10: 133 | break 134 | return model 135 | 136 | 137 | def evaluate(model, eval_loader, args, Eval=False): 138 | rank_all_tmp = [] 139 | eval_iter = iter(eval_loader) 140 | step = 0 141 | for i in range(len(eval_loader)): 142 | eval_data = next(eval_iter) 143 | image, image_id, history, question, answer, answerT, questionL, opt_answer, \ 144 | opt_answerT, answer_ids, answerLen, opt_answerLen = eval_data 145 | 146 | image = image.view(image.size(0), -1, args.img_feat_size) 147 | img_input = Variable(image).cuda() 148 | batch_size = question.size(0) 149 | for rnd in range(10): 150 | ques, tans = question[:, rnd, :], opt_answerT[:, rnd, :].clone().view(-1, args.ans_length) 151 | his = history[:, :rnd + 1, :].clone().view(-1, args.his_length) 152 | ans = opt_answer[:, rnd, :, :].clone().view(-1, args.ans_length) 153 | gt_id = answer_ids[:, rnd] 154 | 155 | ques = Variable(ques).cuda().long() 156 | tans = Variable(tans).cuda().long() 157 | his = Variable(his).cuda().long() 158 | ans = Variable(ans).cuda().long() 159 | gt_index = Variable(gt_id).cuda().long() 160 | 161 | pred = model(img_input, ques, his, ans, tans, rnd + 1, Training=False) 162 | logprob = - pred.permute(1, 0, 2).contiguous().view(-1, args.vocab_size) 163 | 164 | logprob_select = torch.gather(logprob, 1, tans.t().contiguous().view(-1, 1)) 165 | 166 | mask = tans.t().data.eq(0) # generate the mask 167 | if isinstance(logprob, Variable): 168 | mask = Variable(mask, volatile=logprob.volatile) 169 | logprob_select.masked_fill_(mask.view_as(logprob_select), 0) 170 | 171 | prob = logprob_select.view(args.ans_length, -1, 100).sum(0).view(-1, 100) 172 | 173 | for b in range(batch_size): 174 | gt_index.data[b] = gt_index.data[b] + b * 100 175 | 176 | gt_score = prob.view(-1).index_select(0, gt_index) 177 | sort_score, sort_idx = torch.sort(prob, 1) 178 | 179 | count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score)) 180 | rank = count.sum(1) + 1 181 | 182 | rank_all_tmp += list(rank.view(-1).data.cpu().numpy()) 183 | step += 1 184 | if Eval: 185 | sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(eval_loader))) 186 | if (i+1) % 50 == 0: 187 | R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp)) 188 | R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp)) 189 | R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp)) 190 | ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp)) 191 | 
mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp)) 192 | print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' % (i+1, len(eval_loader), mrr, R1, R5, R10, ave)) 193 | 194 | return rank_all_tmp 195 | 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import errno 4 | import os 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | EPS = 1e-7 12 | 13 | 14 | def assert_eq(real, expected): 15 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 16 | 17 | 18 | def assert_array_eq(real, expected): 19 | assert (np.abs(real-expected) < EPS).all(), \ 20 | '%s (true) vs %s (expected)' % (real, expected) 21 | 22 | 23 | def load_folder(folder, suffix): 24 | imgs = [] 25 | for f in sorted(os.listdir(folder)): 26 | if f.endswith(suffix): 27 | imgs.append(os.path.join(folder, f)) 28 | return imgs 29 | 30 | 31 | def load_imageid(folder): 32 | images = load_folder(folder, 'jpg') 33 | img_ids = set() 34 | for img in images: 35 | img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 36 | img_ids.add(img_id) 37 | return img_ids 38 | 39 | 40 | def pil_loader(path): 41 | with open(path, 'rb') as f: 42 | with Image.open(f) as img: 43 | return img.convert('RGB') 44 | 45 | 46 | def weights_init(m): 47 | """custom weights initialization.""" 48 | cname = m.__class__ 49 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 50 | m.weight.data.normal_(0.0, 0.02) 51 | elif cname == nn.BatchNorm2d: 52 | m.weight.data.normal_(1.0, 0.02) 53 | m.bias.data.fill_(0) 54 | else: 55 | print('%s is not initialized.' % cname) 56 | 57 | 58 | def init_net(net, net_file): 59 | if net_file: 60 | net.load_state_dict(torch.load(net_file)) 61 | else: 62 | net.apply(weights_init) 63 | 64 | 65 | def create_dir(path): 66 | if not os.path.exists(path): 67 | try: 68 | os.makedirs(path) 69 | except OSError as exc: 70 | if exc.errno != errno.EEXIST: 71 | raise 72 | 73 | 74 | class Logger(object): 75 | def __init__(self, output_name): 76 | dirname = os.path.dirname(output_name) 77 | if not os.path.exists(dirname): 78 | os.mkdir(dirname) 79 | 80 | self.log_file = open(output_name, 'w') 81 | self.infos = {} 82 | 83 | def append(self, key, val): 84 | vals = self.infos.setdefault(key, []) 85 | vals.append(val) 86 | 87 | def log(self, extra_msg=''): 88 | msgs = [extra_msg] 89 | for key, vals in self.infos.iteritems(): 90 | msgs.append('%s %.6f' % (key, np.mean(vals))) 91 | msg = '\n'.join(msgs) 92 | self.log_file.write(msg + '\n') 93 | self.log_file.flush() 94 | self.infos = {} 95 | return msg 96 | 97 | def write(self, msg): 98 | self.log_file.write(msg + '\n') 99 | self.log_file.flush() 100 | print(msg) 101 | --------------------------------------------------------------------------------
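
Supplementary note: the retrieval metrics logged by train.py (MRR, R@1/5/10, mean rank) are computed inline from the list of per-round ranks that evaluate() returns. Below is a minimal, self-contained sketch of that same computation; the helper name summarize_ranks and the example rank list are illustrative assumptions, not part of the repository.

import numpy as np

def summarize_ranks(rank_all):
    # rank_all: 1-based ranks of the ground-truth answer among the 100
    # candidate options, one entry per (dialog, round) pair, as produced
    # by evaluate() in train.py.
    ranks = np.array(rank_all, dtype='float')
    return {
        'mrr':  float(np.mean(1.0 / ranks)),   # mean reciprocal rank
        'R@1':  float(np.mean(ranks == 1)),    # fraction ranked first
        'R@5':  float(np.mean(ranks <= 5)),
        'R@10': float(np.mean(ranks <= 10)),
        'mean': float(np.mean(ranks)),         # mean rank (lower is better)
    }

if __name__ == '__main__':
    # illustrative ranks for four rounds (assumed values, not real output)
    print(summarize_ranks([1, 3, 7, 20]))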