├── .gitignore ├── README.md ├── cornell_movie_dialogs.lua ├── data └── .gitkeep ├── dataset.lua ├── eval.lua ├── neuralconvo.lua ├── seq2seq.lua ├── tokenizer.lua └── train.lua /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | log/* 3 | *.log 4 | 5 | # Compiled Lua sources 6 | luac.out 7 | 8 | # luarocks build files 9 | *.src.rock 10 | *.zip 11 | *.tar.gz 12 | 13 | # Object files 14 | *.o 15 | *.os 16 | *.ko 17 | *.obj 18 | *.elf 19 | 20 | # Precompiled Headers 21 | *.gch 22 | *.pch 23 | 24 | # Libraries 25 | *.lib 26 | *.a 27 | *.la 28 | *.lo 29 | *.def 30 | *.exp 31 | 32 | # Shared objects (inc. Windows DLLs) 33 | *.dll 34 | *.so 35 | *.so.* 36 | *.dylib 37 | 38 | # Executables 39 | *.exe 40 | *.out 41 | *.app 42 | *.i*86 43 | *.x86_64 44 | *.hex 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Conversational Model in Torch 2 | 3 | This is an attempt at implementing [Sequence to Sequence Learning with Neural Networks (seq2seq)](http://arxiv.org/abs/1409.3215) and reproducing the results in [A Neural Conversational Model](http://arxiv.org/abs/1506.05869) (aka the Google chatbot). 4 | 5 | The Google chatbot paper [became famous](http://www.sciencealert.com/google-s-ai-bot-thinks-the-purpose-of-life-is-to-live-forever) after cleverly answering a few philosophical questions, such as: 6 | 7 | > **Human:** What is the purpose of living? 8 | > **Machine:** To live forever. 9 | 10 | ## How it works 11 | 12 | The model is based on two [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) layers. One for encoding the input sentence into a "thought vector", and another for decoding that vector into a response. This model is called Sequence-to-sequence or seq2seq. 
13 | 14 | ![seq2seq](https://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png) 15 | _Source: http://googleresearch.blogspot.ca/2015/11/computer-respond-to-this-email.html_ 16 | 17 | In this experiment, we train the seq2seq model with movie dialogs from the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html). The lines are shortened to the first sentence. 18 | 19 | ## Sample conversation 20 | 21 | Here's a sample conversation after training for 20 epochs with 50000 examples, using the following command: 22 | 23 | ```sh 24 | th train.lua --cuda --dataset 50000 --hiddenSize 1000 25 | ``` 26 | 27 | (Took 3 days to train on my GeForce GTX 780M.) 28 | 29 | For OpenCL, use `--opencl` instead of `--cuda`. To train on the CPU, omit both flags. 30 | 31 | *WARNING: I can no longer reproduce those results. The cause seems to be a change in one of the dependencies. But I'm currently working on a new implementation based on [harvardnlp/seq2seq-attn](https://github.com/harvardnlp/seq2seq-attn).* 32 | 33 | > **me:** Hello? 34 | > **bot:** Hi. 35 | > 36 | > **me:** How are you? 37 | > **bot:** I'm fine. 38 | > 39 | > **me:** What's your name? 40 | > **bot:** It's hard to describe. 41 | > 42 | > **me:** How so? 43 | > **bot:** I'm not sure. 44 | > 45 | > **me:** What color is the sky? 46 | > **bot:** It's blue. 47 | > 48 | > **me:** What is your job? 49 | > **bot:** It's not that i'm a fucking werewolf! 50 | > 51 | > **me:** What is the purpose of life? 52 | > **bot:** A gift. 53 | > 54 | > **me:** Are you intelligent? 55 | > **bot:** Yes, well... 56 | > 57 | > **me:** Are you a machine? 58 | > **bot:** That's a lie. 59 | > 60 | > **me:** Are you human? 61 | > **bot:** No, i'm not. 62 | > 63 | > **me:** What are you? 64 | > **bot:** I'm not sure. 65 | > 66 | > **me:** Do you plan on taking over the world? 67 | > **bot:** No, i don't. 68 | 69 | Phew! That was close.
Good thing I didn't train it on the full dataset. Please experiment responsibly. 70 | 71 | _(Disclaimer: nonsensical responses have been removed.)_ 72 | 73 | ## Installing 74 | 75 | 1. [Install Torch](http://torch.ch/docs/getting-started.html). 76 | 2. Install the following additional Lua libs: 77 | 78 | ```sh 79 | luarocks install nn 80 | luarocks install rnn 81 | luarocks install penlight 82 | ``` 83 | 84 | To train with CUDA, install the latest CUDA drivers and toolkit, then run: 85 | 86 | ```sh 87 | luarocks install cutorch 88 | luarocks install cunn 89 | ``` 90 | 91 | To train with OpenCL, install the latest OpenCL Torch libs: 92 | 93 | ```sh 94 | luarocks install cltorch 95 | luarocks install clnn 96 | ``` 97 | 98 | 3. Download the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html) and extract all the files into `data/cornell_movie_dialogs`. 99 | 100 | ## Training 101 | 102 | ```sh 103 | th train.lua [-h / options] 104 | ``` 105 | 106 | The model will be saved to `data/model.t7` after each epoch if it has improved (error decreased). 107 | 108 | ### Options (some, not all) 109 | - `--opencl` use OpenCL for computation (requires [torch-cl](https://github.com/hughperkins/distro-cl)) 110 | - `--cuda` use CUDA for computation 111 | - `--gpu [index]` use the nth GPU for computation (e.g.
on a 2015 MacBook `--gpu 0` results in the Intel GPU being used while `--gpu 1` uses the far more powerful AMD GPU) 112 | - `--dataset [size]` control the size of the dataset 113 | - `--maxEpoch [amount]` specify the number of epochs to run 114 | 115 | ## Testing 116 | 117 | To load the model and have a conversation: 118 | 119 | ```sh 120 | th eval.lua 121 | ``` 122 | 123 | ## License 124 | 125 | MIT License 126 | 127 | Copyright (c) 2016 Marc-Andre Cournoyer 128 | 129 | Permission is hereby granted, free of charge, to any person obtaining a copy 130 | of this software and associated documentation files (the "Software"), to deal 131 | in the Software without restriction, including without limitation the rights 132 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 133 | copies of the Software, and to permit persons to whom the Software is 134 | furnished to do so, subject to the following conditions: 135 | 136 | The above copyright notice and this permission notice shall be included in all 137 | copies or substantial portions of the Software. 138 | 139 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 140 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 141 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 142 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 143 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 144 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 145 | SOFTWARE.
146 | -------------------------------------------------------------------------------- /cornell_movie_dialogs.lua: -------------------------------------------------------------------------------- 1 | local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") 2 | local stringx = require "pl.stringx" 3 | local xlua = require "xlua" 4 | 5 | local function parsedLines(file, fields) 6 | local f = assert(io.open(file, 'r')) 7 | 8 | return function() 9 | local line = f:read("*line") 10 | 11 | if line == nil then 12 | f:close() 13 | return 14 | end 15 | 16 | local values = stringx.split(line, " +++$+++ ") 17 | local t = {} 18 | 19 | for i,field in ipairs(fields) do 20 | t[field] = values[i] 21 | end 22 | 23 | return t 24 | end 25 | end 26 | 27 | function CornellMovieDialogs:__init(dir) 28 | self.dir = dir 29 | end 30 | 31 | local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} 32 | local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} 33 | local TOTAL_LINES = 387810 34 | 35 | local function progress(c) 36 | if c % 10000 == 0 then 37 | xlua.progress(c, TOTAL_LINES) 38 | end 39 | end 40 | 41 | function CornellMovieDialogs:load() 42 | local lines = {} 43 | local conversations = {} 44 | local count = 0 45 | 46 | print("-- Parsing Cornell movie dialogs data set ...") 47 | 48 | for line in parsedLines(self.dir .. "/movie_lines.txt", MOVIE_LINES_FIELDS) do 49 | lines[line.lineID] = line 50 | line.lineID = nil 51 | -- Remove unused fields 52 | line.characterID = nil 53 | line.movieID = nil 54 | count = count + 1 55 | progress(count) 56 | end 57 | 58 | for conv in parsedLines(self.dir .. 
"/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do 59 | local conversation = {} 60 | local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") 61 | for i,lineID in ipairs(lineIDs) do 62 | table.insert(conversation, lines[lineID]) 63 | end 64 | table.insert(conversations, conversation) 65 | count = count + 1 66 | progress(count) 67 | end 68 | 69 | xlua.progress(TOTAL_LINES, TOTAL_LINES) 70 | 71 | return conversations 72 | end 73 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macournoyer/neuralconvo/52b648c7f9a430194855e99beb96752faea2de13/data/.gitkeep -------------------------------------------------------------------------------- /dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Formats movie dialog data as a table of example pairs: 3 | 4 | { {word_ids of character1}, {word_ids of character2} } 5 | 6 | Then flips each pair around to get the dialog from the other character's perspective: 7 | 8 | { {word_ids of character2}, {word_ids of character1} } 9 | 10 | Also builds the vocabulary.
11 | ]]-- 12 | 13 | local DataSet = torch.class("neuralconvo.DataSet") 14 | local xlua = require "xlua" 15 | local tokenizer = require "tokenizer" 16 | local list = require "pl.List" 17 | 18 | function DataSet:__init(loader, options) 19 | options = options or {} 20 | 21 | self.examplesFilename = "data/examples.t7" 22 | 23 | -- Reject words once vocab size reaches this threshold 24 | self.maxVocabSize = options.maxVocabSize or 0 25 | 26 | -- Maximum number of words in an example sentence 27 | self.maxExampleLen = options.maxExampleLen or 25 28 | 29 | -- Load only the first few examples (approximately) 30 | self.loadFirst = options.loadFirst 31 | 32 | self.examples = {} 33 | self.word2id = {} 34 | self.id2word = {} 35 | self.wordsCount = 0 36 | 37 | self:load(loader) 38 | end 39 | 40 | function DataSet:load(loader) 41 | local filename = "data/vocab.t7" 42 | 43 | if path.exists(filename) then 44 | print("Loading vocabulary from " .. filename .. " ...") 45 | local data = torch.load(filename) 46 | self.word2id = data.word2id 47 | self.id2word = data.id2word 48 | self.wordsCount = data.wordsCount 49 | self.goToken = data.goToken 50 | self.eosToken = data.eosToken 51 | self.unknownToken = data.unknownToken 52 | self.examplesCount = data.examplesCount 53 | else 54 | print(filename .. " not found") 55 | self:visit(loader:load()) 56 | print("Writing " .. filename ..
" ...") 57 | torch.save(filename, { 58 | word2id = self.word2id, 59 | id2word = self.id2word, 60 | wordsCount = self.wordsCount, 61 | goToken = self.goToken, 62 | eosToken = self.eosToken, 63 | unknownToken = self.unknownToken, 64 | examplesCount = self.examplesCount 65 | }) 66 | end 67 | end 68 | 69 | function DataSet:visit(conversations) 70 | self.examples = {} 71 | 72 | -- Add magic tokens 73 | self.goToken = self:makeWordId("<go>") -- Start of sequence 74 | self.eosToken = self:makeWordId("<eos>") -- End of sequence 75 | self.unknownToken = self:makeWordId("<unknown>") -- Word dropped from vocabulary 76 | 77 | print("-- Pre-processing data") 78 | 79 | local total = self.loadFirst or #conversations * 2 80 | 81 | for i, conversation in ipairs(conversations) do 82 | if i > total then break end 83 | self:visitConversation(conversation) 84 | xlua.progress(i, total) 85 | end 86 | 87 | -- Revisit from the perspective of 2nd character 88 | for i, conversation in ipairs(conversations) do 89 | if #conversations + i > total then break end 90 | self:visitConversation(conversation, 2) 91 | xlua.progress(#conversations + i, total) 92 | end 93 | 94 | print("-- Shuffling ") 95 | local newIdxs = torch.randperm(#self.examples) 96 | local sExamples = {} 97 | for i = 1, #self.examples do 98 | sExamples[i] = self.examples[newIdxs[i]] 99 | end 100 | self.examples = sExamples 101 | 102 | self.examplesCount = #self.examples 103 | self:writeExamplesToFile() 104 | self.examples = nil 105 | 106 | collectgarbage() 107 | end 108 | 109 | function DataSet:writeExamplesToFile() 110 | print("Writing " .. self.examplesFilename ..
" ...") 111 | local file = torch.DiskFile(self.examplesFilename, "w") 112 | 113 | for i, example in ipairs(self.examples) do 114 | file:writeObject(example) 115 | xlua.progress(i, #self.examples) 116 | end 117 | 118 | file:close() 119 | end 120 | 121 | function DataSet:batches(size) 122 | local file = torch.DiskFile(self.examplesFilename, "r") 123 | file:quiet() 124 | local done = false 125 | 126 | return function() 127 | if done then 128 | return 129 | end 130 | 131 | local inputSeqs,targetSeqs = {},{} 132 | local maxInputSeqLen,maxTargetOutputSeqLen = 0,0 133 | 134 | for i = 1, size do 135 | local example = file:readObject() 136 | if example == nil then 137 | done = true 138 | file:close() 139 | return -- end of file; any incomplete final batch is dropped 140 | end 141 | local inputSeq,targetSeq = unpack(example) 142 | if inputSeq:size(1) > maxInputSeqLen then 143 | maxInputSeqLen = inputSeq:size(1) 144 | end 145 | if targetSeq:size(1) > maxTargetOutputSeqLen then 146 | maxTargetOutputSeqLen = targetSeq:size(1) 147 | end 148 | table.insert(inputSeqs, inputSeq) 149 | table.insert(targetSeqs, targetSeq) 150 | end 151 | 152 | local encoderInputs,decoderInputs,decoderTargets = nil,nil,nil 153 | if size == 1 then 154 | encoderInputs = torch.IntTensor(maxInputSeqLen):fill(0) 155 | decoderInputs = torch.IntTensor(maxTargetOutputSeqLen-1):fill(0) 156 | decoderTargets = torch.IntTensor(maxTargetOutputSeqLen-1):fill(0) 157 | else 158 | encoderInputs = torch.IntTensor(maxInputSeqLen,size):fill(0) 159 | decoderInputs = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0) 160 | decoderTargets = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0) 161 | end 162 | 163 | for samplenb = 1, #inputSeqs do 164 | for word = 1,inputSeqs[samplenb]:size(1) do 165 | local eosOffset = maxInputSeqLen - inputSeqs[samplenb]:size(1) -- for left padding 166 | if size == 1 then 167 | encoderInputs[word] = inputSeqs[samplenb][word] 168 | else 169 | encoderInputs[word+eosOffset][samplenb] = inputSeqs[samplenb][word] 170 | end 171 | end 172 | end 173
| 174 | for samplenb = 1, #targetSeqs do 175 | local trimmedEosToken = targetSeqs[samplenb]:sub(1,-2) 176 | for word = 1, trimmedEosToken:size(1) do 177 | if size == 1 then 178 | decoderInputs[word] = trimmedEosToken[word] 179 | else 180 | decoderInputs[word][samplenb] = trimmedEosToken[word] 181 | end 182 | end 183 | end 184 | 185 | for samplenb = 1, #targetSeqs do 186 | local trimmedGoToken = targetSeqs[samplenb]:sub(2,-1) 187 | for word = 1, trimmedGoToken:size(1) do 188 | if size == 1 then 189 | decoderTargets[word] = trimmedGoToken[word] 190 | else 191 | decoderTargets[word][samplenb] = trimmedGoToken[word] 192 | end 193 | end 194 | end 195 | 196 | return encoderInputs,decoderInputs,decoderTargets 197 | end 198 | end 199 | 200 | function DataSet:visitConversation(lines, start) 201 | start = start or 1 202 | 203 | for i = start, #lines, 2 do 204 | local input = lines[i] 205 | local target = lines[i+1] 206 | 207 | if target then 208 | local inputIds = self:visitText(input.text) 209 | local targetIds = self:visitText(target.text, 2) 210 | 211 | if inputIds and targetIds then 212 | -- Reverse inputs (feeding the input sentence backwards helps seq2seq) 213 | inputIds = list.reverse(inputIds) 214 | 215 | table.insert(targetIds, 1, self.goToken) 216 | table.insert(targetIds, self.eosToken) 217 | 218 | table.insert(self.examples, { torch.IntTensor(inputIds), torch.IntTensor(targetIds) }) 219 | end 220 | end 221 | end 222 | end 223 | 224 | function DataSet:visitText(text, additionalTokens) 225 | local words = {} 226 | additionalTokens = additionalTokens or 0 227 | 228 | if text == "" then 229 | return 230 | end 231 | 232 | for t, word in tokenizer.tokenize(text) do 233 | table.insert(words, self:makeWordId(word)) 234 | -- Only keep the first sentence 235 | if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then 236 | break 237 | end 238 | end 239 | 240 | if #words == 0 then 241 | return 242 | end 243 | 244 | return words 245 | end 246 | 247 | function DataSet:makeWordId(word) 248 | if self.maxVocabSize > 0 and
self.wordsCount >= self.maxVocabSize then 249 | -- We've reached the maximum size for the vocab. Replace w/ unknown token 250 | return self.unknownToken 251 | end 252 | 253 | word = word:lower() 254 | 255 | local id = self.word2id[word] 256 | 257 | if not id then 258 | self.wordsCount = self.wordsCount + 1 259 | id = self.wordsCount 260 | self.id2word[id] = word 261 | self.word2id[word] = id 262 | end 263 | 264 | return id 265 | end 266 | -------------------------------------------------------------------------------- /eval.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | local tokenizer = require "tokenizer" 3 | local list = require "pl.List" 4 | local options = {} 5 | 6 | cmd = torch.CmdLine() 7 | cmd:text('Options:') 8 | cmd:option('--debug', false, 'show debug info') 9 | cmd:text() 10 | options = cmd:parse(arg) 11 | 12 | -- Data 13 | dataset = neuralconvo.DataSet() 14 | 15 | print("-- Loading model") 16 | model = torch.load("data/model.t7") 17 | 18 | -- Word IDs to sentence 19 | function pred2sent(wordIds, i) 20 | local words = {} 21 | i = i or 1 22 | 23 | for _, wordId in ipairs(wordIds) do 24 | local word = dataset.id2word[wordId[i]] 25 | table.insert(words, word) 26 | end 27 | 28 | return tokenizer.join(words) 29 | end 30 | 31 | function printProbabilityTable(wordIds, probabilities, num) 32 | print(string.rep("-", num * 22)) 33 | 34 | for p, wordId in ipairs(wordIds) do 35 | local line = "| " 36 | for i = 1, num do 37 | local word = dataset.id2word[wordId[i]] 38 | line = line .. string.format("%-10s(%4d%%)", word, torch.exp(probabilities[p][i]) * 100) .. 
" | " 39 | end 40 | print(line) 41 | end 42 | 43 | print(string.rep("-", num * 22)) 44 | end 45 | 46 | function say(text) 47 | local wordIds = {} 48 | 49 | for t, word in tokenizer.tokenize(text) do 50 | local id = dataset.word2id[word:lower()] or dataset.unknownToken 51 | table.insert(wordIds, id) 52 | end 53 | 54 | local input = torch.Tensor(list.reverse(wordIds)) 55 | local wordIds, probabilities = model:eval(input) 56 | 57 | print("neuralconvo> " .. pred2sent(wordIds)) 58 | 59 | if options.debug then 60 | printProbabilityTable(wordIds, probabilities, 4) 61 | end 62 | end 63 | 64 | print("\nType a sentence and hit enter to submit.") 65 | print("CTRL+C then enter to quit.\n") 66 | while true do 67 | io.write("you> ") 68 | io.flush() 69 | io.write(say(io.read())) 70 | end 71 | -------------------------------------------------------------------------------- /neuralconvo.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'rnn' 4 | 5 | neuralconvo = {} 6 | 7 | torch.include('neuralconvo', 'cornell_movie_dialogs.lua') 8 | torch.include('neuralconvo', 'dataset.lua') 9 | torch.include('neuralconvo', 'seq2seq.lua') 10 | 11 | return neuralconvo 12 | -------------------------------------------------------------------------------- /seq2seq.lua: -------------------------------------------------------------------------------- 1 | -- Based on https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling.lua 2 | local Seq2Seq = torch.class("neuralconvo.Seq2Seq") 3 | 4 | function Seq2Seq:__init(vocabSize, hiddenSize) 5 | self.vocabSize = assert(vocabSize, "vocabSize required at arg #1") 6 | self.hiddenSize = assert(hiddenSize, "hiddenSize required at arg #2") 7 | self:buildModel() 8 | end 9 | 10 | function Seq2Seq:buildModel() 11 | self.encoder = nn.Sequential() 12 | self.encoder:add(nn.LookupTableMaskZero(self.vocabSize, self.hiddenSize)) 13 | self.encoderLSTM = 
nn.FastLSTM(self.hiddenSize, self.hiddenSize):maskZero(1) 14 | self.encoder:add(nn.Sequencer(self.encoderLSTM)) 15 | self.encoder:add(nn.Select(1,-1)) 16 | 17 | self.decoder = nn.Sequential() 18 | self.decoder:add(nn.LookupTableMaskZero(self.vocabSize, self.hiddenSize)) 19 | self.decoderLSTM = nn.FastLSTM(self.hiddenSize, self.hiddenSize):maskZero(1) 20 | self.decoder:add(nn.Sequencer(self.decoderLSTM)) 21 | self.decoder:add(nn.Sequencer(nn.MaskZero(nn.Linear(self.hiddenSize, self.vocabSize),1))) 22 | self.decoder:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(),1))) 23 | 24 | self.encoder:zeroGradParameters() 25 | self.decoder:zeroGradParameters() 26 | end 27 | 28 | function Seq2Seq:cuda() 29 | self.encoder:cuda() 30 | self.decoder:cuda() 31 | 32 | if self.criterion then 33 | self.criterion:cuda() 34 | end 35 | end 36 | 37 | function Seq2Seq:float() 38 | self.encoder:float() 39 | self.decoder:float() 40 | 41 | if self.criterion then 42 | self.criterion:float() 43 | end 44 | end 45 | 46 | function Seq2Seq:cl() 47 | self.encoder:cl() 48 | self.decoder:cl() 49 | 50 | if self.criterion then 51 | self.criterion:cl() 52 | end 53 | end 54 | 55 | function Seq2Seq:getParameters() 56 | return nn.Container():add(self.encoder):add(self.decoder):getParameters() 57 | end 58 | 59 | --[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]-- 60 | function Seq2Seq:forwardConnect(inputSeqLen) 61 | self.decoderLSTM.userPrevOutput = 62 | nn.rnn.recursiveCopy(self.decoderLSTM.userPrevOutput, self.encoderLSTM.outputs[inputSeqLen]) 63 | self.decoderLSTM.userPrevCell = 64 | nn.rnn.recursiveCopy(self.decoderLSTM.userPrevCell, self.encoderLSTM.cells[inputSeqLen]) 65 | end 66 | 67 | --[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]-- 68 | function Seq2Seq:backwardConnect(inputSeqLen) 69 | self.encoderLSTM:setGradHiddenState(inputSeqLen, self.decoderLSTM:getGradHiddenState(0)) 70 | end 71 | 72 | local MAX_OUTPUT_SIZE = 20 73 | 74 | function Seq2Seq:eval(input) 75 | 
assert(self.goToken, "No goToken specified") 76 | assert(self.eosToken, "No eosToken specified") 77 | 78 | self.encoder:forward(input) 79 | self:forwardConnect(input:size(1)) 80 | 81 | local predictions = {} 82 | local probabilities = {} 83 | 84 | -- Forward <go> and all of its output recursively back to the decoder 85 | local output = {self.goToken} 86 | for i = 1, MAX_OUTPUT_SIZE do 87 | local prediction = self.decoder:forward(torch.Tensor(output))[#output] 88 | -- prediction contains the log-probability of each word ID. 89 | -- The index of the probability is the word ID. 90 | local prob, wordIds = prediction:topk(5, 1, true, true) 91 | 92 | -- First one is the most likely. 93 | local next_output = wordIds[1] 94 | table.insert(output, next_output) 95 | 96 | -- Terminate on EOS token 97 | if next_output == self.eosToken then 98 | break 99 | end 100 | 101 | table.insert(predictions, wordIds) 102 | table.insert(probabilities, prob) 103 | end 104 | 105 | self.decoder:forget() 106 | self.encoder:forget() 107 | 108 | return predictions, probabilities 109 | end 110 | -------------------------------------------------------------------------------- /tokenizer.lua: -------------------------------------------------------------------------------- 1 | local lexer = require "pl.lexer" 2 | local yield = coroutine.yield 3 | local M = {} 4 | 5 | local function word(token) 6 | return yield("word", token) 7 | end 8 | 9 | local function quote(token) 10 | return yield("quote", token) 11 | end 12 | 13 | local function space(token) 14 | return yield("space", token) 15 | end 16 | 17 | local function tag(token) 18 | return yield("tag", token) 19 | end 20 | 21 | local function punct(token) 22 | return yield("punct", token) 23 | end 24 | 25 | local function endpunct(token) 26 | return yield("endpunct", token) 27 | end 28 | 29 | local function unknown(token) 30 | return yield("unknown", token) 31 | end 32 | 33 | function M.tokenize(text) 34 | return lexer.scan(text, { 35 | { "^%s+", space }, 36 | {
"^['\"]", quote }, 37 | { "^%w+", word }, 38 | { "^%-+", space }, 39 | { "^[,:;%-]", punct }, 40 | { "^%.+", endpunct }, 41 | { "^[%.%?!]", endpunct }, 42 | { "^</?%w+>", tag }, -- HTML-style tags in some movie lines (filtered out below) 43 | { "^.", unknown }, 44 | }, { [space]=true, [tag]=true }) 45 | end 46 | 47 | function M.join(words) 48 | local s = table.concat(words, " ") 49 | s = s:gsub("^%l", string.upper) 50 | s = s:gsub(" (') ", "%1") 51 | s = s:gsub(" ([,:;%-%.%?!])", "%1") 52 | return s 53 | end 54 | 55 | return M -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | require 'xlua' 3 | require 'optim' 4 | 5 | cmd = torch.CmdLine() 6 | cmd:text('Options:') 7 | cmd:option('--dataset', 0, 'approximate size of dataset to use (0 = all)') 8 | cmd:option('--maxVocabSize', 0, 'max number of words in the vocab (0 = no limit)') 9 | cmd:option('--cuda', false, 'use CUDA') 10 | cmd:option('--opencl', false, 'use OpenCL') 11 | cmd:option('--hiddenSize', 300, 'number of hidden units in LSTM') 12 | cmd:option('--learningRate', 0.001, 'learning rate at t=0') 13 | cmd:option('--gradientClipping', 5, 'clip gradients at this value') 14 | cmd:option('--momentum', 0.9, 'momentum') 15 | cmd:option('--minLR', 0.00001, 'minimum learning rate') 16 | cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR') 17 | cmd:option('--maxEpoch', 50, 'maximum number of epochs to run') 18 | cmd:option('--batchSize', 10, 'mini-batch size') 19 | cmd:option('--gpu', 0, 'Zero-indexed ID of the GPU to use.
Optional.') 20 | 21 | cmd:text() 22 | options = cmd:parse(arg) 23 | 24 | if options.dataset == 0 then 25 | options.dataset = nil 26 | end 27 | 28 | -- Data 29 | print("-- Loading dataset") 30 | dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"), 31 | { 32 | loadFirst = options.dataset, 33 | maxVocabSize = options.maxVocabSize 34 | }) 35 | 36 | print("\nDataset stats:") 37 | print(" Vocabulary size: " .. dataset.wordsCount) 38 | print(" Examples: " .. dataset.examplesCount) 39 | 40 | -- Model 41 | model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize) 42 | model.goToken = dataset.goToken 43 | model.eosToken = dataset.eosToken 44 | 45 | -- Training parameters 46 | if options.batchSize > 1 then 47 | model.criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(),1)) 48 | else 49 | model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) 50 | end 51 | 52 | local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch 53 | local minMeanError = nil 54 | 55 | -- Enable CUDA or OpenCL 56 | if options.cuda then 57 | require 'cutorch' 58 | require 'cunn' 59 | cutorch.setDevice(options.gpu + 1) 60 | model:cuda() 61 | elseif options.opencl then 62 | require 'cltorch' 63 | require 'clnn' 64 | cltorch.setDevice(options.gpu + 1) 65 | model:cl() 66 | end 67 | 68 | -- Run the experiment 69 | local optimState = {learningRate=options.learningRate,momentum=options.momentum} 70 | for epoch = 1, options.maxEpoch do 71 | collectgarbage() 72 | 73 | local nextBatch = dataset:batches(options.batchSize) 74 | local params, gradParams = model:getParameters() 75 | 76 | -- Define optimizer 77 | local function feval(x) 78 | if x ~= params then 79 | params:copy(x) 80 | end 81 | 82 | gradParams:zero() 83 | local encoderInputs, decoderInputs, decoderTargets = nextBatch() 84 | 85 | if options.cuda then 86 | encoderInputs = encoderInputs:cuda() 87 | decoderInputs = decoderInputs:cuda() 88 | decoderTargets =
decoderTargets:cuda() 89 | elseif options.opencl then 90 | encoderInputs = encoderInputs:cl() 91 | decoderInputs = decoderInputs:cl() 92 | decoderTargets = decoderTargets:cl() 93 | end 94 | 95 | -- Forward pass 96 | local encoderOutput = model.encoder:forward(encoderInputs) 97 | model:forwardConnect(encoderInputs:size(1)) 98 | local decoderOutput = model.decoder:forward(decoderInputs) 99 | local loss = model.criterion:forward(decoderOutput, decoderTargets) 100 | 101 | local avgSeqLen = nil 102 | if #decoderInputs:size() == 1 then 103 | avgSeqLen = decoderInputs:size(1) 104 | else 105 | avgSeqLen = torch.sum(torch.sign(decoderInputs)) / decoderInputs:size(2) 106 | end 107 | loss = loss / avgSeqLen 108 | 109 | -- Backward pass 110 | local dloss_doutput = model.criterion:backward(decoderOutput, decoderTargets) 111 | model.decoder:backward(decoderInputs, dloss_doutput) 112 | model:backwardConnect(encoderInputs:size(1)) 113 | model.encoder:backward(encoderInputs, encoderOutput:zero()) 114 | 115 | gradParams:clamp(-options.gradientClipping, options.gradientClipping) 116 | 117 | return loss,gradParams 118 | end 119 | 120 | -- run epoch 121 | 122 | print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch .. 123 | " (LR= " .. optimState.learningRate .. ")") 124 | print("") 125 | 126 | local errors = {} 127 | local timer = torch.Timer() 128 | 129 | for i=1, dataset.examplesCount/options.batchSize do 130 | collectgarbage() 131 | local _,tloss = optim.adam(feval, params, optimState) 132 | err = tloss[1] -- optim returns a list 133 | 134 | model.decoder:forget() 135 | model.encoder:forget() 136 | 137 | table.insert(errors,err) 138 | xlua.progress(i * options.batchSize, dataset.examplesCount) 139 | end 140 | 141 | xlua.progress(dataset.examplesCount, dataset.examplesCount) 142 | timer:stop() 143 | 144 | errors = torch.Tensor(errors) 145 | print("\n\nFinished in " .. xlua.formatTime(timer:time().real) .. 146 | " " .. (dataset.examplesCount / timer:time().real) .. 
' examples/sec.') 147 | print("\nEpoch stats:") 148 | print(" Errors: min= " .. errors:min()) 149 | print(" max= " .. errors:max()) 150 | print(" median= " .. errors:median()[1]) 151 | print(" mean= " .. errors:mean()) 152 | print(" std= " .. errors:std()) 153 | print(" ppl= " .. torch.exp(errors:mean())) 154 | 155 | -- Save the model if it improved. 156 | if minMeanError == nil or errors:mean() < minMeanError then 157 | print("\n(Saving model ...)") 158 | params, gradParams = nil,nil 159 | collectgarbage() 160 | -- Model is saved as CPU 161 | model:float() 162 | torch.save("data/model.t7", model) 163 | collectgarbage() 164 | if options.cuda then 165 | model:cuda() 166 | elseif options.opencl then 167 | model:cl() 168 | end 169 | collectgarbage() 170 | minMeanError = errors:mean() 171 | end 172 | 173 | optimState.learningRate = optimState.learningRate + decayFactor 174 | optimState.learningRate = math.max(options.minLR, optimState.learningRate) 175 | end 176 | --------------------------------------------------------------------------------
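A note on the decoder data layout: `DataSet:batches` in dataset.lua frames each target sentence as `<go> w1 ... wn <eos>`, then builds decoder inputs by dropping the trailing `<eos>` (`targetSeqs[samplenb]:sub(1,-2)`) and decoder targets by dropping the leading `<go>` (`:sub(2,-1)`), so the target at each step is the word after the input. A minimal Python sketch of that one-step shift, with hypothetical stand-in token ids (not part of the repo):

```python
# Hypothetical ids standing in for dataset.goToken / dataset.eosToken.
GO, EOS = 1, 2

def decoder_pair(target_seq):
    """Given a target sequence framed as [GO, w1, ..., wn, EOS],
    return (decoder_inputs, decoder_targets) shifted by one step,
    mirroring sub(1,-2) and sub(2,-1) in dataset.lua."""
    return target_seq[:-1], target_seq[1:]

inputs, targets = decoder_pair([GO, 5, 9, 7, EOS])
print(inputs)   # [1, 5, 9, 7]
print(targets)  # [5, 9, 7, 2]
```

This teacher-forcing layout is what lets `SequencerCriterion` in train.lua compare the decoder's output at step t against the ground-truth word at step t+1.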