├── model
│   └── tmp.txt
├── result
│   └── results_BOWIMG.zip
├── README.md
├── LinearNB.lua
├── opensource_utils.lua
├── opensource_baseline.lua
├── opensource_interactive.lua
└── opensource_base.lua
/model/tmp.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/result/results_BOWIMG.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoubolei/VQAbaseline/HEAD/result/results_BOWIMG.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple Baseline for Visual Question Answering
2 |
3 | We describe a very simple bag-of-words baseline for visual question answering. Details are in the arXiv paper http://arxiv.org/pdf/1512.02167.pdf. The code was developed by [Bolei Zhou](http://people.csail.mit.edu/bzhou/) and [Yuandong Tian](http://yuandong-tian.com/).
4 |
5 | 
6 |
7 |
8 | Demo is available at http://visualqa.csail.mit.edu/
9 |
10 | To train the model with this code, the following files of the VQA dataset are needed:
11 | - The pre-processed text data is at http://visualqa.csail.mit.edu/data_vqa_txt.zip
12 | - The GoogLeNet features of all the COCO images are at http://visualqa.csail.mit.edu/data_vqa_feat.zip
13 |
14 | The pre-trained model used in the paper is at http://visualqa.csail.mit.edu/coco_qadevi_BOWIMG_bestepoch93_final.t7model. It achieves 55.89% accuracy on the Open-Ended task and 61.69% on the Multiple-Choice task on the test-standard split of the COCO VQA dataset.
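
As a minimal sketch (assuming Torch with the nn/cunn packages, the data decompressed as described above, and the downloaded model placed at `model/BOWIMG.t7`, the path that `loadPretrained` expects), the pre-trained model can be loaded for evaluation with:

```lua
-- Hypothetical snippet: load the pre-trained BOWIMG model and its vocabulary.
require 'paths'
paths.dofile('opensource_utils.lua')
local opt = initial_params()          -- parse options and pick the freest GPU
local context = loadPretrained(opt)   -- returns model, criterion and manager_vocab
```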
15 |
16 | Contact Bolei Zhou (zhoubolei@gmail.com) if you have any questions.
17 |
18 | Please cite our arXiv note if you use our code:
19 |
20 | B. Zhou, Y. Tian, S. Sukhbaatar, A. Szlam, R. Fergus.
21 | Simple Baseline for Visual Question Answering.
22 | arXiv:1512.02167
23 |
--------------------------------------------------------------------------------
/LinearNB.lua:
--------------------------------------------------------------------------------
1 | local LinearNB, parent = torch.class('nn.LinearNB', 'nn.Module')
2 |
3 | function LinearNB:__init(inputSize, outputSize)
4 | parent.__init(self)
5 |
6 | self.outputSize = outputSize
7 | self.inputSize = inputSize
8 | self.weight = torch.Tensor(outputSize, inputSize)
9 | self.gradWeight = torch.Tensor(outputSize, inputSize)
10 |
11 | self:reset()
12 | end
13 |
14 | function LinearNB:reset(stdv)
15 | if stdv then
16 | stdv = stdv * math.sqrt(3)
17 | else
18 | stdv = 1./math.sqrt(self.weight:size(2))
19 | end
20 | if nn.oldSeed then
21 | for i=1,self.weight:size(1) do
22 | self.weight:select(1, i):apply(function()
23 | return torch.uniform(-stdv, stdv)
24 | end)
25 | end
26 | else
27 | self.weight:uniform(-stdv, stdv)
28 | end
29 | end
30 |
31 | function LinearNB:updateOutput(input)
32 | if input:dim() == 1 then
33 | self.output:resize(self.outputSize)
34 | self.output:zero()
35 | self.output:addmv(1, self.weight, input)
36 | elseif input:dim() == 2 then
37 | local nframe = input:size(1)
38 | local nunit = self.outputSize
39 | self.output:resize(nframe, nunit):zero()
40 | if not self.addBuffer or self.addBuffer:size(1) ~= nframe then
41 | self.addBuffer = input.new(nframe):fill(1)
42 | end
43 | if nunit == 1 then
44 | -- Special case to fix output size of 1 bug:
45 | self.output:select(2,1):addmv(1, input, self.weight:select(1,1))
46 | else
47 | self.output:addmm(1, input, self.weight:t())
48 | end
49 | else
50 | error('input must be vector or matrix')
51 | end
52 |
53 | return self.output
54 | end
55 |
56 | function LinearNB:updateGradInput(input, gradOutput)
57 | if self.gradInput then
58 |
59 | local nElement = self.gradInput:nElement()
60 | self.gradInput:resizeAs(input)
61 | if self.gradInput:nElement() ~= nElement then
62 | self.gradInput:zero()
63 | end
64 | if input:dim() == 1 then
65 | self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
66 | elseif input:dim() == 2 then
67 | self.gradInput:addmm(0, 1, gradOutput, self.weight)
68 | end
69 |
70 | return self.gradInput
71 | end
72 | end
73 |
74 | function LinearNB:accGradParameters(input, gradOutput, scale)
75 | scale = scale or 1
76 |
77 | if input:dim() == 1 then
78 | self.gradWeight:addr(scale, gradOutput, input)
79 | elseif input:dim() == 2 then
80 | local nunit = self.outputSize
81 | if nunit == 1 then
82 | -- Special case to fix output size of 1 bug:
83 | self.gradWeight:select(1,1):addmv(scale,
84 | input:t(), gradOutput:select(2,1))
85 | else
86 | self.gradWeight:addmm(scale, gradOutput:t(), input)
87 | end
88 | end
89 | end
90 |
91 | -- we do not need to accumulate parameters when sharing
92 | LinearNB.sharedAccUpdateGradParameters = LinearNB.accUpdateGradParameters
93 |
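-- Usage sketch: nn.LinearNB behaves like nn.Linear but without a bias term.
--   local layer = nn.LinearNB(1024, 10)
--   local out = layer:forward(torch.randn(32, 1024))  -- returns a 32x10 output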
94 |
--------------------------------------------------------------------------------
/opensource_utils.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'cunn'
3 |
4 | paths.dofile('LinearNB.lua')
5 |
6 | function getFreeGPU()
7 | -- select the GPU with the most free memory to train on
8 | local nDevice = cutorch.getDeviceCount()
9 | local memSet = torch.Tensor(nDevice)
10 | for i=1, nDevice do
11 | local tmp, _ = cutorch.getMemoryUsage(i)
12 | memSet[i] = tmp
13 | end
14 | local _, curDeviceID = torch.max(memSet,1)
15 | return curDeviceID[1]
16 | end
17 |
18 | function build_model(opt, manager_vocab)
19 | -- build one of the baseline models: BOW, IMG, or BOWIMG
20 | local model
21 | if opt.method == 'BOW' then
22 | model = nn.Sequential()
23 | local module_tdata = nn.LinearNB(manager_vocab.nvocab_question, opt.embed_word)
24 | model:add(module_tdata)
25 | model:add(nn.Linear(opt.embed_word, manager_vocab.nvocab_answer))
26 |
27 | elseif opt.method == 'IMG' then
28 | model = nn.Sequential()
29 | model:add(nn.Linear(opt.vdim, manager_vocab.nvocab_answer))
30 |
31 | elseif opt.method == 'BOWIMG' then
32 | model = nn.Sequential()
33 | local module_tdata = nn.Sequential():add(nn.SelectTable(1)):add(nn.LinearNB(manager_vocab.nvocab_question, opt.embed_word))
34 | local module_vdata = nn.Sequential():add(nn.SelectTable(2))
35 | local cat = nn.ConcatTable():add(module_tdata):add(module_vdata)
36 | model:add(cat):add(nn.JoinTable(2))
37 | model:add(nn.LinearNB(opt.embed_word + opt.vdim, manager_vocab.nvocab_answer))
38 |
39 | else
40 | error('no such method: ' .. opt.method)
41 |
42 | end
43 |
44 | model:add(nn.LogSoftMax())
45 | local criterion = nn.ClassNLLCriterion()
46 | criterion.sizeAverage = false
47 | model:cuda()
48 | criterion:cuda()
49 |
50 | return model, criterion
51 | end
52 |
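-- Usage sketch (manager_vocab comes from load_visualqadataset in opensource_base.lua):
--   local opt = initial_params()
--   opt.method = 'BOWIMG'
--   local model, criterion = build_model(opt, manager_vocab)
--   -- for BOWIMG the model takes {bag_of_words, visual_feature} CUDA tensors as input
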
53 | function initial_params()
54 | local gpuidx = getFreeGPU()
55 | print('use GPU IDX=' .. gpuidx)
56 | cutorch.setDevice(gpuidx)
57 |
58 | local cmd = torch.CmdLine()
59 |
60 | -- parameters for general setting
61 | cmd:option('--savepath', 'model')
62 |
63 | -- parameters for the visual feature
64 | cmd:option('--vfeat', 'googlenetFC')
65 | cmd:option('--vdim', 1024)
66 |
67 | -- parameters for data pre-process
68 | cmd:option('--thresh_questionword',6, 'threshold for the word freq on question')
69 | cmd:option('--thresh_answerword', 3, 'threshold for the word freq on the answer')
70 | cmd:option('--batchsize', 100)
71 | cmd:option('--seq_length', 50)
72 |
73 | -- parameters for learning
74 | cmd:option('--uniformLR', 0, 'whether to use uniform learning rate for all the parameters')
75 | cmd:option('--epochs', 100)
76 | cmd:option('--nepoch_lr', 100)
77 | cmd:option('--decay', 1.2)
78 | cmd:option('--embed_word', 1024,'the word embedding dimension in baseline')
79 |
80 | -- parameters for universal learning rate
81 | cmd:option('--maxgradnorm', 20)
82 | cmd:option('--maxweightnorm', 2000)
83 |
84 | -- parameters for different learning rates for different layers
85 | cmd:option('--lr_wordembed', 0.8)
86 | cmd:option('--lr_other', 0.01)
87 | cmd:option('--weightClip_wordembed', 1500)
88 | cmd:option('--weightClip_other', 20)
89 |
90 | return cmd:parse(arg or {})
91 | end
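-- Every option above can be overridden on the command line, e.g. (hypothetical values):
--   th opensource_baseline.lua --batchsize 200 --epochs 50 --lr_other 0.02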
92 |
93 | function loadPretrained(opt)
94 | --load the pre-trained model and rebuild it, returning the model, criterion and vocabulary
95 | local method = 'BOWIMG'
96 | local model_path = 'model/BOWIMG.t7'
97 | opt.method = method
98 |
99 | -- load pre-trained model
100 | local f_model = torch.load(model_path)
101 | local manager_vocab = f_model.manager_vocab
102 | -- Some simple fix for old models.
103 | if manager_vocab.vocab_map_question['END'] == nil then
104 | manager_vocab.vocab_map_question['END'] = -1
105 | manager_vocab.ivocab_map_question[-1] = 'END'
106 | end
107 |
108 | local model, criterion = build_model(opt, manager_vocab)
109 | local paramx, paramdx = model:getParameters()
110 | paramx:copy(f_model.paramx)
111 |
112 | return {
113 | model = model,
114 | criterion = criterion,
115 | manager_vocab = manager_vocab
116 | }
117 | end
118 |
--------------------------------------------------------------------------------
/opensource_baseline.lua:
--------------------------------------------------------------------------------
1 | require 'paths'
2 |
3 | local stringx = require 'pl.stringx'
4 | local file = require 'pl.file'
5 |
6 | paths.dofile('opensource_base.lua')
7 | paths.dofile('opensource_utils.lua')
8 |
9 | -- local debugger = require 'fb.debugger'
10 |
11 | function adjust_learning_rate(epoch_num, opt, config_layers)
12 | -- Every opt.nepoch_lr iterations, the learning rate is reduced.
13 | if epoch_num % opt.nepoch_lr == 0 then
14 | for j = 1, #config_layers.lr_rates do
15 | config_layers.lr_rates[j] = config_layers.lr_rates[j] / opt.decay
16 | end
17 | end
18 | end
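-- For example, with the defaults nepoch_lr = 100 and decay = 1.2 from initial_params,
-- every learning rate is divided by 1.2 once per 100 epochs.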
19 |
20 | function runTrainVal()
21 | local method = 'BOWIMG'
22 | local step_trainval = true -- step for training and validation
23 | local step_trainall = true -- step for combining train2014 and val2014
24 | local opt = initial_params()
25 | opt.method = method
26 | opt.save = paths.concat(opt.savepath, method ..'.t7')
27 |
28 | local stat = {}
29 | -- load data inside
30 | if step_trainval then
31 | local state_train, manager_vocab = load_visualqadataset(opt, 'trainval2014_train', nil)
32 | local state_val, _ = load_visualqadataset(opt, 'trainval2014_val', manager_vocab)
33 | local model, criterion = build_model(opt, manager_vocab)
34 | local paramx, paramdx = model:getParameters()
35 | local params_current, gparams_current = model:parameters()
36 |
37 | local config_layers, grad_last = config_layer_params(opt, params_current, gparams_current, 1)
38 | -- Save variables into context so that train_epoch could use.
39 | local context = {
40 | model = model,
41 | criterion = criterion,
42 | paramx = paramx,
43 | paramdx = paramdx,
44 | params_current = params_current,
45 | gparams_current = gparams_current,
46 | config_layers = config_layers,
47 | grad_last = grad_last
48 | }
49 | print(params_current)
50 | print('start training ...')
51 | for i = 1, opt.epochs do
52 | print(method .. ' epoch '..i)
53 | train_epoch(opt, state_train, manager_vocab, context, 'train')
54 | _, _, perfs = train_epoch(opt, state_val, manager_vocab, context, 'val')
55 | -- Accumulate statistics
56 | stat[i] = {acc, perfs.most_freq, perfs.openend_overall, perfs.multiple_overall}
57 | -- Adjust the learning rate
58 | adjust_learning_rate(i, opt, config_layers)
59 | end
60 | end
61 |
62 | if step_trainall then
63 | local nEpoch_best = 1
64 | local acc_openend_best = 0
65 | if step_trainval then
66 |
67 | -- Select the best train epoch number and combine train2014 and val2014
68 | for i = 1, #stat do
69 | if stat[i][3]> acc_openend_best then
70 | nEpoch_best = i
71 | acc_openend_best = stat[i][3]
72 | end
73 | end
74 |
75 | print('best epoch number is ' .. nEpoch_best)
76 | print('best acc is ' .. acc_openend_best)
77 | else
78 | nEpoch_best = 100
79 | end
80 | -- Combine train2014 and val2014
81 | local nEpoch_trainAll = nEpoch_best
82 | local state_train, manager_vocab = load_visualqadataset(opt, 'trainval2014', nil)
83 | -- recreate the model
84 | local model, criterion = build_model(opt, manager_vocab)
85 | local paramx, paramdx = model:getParameters()
86 | local params_current, gparams_current = model:parameters()
87 |
88 | local config_layers, grad_last = config_layer_params(opt, params_current, gparams_current, 1)
89 |
90 | local context = {
91 | model = model,
92 | criterion = criterion,
93 | paramx = paramx,
94 | paramdx = paramdx,
95 | params_current = params_current,
96 | gparams_current = gparams_current,
97 | config_layers = config_layers,
98 | grad_last = grad_last
99 | }
100 | print(params_current)
101 |
102 | print('start training on all data ...')
103 | stat = {}
104 | for i=1, nEpoch_trainAll do
105 | print('epoch '..i .. '/' ..nEpoch_trainAll)
106 | _, _, perfs = train_epoch(opt, state_train, manager_vocab, context, 'train')
107 | stat[i] = {acc, perfs.most_freq, perfs.openend_overall, perfs.multiple_overall}
108 | adjust_learning_rate(i, opt, config_layers)
109 |
110 | local modelname_curr = opt.save
111 | save_model(opt, manager_vocab, context, modelname_curr)
112 | end
113 | end
114 | end
115 |
116 | function runTest()
117 | --load the pre-trained model, evaluate on the test set, and generate the JSON files that can be submitted to the evaluation server
118 | local opt = initial_params()
119 | local context = loadPretrained(opt)
120 | local manager_vocab = context.manager_vocab
121 |
122 | -- load test data
123 | local testSet = 'test-dev2015' --'test2015' and 'test-dev2015'
124 | local state_test, _ = load_visualqadataset(opt, testSet, manager_vocab)
125 |
126 | -- predict
127 | local pred, prob, perfs = train_epoch(opt, state_test, manager_vocab, context, 'test')
128 |
129 | -- output JSON files to be submitted to the VQA evaluation server
130 | local file_json_openend = 'result/vqa_OpenEnded_mscoco_' .. testSet .. '_'.. opt.method .. '_results.json'
131 | local file_json_multiple = 'result/vqa_MultipleChoice_mscoco_' .. testSet .. '_'.. opt.method .. '_results.json'
132 | print('output the OpenEnd prediction to JSON file...' .. file_json_openend)
133 | local choice = 0
134 | outputJSONanswer(state_test, manager_vocab, prob, file_json_openend, choice)
135 | print('output the MultipleChoice prediction to JSON file...' .. file_json_multiple)
136 | choice = 1
137 | outputJSONanswer(state_test, manager_vocab, prob, file_json_multiple, choice)
138 |
139 | collectgarbage()
140 |
141 | end
142 |
143 | runTrainVal()
144 | runTest()
145 |
--------------------------------------------------------------------------------
/opensource_interactive.lua:
--------------------------------------------------------------------------------
1 | require 'paths'
2 | require 'cunn'
3 | require 'nn'
4 |
5 | local stringx = require 'pl.stringx'
6 | local file = require 'pl.file'
7 | -- local debugger = require 'fb.debugger'
8 |
9 | paths.dofile('opensource_base.lua')
10 | paths.dofile('opensource_utils.lua')
11 |
12 | local rank_context = { }
13 |
14 | function rankPrepare()
15 | local opt = initial_params()
16 | rank_context.context = loadPretrained(opt)
17 |
18 | local manager_vocab = rank_context.context.manager_vocab
19 | local model = rank_context.context.model
20 |
21 | -- Question-Answer matrix:
22 | --
23 | -- The entire model is like r = M_i * v_i + M_w * M_emb * v_onehot, there is no non-linearity.
24 | -- So we could just dump M_w * M_emb from here.
25 | --
26 | -- nvocab_answer * embed_word
27 | local M_w = model.modules[3].weight[{ { }, { 1, opt.embed_word } }]
28 | -- embed_word * nvocab_question
29 | local M_emb = model.modules[1].modules[1].modules[2].weight
30 |
31 | -- nvocab_answer * nvocab_question
32 | rank_context.rankWord_M = torch.mm(M_w, M_emb)
33 |
34 | -- Image-Answer matrix: vdim * nvocab_answer
35 | rank_context.rankImage_M = model.modules[3].weight[{ { }, { opt.embed_word + 1, opt.embed_word + opt.vdim }}]:transpose(1, 2)
36 |
37 | -- load test data
38 | local testSet = 'test-dev2015' --'test2015' and 'test-dev2015'
39 | local state, _ = load_visualqadataset(opt, testSet, manager_vocab)
40 |
41 | local x_rep = rank_context.rankWord_M.new():resize(state.x_question:size(1), manager_vocab.nvocab_question):zero()
42 | for i = 1, state.x_question:size(1) do
43 | x_rep[i]:copy(bagofword(manager_vocab, state.x_question[i]))
44 | end
45 | -- nvocab_answer * #question
46 | rank_context.rankQuestion_M = torch.mm(rank_context.rankWord_M, x_rep:transpose(1, 2))
47 | rank_context.questions = state.data_question
48 |
49 | rank_context.imglist = { }
50 | rank_context.inv_imglist = { }
51 |
52 | for k, v in pairs(state.featureMap) do
53 | table.insert(rank_context.imglist, k)
54 | rank_context.inv_imglist[k] = #rank_context.imglist
55 | end
56 | rank_context.featureMap = state.featureMap
57 |
58 | -- Dump rankImg matrix.
59 | local batch = 128
60 | local mat_feat = M_w.new():resize(batch, opt.vdim):zero()
61 | local rankImage_M2 = M_w.new():resize(#rank_context.imglist, manager_vocab.nvocab_answer):zero()
62 |
63 | for i = 1, #rank_context.imglist, batch do
64 | local this_batch = math.min(batch, #rank_context.imglist - i + 1)
65 | for j = 1, this_batch do
66 | local img_name = rank_context.imglist[j + i - 1]
67 | mat_feat[j]:copy(rank_context.featureMap[img_name])
68 | end
69 | -- batch * nvocab_answer
70 | local this_output = torch.mm(mat_feat:sub(1, this_batch), rank_context.rankImage_M)
71 | -- Collect the score.
72 | rankImage_M2[{ {i, i + this_batch - 1}, { } }]:copy(this_output)
73 | end
74 | -- nvocab_answer * #rank_context.imglist
75 | rank_context.rankImage_M2 = rankImage_M2:transpose(1, 2)
76 | end
77 |
78 | function rankWord(answer, topn)
79 | local manager_vocab = rank_context.context.manager_vocab
80 | local model = rank_context.context.model
81 |
82 | -- Convert the answer to idx and return topn words
83 | -- require 'fb.debugger'.enter()
84 | if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
85 | local answerid = manager_vocab.vocab_map_answer[answer]
86 | if answerid == nil then return end
87 |
88 | local score = rank_context.rankWord_M[answerid]:clone():squeeze()
89 | local sortedScore, sortedIndices = score:sort(1, true)
90 |
91 | local res = { }
92 | for i = 1, topn do
93 | local w = manager_vocab.ivocab_map_question[sortedIndices[i]]
94 | table.insert(res, { word = w, score = sortedScore[i] })
95 | end
96 |
97 | return res
98 | end
99 |
100 | function rankImage(answer, topn)
101 | local manager_vocab = rank_context.context.manager_vocab
102 |
103 | -- Given the answer, rank the image most relevant to the answer.
104 | -- Convert the answer to idx and return topn words
105 | if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
106 | local answerid = manager_vocab.vocab_map_answer[answer]
107 | if answerid == nil then return end
108 |
109 | local score = rank_context.rankImage_M2[answerid]:clone():squeeze()
110 | local sortedScore, sortedIndices = score:sort(1, true)
111 |
112 | local res = { }
113 | for i = 1, topn do
114 | local idx = sortedIndices[i]
115 | -- Get the image link.
116 | local imgname = rank_context.imglist[idx]
117 | local id = tonumber(imgname:match("_(%d+)"))
118 | local url = "http://mscoco.org/images/" .. tostring(id)
119 | table.insert(res, { idx = idx, imgname = imgname, url = url, score = sortedScore[i] })
120 | end
121 |
122 | return res
123 | end
124 |
125 | function rankQuestion(answer, topn)
126 | local manager_vocab = rank_context.context.manager_vocab
127 |
128 | -- Given the answer, rank the image most relevant to the answer.
129 | -- Convert the answer to idx and return topn words
130 | if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
131 | local answerid = manager_vocab.vocab_map_answer[answer]
132 | if answerid == nil then return end
133 |
134 | local score = rank_context.rankQuestion_M[answerid]:clone():squeeze()
135 | local sortedScore, sortedIndices = score:sort(1, true)
136 |
137 | local res = { }
138 | for i = 1, topn do
139 | local idx = sortedIndices[i]
140 | table.insert(res, { idx = idx, question = rank_context.questions[idx], score = sortedScore[i] })
141 | end
142 |
143 | return res
144 | end
145 |
146 | local function smart_split(s, quotes)
147 | quotes = quotes or { ['"'] = true }
148 | local res = { }
149 | local start = 1
150 | local quote_stack = { }
151 | for i = 1, #s do
152 | local c = s:sub(i, i)
153 | if c == ' ' then
154 | if #quote_stack == 0 then
155 | table.insert(res, s:sub(start, i - 1))
156 | start = i + 1
157 | end
158 | elseif quotes[c] then
159 | if #quote_stack == 0 or c ~= quote_stack[#quote_stack] then
160 | table.insert(quote_stack, c)
161 | else
162 | table.remove(quote_stack)
163 | end
164 | end
165 | end
166 | table.insert(res, s:sub(start, -1))
167 | return res
168 | end
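-- Example: smart_split('rankw "fire hydrant" 5') returns
-- { 'rankw', '"fire hydrant"', '5' }; the surrounding quotes are kept here and
-- stripped later by rankWord/rankQuestion/rankImage.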
169 |
170 | local commands = {
171 | rankw = {
172 | -- Given answer, rank words.
173 | exec = function(tokens)
174 | local topn = tonumber(tokens[3]) or 5
175 | local res = rankWord(tokens[2], topn)
176 | local success = false
177 | local s = ""
178 | if res then
179 | for i, v in pairs(res) do
180 | s = s .. string.format("[%d]: %s (%.2f)", i, v.word, v.score) .. "\n"
181 | end
182 | success = true
183 | end
184 | return success, s
185 | end,
186 | help = "\"rankw answer 5\" ranks question words and shows the top 5 words that lead to the answer. If the count is omitted, it defaults to 5. If the answer contains whitespace, wrap it in quotes."
187 | },
188 | rankq = {
189 | -- Given answer, rank sentence.
190 | exec = function(tokens)
191 | local topn = tonumber(tokens[3]) or 5
192 | local res = rankQuestion(tokens[2], topn)
193 | local success = false
194 | local s = ""
195 | if res then
196 | for i, v in pairs(res) do
197 | s = s .. string.format("[%d]: %s (%.2f)", i, v.question, v.score) .. "\n"
198 | end
199 | success = true
200 | end
201 | return success, s
202 | end,
203 | help = "\"rankq answer 5\" ranks questions from the test set and shows the top 5 that lead to the answer. If the count is omitted, it defaults to 5. If the answer contains whitespace, wrap it in quotes."
204 | },
205 | ranki = {
206 | exec = function(tokens)
207 | local topn = tonumber(tokens[3]) or 5
208 | local res = rankImage(tokens[2], topn)
209 | local success = false
210 | local s = ""
211 | if res then
212 | for i, v in pairs(res) do
213 | s = s .. string.format("[%d]: %s (%.2f)", i, v.imgname, v.score) .. "\n"
214 | end
215 | for i, v in pairs(res) do
216 | s = s .. string.format("<img src=\"%s\"/>", v.url)
217 | end
218 | s = s .. "\n"
219 | success = true
220 | end
221 | return success, s
222 | end,
223 | help = "\"ranki answer 5\" ranks images and shows the top 5 images that lead to the answer. If the count is omitted, it defaults to 5. If the answer contains whitespace, wrap it in quotes."
224 | },
225 | quit = {
226 | exec = function(tokens)
227 | return true, "Bye bye!", true
228 | end,
229 | help = "Quit the interactive environment"
230 | }
231 | }
232 |
233 | local history = ""
234 |
235 | local function init_webpage()
236 | history = ""
237 | end
238 |
239 | local function write2webpage(s)
240 | local lines = stringx.split(s, "\n")
241 | for i = 1, #lines do
242 | history = history .. lines[i] .. "<br>\n"
243 | end
244 | end
245 |
246 | local function save_webpage(filename)
247 | io.open(filename, "w"):write(history):close()
248 | end
249 |
250 | function run_interactive()
251 | -- Run interactive environment.
252 | print("Preload the model...")
253 | rankPrepare()
254 |
255 | -- Generate help string.
256 | local help_str = "Usage: \n"
257 | for k, v in pairs(commands) do
258 | help_str = help_str .. k .. ":\n " .. commands[k].help .. "\n"
259 | end
260 | print("Ready...")
261 |
262 | init_webpage()
263 |
264 | while true do
265 | local command = io.read("*l")
266 | local tokens = smart_split(command)
267 | if #tokens > 0 then
268 | if tokens[1] == "help" or commands[tokens[1]] == nil then
269 | print(help_str)
270 | else
271 | local success, s, is_quit = commands[tokens[1]].exec(tokens)
272 | print(s)
273 | if success then
274 | write2webpage(command)
275 | write2webpage(s)
276 | save_webpage("/home/yuandong/public_html/webpage.html")
277 | end
278 | if is_quit then break end
279 | end
280 | end
281 | end
282 | end
283 |
284 | run_interactive()
285 |
--------------------------------------------------------------------------------
/opensource_base.lua:
--------------------------------------------------------------------------------
1 | --local debugger = require 'fb.debugger'
2 | local stringx = require 'pl.stringx'
3 | local file = require 'pl.file'
4 |
5 | -- Here we specify different learning rate and gradClip for different layers.
6 | -- This is *critical* for the performance of BOW.
7 | function config_layer_params(opt, params_current, gparams_current, IDX_wordembed)
8 | local lr_wordembed = opt.lr_wordembed
9 | local lr_other = opt.lr_other
10 | local weightClip_wordembed = opt.weightClip_wordembed
11 | local weightClip_other = opt.weightClip_other
12 |
13 | print("lr_wordembed = " .. lr_wordembed)
14 | print("lr_other = " .. lr_other)
15 | print("weightClip_wordembed = " .. weightClip_wordembed)
16 | print("weightClip_other = " .. weightClip_other)
17 |
18 | local gradientClip_dummy = 0.1
19 | local weightRegConsts_dummy = 0.000005
20 | local initialRange_dummy = 0.1
21 | local moments_dummy = 0.9
22 |
23 | -- Initialize specification of layers.
24 | local config_layers = {
25 | lr_rates = {},
26 | gradientClips = {},
27 | weightClips = {},
28 | moments = {},
29 | weightRegConsts = {},
30 | initialRange = {}
31 | }
32 |
33 | -- grad_last is used to add momentum to the gradient.
34 | local grad_last = {}
35 | if IDX_wordembed == 1 then
36 | -- assume wordembed matrix is the params_current[1]
37 | config_layers.lr_rates = {lr_wordembed}
38 | config_layers.gradientClips = {gradientClip_dummy}
39 | config_layers.weightClips = {weightClip_wordembed}
40 | config_layers.moments = {moments_dummy}
41 | config_layers.weightRegConsts = {weightRegConsts_dummy}
42 | config_layers.initialRange = {initialRange_dummy}
43 | for i = 2, #params_current do
44 | table.insert(config_layers.lr_rates, lr_other)
45 | table.insert(config_layers.moments, moments_dummy)
46 | table.insert(config_layers.gradientClips, gradientClip_dummy)
47 | table.insert(config_layers.weightClips, weightClip_other)
48 | table.insert(config_layers.weightRegConsts, weightRegConsts_dummy)
49 | table.insert(config_layers.initialRange, initialRange_dummy)
50 | end
51 |
52 | else
53 | for i = 1, #params_current do
54 | table.insert(config_layers.lr_rates, lr_other)
55 | table.insert(config_layers.moments, moments_dummy)
56 | table.insert(config_layers.gradientClips, gradientClip_dummy)
57 | table.insert(config_layers.weightClips, weightClip_other)
58 | table.insert(config_layers.weightRegConsts, weightRegConsts_dummy)
59 | table.insert(config_layers.initialRange, initialRange_dummy)
60 | end
61 | end
62 |
63 | for i=1, #gparams_current do
64 | grad_last[i] = gparams_current[i]:clone()
65 | grad_last[i]:fill(0)
66 | end
67 | return config_layers, grad_last
68 | end
69 |
70 | ---------------------------------------
71 | ---- data IO relevant functions--------
72 | ---------------------------------------
73 |
74 | function existfile(filename)
75 | local f=io.open(filename,"r")
76 | if f~=nil then io.close(f) return true else return false end
77 | end
78 |
79 | function load_filelist(fname)
80 | local data = file.read(fname)
81 | data = stringx.replace(data,'\n',' ')
82 | data = stringx.split(data)
83 | local imglist_ind = {}
84 | for i=1, #data do
85 | imglist_ind[i] = stringx.split(data[i],'.')[1]
86 | end
87 | return imglist_ind
88 | end
89 |
90 | function build_vocab(data, thresh, IDX_singleline, IDX_includeEnd)
91 | if IDX_singleline == 1 then
92 | data = stringx.split(data,'\n')
93 | else
94 | data = stringx.replace(data,'\n', ' ')
95 | data = stringx.split(data)
96 | end
97 | local countWord = {}
98 | for i=1, #data do
99 | if countWord[data[i]] == nil then
100 | countWord[data[i]] = 1
101 | else
102 | countWord[data[i]] = countWord[data[i]] + 1
103 | end
104 | end
105 | local vocab_map_ = {}
106 | local ivocab_map_ = {}
107 | local vocab_idx = 0
108 | if IDX_includeEnd==1 then
109 | vocab_idx = 1
110 | vocab_map_['NA'] = 1
111 | ivocab_map_[1] = 'NA'
112 | end
113 |
114 | for i=1, #data do
115 | if vocab_map_[data[i]]==nil then
116 | if countWord[data[i]]>=thresh then
117 | vocab_idx = vocab_idx+1
118 | vocab_map_[data[i]] = vocab_idx
119 | ivocab_map_[vocab_idx] = data[i]
120 | --print(vocab_idx..'-'.. data[i] ..'--'.. countWord[data[i]])
121 | else
122 | vocab_map_[data[i]] = vocab_map_['NA']
123 | end
124 | end
125 | end
126 | vocab_map_['END'] = -1
127 | return vocab_map_, ivocab_map_, vocab_idx
128 | end
129 |
130 | function load_visualqadataset(opt, dataType, manager_vocab)
131 | -- Change it to your path.
132 | -- local path_imglist = 'datasets/coco_dataset/allimage2014'
133 | -- All COCO images.
134 |
135 | -- VQA question/answer txt files.
136 | -- Download data_vqa_feat.zip and data_vqa_txt.zip and decompress into this folder
137 | local path_dataset = '/data/local/vqa_opensource'
138 |
139 | local prefix = 'coco_' .. dataType
140 | local filename_question = paths.concat(path_dataset, prefix .. '_question.txt')
141 | local filename_answer = paths.concat(path_dataset, prefix .. '_answer.txt')
142 | local filename_imglist = paths.concat(path_dataset, prefix .. '_imglist.txt')
143 | local filename_allanswer = paths.concat(path_dataset, prefix .. '_allanswer.txt')
144 | local filename_choice = paths.concat(path_dataset, prefix .. '_choice.txt')
145 | local filename_question_type = paths.concat(path_dataset, prefix .. '_question_type.txt')
146 | local filename_answer_type = paths.concat(path_dataset, prefix .. '_answer_type.txt')
147 | local filename_questionID = paths.concat(path_dataset, prefix .. '_questionID.txt')
148 |
149 | if existfile(filename_allanswer) then
150 | data_allanswer = file.read(filename_allanswer)
151 | data_allanswer = stringx.split(data_allanswer,'\n')
152 | end
153 | if existfile(filename_choice) then
154 | data_choice = file.read(filename_choice)
155 | data_choice = stringx.split(data_choice, '\n')
156 | end
157 | if existfile(filename_question_type) then
158 | data_question_type = file.read(filename_question_type)
159 | data_question_type = stringx.split(data_question_type,'\n')
160 | end
161 | if existfile(filename_answer_type) then
162 | data_answer_type = file.read(filename_answer_type)
163 | data_answer_type = stringx.split(data_answer_type, '\n')
164 | end
165 | if existfile(filename_questionID) then
166 | data_questionID = file.read(filename_questionID)
167 | data_questionID = stringx.split(data_questionID,'\n')
168 | end
169 |
170 | local data_answer
171 | local data_answer_split
172 | if existfile(filename_answer) then
173 | print("Load answer file = " .. filename_answer)
174 | data_answer = file.read(filename_answer)
175 | data_answer_split = stringx.split(data_answer,'\n')
176 | end
177 |
178 | print("Load question file = " .. filename_question)
179 | local data_question = file.read(filename_question)
180 | local data_question_split = stringx.split(data_question,'\n')
181 | local manager_vocab_ = {}
182 |
183 | if manager_vocab == nil then
184 | local vocab_map_answer, ivocab_map_answer, nvocab_answer = build_vocab(data_answer, opt.thresh_answerword, 1, 0)
185 | local vocab_map_question, ivocab_map_question, nvocab_question = build_vocab(data_question,opt.thresh_questionword, 0, 1)
186 | print(' no.vocab_question=' .. nvocab_question.. ', no.vocab_answer=' .. nvocab_answer)
187 | manager_vocab_ = {vocab_map_answer=vocab_map_answer, ivocab_map_answer=ivocab_map_answer, vocab_map_question=vocab_map_question, ivocab_map_question=ivocab_map_question, nvocab_answer=nvocab_answer, nvocab_question=nvocab_question}
188 | else
189 | manager_vocab_ = manager_vocab
190 | end
191 |
192 | local imglist = load_filelist(filename_imglist)
193 | local nSample = #imglist
194 | -- Cap the sample count at the number of available questions.
195 | if nSample > #data_question_split then
196 | nSample = #data_question_split
197 | end
198 |
199 | -- Answers.
200 | local x_answer = torch.zeros(nSample):fill(-1)
201 | if opt.multipleanswer == 1 then
202 | x_answer = torch.zeros(nSample, 10)
203 | end
204 | local x_answer_num = torch.zeros(nSample)
205 |
206 | -- Convert words in answers and questions to indices into the dictionary.
207 | local x_question = torch.zeros(nSample, opt.seq_length)
208 | for i = 1, nSample do
209 | local words = stringx.split(data_question_split[i])
210 | -- Answers
211 | if existfile(filename_answer) then
212 | local answer = data_answer_split[i]
213 | if manager_vocab_.vocab_map_answer[answer] == nil then
214 | x_answer[i] = -1
215 | else
216 | x_answer[i] = manager_vocab_.vocab_map_answer[answer]
217 | end
218 | end
219 | -- Questions
220 | for j = 1, opt.seq_length do
221 | if j <= #words then
222 | if manager_vocab_.vocab_map_question[words[j]] == nil then
223 | x_question[{i, j}] = 1
224 | else
225 | x_question[{i, j}] = manager_vocab_.vocab_map_question[words[j]]
226 | end
227 | else
228 | x_question[{i, j}] = manager_vocab_.vocab_map_question['END']
229 | end
230 | end
231 | end
232 |
233 | ---------------------------
234 | -- start loading features -
235 | ---------------------------
236 | local featureMap = {}
237 | local featName = 'googlenetFCdense'
238 |
239 | print(featName)
240 |
241 | -- Possible combinations of data loading
242 | local loading_spec = {
243 | trainval2014 = { train = true, val = true, test = false },
244 | trainval2014_train = { train = true, val = true, test = false },
245 | trainval2014_val = { train = false, val = true, test = false },
246 | train2014 = { train = true, val = false, test = false },
247 | val2014 = { train = false, val = true, test = false },
248 | test2015 = { train = false, val = false, test = true }
249 | }
250 | loading_spec['test-dev2015'] = { train = false, val = false, test = true }
251 | local feature_prefixSet = {
252 | train = paths.concat(path_dataset, 'coco_train2014_' .. featName),
253 | val = paths.concat(path_dataset, 'coco_val2014_' .. featName),
254 | test = paths.concat(path_dataset,'coco_test2015_' .. featName)
255 | }
256 |
257 | for k, feature_prefix in pairs(feature_prefixSet) do
258 | -- Check if we need to load this dataset.
259 | if loading_spec[dataType][k] then
260 | local feature_imglist = torch.load(feature_prefix ..'_imglist.dat')
261 | local featureSet = torch.load(feature_prefix ..'_feat.dat')
262 | for i = 1, #feature_imglist do
263 | local feat_in = torch.squeeze(featureSet[i])
264 | featureMap[feature_imglist[i]] = feat_in
265 | end
266 | end
267 | end
268 |
269 | collectgarbage()
270 | -- Return the state.
271 | local _state = {
272 | x_question = x_question,
273 | x_answer = x_answer,
274 | x_answer_num = x_answer_num,
275 | featureMap = featureMap,
276 | data_question = data_question_split,
277 | data_answer = data_answer_split,
278 | imglist = imglist,
279 | path_imglist = path_imglist,
280 | data_allanswer = data_allanswer,
281 | data_choice = data_choice,
282 | data_question_type = data_question_type,
283 | data_answer_type = data_answer_type,
284 | data_questionID = data_questionID
285 |
286 | }
287 |
288 | return _state, manager_vocab_
289 | end
290 |
291 | --------------------------------------------
292 | -- training relevant code
293 | --------------------------------------------
294 | function save_model(opt, manager_vocab, context, path)
295 | print('saving model ' .. path)
296 | local d = {}
297 | d.paramx = context.paramx:float()
298 | d.manager_vocab = manager_vocab
299 | d.stat = stat
300 | d.config_layers = config_layers
301 | d.opt = opt
302 |
303 | torch.save(path, d)
304 | end
305 |
306 | function bagofword(manager_vocab, x_seq)
307 | -- turn a list of word indices into a bag-of-words vector
308 | local outputVector = torch.zeros(manager_vocab.nvocab_question)
309 | for i= 1, x_seq:size(1) do
310 | if x_seq[i] ~= manager_vocab.vocab_map_question['END'] then
311 | outputVector[x_seq[i]] = 1
312 | else
313 | break
314 | end
315 | end
316 | return outputVector
317 | end
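-- Example: for x_seq = {4, 9, END, END, ...} the returned vector is all zeros
-- except positions 4 and 9, which are set to 1 (a fixed-size bag of words).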
318 |
319 | function add_count(t, ...)
320 | -- Input: table for counting, k1, v1, k2, v2, k3, v3
321 | -- Output: t[k1] += (v1, 1), t[k2] += (v2, 1), etc.
322 | local args = { ... }
323 | local i = 1
324 | while i < #args do
325 | local k = args[i]
326 | local v = args[i + 1]
327 | if t[k] == nil then
328 | t[k] = { v, 1 }
329 | else
330 | t[k][1] = t[k][1] + v
331 | t[k][2] = t[k][2] + 1
332 | end
333 | i = i + 2
334 | end
335 | end
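-- Example: add_count(t, "openend_overall", 1) followed by
-- add_count(t, "openend_overall", 0) leaves t["openend_overall"] = {1, 2}
-- (a running sum and a count), so compute_accuracy(t) yields 0.5 for that key.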
336 |
337 | function compute_accuracy(t)
338 | local res = { }
339 | for k, v in pairs(t) do
340 | res[k] = v[1] / v[2]
341 | end
342 | return res
343 | end
344 |
345 | function evaluate_answer(state, manager_vocab, pred_answer, prob_answer, selectIDX)
346 | -- evaluate predictions against the VQA dataset ground truth
347 | selectIDX = selectIDX or torch.range(1, state.x_answer:size(1))
348 | local pred_answer_word = {}
349 | local gt_answer_word = state.data_answer
350 | local gt_allanswer = state.data_allanswer
351 |
352 | local perfs = { }
353 | local count_question_type = {}
354 | local count_answer_type = {}
355 |
356 | for sampleID = 1, selectIDX:size(1) do
357 | local i = selectIDX[sampleID]
358 |
359 | -- Prediction correct.
360 | if manager_vocab.ivocab_map_answer[pred_answer[i]]== gt_answer_word[i] then
361 | add_count(perfs, "most_freq", 1)
362 | else
363 | add_count(perfs, "most_freq",0)
364 | end
365 |
366 | -- Estimate using the standard criterion (min(#correct matches / 3, 1)).
367 | -- Also estimate the multiple choice case.
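-- e.g. a predicted answer matching 2 of the 10 human answers scores min(2/3, 1).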
368 | local question_type = state.data_question_type[i]
369 | local answer_type = state.data_answer_type[i]
370 |
371 | -- Compute accuracy for multiple choices.
372 | local choices = stringx.split(state.data_choice[i], ',')
373 | local score_choices = torch.zeros(#choices):fill(-1000000)
374 | for j = 1, #choices do
375 | local IDX_pred = manager_vocab.vocab_map_answer[choices[j]]
376 | if IDX_pred ~= nil then
377 | local score = prob_answer[{i, IDX_pred}]
378 | if score ~= nil then
379 | score_choices[j] = score
380 | end
381 | end
382 | end
383 | local val_max, IDX_max = torch.max(score_choices, 1)
384 | local word_pred_answer_multiple = choices[IDX_max[1]]
385 | local word_pred_answer_openend = manager_vocab.ivocab_map_answer[pred_answer[i]]
386 |
387 | -- Compare the predicted answer with all gt answers from humans.
388 | if gt_allanswer then
389 | local answers = stringx.split(gt_allanswer[i], ',')
390 | -- The number of answers matched with human answers.
391 | local count_curr_openend = 0
392 | local count_curr_multiple = 0
393 | for j = 1, #answers do
394 | count_curr_openend = count_curr_openend + (word_pred_answer_openend == answers[j] and 1 or 0)
395 | count_curr_multiple = count_curr_multiple + (word_pred_answer_multiple == answers[j] and 1 or 0)
396 | end
397 |
398 | local increment = math.min(count_curr_openend * 1.0/3, 1.0)
399 | add_count(perfs, "openend_overall", increment,
400 | "openend_q_" .. question_type, increment,
401 | "openend_a_" .. answer_type, increment)
402 |
403 | increment = math.min(count_curr_multiple * 1.0/3, 1.0)
404 | add_count(perfs, "multiple_overall", increment,
405 | "multiple_q_" .. question_type, increment,
406 | "multiple_a_" .. answer_type, increment)
407 | end
408 | end
409 |
410 | -- Compute accuracy
411 | return compute_accuracy(perfs)
412 | end
413 |
414 | function outputJSONanswer(state, manager_vocab, prob, file_json, choice)
415 | -- Dump the prediction results to a JSON file
416 | local f_json = io.open(file_json,'w')
417 | f_json:write('[')
418 |
419 | for i = 1, prob:size(1) do
420 | local choices = stringx.split(state.data_choice[i], ',')
421 | local score_choices = torch.zeros(#choices):fill(-1000000)
422 | for j=1, #choices do
423 | local IDX_pred = manager_vocab.vocab_map_answer[choices[j]]
424 | if IDX_pred ~= nil then
425 | local score = prob[{i, IDX_pred}]
426 | if score ~= nil then
427 | score_choices[j] = score
428 | end
429 | end
430 | end
431 | local val_max,IDX_max = torch.max(score_choices,1)
432 | local val_max_open,IDX_max_open = torch.max(prob[i],1)
433 | local word_pred_answer_multiple = choices[IDX_max[1]]
434 | local word_pred_answer_openend = manager_vocab.ivocab_map_answer[IDX_max_open[1]]
435 | local answer_pred = word_pred_answer_openend
436 | if choice == 1 then
437 | answer_pred = word_pred_answer_multiple
438 | end
439 | local questionID = state.data_questionID[i]
440 | f_json:write('{"answer": "' .. answer_pred .. '","question_id": ' .. questionID .. '}')
441 | if i< prob:size(1) then
442 | f_json:write(',')
443 | end
444 | end
445 | f_json:write(']')
446 | f_json:close()
447 |
448 | end
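-- The resulting file is a JSON array with entries such as
--   {"answer": "yes", "question_id": 1234}   (question_id values are illustrative)
-- which matches the submission format of the VQA evaluation server.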
449 |
450 | function train_epoch(opt, state, manager_vocab, context, updateIDX)
451 | -- Dump context to the local namespace.
452 | local model = context.model
453 | local criterion = context.criterion
454 | local paramx = context.paramx
455 | local paramdx = context.paramdx
456 | local params_current = context.params_current
457 | local gparams_current = context.gparams_current
458 | local config_layers = context.config_layers
459 | local grad_last = context.grad_last
460 |
461 | local loss = 0.0
462 | local N = math.ceil(state.x_question:size(1) / opt.batchsize)
463 | local prob_answer = torch.zeros(state.x_question:size(1), manager_vocab.nvocab_answer)
464 | local pred_answer = torch.zeros(state.x_question:size(1))
465 | local target = torch.zeros(opt.batchsize)
466 |
467 | local featBatch_visual = torch.zeros(opt.batchsize, opt.vdim)
468 | local featBatch_word = torch.zeros(opt.batchsize, manager_vocab.nvocab_question)
469 | local word_idx = torch.zeros(opt.batchsize, opt.seq_length)
470 |
471 | local IDXset_batch = torch.zeros(opt.batchsize)
472 | local nSample_batch = 0
473 | local count_batch = 0
474 | local nBatch = 0
475 |
476 | local randIDX = torch.randperm(state.x_question:size(1))
477 | for iii = 1, state.x_question:size(1) do
478 | local i = randIDX[iii]
479 | local first_answer = -1
480 | if updateIDX~='test' then
481 | first_answer = state.x_answer[i]
482 | end
483 | if first_answer == -1 and updateIDX == 'train' then
484 | --skip the sample with NA answer
485 | else
486 | nSample_batch = nSample_batch + 1
487 | IDXset_batch[nSample_batch] = i
488 | if updateIDX ~= 'test' then
489 | target[nSample_batch] = state.x_answer[i]
490 | end
491 | local filename = state.imglist[i] -- e.g. 'COCO_train2014_000000000092'
492 | local feat_visual = state.featureMap[filename]:clone()
493 | local feat_word = bagofword(manager_vocab, state.x_question[i])
494 |
495 | word_idx[nSample_batch] = state.x_question[i]
496 | featBatch_word[nSample_batch] = feat_word:clone()
497 | featBatch_visual[nSample_batch] = feat_visual:clone()
498 |
499 | while i == state.x_question:size(1) and nSample_batch< opt.batchsize do
500 | -- padding the extra sample to complete a batch for training
501 | nSample_batch = nSample_batch+1
502 | IDXset_batch[nSample_batch] = i
503 | target[nSample_batch] = first_answer
504 | featBatch_visual[nSample_batch] = feat_visual:clone()
505 | featBatch_word[nSample_batch] = feat_word:clone()
506 | word_idx[nSample_batch] = state.x_question[i]
507 | end
508 | if nSample_batch == opt.batchsize then
509 | nBatch = nBatch+1
510 | word_idx = word_idx:cuda()
511 | nSample_batch = 0
512 | target = target:cuda()
513 | featBatch_word = featBatch_word:cuda()
514 | featBatch_visual = featBatch_visual:cuda()
515 | ----------forward pass----------------------
516 | --switch between the input formats of the baseline variants
517 | if opt.method == 'BOW' then
518 | input = featBatch_word
519 | elseif opt.method == 'BOWIMG' then
520 | input = {featBatch_word, featBatch_visual}
521 | elseif opt.method == 'IMG' then
522 | input = featBatch_visual
523 | else
524 | error('unknown baseline method')
525 | end
526 |
527 | local output = model:forward(input)
528 | local err = criterion:forward(output, target)
529 | local prob_batch = output:float()
530 |
531 | loss = loss + err
532 | for j = 1, opt.batchsize do
533 | prob_answer[IDXset_batch[j]] = prob_batch[j]
534 | end
535 | --------------------backforward pass
536 | if updateIDX == 'train' then
537 | model:zeroGradParameters()
538 | local df = criterion:backward(output, target)
539 | local df_model = model:backward(input, df)
540 |
541 | -------------Update the params of baseline softmax---
542 | if opt.uniformLR ~= 1 then
543 | for i=1, #params_current do
544 | local gnorm = gparams_current[i]:norm()
545 | if config_layers.gradientClips[i]>0 and gnorm > config_layers.gradientClips[i] then
546 | gparams_current[i]:mul(config_layers.gradientClips[i]/gnorm)
547 | end
548 |
549 | grad_last[i]:mul(config_layers.moments[i])
550 | local tmp = torch.mul(gparams_current[i],-config_layers.lr_rates[i])
551 | grad_last[i]:add(tmp)
552 | params_current[i]:add(grad_last[i])
553 | if config_layers.weightRegConsts[i]>0 then
554 | local a = config_layers.lr_rates[i] * config_layers.weightRegConsts[i]
555 | params_current[i]:mul(1-a)
556 | end
557 | local pnorm = params_current[i]:norm()
558 | if config_layers.weightClips[i]>0 and pnorm > config_layers.weightClips[i] then
559 | params_current[i]:mul(config_layers.weightClips[i]/pnorm)
560 | end
561 | end
562 | else
563 | local norm_dw = paramdx:norm()
564 | if norm_dw > opt.maxgradnorm then
565 | local shrink_factor = opt.maxgradnorm / norm_dw
566 | paramdx:mul(shrink_factor)
567 | end
568 | paramx:add(paramdx:mul(-opt.lr))
569 | end
570 |
571 | end
572 |
573 | --batch finished
574 | count_batch = count_batch+1
575 | if count_batch == 120 then
576 | collectgarbage()
577 | count_batch = 0
578 | end
579 | end
580 | end -- end of skipping samples with -1 (NA) answers
581 | end
582 | -- 1 epoch finished
583 | local y_max, i_max = torch.max(prob_answer,2)
584 | i_max = torch.squeeze(i_max)
585 | pred_answer = i_max:clone()
586 | if updateIDX~='test' then
587 |
588 | local gtAnswer = state.x_answer:clone()
589 | gtAnswer = gtAnswer:long()
590 | local correctNum = torch.sum(torch.eq(pred_answer, gtAnswer))
591 | acc = correctNum*1.0/pred_answer:size(1)
592 | else
593 | acc = -1
594 | end
595 | print(updateIDX ..': acc (mostFreq) =' .. acc)
596 | local perfs = nil
597 | if updateIDX ~= 'test' and state.data_allanswer ~= nil then
598 | -- use the standard evaluation criteria of the VQA dataset (Virginia Tech)
599 | perfs = evaluate_answer(state, manager_vocab, pred_answer, prob_answer)
600 | print(updateIDX .. ': acc.match mostfreq = ' .. perfs.most_freq)
601 | print(updateIDX .. ': acc.dataset (OpenEnd) =' .. perfs.openend_overall)
602 | print(updateIDX .. ': acc.dataset (MultipleChoice) =' .. perfs.multiple_overall)
603 | -- If you want to see more statistics. do the following:
604 | -- print(perfs)
605 | end
606 | print(updateIDX .. ' loss=' .. loss/nBatch)
607 | return pred_answer, prob_answer, perfs
608 | end
609 |
610 |
--------------------------------------------------------------------------------