├── model
│   └── tmp.txt
├── result
│   └── results_BOWIMG.zip
├── README.md
├── LinearNB.lua
├── opensource_utils.lua
├── opensource_baseline.lua
├── opensource_interactive.lua
└── opensource_base.lua

--------------------------------------------------------------------------------
/model/tmp.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/result/results_BOWIMG.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoubolei/VQAbaseline/HEAD/result/results_BOWIMG.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple Baseline for Visual Question Answering

We describe a very simple bag-of-words baseline for visual question answering. The baseline is described in the arXiv paper http://arxiv.org/pdf/1512.02167.pdf. The code was developed by [Bolei Zhou](http://people.csail.mit.edu/bzhou/) and [Yuandong Tian](http://yuandong-tian.com/).

![results](http://visualqa.csail.mit.edu/example.jpg)

A demo is available at http://visualqa.csail.mit.edu/

To train the model with this code, the following files from the VQA dataset are needed:
- The pre-processed text data: http://visualqa.csail.mit.edu/data_vqa_txt.zip
- The GoogLeNet features of all the COCO images: http://visualqa.csail.mit.edu/data_vqa_feat.zip

The pre-trained model used in the paper is at http://visualqa.csail.mit.edu/coco_qadevi_BOWIMG_bestepoch93_final.t7model. It achieves 55.89 (Open-Ended) and 61.69 (Multiple-Choice) accuracy on the test-standard split of the COCO VQA dataset.

Contact Bolei Zhou (zhoubolei@gmail.com) if you have any questions.

Please cite our arXiv note if you use the code:

    B. Zhou, Y. Tian, S. Sukhbaatar, A. Szlam, R. Fergus.
    Simple Baseline for Visual Question Answering.
    arXiv:1512.02167
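A minimal sketch for loading the pre-trained model from Lua, assuming a CUDA-capable Torch7 install and that the model file has been placed at `model/BOWIMG.t7` (the path expected by `loadPretrained` in `opensource_utils.lua`):

```lua
require 'paths'
paths.dofile('opensource_base.lua')
paths.dofile('opensource_utils.lua')

local opt = initial_params()         -- parses command-line flags, picks a GPU
local context = loadPretrained(opt)  -- model, criterion and vocabulary manager
print(context.manager_vocab.nvocab_answer .. ' answer classes')
```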
--------------------------------------------------------------------------------
/LinearNB.lua:
--------------------------------------------------------------------------------
local LinearNB, parent = torch.class('nn.LinearNB', 'nn.Module')

function LinearNB:__init(inputSize, outputSize)
    parent.__init(self)

    self.outputSize = outputSize
    self.inputSize = inputSize
    self.weight = torch.Tensor(outputSize, inputSize)
    self.gradWeight = torch.Tensor(outputSize, inputSize)

    self:reset()
end

function LinearNB:reset(stdv)
    if stdv then
        stdv = stdv * math.sqrt(3)
    else
        stdv = 1. / math.sqrt(self.weight:size(2))
    end
    if nn.oldSeed then
        for i = 1, self.weight:size(1) do
            self.weight:select(1, i):apply(function()
                return torch.uniform(-stdv, stdv)
            end)
        end
    else
        self.weight:uniform(-stdv, stdv)
    end
end

function LinearNB:updateOutput(input)
    if input:dim() == 1 then
        self.output:resize(self.outputSize)
        self.output:zero()
        self.output:addmv(1, self.weight, input)
    elseif input:dim() == 2 then
        local nframe = input:size(1)
        local nunit = self.outputSize
        self.output:resize(nframe, nunit):zero()
        if not self.addBuffer or self.addBuffer:size(1) ~= nframe then
            self.addBuffer = input.new(nframe):fill(1)
        end
        if nunit == 1 then
            -- Special case to fix output size of 1 bug:
            self.output:select(2, 1):addmv(1, input, self.weight:select(1, 1))
        else
            self.output:addmm(1, input, self.weight:t())
        end
    else
        error('input must be vector or matrix')
    end

    return self.output
end

function LinearNB:updateGradInput(input, gradOutput)
    if self.gradInput then
        local nElement = self.gradInput:nElement()
        self.gradInput:resizeAs(input)
        if self.gradInput:nElement() ~= nElement then
            self.gradInput:zero()
        end
        if input:dim() == 1 then
            self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
        elseif input:dim() == 2 then
            self.gradInput:addmm(0, 1, gradOutput, self.weight)
        end

        return self.gradInput
    end
end

function LinearNB:accGradParameters(input, gradOutput, scale)
    scale = scale or 1

    if input:dim() == 1 then
        self.gradWeight:addr(scale, gradOutput, input)
    elseif input:dim() == 2 then
        local nunit = self.outputSize
        if nunit == 1 then
            -- Special case to fix output size of 1 bug:
            self.gradWeight:select(1, 1):addmv(scale,
                input:t(), gradOutput:select(2, 1))
        else
            self.gradWeight:addmm(scale, gradOutput:t(), input)
        end
    end
end

-- we do not need to accumulate parameters when sharing
LinearNB.sharedAccUpdateGradParameters = LinearNB.accUpdateGradParameters

--------------------------------------------------------------------------------
/opensource_utils.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'cunn'

paths.dofile('LinearNB.lua')

function getFreeGPU()
    -- select the GPU with the most free memory to train on
    local nDevice = cutorch.getDeviceCount()
    local memSet = torch.Tensor(nDevice)
    for i = 1, nDevice do
        local tmp, _ = cutorch.getMemoryUsage(i)
        memSet[i] = tmp
    end
    local _, curDeviceID = torch.max(memSet, 1)
    return curDeviceID[1]
end

function build_model(opt, manager_vocab)
    -- Build the baseline model.
    local model
    if opt.method == 'BOW' then
        model = nn.Sequential()
        local module_tdata = nn.LinearNB(manager_vocab.nvocab_question, opt.embed_word)
        model:add(module_tdata)
        model:add(nn.Linear(opt.embed_word, manager_vocab.nvocab_answer))

    elseif opt.method == 'IMG' then
        model = nn.Sequential()
        model:add(nn.Linear(opt.vdim, manager_vocab.nvocab_answer))

    elseif opt.method == 'BOWIMG' then
        model = nn.Sequential()
        local module_tdata = nn.Sequential():add(nn.SelectTable(1)):add(nn.LinearNB(manager_vocab.nvocab_question, opt.embed_word))
        local module_vdata = nn.Sequential():add(nn.SelectTable(2))
        local cat = nn.ConcatTable():add(module_tdata):add(module_vdata)
        model:add(cat):add(nn.JoinTable(2))
        model:add(nn.LinearNB(opt.embed_word + opt.vdim, manager_vocab.nvocab_answer))

    else
        error('unknown method: ' .. opt.method)
    end

    model:add(nn.LogSoftMax())
    local criterion = nn.ClassNLLCriterion()
    criterion.sizeAverage = false
    model:cuda()
    criterion:cuda()

    return model, criterion
end
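-- A small self-contained sketch of what build_model constructs for
-- method = 'BOWIMG': a purely linear network over the concatenation of the
-- word-embedded bag-of-words vector and the image feature. The sizes below
-- are made-up toy values and the demo stays on the CPU; call it manually to try.
local function demo_bowimg_forward()
    local nvocab_q, nvocab_a, embed, vdim = 10, 5, 8, 6
    local m = nn.Sequential()
    local module_tdata = nn.Sequential():add(nn.SelectTable(1)):add(nn.LinearNB(nvocab_q, embed))
    local module_vdata = nn.Sequential():add(nn.SelectTable(2))
    m:add(nn.ConcatTable():add(module_tdata):add(module_vdata)):add(nn.JoinTable(2))
    m:add(nn.LinearNB(embed + vdim, nvocab_a))
    m:add(nn.LogSoftMax())
    local bow = torch.rand(2, nvocab_q)  -- batch of 2 bag-of-words vectors
    local img = torch.rand(2, vdim)      -- batch of 2 image features
    print(m:forward({bow, img}):size())  -- 2 x nvocab_a log-probabilities
end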
function initial_params()
    local gpuidx = getFreeGPU()
    print('use GPU IDX=' .. gpuidx)
    cutorch.setDevice(gpuidx)

    local cmd = torch.CmdLine()

    -- parameters for the general setting
    cmd:option('--savepath', 'model')

    -- parameters for the visual feature
    cmd:option('--vfeat', 'googlenetFC')
    cmd:option('--vdim', 1024)

    -- parameters for data pre-processing
    cmd:option('--thresh_questionword', 6, 'threshold for the word frequency in questions')
    cmd:option('--thresh_answerword', 3, 'threshold for the word frequency in answers')
    cmd:option('--batchsize', 100)
    cmd:option('--seq_length', 50)

    -- parameters for learning
    cmd:option('--uniformLR', 0, 'whether to use a uniform learning rate for all the parameters')
    cmd:option('--epochs', 100)
    cmd:option('--nepoch_lr', 100)
    cmd:option('--decay', 1.2)
    cmd:option('--embed_word', 1024, 'the word embedding dimension in the baseline')

    -- parameters for the universal learning rate
    cmd:option('--maxgradnorm', 20)
    cmd:option('--maxweightnorm', 2000)

    -- parameters for different learning rates for different layers
    cmd:option('--lr_wordembed', 0.8)
    cmd:option('--lr_other', 0.01)
    cmd:option('--weightClip_wordembed', 1500)
    cmd:option('--weightClip_other', 20)

    return cmd:parse(arg or {})
end
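-- Every option above can be overridden on the command line via torch.CmdLine;
-- the flag values below are only examples:
--   th opensource_baseline.lua --epochs 50 --lr_other 0.005 --batchsize 200
local function demo_initial_params()
    local opt = initial_params()  -- note: also selects a GPU via getFreeGPU()
    print(opt.epochs, opt.lr_wordembed, opt.lr_other)
end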
function loadPretrained(opt)
    -- Load the pre-trained model so it can be evaluated on the test set
    -- (see runTest in opensource_baseline.lua, which writes the JSON result
    -- files for the evaluation server).
    local method = 'BOWIMG'
    local model_path = 'model/BOWIMG.t7'
    opt.method = method

    -- load the pre-trained model
    local f_model = torch.load(model_path)
    local manager_vocab = f_model.manager_vocab
    -- Some simple fixes for old models.
    if manager_vocab.vocab_map_question['END'] == nil then
        manager_vocab.vocab_map_question['END'] = -1
        manager_vocab.ivocab_map_question[-1] = 'END'
    end

    local model, criterion = build_model(opt, manager_vocab)
    local paramx, paramdx = model:getParameters()
    paramx:copy(f_model.paramx)

    return {
        model = model,
        criterion = criterion,
        manager_vocab = manager_vocab
    }
end

--------------------------------------------------------------------------------
/opensource_baseline.lua:
--------------------------------------------------------------------------------
require 'paths'

local stringx = require 'pl.stringx'
local file = require 'pl.file'

paths.dofile('opensource_base.lua')
paths.dofile('opensource_utils.lua')

-- fb.debugger ships with Facebook's fblualib and is not needed to run;
-- uncomment if you have it installed.
-- local debugger = require 'fb.debugger'

function adjust_learning_rate(epoch_num, opt, config_layers)
    -- Every opt.nepoch_lr epochs, the learning rate is divided by opt.decay.
    if epoch_num % opt.nepoch_lr == 0 then
        for j = 1, #config_layers.lr_rates do
            config_layers.lr_rates[j] = config_layers.lr_rates[j] / opt.decay
        end
    end
end
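-- Sketch of the schedule implemented above: the rate is divided by opt.decay
-- once per completed opt.nepoch_lr epochs, i.e.
--   lr(k) = lr(0) / decay^floor(k / nepoch_lr).
-- The numbers below are the defaults from initial_params.
local function demo_decay_schedule()
    local lr0, decay, nepoch_lr = 0.8, 1.2, 100
    for _, k in ipairs({100, 200, 300}) do
        print(k, lr0 / math.pow(decay, math.floor(k / nepoch_lr)))
    end
end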
function runTrainVal()
    local method = 'BOWIMG'
    local step_trainval = true -- step for training and validation
    local step_trainall = true -- step for combining train2014 and val2014
    local opt = initial_params()
    opt.method = method
    opt.save = paths.concat(opt.savepath, method .. '.t7')

    local stat = {}
    -- load the data
    if step_trainval then
        local state_train, manager_vocab = load_visualqadataset(opt, 'trainval2014_train', nil)
        local state_val, _ = load_visualqadataset(opt, 'trainval2014_val', manager_vocab)
        local model, criterion = build_model(opt, manager_vocab)
        local paramx, paramdx = model:getParameters()
        local params_current, gparams_current = model:parameters()

        local config_layers, grad_last = config_layer_params(opt, params_current, gparams_current, 1)
        -- Save variables into the context so that train_epoch can use them.
        local context = {
            model = model,
            criterion = criterion,
            paramx = paramx,
            paramdx = paramdx,
            params_current = params_current,
            gparams_current = gparams_current,
            config_layers = config_layers,
            grad_last = grad_last
        }
        print(params_current)
        print('start training ...')
        for i = 1, opt.epochs do
            print(method .. ' epoch ' .. i)
            train_epoch(opt, state_train, manager_vocab, context, 'train')
            _, _, perfs = train_epoch(opt, state_val, manager_vocab, context, 'val')
            -- Accumulate statistics (acc is a global set by train_epoch).
            stat[i] = {acc, perfs.most_freq, perfs.openend_overall, perfs.multiple_overall}
            -- Adjust the learning rate
            adjust_learning_rate(i, opt, config_layers)
        end
    end

    if step_trainall then
        local nEpoch_best = 1
        local acc_openend_best = 0
        if step_trainval then
            -- Select the best epoch number from the train/val step before
            -- combining train2014 and val2014.
            for i = 1, #stat do
                if stat[i][3] > acc_openend_best then
                    nEpoch_best = i
                    acc_openend_best = stat[i][3]
                end
            end

            print('best epoch number is ' .. nEpoch_best)
            print('best acc is ' .. acc_openend_best)
        else
            nEpoch_best = 100
        end
        -- Combine train2014 and val2014
        local nEpoch_trainAll = nEpoch_best
        local state_train, manager_vocab = load_visualqadataset(opt, 'trainval2014', nil)
        -- recreate the model
        local model, criterion = build_model(opt, manager_vocab)
        local paramx, paramdx = model:getParameters()
        local params_current, gparams_current = model:parameters()

        local config_layers, grad_last = config_layer_params(opt, params_current, gparams_current, 1)

        local context = {
            model = model,
            criterion = criterion,
            paramx = paramx,
            paramdx = paramdx,
            params_current = params_current,
            gparams_current = gparams_current,
            config_layers = config_layers,
            grad_last = grad_last
        }
        print(params_current)

        print('start training on all data ...')
        stat = {}
        for i = 1, nEpoch_trainAll do
            print('epoch ' .. i .. '/' .. nEpoch_trainAll)
            _, _, perfs = train_epoch(opt, state_train, manager_vocab, context, 'train')
            stat[i] = {acc, perfs.most_freq, perfs.openend_overall, perfs.multiple_overall}
            adjust_learning_rate(i, opt, config_layers)

            local modelname_curr = opt.save
            save_model(opt, manager_vocab, context, modelname_curr)
        end
    end
end
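-- Reminder of the stat layout used in runTrainVal:
--   stat[i] = {acc, most_freq, openend_overall, multiple_overall},
-- and the best epoch maximizes the open-ended accuracy stat[i][3].
-- A toy illustration with made-up numbers:
local function demo_select_best_epoch()
    local stat = {{0, 0, 0.51, 0.57}, {0, 0, 0.55, 0.61}, {0, 0, 0.54, 0.60}}
    local nEpoch_best, acc_openend_best = 1, 0
    for i = 1, #stat do
        if stat[i][3] > acc_openend_best then
            nEpoch_best, acc_openend_best = i, stat[i][3]
        end
    end
    print(nEpoch_best, acc_openend_best) -- 2    0.55
end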
function runTest()
    -- Load the pre-trained model, evaluate it on the test set and generate
    -- the JSON files that can be submitted to the evaluation server.
    local opt = initial_params()
    local context = loadPretrained(opt)
    local manager_vocab = context.manager_vocab

    -- load the test data
    local testSet = 'test-dev2015' -- either 'test2015' or 'test-dev2015'
    local state_test, _ = load_visualqadataset(opt, testSet, manager_vocab)

    -- predict
    local pred, prob, perfs = train_epoch(opt, state_test, manager_vocab, context, 'test')

    -- output JSON files to be submitted to the VQA evaluation server
    local file_json_openend = 'result/vqa_OpenEnded_mscoco_' .. testSet .. '_' .. opt.method .. '_results.json'
    local file_json_multiple = 'result/vqa_MultipleChoice_mscoco_' .. testSet .. '_' .. opt.method .. '_results.json'
    print('output the OpenEnded prediction to JSON file...' .. file_json_openend)
    local choice = 0
    outputJSONanswer(state_test, manager_vocab, prob, file_json_openend, choice)
    print('output the MultipleChoice prediction to JSON file...' .. file_json_multiple)
    choice = 1
    outputJSONanswer(state_test, manager_vocab, prob, file_json_multiple, choice)

    collectgarbage()
end

runTrainVal()
runTest()

--------------------------------------------------------------------------------
/opensource_interactive.lua:
--------------------------------------------------------------------------------
require 'paths'
require 'cunn'
require 'nn'

local stringx = require 'pl.stringx'
local file = require 'pl.file'
-- fb.debugger ships with Facebook's fblualib and is not needed to run;
-- uncomment if you have it installed.
-- local debugger = require 'fb.debugger'

paths.dofile('opensource_base.lua')
paths.dofile('opensource_utils.lua')

local rank_context = { }

function rankPrepare()
    local opt = initial_params()
    rank_context.context = loadPretrained(opt)

    local manager_vocab = rank_context.context.manager_vocab
    local model = rank_context.context.model

    -- Question-Answer matrix:
    --
    -- The entire model is r = M_i * v_i + M_w * M_emb * v_onehot; there is no
    -- non-linearity, so we can just dump M_w * M_emb here.
    --
    -- nvocab_answer * embed_word
    local M_w = model.modules[3].weight[{ { }, { 1, opt.embed_word } }]
    -- embed_word * nvocab_question
    local M_emb = model.modules[1].modules[1].modules[2].weight

    -- nvocab_answer * nvocab_question
    rank_context.rankWord_M = torch.mm(M_w, M_emb)

    -- Image-Answer matrix: vdim * nvocab_answer
    rank_context.rankImage_M = model.modules[3].weight[{ { }, { opt.embed_word + 1, opt.embed_word + opt.vdim }}]:transpose(1, 2)

    -- load the test data
    local testSet = 'test-dev2015' -- either 'test2015' or 'test-dev2015'
    local state, _ = load_visualqadataset(opt, testSet, manager_vocab)

    local x_rep = rank_context.rankWord_M.new():resize(state.x_question:size(1), manager_vocab.nvocab_question):zero()
    for i = 1, state.x_question:size(1) do
        x_rep[i]:copy(bagofword(manager_vocab, state.x_question[i]))
    end
    -- nvocab_answer * #question
    rank_context.rankQuestion_M = torch.mm(rank_context.rankWord_M, x_rep:transpose(1, 2))
    rank_context.questions = state.data_question

    rank_context.imglist = { }
    rank_context.inv_imglist = { }

    for k, v in pairs(state.featureMap) do
        table.insert(rank_context.imglist, k)
        rank_context.inv_imglist[k] = #rank_context.imglist
    end
    rank_context.featureMap = state.featureMap

    -- Dump the rankImage matrix.
    local batch = 128
    local mat_feat = M_w.new():resize(batch, opt.vdim):zero()
    local rankImage_M2 = M_w.new():resize(#rank_context.imglist, manager_vocab.nvocab_answer):zero()

    for i = 1, #rank_context.imglist, batch do
        local this_batch = math.min(batch, #rank_context.imglist - i + 1)
        for j = 1, this_batch do
            local img_name = rank_context.imglist[j + i - 1]
            mat_feat[j]:copy(rank_context.featureMap[img_name])
        end
        -- batch * nvocab_answer
        local this_output = torch.mm(mat_feat:sub(1, this_batch), rank_context.rankImage_M)
        -- Collect the scores.
        rankImage_M2[{ {i, i + this_batch - 1}, { } }]:copy(this_output)
    end
    -- nvocab_answer * #rank_context.imglist
    rank_context.rankImage_M2 = rankImage_M2:transpose(1, 2)
end
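-- Why pre-multiplying M_w * M_emb in rankPrepare is valid: with no
-- non-linearity before the LogSoftMax, the answer score is linear in the
-- one-hot question vector, so the two matrices can be collapsed once and
-- reused. A toy check with random matrices:
local function demo_linear_factorization()
    local M_w = torch.rand(5, 8)    -- nvocab_answer x embed_word (toy sizes)
    local M_emb = torch.rand(8, 10) -- embed_word x nvocab_question
    local v = torch.rand(10)        -- bag-of-words question vector
    local a = torch.mv(torch.mm(M_w, M_emb), v) -- collapsed, as in rankPrepare
    local b = torch.mv(M_w, torch.mv(M_emb, v)) -- layer by layer
    print((a - b):abs():max())                  -- ~0 up to float error
end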
function rankWord(answer, topn)
    local manager_vocab = rank_context.context.manager_vocab
    local model = rank_context.context.model

    -- Convert the answer to its index and return the topn question words.
    -- require 'fb.debugger'.enter()
    if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
    local answerid = manager_vocab.vocab_map_answer[answer]
    if answerid == nil then return end

    local score = rank_context.rankWord_M[answerid]:clone():squeeze()
    local sortedScore, sortedIndices = score:sort(1, true)

    local res = { }
    for i = 1, topn do
        local w = manager_vocab.ivocab_map_question[sortedIndices[i]]
        table.insert(res, { word = w, score = sortedScore[i] })
    end

    return res
end

function rankImage(answer, topn)
    local manager_vocab = rank_context.context.manager_vocab

    -- Given the answer, rank the images most relevant to it.
    -- Convert the answer to its index and return the topn images.
    if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
    local answerid = manager_vocab.vocab_map_answer[answer]
    if answerid == nil then return end

    local score = rank_context.rankImage_M2[answerid]:clone():squeeze()
    local sortedScore, sortedIndices = score:sort(1, true)

    local res = { }
    for i = 1, topn do
        local idx = sortedIndices[i]
        -- Get the image link.
        local imgname = rank_context.imglist[idx]
        local id = tonumber(imgname:match("_(%d+)"))
        local url = "http://mscoco.org/images/" .. tostring(id)
        table.insert(res, { idx = idx, imgname = imgname, url = url, score = sortedScore[i] })
    end

    return res
end

function rankQuestion(answer, topn)
    local manager_vocab = rank_context.context.manager_vocab

    -- Given the answer, rank the questions most relevant to it.
    -- Convert the answer to its index and return the topn questions.
    if answer:sub(1, 1) == '"' then answer = answer:sub(2, -2) end
    local answerid = manager_vocab.vocab_map_answer[answer]
    if answerid == nil then return end

    local score = rank_context.rankQuestion_M[answerid]:clone():squeeze()
    local sortedScore, sortedIndices = score:sort(1, true)

    local res = { }
    for i = 1, topn do
        local idx = sortedIndices[i]
        table.insert(res, { idx = idx, question = rank_context.questions[idx], score = sortedScore[i] })
    end

    return res
end

local function smart_split(s, quotes)
    quotes = quotes or { ['"'] = true }
    local res = { }
    local start = 1
    local quote_stack = { }
    for i = 1, #s do
        local c = s:sub(i, i)
        if c == ' ' then
            if #quote_stack == 0 then
                table.insert(res, s:sub(start, i - 1))
                start = i + 1
            end
        elseif quotes[c] then
            if #quote_stack == 0 or c ~= quote_stack[#quote_stack] then
                table.insert(quote_stack, c)
            else
                table.remove(quote_stack)
            end
        end
    end
    table.insert(res, s:sub(start, -1))
    return res
end
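-- Sketch of smart_split's behaviour: whitespace splits tokens, but spaces
-- inside double quotes are preserved (the surrounding quotes stay in the
-- token; rankWord and friends strip them).
local function demo_smart_split()
    local tokens = smart_split('rankw "tennis racket" 5')
    -- tokens = { 'rankw', '"tennis racket"', '5' }
    for i, t in ipairs(tokens) do print(i, t) end
end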
local commands = {
    rankw = {
        -- Given an answer, rank the question words.
        exec = function(tokens)
            local topn = tonumber(tokens[3]) or 5
            local res = rankWord(tokens[2], topn)
            local success = false
            local s = ""
            if res then
                for i, v in pairs(res) do
                    s = s .. string.format("[%d]: %s (%.2f)", i, v.word, v.score) .. "\n"
                end
                success = true
            end
            return success, s
        end,
        help = "\"rankw answer 5\" will rank the words and show the top5 question words that give the answer. If 5 is not given, it defaults to top5. If the answer contains whitespace, use quotes to delimit it."
    },
    rankq = {
        -- Given an answer, rank the questions.
        exec = function(tokens)
            local topn = tonumber(tokens[3]) or 5
            local res = rankQuestion(tokens[2], topn)
            local success = false
            local s = ""
            if res then
                for i, v in pairs(res) do
                    s = s .. string.format("[%d]: %s (%.2f)", i, v.question, v.score) .. "\n"
                end
                success = true
            end
            return success, s
        end,
        help = "\"rankq answer 5\" will rank the questions from the test set and show the top5 that give the answer. If 5 is not given, it defaults to top5. If the answer contains whitespace, use quotes to delimit it."
    },
    ranki = {
        exec = function(tokens)
            local topn = tonumber(tokens[3]) or 5
            local res = rankImage(tokens[2], topn)
            local success = false
            local s = ""
            if res then
                for i, v in pairs(res) do
                    s = s .. string.format("[%d]: %s (%.2f)", i, v.imgname, v.score) .. "\n"
                end
                for i, v in pairs(res) do
                    -- <img> markup reconstructed; the original tag was lost in extraction.
                    s = s .. string.format('<img src="%s"/>', v.url)
                end
                s = s .. "\n"
                success = true
            end
            return success, s
        end,
        help = "\"ranki answer 5\" will rank the images and show the top5 images that give the answer. If 5 is not given, it defaults to top5. If the answer contains whitespace, use quotes to delimit it."
    },
    quit = {
        exec = function(tokens)
            return true, "Bye bye!", true
        end,
        help = "Quit the interactive environment"
    }
}

local history = ""

local function init_webpage()
    history = ""
end

local function write2webpage(s)
    local lines = stringx.split(s, "\n")
    for i = 1, #lines do
        -- "<br/>" markup reconstructed; the original tag was lost in extraction.
        history = history .. lines[i] .. "<br/>\n"
    end
end
local function save_webpage(filename)
    io.open(filename, "w"):write(history .. ""):close()
end

function run_interactive()
    -- Run the interactive environment.
    print("Preload the model...")
    rankPrepare()

    -- Generate the help string.
    local help_str = "Usage: \n"
    for k, v in pairs(commands) do
        help_str = help_str .. k .. ":\n    " .. commands[k].help .. "\n"
    end
    print("Ready...")

    init_webpage()

    while true do
        local command = io.read("*l")
        local tokens = smart_split(command)
        if #tokens > 0 then
            if tokens[1] == "help" then
                print(help_str)
            elseif commands[tokens[1]] == nil then
                print("Unknown command: " .. tokens[1])
            else
                local success, s, is_quit = commands[tokens[1]].exec(tokens)
                print(s)
                if success then
                    write2webpage(command)
                    write2webpage(s)
                    -- Change this path to your own web directory.
                    save_webpage("/home/yuandong/public_html/webpage.html")
                end
                if is_quit then break end
            end
        end
    end
end

run_interactive()

--------------------------------------------------------------------------------
/opensource_base.lua:
--------------------------------------------------------------------------------
--local debugger = require 'fb.debugger'
local stringx = require 'pl.stringx'
local file = require 'pl.file'

-- Here we specify different learning rates and gradient clips for different
-- layers. This is *critical* for the performance of BOW.
function config_layer_params(opt, params_current, gparams_current, IDX_wordembed)
    local lr_wordembed = opt.lr_wordembed
    local lr_other = opt.lr_other
    local weightClip_wordembed = opt.weightClip_wordembed
    local weightClip_other = opt.weightClip_other

    print("lr_wordembed = " .. lr_wordembed)
    print("lr_other = " .. lr_other)
    print("weightClip_wordembed = " .. weightClip_wordembed)
    print("weightClip_other = " .. weightClip_other)

    local gradientClip_dummy = 0.1
    local weightRegConsts_dummy = 0.000005
    local initialRange_dummy = 0.1
    local moments_dummy = 0.9

    -- Initialize the per-layer specification.
    local config_layers = {
        lr_rates = {},
        gradientClips = {},
        weightClips = {},
        moments = {},
        weightRegConsts = {},
        initialRange = {}
    }

    -- grad_last is used to add momentum to the gradient.
    local grad_last = {}
    if IDX_wordembed == 1 then
        -- assume the word-embedding matrix is params_current[1]
        config_layers.lr_rates = {lr_wordembed}
        config_layers.gradientClips = {gradientClip_dummy}
        config_layers.weightClips = {weightClip_wordembed}
        config_layers.moments = {moments_dummy}
        config_layers.weightRegConsts = {weightRegConsts_dummy}
        config_layers.initialRange = {initialRange_dummy}
        for i = 2, #params_current do
            table.insert(config_layers.lr_rates, lr_other)
            table.insert(config_layers.moments, moments_dummy)
            table.insert(config_layers.gradientClips, gradientClip_dummy)
            table.insert(config_layers.weightClips, weightClip_other)
            table.insert(config_layers.weightRegConsts, weightRegConsts_dummy)
            table.insert(config_layers.initialRange, initialRange_dummy)
        end
    else
        for i = 1, #params_current do
            table.insert(config_layers.lr_rates, lr_other)
            table.insert(config_layers.moments, moments_dummy)
            table.insert(config_layers.gradientClips, gradientClip_dummy)
            table.insert(config_layers.weightClips, weightClip_other)
            table.insert(config_layers.weightRegConsts, weightRegConsts_dummy)
            table.insert(config_layers.initialRange, initialRange_dummy)
        end
    end

    for i = 1, #gparams_current do
        grad_last[i] = gparams_current[i]:clone()
        grad_last[i]:fill(0)
    end
    return config_layers, grad_last
end
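-- Sketch of how grad_last is consumed in train_epoch (per-layer SGD with
-- momentum; the names mirror the config_layers fields set up above):
--   grad_last[i] = moments[i] * grad_last[i] - lr_rates[i] * grad[i]
--   param[i]     = param[i] + grad_last[i]
-- A toy illustration on a single one-element "layer":
local function demo_momentum_step()
    local w, g_last = torch.Tensor{1.0}, torch.Tensor{0.0}
    local grad, lr, moment = torch.Tensor{0.5}, 0.01, 0.9
    for step = 1, 3 do
        g_last:mul(moment):add(torch.mul(grad, -lr))
        w:add(g_last)
        print(step, w[1])
    end
end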
---------------------------------------
---- data IO relevant functions--------
---------------------------------------

function existfile(filename)
    local f = io.open(filename, "r")
    if f ~= nil then
        io.close(f)
        return true
    else
        return false
    end
end

function load_filelist(fname)
    local data = file.read(fname)
    data = stringx.replace(data, '\n', ' ')
    data = stringx.split(data)
    local imglist_ind = {}
    for i = 1, #data do
        imglist_ind[i] = stringx.split(data[i], '.')[1]
    end
    return imglist_ind
end

function build_vocab(data, thresh, IDX_singleline, IDX_includeEnd)
    if IDX_singleline == 1 then
        data = stringx.split(data, '\n')
    else
        data = stringx.replace(data, '\n', ' ')
        data = stringx.split(data)
    end
    local countWord = {}
    for i = 1, #data do
        if countWord[data[i]] == nil then
            countWord[data[i]] = 1
        else
            countWord[data[i]] = countWord[data[i]] + 1
        end
    end
    local vocab_map_ = {}
    local ivocab_map_ = {}
    local vocab_idx = 0
    if IDX_includeEnd == 1 then
        vocab_idx = 1
        vocab_map_['NA'] = 1
        ivocab_map_[1] = 'NA'
    end

    for i = 1, #data do
        if vocab_map_[data[i]] == nil then
            if countWord[data[i]] >= thresh then
                vocab_idx = vocab_idx + 1
                vocab_map_[data[i]] = vocab_idx
                ivocab_map_[vocab_idx] = data[i]
                --print(vocab_idx .. '-' .. data[i] .. '--' .. countWord[data[i]])
            else
                vocab_map_[data[i]] = vocab_map_['NA']
            end
        end
    end
    vocab_map_['END'] = -1
    return vocab_map_, ivocab_map_, vocab_idx
end
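-- Sketch of build_vocab on a toy corpus: words below the frequency threshold
-- collapse into the 'NA' entry (index 1 when IDX_includeEnd == 1), and 'END'
-- is mapped to the sentinel index -1.
local function demo_build_vocab()
    local data = 'what color is the cat\nwhat color is the hat\nis it red'
    local vocab, ivocab, nvocab = build_vocab(data, 2, 0, 1)
    print(nvocab)         -- 5: the 'NA' slot plus what/color/is/the
    print(vocab['color']) -- a real index (appears twice, meets thresh = 2)
    print(vocab['red'])   -- 1, i.e. mapped to 'NA' (appears only once)
    print(vocab['END'])   -- -1 sentinel
end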
function load_visualqadataset(opt, dataType, manager_vocab)
    -- Change this to your path.
    -- local path_imglist = 'datasets/coco_dataset/allimage2014'
    -- All COCO images.

    -- VQA question/answer txt files.
    -- Download data_vqa_feat.zip and data_vqa_txt.zip and decompress them into this folder.
    local path_dataset = '/data/local/vqa_opensource'

    local prefix = 'coco_' .. dataType
    local filename_question = paths.concat(path_dataset, prefix .. '_question.txt')
    local filename_answer = paths.concat(path_dataset, prefix .. '_answer.txt')
    local filename_imglist = paths.concat(path_dataset, prefix .. '_imglist.txt')
    local filename_allanswer = paths.concat(path_dataset, prefix .. '_allanswer.txt')
    local filename_choice = paths.concat(path_dataset, prefix .. '_choice.txt')
    local filename_question_type = paths.concat(path_dataset, prefix .. '_question_type.txt')
    local filename_answer_type = paths.concat(path_dataset, prefix .. '_answer_type.txt')
    local filename_questionID = paths.concat(path_dataset, prefix .. '_questionID.txt')

    if existfile(filename_allanswer) then
        data_allanswer = file.read(filename_allanswer)
        data_allanswer = stringx.split(data_allanswer, '\n')
    end
    if existfile(filename_choice) then
        data_choice = file.read(filename_choice)
        data_choice = stringx.split(data_choice, '\n')
    end
    if existfile(filename_question_type) then
        data_question_type = file.read(filename_question_type)
        data_question_type = stringx.split(data_question_type, '\n')
    end
    if existfile(filename_answer_type) then
        data_answer_type = file.read(filename_answer_type)
        data_answer_type = stringx.split(data_answer_type, '\n')
    end
    if existfile(filename_questionID) then
        data_questionID = file.read(filename_questionID)
        data_questionID = stringx.split(data_questionID, '\n')
    end

    local data_answer
    local data_answer_split
    if existfile(filename_answer) then
        print("Load answer file = " .. filename_answer)
        data_answer = file.read(filename_answer)
        data_answer_split = stringx.split(data_answer, '\n')
    end

    print("Load question file = " .. filename_question)
    local data_question = file.read(filename_question)
    local data_question_split = stringx.split(data_question, '\n')
    local manager_vocab_ = {}

    if manager_vocab == nil then
        local vocab_map_answer, ivocab_map_answer, nvocab_answer = build_vocab(data_answer, opt.thresh_answerword, 1, 0)
        local vocab_map_question, ivocab_map_question, nvocab_question = build_vocab(data_question, opt.thresh_questionword, 0, 1)
        print(' no.vocab_question=' .. nvocab_question .. ', no.vocab_answer=' .. nvocab_answer)
        manager_vocab_ = {
            vocab_map_answer = vocab_map_answer,
            ivocab_map_answer = ivocab_map_answer,
            vocab_map_question = vocab_map_question,
            ivocab_map_question = ivocab_map_question,
            nvocab_answer = nvocab_answer,
            nvocab_question = nvocab_question
        }
    else
        manager_vocab_ = manager_vocab
    end

    local imglist = load_filelist(filename_imglist)
    local nSample = #imglist
    -- There may be fewer questions than images; only keep as many samples as
    -- there are questions.
    if nSample > #data_question_split then
        nSample = #data_question_split
    end

    -- Answers.
    local x_answer = torch.zeros(nSample):fill(-1)
    if opt.multipleanswer == 1 then
        x_answer = torch.zeros(nSample, 10)
    end
    local x_answer_num = torch.zeros(nSample)
    -- Convert words in answers and questions to indices into the dictionary.
    local x_question = torch.zeros(nSample, opt.seq_length)
    for i = 1, nSample do
        local words = stringx.split(data_question_split[i])
        -- Answers
        if existfile(filename_answer) then
            local answer = data_answer_split[i]
            if manager_vocab_.vocab_map_answer[answer] == nil then
                x_answer[i] = -1
            else
                x_answer[i] = manager_vocab_.vocab_map_answer[answer]
            end
        end
        -- Questions
        for j = 1, opt.seq_length do
            if j <= #words then
                if manager_vocab_.vocab_map_question[words[j]] == nil then
                    x_question[{i, j}] = 1
                else
                    x_question[{i, j}] = manager_vocab_.vocab_map_question[words[j]]
                end
            else
                x_question[{i, j}] = manager_vocab_.vocab_map_question['END']
            end
        end
    end

    ---------------------------
    -- start loading features -
    ---------------------------
    local featureMap = {}
    local featName = 'googlenetFCdense'

    print(featName)

    -- Possible combinations of data loading
    local loading_spec = {
        trainval2014 = { train = true, val = true, test = false },
        trainval2014_train = { train = true, val = true, test = false },
        trainval2014_val = { train = false, val = true, test = false },
        train2014 = { train = true, val = false, test = false },
        val2014 = { train = false, val = true, test = false },
        test2015 = { train = false, val = false, test = true }
    }
    loading_spec['test-dev2015'] = { train = false, val = false, test = true }
    local feature_prefixSet = {
        train = paths.concat(path_dataset, 'coco_train2014_' .. featName),
        val = paths.concat(path_dataset, 'coco_val2014_' .. featName),
        test = paths.concat(path_dataset, 'coco_test2015_' .. featName)
    }

    for k, feature_prefix in pairs(feature_prefixSet) do
        -- Check whether we need to load this split.
        if loading_spec[dataType][k] then
            local feature_imglist = torch.load(feature_prefix .. '_imglist.dat')
            local featureSet = torch.load(feature_prefix .. '_feat.dat')
            for i = 1, #feature_imglist do
                local feat_in = torch.squeeze(featureSet[i])
                featureMap[feature_imglist[i]] = feat_in
            end
        end
    end

    collectgarbage()
    -- Return the state.
    local _state = {
        x_question = x_question,
        x_answer = x_answer,
        x_answer_num = x_answer_num,
        featureMap = featureMap,
        data_question = data_question_split,
        data_answer = data_answer_split,
        imglist = imglist,
        path_imglist = path_imglist, -- nil unless path_imglist is set above
        data_allanswer = data_allanswer,
        data_choice = data_choice,
        data_question_type = data_question_type,
        data_answer_type = data_answer_type,
        data_questionID = data_questionID
    }

    return _state, manager_vocab_
end

--------------------------------------------
-- training relevant code
--------------------------------------------
function save_model(opt, manager_vocab, context, path)
    print('saving model ' .. path)
    local d = {}
    d.paramx = context.paramx:float()
    d.manager_vocab = manager_vocab
    -- stat and config_layers are read from globals here, if present.
    d.stat = stat
    d.config_layers = config_layers
    d.opt = opt

    torch.save(path, d)
end

function bagofword(manager_vocab, x_seq)
    -- Turn a list of word indices into a bag-of-words vector.
    local outputVector = torch.zeros(manager_vocab.nvocab_question)
    for i = 1, x_seq:size(1) do
        if x_seq[i] ~= manager_vocab.vocab_map_question['END'] then
            outputVector[x_seq[i]] = 1
        else
            break
        end
    end
    return outputVector
end
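-- Sketch of bagofword on a toy vocabulary: the index sequence becomes a
-- binary indicator vector, stopping at the first 'END' (-1) marker.
local function demo_bagofword()
    local manager = {nvocab_question = 6, vocab_map_question = {END = -1}}
    local x_seq = torch.Tensor{3, 5, 3, -1, 2} -- entries after -1 are padding
    print(bagofword(manager, x_seq))           -- ones at positions 3 and 5
end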
function add_count(t, ...)
    -- Input: a counting table, then k1, v1, k2, v2, k3, v3, ...
    -- Output: t[k1] += (v1, 1), t[k2] += (v2, 1), etc.
    local args = { ... }
    local i = 1
    while i < #args do
        local k = args[i]
        local v = args[i + 1]
        if t[k] == nil then
            t[k] = { v, 1 }
        else
            t[k][1] = t[k][1] + v
            t[k][2] = t[k][2] + 1
        end
        i = i + 2
    end
end

function compute_accuracy(t)
    local res = { }
    for k, v in pairs(t) do
        res[k] = v[1] / v[2]
    end
    return res
end
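-- Sketch of the two helpers above: add_count accumulates (sum, count) pairs
-- per key and compute_accuracy turns them into means. The keys below follow
-- the naming used in evaluate_answer.
local function demo_add_count()
    local t = {}
    add_count(t, 'openend_overall', 1.0, 'openend_q_what', 1.0)
    add_count(t, 'openend_overall', 0.0, 'openend_q_what', 0.0)
    add_count(t, 'openend_overall', 0.5)
    print(compute_accuracy(t).openend_overall) -- (1.0 + 0.0 + 0.5) / 3 = 0.5
end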
function evaluate_answer(state, manager_vocab, pred_answer, prob_answer, selectIDX)
    -- Evaluation for the VQA dataset.
    selectIDX = selectIDX or torch.range(1, state.x_answer:size(1))
    local pred_answer_word = {}
    local gt_answer_word = state.data_answer
    local gt_allanswer = state.data_allanswer

    local perfs = { }
    local count_question_type = {}
    local count_answer_type = {}

    for sampleID = 1, selectIDX:size(1) do
        local i = selectIDX[sampleID]

        -- Prediction matches the most frequent ground-truth answer.
        if manager_vocab.ivocab_map_answer[pred_answer[i]] == gt_answer_word[i] then
            add_count(perfs, "most_freq", 1)
        else
            add_count(perfs, "most_freq", 0)
        end

        -- Estimate using the standard criterion: min(#matching human answers / 3, 1).
        -- Also estimate the multiple-choice case.
        local question_type = state.data_question_type[i]
        local answer_type = state.data_answer_type[i]

        -- Compute accuracy for multiple choice.
        local choices = stringx.split(state.data_choice[i], ',')
        local score_choices = torch.zeros(#choices):fill(-1000000)
        for j = 1, #choices do
            local IDX_pred = manager_vocab.vocab_map_answer[choices[j]]
            if IDX_pred ~= nil then
                local score = prob_answer[{i, IDX_pred}]
                if score ~= nil then
                    score_choices[j] = score
                end
            end
        end
        local val_max, IDX_max = torch.max(score_choices, 1)
        local word_pred_answer_multiple = choices[IDX_max[1]]
        local word_pred_answer_openend = manager_vocab.ivocab_map_answer[pred_answer[i]]

        -- Compare the predicted answer with all ground-truth answers from humans.
        if gt_allanswer then
            local answers = stringx.split(gt_allanswer[i], ',')
            -- The number of matches with human answers.
            local count_curr_openend = 0
            local count_curr_multiple = 0
            for j = 1, #answers do
                count_curr_openend = count_curr_openend + (word_pred_answer_openend == answers[j] and 1 or 0)
                count_curr_multiple = count_curr_multiple + (word_pred_answer_multiple == answers[j] and 1 or 0)
            end

            local increment = math.min(count_curr_openend * 1.0 / 3, 1.0)
            add_count(perfs, "openend_overall", increment,
                      "openend_q_" .. question_type, increment,
                      "openend_a_" .. answer_type, increment)

            increment = math.min(count_curr_multiple * 1.0 / 3, 1.0)
            add_count(perfs, "multiple_overall", increment,
                      "multiple_q_" .. question_type, increment,
                      "multiple_a_" .. answer_type, increment)
        end
    end

    -- Compute the accuracies.
    return compute_accuracy(perfs)
end

function outputJSONanswer(state, manager_vocab, prob, file_json, choice)
    -- Dump the prediction results to a JSON file.
    local f_json = io.open(file_json, 'w')
    f_json:write('[')

    for i = 1, prob:size(1) do
        local choices = stringx.split(state.data_choice[i], ',')
        local score_choices = torch.zeros(#choices):fill(-1000000)
        for j = 1, #choices do
            local IDX_pred = manager_vocab.vocab_map_answer[choices[j]]
            if IDX_pred ~= nil then
                local score = prob[{i, IDX_pred}]
                if score ~= nil then
                    score_choices[j] = score
                end
            end
        end
        local val_max, IDX_max = torch.max(score_choices, 1)
        local val_max_open, IDX_max_open = torch.max(prob[i], 1)
        local word_pred_answer_multiple = choices[IDX_max[1]]
        local word_pred_answer_openend = manager_vocab.ivocab_map_answer[IDX_max_open[1]]
        local answer_pred = word_pred_answer_openend
        if choice == 1 then
            answer_pred = word_pred_answer_multiple
        end
        local questionID = state.data_questionID[i]
        f_json:write('{"answer": "' .. answer_pred .. '","question_id": ' .. questionID .. '}')
        if i < prob:size(1) then
            f_json:write(',')
        end
    end
    f_json:write(']')
    f_json:close()
end
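-- Sketch of the result format written by outputJSONanswer, which is what the
-- VQA evaluation server expects (the question_id values here are made up):
--   [{"answer": "yes", "question_id": 1},{"answer": "2", "question_id": 2}]
local function demo_result_format()
    local entries = {{answer = 'yes', question_id = 1}, {answer = '2', question_id = 2}}
    local parts = {}
    for _, e in ipairs(entries) do
        table.insert(parts, '{"answer": "' .. e.answer .. '","question_id": ' .. e.question_id .. '}')
    end
    print('[' .. table.concat(parts, ',') .. ']')
end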
function train_epoch(opt, state, manager_vocab, context, updateIDX)
    -- Dump the context into the local namespace.
    local model = context.model
    local criterion = context.criterion
    local paramx = context.paramx
    local paramdx = context.paramdx
    local params_current = context.params_current
    local gparams_current = context.gparams_current
    local config_layers = context.config_layers
    local grad_last = context.grad_last

    local loss = 0.0
    local N = math.ceil(state.x_question:size(1) / opt.batchsize)
    local prob_answer = torch.zeros(state.x_question:size(1), manager_vocab.nvocab_answer)
    local pred_answer = torch.zeros(state.x_question:size(1))
    local target = torch.zeros(opt.batchsize)

    local featBatch_visual = torch.zeros(opt.batchsize, opt.vdim)
    local featBatch_word = torch.zeros(opt.batchsize, manager_vocab.nvocab_question)
    local word_idx = torch.zeros(opt.batchsize, opt.seq_length)

    local IDXset_batch = torch.zeros(opt.batchsize)
    local nSample_batch = 0
    local count_batch = 0
    local nBatch = 0

    local randIDX = torch.randperm(state.x_question:size(1))
    for iii = 1, state.x_question:size(1) do
        local i = randIDX[iii]
        local first_answer = -1
        if updateIDX ~= 'test' then
            first_answer = state.x_answer[i]
        end
        if first_answer == -1 and updateIDX == 'train' then
            -- skip samples with an NA answer
        else
            nSample_batch = nSample_batch + 1
            IDXset_batch[nSample_batch] = i
            if updateIDX ~= 'test' then
                target[nSample_batch] = state.x_answer[i]
            end
            local filename = state.imglist[i] -- e.g. 'COCO_train2014_000000000092'
            local feat_visual = state.featureMap[filename]:clone()
            local feat_word = bagofword(manager_vocab, state.x_question[i])

            word_idx[nSample_batch] = state.x_question[i]
            featBatch_word[nSample_batch] = feat_word:clone()
            featBatch_visual[nSample_batch] = feat_visual:clone()

            while i == state.x_question:size(1) and nSample_batch < opt.batchsize do
                -- pad with extra copies of the last sample to complete the batch
                nSample_batch = nSample_batch + 1
                IDXset_batch[nSample_batch] = i
                target[nSample_batch] = first_answer
                featBatch_visual[nSample_batch] = feat_visual:clone()
                featBatch_word[nSample_batch] = feat_word:clone()
                word_idx[nSample_batch] = state.x_question[i]
            end
            if nSample_batch == opt.batchsize then
                nBatch = nBatch + 1
                word_idx = word_idx:cuda()
                nSample_batch = 0
                target = target:cuda()
                featBatch_word = featBatch_word:cuda()
                featBatch_visual = featBatch_visual:cuda()
                ---------- forward pass ----------
                -- switch between the baseline variants
                local input
                if opt.method == 'BOW' then
                    input = featBatch_word
                elseif opt.method == 'BOWIMG' then
                    input = {featBatch_word, featBatch_visual}
                elseif opt.method == 'IMG' then
                    input = featBatch_visual
                else
                    error('unknown baseline method: ' .. opt.method)
                end

                local output = model:forward(input)
                local err = criterion:forward(output, target)
                local prob_batch = output:float()

                loss = loss + err
                for j = 1, opt.batchsize do
                    prob_answer[IDXset_batch[j]] = prob_batch[j]
                end
                ---------- backward pass ----------
                if updateIDX == 'train' then
                    model:zeroGradParameters()
                    local df = criterion:backward(output, target)
                    local df_model = model:backward(input, df)

                    -- Update the parameters of the baseline softmax.
                    if opt.uniformLR ~= 1 then
                        -- Per-layer update: clip the gradient norm, apply
                        -- momentum and weight decay, then clip the weight norm.
                        for i = 1, #params_current do
                            local gnorm = gparams_current[i]:norm()
                            if config_layers.gradientClips[i] > 0 and gnorm > config_layers.gradientClips[i] then
                                gparams_current[i]:mul(config_layers.gradientClips[i] / gnorm)
                            end

                            grad_last[i]:mul(config_layers.moments[i])
                            local tmp = torch.mul(gparams_current[i], -config_layers.lr_rates[i])
                            grad_last[i]:add(tmp)
                            params_current[i]:add(grad_last[i])
                            if config_layers.weightRegConsts[i] > 0 then
                                local a = config_layers.lr_rates[i] * config_layers.weightRegConsts[i]
                                params_current[i]:mul(1 - a)
                            end
                            local pnorm = params_current[i]:norm()
                            if config_layers.weightClips[i] > 0 and pnorm > config_layers.weightClips[i] then
                                params_current[i]:mul(config_layers.weightClips[i] / pnorm)
                            end
                        end
                    else
                        -- Uniform learning rate for all parameters. Note that
                        -- this branch expects an opt.lr field, which is not
                        -- among the defaults in initial_params.
                        local norm_dw = paramdx:norm()
                        if norm_dw > opt.maxgradnorm then
                            local shrink_factor = opt.maxgradnorm / norm_dw
                            paramdx:mul(shrink_factor)
                        end
                        paramx:add(paramdx:mul(-opt.lr))
                    end
                end

                -- batch finished
                count_batch = count_batch + 1
                if count_batch == 120 then
                    collectgarbage()
                    count_batch = 0
                end
            end
        end -- end of the branch that skips samples with answer index -1
    end
    -- one epoch finished
    local y_max, i_max = torch.max(prob_answer, 2)
    i_max = torch.squeeze(i_max)
    pred_answer = i_max:clone()
    if updateIDX ~= 'test' then
        local gtAnswer = state.x_answer:clone()
        gtAnswer = gtAnswer:long()
        local correctNum = torch.sum(torch.eq(pred_answer, gtAnswer))
        -- acc is intentionally global: runTrainVal reads it after each epoch.
        acc = correctNum * 1.0 / pred_answer:size(1)
    else
        acc = -1
    end
    print(updateIDX .. ': acc (mostFreq) =' .. acc)
    local perfs = nil
    if updateIDX ~= 'test' and state.data_allanswer ~= nil then
        -- use the standard evaluation criterion of the VQA dataset (Virginia Tech)
        perfs = evaluate_answer(state, manager_vocab, pred_answer, prob_answer)
        print(updateIDX .. ': acc.match mostfreq = ' .. perfs.most_freq)
        print(updateIDX .. ': acc.dataset (OpenEnd) =' .. perfs.openend_overall)
        print(updateIDX .. ': acc.dataset (MultipleChoice) =' .. perfs.multiple_overall)
        -- If you want to see more statistics, do the following:
        -- print(perfs)
    end
    print(updateIDX .. ' loss=' .. loss / nBatch)
    return pred_answer, prob_answer, perfs
end

--------------------------------------------------------------------------------