├── .gitignore ├── LICENSE ├── README.md ├── cornell_movie_dialogs.lua ├── dataset.lua ├── eval.lua ├── movie_script_parser.lua ├── neuralconvo.lua ├── seq2seq.lua ├── tokenizer.lua └── train.lua /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | log/* 3 | *.log 4 | 5 | # Compiled Lua sources 6 | luac.out 7 | 8 | # luarocks build files 9 | *.src.rock 10 | *.zip 11 | *.tar.gz 12 | 13 | # Object files 14 | *.o 15 | *.os 16 | *.ko 17 | *.obj 18 | *.elf 19 | 20 | # Precompiled Headers 21 | *.gch 22 | *.pch 23 | 24 | # Libraries 25 | *.lib 26 | *.a 27 | *.la 28 | *.lo 29 | *.def 30 | *.exp 31 | 32 | # Shared objects (inc. Windows DLLs) 33 | *.dll 34 | *.so 35 | *.so.* 36 | *.dylib 37 | 38 | # Executables 39 | *.exe 40 | *.out 41 | *.app 42 | *.i*86 43 | *.x86_64 44 | *.hex 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Marc-Andre Cournoyer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Overview 2 | ============ 3 | This is an attempt at implementing [Sequence to Sequence Learning with Neural Networks (seq2seq)](http://arxiv.org/abs/1409.3215) and reproducing the results in [A Neural Conversational Model](http://arxiv.org/abs/1506.05869) (aka the Google chatbot). The model is based on two [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) layers: one for encoding the input sentence into a "thought vector", and another for decoding that vector into a response. This model is called sequence-to-sequence, or seq2seq. This is the code for 'Build a Chatbot' on [YouTube](https://www.youtube.com/watch?v=5_SAroSvC0E&feature=youtu.be). 4 | 5 | ![seq2seq](https://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png) 6 | _Source: http://googleresearch.blogspot.ca/2015/11/computer-respond-to-this-email.html_ 7 | 8 | 9 | Dependencies 10 | ============ 11 | 12 | 1. [Install Torch](http://torch.ch/docs/getting-started.html). 13 | 2. Install the following additional Lua libs: 14 | 15 | ```sh 16 | luarocks install nn 17 | luarocks install rnn 18 | luarocks install penlight 19 | ``` 20 | 21 | To train with CUDA, install the latest CUDA drivers and toolkit, then run: 22 | 23 | ```sh 24 | luarocks install cutorch 25 | luarocks install cunn 26 | ``` 27 | 28 | To train with OpenCL, install the latest OpenCL Torch libs: 29 | 30 | ```sh 31 | luarocks install cltorch 32 | luarocks install clnn 33 | ``` 34 | 35 | 3. 
Download the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html) and extract all the files into data/cornell_movie_dialogs. 36 | 37 | 38 | Basic Usage 39 | =========== 40 | Here's a sample conversation after training for 20 epochs with 50000 examples, using the following command: 41 | 42 | ```sh 43 | th train.lua --cuda --dataset 50000 --hiddenSize 1000 44 | ``` 45 | 46 | (Took 3 days to train on my GeForce GTX 780M.) 47 | 48 | For OpenCL, use `--opencl` instead of `--cuda`. To train on the CPU, omit both flags. Use the `--dataset NUMBER` option to control the size of the dataset. Training on the full dataset takes about 5h for a single epoch. The model will be saved to `data/model.t7` after each epoch if it has improved (error decreased). 49 | 50 | > **me:** Hello? 51 | > **bot:** Hi. 52 | > 53 | > **me:** How are you? 54 | > **bot:** I'm fine. 55 | > 56 | > **me:** What's your name? 57 | > **bot:** It's hard to describe. 58 | > 59 | > **me:** How so? 60 | > **bot:** I'm not sure. 61 | > 62 | > **me:** What color is the sky? 63 | > **bot:** It's blue. 64 | > 65 | > **me:** What is your job? 66 | > **bot:** It's not that i'm a fucking werewolf! 67 | > 68 | > **me:** What is the purpose of life? 69 | > **bot:** A gift. 70 | > 71 | > **me:** Are you intelligent? 72 | > **bot:** Yes, well... 73 | > 74 | > **me:** Are you a machine? 75 | > **bot:** That's a lie. 76 | > 77 | > **me:** Are you human? 78 | > **bot:** No, i'm not. 79 | > 80 | > **me:** What are you? 81 | > **bot:** I'm not sure. 82 | 83 | To load the model and have a conversation: 84 | 85 | ```sh 86 | th -i eval.lua --cuda # Skip --cuda if you didn't train with it 87 | # ... 88 | th> say "Hello." 89 | ``` 90 | 91 | Credits 92 | =========== 93 | Credit for the vast majority of the code here goes to [Marc-André Cournoyer](https://github.com/macournoyer). I've merely created a wrapper around all of the important functions to get people started. 
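
The encode/decode loop described in the Overview can be sketched language-agnostically. The following is a toy pure-Python stand-in (hypothetical, not the repo's Torch model) showing the data flow: the input is reversed (as dataset.lua does), "encoded", then greedily decoded word by word until an end-of-sequence marker, with the same 20-word cap seq2seq.lua uses.

```python
# Toy illustration of the seq2seq loop: NOT the actual Torch model,
# just a stand-in showing how decoding feeds its own output back in.
GO, EOS = "<go>", "<eos>"
MAX_OUTPUT_SIZE = 20  # same cap as MAX_OUTPUT_SIZE in seq2seq.lua

def encode(words):
    # dataset.lua feeds the input sentence reversed; in the real model the
    # LSTM's final state is the "thought vector". Here we keep the words.
    return list(reversed(words))

def decode(thought, predict_next):
    # Greedy decoding: start from <go>, feed each predicted word back in,
    # stop at <eos> or after MAX_OUTPUT_SIZE words.
    output = [GO]
    for _ in range(MAX_OUTPUT_SIZE):
        word = predict_next(thought, output)
        if word == EOS:
            break
        output.append(word)
    return output[1:]  # drop the <go> marker
```

A trivial `predict_next` that parrots the thought vector back (a hypothetical stand-in for the decoder LSTM) shows the loop terminating on `<eos>`.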
94 | -------------------------------------------------------------------------------- /cornell_movie_dialogs.lua: -------------------------------------------------------------------------------- 1 | local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") 2 | local stringx = require "pl.stringx" 3 | local xlua = require "xlua" 4 | 5 | local function parsedLines(file, fields) 6 | local f = assert(io.open(file, 'r')) 7 | 8 | return function() 9 | local line = f:read("*line") 10 | 11 | if line == nil then 12 | f:close() 13 | return 14 | end 15 | 16 | local values = stringx.split(line, " +++$+++ ") 17 | local t = {} 18 | 19 | for i,field in ipairs(fields) do 20 | t[field] = values[i] 21 | end 22 | 23 | return t 24 | end 25 | end 26 | 27 | function CornellMovieDialogs:__init(dir) 28 | self.dir = dir 29 | end 30 | 31 | local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} 32 | local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} 33 | local TOTAL_LINES = 387810 34 | 35 | local function progress(c) 36 | if c % 10000 == 0 then 37 | xlua.progress(c, TOTAL_LINES) 38 | end 39 | end 40 | 41 | function CornellMovieDialogs:load() 42 | local lines = {} 43 | local conversations = {} 44 | local count = 0 45 | 46 | print("-- Parsing Cornell movie dialogs data set ...") 47 | 48 | for line in parsedLines(self.dir .. "/movie_lines.txt", MOVIE_LINES_FIELDS) do 49 | lines[line.lineID] = line 50 | line.lineID = nil 51 | -- Remove unused fields 52 | line.characterID = nil 53 | line.movieID = nil 54 | count = count + 1 55 | progress(count) 56 | end 57 | 58 | for conv in parsedLines(self.dir .. 
"/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do 59 | local conversation = {} 60 | local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") 61 | for i,lineID in ipairs(lineIDs) do 62 | table.insert(conversation, lines[lineID]) 63 | end 64 | table.insert(conversations, conversation) 65 | count = count + 1 66 | progress(count) 67 | end 68 | 69 | xlua.progress(TOTAL_LINES, TOTAL_LINES) 70 | 71 | return conversations 72 | end 73 | -------------------------------------------------------------------------------- /dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Format movie dialog data as a table of line 1: 3 | 4 | { {word_ids of character1}, {word_ids of character2} } 5 | 6 | Then flips it around and get the dialog from the other character's perspective: 7 | 8 | { {word_ids of character2}, {word_ids of character1} } 9 | 10 | Also builds the vocabulary. 11 | ]]-- 12 | 13 | local DataSet = torch.class("neuralconvo.DataSet") 14 | local xlua = require "xlua" 15 | local tokenizer = require "tokenizer" 16 | local list = require "pl.List" 17 | 18 | function DataSet:__init(loader, options) 19 | options = options or {} 20 | 21 | self.examplesFilename = "data/examples.t7" 22 | 23 | -- Discard words with lower frequency then this 24 | self.minWordFreq = options.minWordFreq or 1 25 | 26 | -- Maximum number of words in an example sentence 27 | self.maxExampleLen = options.maxExampleLen or 25 28 | 29 | -- Load only first fews examples (approximately) 30 | self.loadFirst = options.loadFirst 31 | 32 | self.examples = {} 33 | self.word2id = {} 34 | self.id2word = {} 35 | self.wordsCount = 0 36 | 37 | self:load(loader) 38 | end 39 | 40 | function DataSet:load(loader) 41 | local filename = "data/vocab.t7" 42 | 43 | if path.exists(filename) then 44 | print("Loading vocabulary from " .. filename .. 
" ...") 45 | local data = torch.load(filename) 46 | self.word2id = data.word2id 47 | self.id2word = data.id2word 48 | self.wordsCount = data.wordsCount 49 | self.goToken = data.goToken 50 | self.eosToken = data.eosToken 51 | self.unknownToken = data.unknownToken 52 | self.examplesCount = data.examplesCount 53 | else 54 | print("" .. filename .. " not found") 55 | self:visit(loader:load()) 56 | print("Writing " .. filename .. " ...") 57 | torch.save(filename, { 58 | word2id = self.word2id, 59 | id2word = self.id2word, 60 | wordsCount = self.wordsCount, 61 | goToken = self.goToken, 62 | eosToken = self.eosToken, 63 | unknownToken = self.unknownToken, 64 | examplesCount = self.examplesCount 65 | }) 66 | end 67 | end 68 | 69 | function DataSet:visit(conversations) 70 | -- Table for keeping track of word frequency 71 | self.wordFreq = {} 72 | self.examples = {} 73 | 74 | -- Add magic tokens 75 | self.goToken = self:makeWordId("") -- Start of sequence 76 | self.eosToken = self:makeWordId("") -- End of sequence 77 | self.unknownToken = self:makeWordId("") -- Word dropped from vocabulary 78 | 79 | print("-- Pre-processing data") 80 | 81 | local total = self.loadFirst or #conversations * 2 82 | 83 | for i, conversation in ipairs(conversations) do 84 | if i > total then break end 85 | self:visitConversation(conversation) 86 | xlua.progress(i, total) 87 | end 88 | 89 | -- Revisit from the perspective of 2nd character 90 | for i, conversation in ipairs(conversations) do 91 | if #conversations + i > total then break end 92 | self:visitConversation(conversation, 2) 93 | xlua.progress(#conversations + i, total) 94 | end 95 | 96 | print("-- Removing low frequency words") 97 | 98 | for i, datum in ipairs(self.examples) do 99 | self:removeLowFreqWords(datum[1]) 100 | self:removeLowFreqWords(datum[2]) 101 | xlua.progress(i, #self.examples) 102 | end 103 | 104 | self.wordFreq = nil 105 | 106 | self.examplesCount = #self.examples 107 | self:writeExamplesToFile() 108 | self.examples = 
nil 109 | 110 | collectgarbage() 111 | end 112 | 113 | function DataSet:writeExamplesToFile() 114 | print("Writing " .. self.examplesFilename .. " ...") 115 | local file = torch.DiskFile(self.examplesFilename, "w") 116 | 117 | for i, example in ipairs(self.examples) do 118 | file:writeObject(example) 119 | xlua.progress(i, #self.examples) 120 | end 121 | 122 | file:close() 123 | end 124 | 125 | function DataSet:batches(size) 126 | local file = torch.DiskFile(self.examplesFilename, "r") 127 | file:quiet() 128 | local done = false 129 | 130 | return function() 131 | if done then 132 | return 133 | end 134 | 135 | local examples = {} 136 | 137 | for i = 1, size do 138 | local example = file:readObject() 139 | if example == nil then 140 | done = true 141 | file:close() 142 | return examples 143 | end 144 | table.insert(examples, example) 145 | end 146 | 147 | return examples 148 | end 149 | end 150 | 151 | function DataSet:removeLowFreqWords(input) 152 | for i = 1, input:size(1) do 153 | local id = input[i] 154 | local word = self.id2word[id] 155 | 156 | if word == nil then 157 | -- Already removed 158 | input[i] = self.unknownToken 159 | 160 | elseif self.wordFreq[word] < self.minWordFreq then 161 | input[i] = self.unknownToken 162 | 163 | self.word2id[word] = nil 164 | self.id2word[id] = nil 165 | self.wordsCount = self.wordsCount - 1 166 | end 167 | end 168 | end 169 | 170 | function DataSet:visitConversation(lines, start) 171 | start = start or 1 172 | 173 | for i = start, #lines, 2 do 174 | local input = lines[i] 175 | local target = lines[i+1] 176 | 177 | if target then 178 | local inputIds = self:visitText(input.text) 179 | local targetIds = self:visitText(target.text, 2) 180 | 181 | if inputIds and targetIds then 182 | -- Revert inputs 183 | inputIds = list.reverse(inputIds) 184 | 185 | table.insert(targetIds, 1, self.goToken) 186 | table.insert(targetIds, self.eosToken) 187 | 188 | table.insert(self.examples, { torch.IntTensor(inputIds), 
torch.IntTensor(targetIds) }) 189 | end 190 | end 191 | end 192 | end 193 | 194 | function DataSet:visitText(text, additionalTokens) 195 | local words = {} 196 | additionalTokens = additionalTokens or 0 197 | 198 | if text == "" then 199 | return 200 | end 201 | 202 | for t, word in tokenizer.tokenize(text) do 203 | table.insert(words, self:makeWordId(word)) 204 | -- Only keep the first sentence 205 | if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then 206 | break 207 | end 208 | end 209 | 210 | if #words == 0 then 211 | return 212 | end 213 | 214 | return words 215 | end 216 | 217 | function DataSet:makeWordId(word) 218 | word = word:lower() 219 | 220 | local id = self.word2id[word] 221 | 222 | if id then 223 | self.wordFreq[word] = self.wordFreq[word] + 1 224 | else 225 | self.wordsCount = self.wordsCount + 1 226 | id = self.wordsCount 227 | self.id2word[id] = word 228 | self.word2id[word] = id 229 | self.wordFreq[word] = 1 230 | end 231 | 232 | return id 233 | end 234 | -------------------------------------------------------------------------------- /eval.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | local tokenizer = require "tokenizer" 3 | local list = require "pl.List" 4 | local options = {} 5 | 6 | if dataset == nil then 7 | cmd = torch.CmdLine() 8 | cmd:text('Options:') 9 | cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') 10 | cmd:option('--opencl', false, 'use OpenCL. 
Training must be done on OpenCL') 11 | cmd:option('--debug', false, 'show debug info') 12 | cmd:text() 13 | options = cmd:parse(arg) 14 | 15 | -- Data 16 | dataset = neuralconvo.DataSet() 17 | 18 | -- Enabled CUDA 19 | if options.cuda then 20 | require 'cutorch' 21 | require 'cunn' 22 | elseif options.opencl then 23 | require 'cltorch' 24 | require 'clnn' 25 | end 26 | end 27 | 28 | if model == nil then 29 | print("-- Loading model") 30 | model = torch.load("data/model.t7") 31 | end 32 | 33 | -- Word IDs to sentence 34 | function pred2sent(wordIds, i) 35 | local words = {} 36 | i = i or 1 37 | 38 | for _, wordId in ipairs(wordIds) do 39 | local word = dataset.id2word[wordId[i]] 40 | table.insert(words, word) 41 | end 42 | 43 | return tokenizer.join(words) 44 | end 45 | 46 | function printProbabilityTable(wordIds, probabilities, num) 47 | print(string.rep("-", num * 22)) 48 | 49 | for p, wordId in ipairs(wordIds) do 50 | local line = "| " 51 | for i = 1, num do 52 | local word = dataset.id2word[wordId[i]] 53 | line = line .. string.format("%-10s(%4d%%)", word, probabilities[p][i] * 100) .. " | " 54 | end 55 | print(line) 56 | end 57 | 58 | print(string.rep("-", num * 22)) 59 | end 60 | 61 | function say(text) 62 | local wordIds = {} 63 | 64 | for t, word in tokenizer.tokenize(text) do 65 | local id = dataset.word2id[word:lower()] or dataset.unknownToken 66 | table.insert(wordIds, id) 67 | end 68 | 69 | local input = torch.Tensor(list.reverse(wordIds)) 70 | local wordIds, probabilities = model:eval(input) 71 | 72 | print(">> " .. 
pred2sent(wordIds)) 73 | 74 | if options.debug then 75 | printProbabilityTable(wordIds, probabilities, 4) 76 | end 77 | end 78 | -------------------------------------------------------------------------------- /movie_script_parser.lua: -------------------------------------------------------------------------------- 1 | local Parser = torch.class("neuralconvo.MovieScriptParser") 2 | 3 | function Parser:parse(file) 4 | local f = assert(io.open(file, 'r')) 5 | self.input = f:read("*all") 6 | f:close() 7 | 8 | self.pos = 0 9 | self.match = nil 10 | 11 | -- Find start of script 12 | repeat self:acceptLine() until self:accept("
")
 13 | 
 14 |   local dialogs = {}
 15 | 
 16 |   -- Apply rules until end of script
 17 |   while not self:accept("
") and self:acceptLine() do 18 | local dialog = self:parseDialog() 19 | if dialog then 20 | table.insert(dialogs, dialog) 21 | end 22 | end 23 | 24 | return dialogs 25 | end 26 | 27 | -- Returns true if regexp matches and advance position 28 | function Parser:accept(regexp) 29 | local match = string.match(self.input, "^" .. regexp, self.pos) 30 | if match then 31 | self.pos = self.pos + #match 32 | self.match = match 33 | return true 34 | end 35 | end 36 | 37 | -- Accept anything up to the end of line 38 | function Parser:acceptLine() 39 | return self:accept(".-\n") 40 | end 41 | 42 | function Parser:acceptSep() 43 | while self:accept("") or self:accept(" +") do end 44 | return self:accept("\n") 45 | end 46 | 47 | function Parser:parseDialog() 48 | local dialogs = {} 49 | 50 | repeat 51 | local dialog = self:parseSpeech() 52 | if dialog then 53 | table.insert(dialogs, dialog) 54 | end 55 | until not self:acceptSep() 56 | 57 | if #dialogs > 0 then 58 | return dialogs 59 | end 60 | end 61 | 62 | -- Matches: 63 | -- 64 | -- NAME 65 | -- some nice text 66 | -- more text. 
67 | -- 68 | -- or 69 | -- 70 | -- NAME; text 71 | function Parser:parseSpeech() 72 | local name 73 | 74 | self:accept("<b>") 75 | self:accept("</b>") 76 | 77 | -- Get the character name (all caps) 78 | -- TODO remove parenthesis from name 79 | if self:accept(" +") and self:accept("[A-Z][A-Z%- %.%(%)]+") then 80 | name = self.match 81 | else 82 | return 83 | end 84 | 85 | -- Handle inline dialog: `NAME; text` 86 | if self:accept(";") and self:accept("[^\n]+") then 87 | return { 88 | character = name, 89 | text = self.match 90 | } 91 | end 92 | 93 | self:accept("\n") 94 | 95 | if not self:accept("</b>") then 96 | return 97 | end 98 | 99 | -- Get the dialog lines 100 | -- TODO remove parenthesis from text 101 | local lines = {} 102 | while self:accept(" +") do 103 | -- The actual line of dialog 104 | if self:accept("[^\n]+") then 105 | table.insert(lines, self.match) 106 | end 107 | self:accept("\n") 108 | end 109 | 110 | if #lines > 0 then 111 | return { 112 | character = name, 113 | text = table.concat(lines) 114 | } 115 | end 116 | end 117 | -------------------------------------------------------------------------------- /neuralconvo.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'rnn' 4 | 5 | neuralconvo = {} 6 | 7 | torch.include('neuralconvo', 'cornell_movie_dialogs.lua') 8 | torch.include('neuralconvo', 'dataset.lua') 9 | torch.include('neuralconvo', 'movie_script_parser.lua') 10 | torch.include('neuralconvo', 'seq2seq.lua') 11 | 12 | return neuralconvo 13 | -------------------------------------------------------------------------------- /seq2seq.lua: -------------------------------------------------------------------------------- 1 | -- Based on https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling.lua 2 | local Seq2Seq = torch.class("neuralconvo.Seq2Seq") 3 | 4 | function Seq2Seq:__init(vocabSize, hiddenSize) 5 | self.vocabSize = assert(vocabSize, 
"vocabSize required at arg #1") 6 | self.hiddenSize = assert(hiddenSize, "hiddenSize required at arg #2") 7 | 8 | self:buildModel() 9 | end 10 | 11 | function Seq2Seq:buildModel() 12 | self.encoder = nn.Sequential() 13 | self.encoder:add(nn.LookupTable(self.vocabSize, self.hiddenSize)) 14 | self.encoder:add(nn.SplitTable(1, 2)) 15 | self.encoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize) 16 | self.encoder:add(nn.Sequencer(self.encoderLSTM)) 17 | self.encoder:add(nn.SelectTable(-1)) 18 | 19 | self.decoder = nn.Sequential() 20 | self.decoder:add(nn.LookupTable(self.vocabSize, self.hiddenSize)) 21 | self.decoder:add(nn.SplitTable(1, 2)) 22 | self.decoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize) 23 | self.decoder:add(nn.Sequencer(self.decoderLSTM)) 24 | self.decoder:add(nn.Sequencer(nn.Linear(self.hiddenSize, self.vocabSize))) 25 | self.decoder:add(nn.Sequencer(nn.LogSoftMax())) 26 | 27 | self.encoder:zeroGradParameters() 28 | self.decoder:zeroGradParameters() 29 | end 30 | 31 | function Seq2Seq:cuda() 32 | self.encoder:cuda() 33 | self.decoder:cuda() 34 | 35 | if self.criterion then 36 | self.criterion:cuda() 37 | end 38 | end 39 | 40 | function Seq2Seq:cl() 41 | self.encoder:cl() 42 | self.decoder:cl() 43 | 44 | if self.criterion then 45 | self.criterion:cl() 46 | end 47 | end 48 | 49 | --[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]-- 50 | function Seq2Seq:forwardConnect(inputSeqLen) 51 | self.decoderLSTM.userPrevOutput = 52 | nn.rnn.recursiveCopy(self.decoderLSTM.userPrevOutput, self.encoderLSTM.outputs[inputSeqLen]) 53 | self.decoderLSTM.userPrevCell = 54 | nn.rnn.recursiveCopy(self.decoderLSTM.userPrevCell, self.encoderLSTM.cells[inputSeqLen]) 55 | end 56 | 57 | --[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]-- 58 | function Seq2Seq:backwardConnect() 59 | self.encoderLSTM.userNextGradCell = 60 | nn.rnn.recursiveCopy(self.encoderLSTM.userNextGradCell, self.decoderLSTM.userGradPrevCell) 61 | 
self.encoderLSTM.gradPrevOutput = 62 | nn.rnn.recursiveCopy(self.encoderLSTM.gradPrevOutput, self.decoderLSTM.userGradPrevOutput) 63 | end 64 | 65 | function Seq2Seq:train(input, target) 66 | local encoderInput = input 67 | local decoderInput = target:sub(1, -2) 68 | local decoderTarget = target:sub(2, -1) 69 | 70 | -- Forward pass 71 | local encoderOutput = self.encoder:forward(encoderInput) 72 | self:forwardConnect(encoderInput:size(1)) 73 | local decoderOutput = self.decoder:forward(decoderInput) 74 | local Edecoder = self.criterion:forward(decoderOutput, decoderTarget) 75 | 76 | if Edecoder ~= Edecoder then -- Exit early on bad (NaN) error 77 | return Edecoder 78 | end 79 | 80 | -- Backward pass 81 | local gEdec = self.criterion:backward(decoderOutput, decoderTarget) 82 | self.decoder:backward(decoderInput, gEdec) 83 | self:backwardConnect() 84 | self.encoder:backward(encoderInput, encoderOutput:zero()) 85 | 86 | self.encoder:updateGradParameters(self.momentum) 87 | self.decoder:updateGradParameters(self.momentum) 88 | self.decoder:updateParameters(self.learningRate) 89 | self.encoder:updateParameters(self.learningRate) 90 | self.encoder:zeroGradParameters() 91 | self.decoder:zeroGradParameters() 92 | 93 | self.decoder:forget() 94 | self.encoder:forget() 95 | 96 | return Edecoder 97 | end 98 | 99 | local MAX_OUTPUT_SIZE = 20 100 | 101 | function Seq2Seq:eval(input) 102 | assert(self.goToken, "No goToken specified") 103 | assert(self.eosToken, "No eosToken specified") 104 | 105 | self.encoder:forward(input) 106 | self:forwardConnect(input:size(1)) 107 | 108 | local predictions = {} 109 | local probabilities = {} 110 | 111 | -- Feed each predicted word recursively back into the decoder 112 | local output = {self.goToken} 113 | for i = 1, MAX_OUTPUT_SIZE do 114 | local prediction = self.decoder:forward(torch.Tensor(output))[#output] 115 | -- prediction contains the probabilities for each word ID. 116 | -- The index of the probability is the word ID. 
117 | local prob, wordIds = prediction:topk(5, 1, true, true) 118 | 119 | -- First one is the most likely. 120 | local next_output = wordIds[1] 121 | table.insert(output, next_output) 122 | 123 | -- Terminate on EOS token 124 | if next_output == self.eosToken then 125 | break 126 | end 127 | 128 | table.insert(predictions, wordIds) 129 | table.insert(probabilities, prob) 130 | end 131 | 132 | self.decoder:forget() 133 | self.encoder:forget() 134 | 135 | return predictions, probabilities 136 | end 137 | -------------------------------------------------------------------------------- /tokenizer.lua: -------------------------------------------------------------------------------- 1 | local lexer = require "pl.lexer" 2 | local yield = coroutine.yield 3 | local M = {} 4 | 5 | local function word(token) 6 | return yield("word", token) 7 | end 8 | 9 | local function quote(token) 10 | return yield("quote", token) 11 | end 12 | 13 | local function space(token) 14 | return yield("space", token) 15 | end 16 | 17 | local function tag(token) 18 | return yield("tag", token) 19 | end 20 | 21 | local function punct(token) 22 | return yield("punct", token) 23 | end 24 | 25 | local function endpunct(token) 26 | return yield("endpunct", token) 27 | end 28 | 29 | local function unknown(token) 30 | return yield("unknown", token) 31 | end 32 | 33 | function M.tokenize(text) 34 | return lexer.scan(text, { 35 | { "^%s+", space }, 36 | { "^['\"]", quote }, 37 | { "^%w+", word }, 38 | { "^%-+", space }, 39 | { "^[,:;%-]", punct }, 40 | { "^%.+", endpunct }, 41 | { "^[%.%?!]", endpunct }, 42 | { "^</?%a+>", tag }, 43 | { "^.", unknown }, 44 | }, { [space]=true, [tag]=true }) 45 | end 46 | 47 | function M.join(words) 48 | local s = table.concat(words, " ") 49 | s = s:gsub("^%l", string.upper) 50 | s = s:gsub(" (') ", "%1") 51 | s = s:gsub(" ([,:;%-%.%?!])", "%1") 52 | return s 53 | end 54 | 55 | return M -------------------------------------------------------------------------------- /train.lua: 
-------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | require 'xlua' 3 | 4 | cmd = torch.CmdLine() 5 | cmd:text('Options:') 6 | cmd:option('--dataset', 0, 'approximate size of dataset to use (0 = all)') 7 | cmd:option('--minWordFreq', 1, 'minimum frequency of words kept in vocab') 8 | cmd:option('--cuda', false, 'use CUDA') 9 | cmd:option('--opencl', false, 'use opencl') 10 | cmd:option('--hiddenSize', 300, 'number of hidden units in LSTM') 11 | cmd:option('--learningRate', 0.05, 'learning rate at t=0') 12 | cmd:option('--momentum', 0.9, 'momentum') 13 | cmd:option('--minLR', 0.00001, 'minimum learning rate') 14 | cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR') 15 | cmd:option('--maxEpoch', 50, 'maximum number of epochs to run') 16 | cmd:option('--batchSize', 1000, 'number of examples to load at once') 17 | 18 | cmd:text() 19 | options = cmd:parse(arg) 20 | 21 | if options.dataset == 0 then 22 | options.dataset = nil 23 | end 24 | 25 | -- Data 26 | print("-- Loading dataset") 27 | dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"), 28 | { 29 | loadFirst = options.dataset, 30 | minWordFreq = options.minWordFreq 31 | }) 32 | 33 | print("\nDataset stats:") 34 | print(" Vocabulary size: " .. dataset.wordsCount) 35 | print(" Examples: " .. 
dataset.examplesCount) 36 | 37 | -- Model 38 | model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize) 39 | model.goToken = dataset.goToken 40 | model.eosToken = dataset.eosToken 41 | 42 | -- Training parameters 43 | model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) 44 | model.learningRate = options.learningRate 45 | model.momentum = options.momentum 46 | local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch 47 | local minMeanError = nil 48 | 49 | -- Enabled CUDA 50 | if options.cuda then 51 | require 'cutorch' 52 | require 'cunn' 53 | model:cuda() 54 | elseif options.opencl then 55 | require 'cltorch' 56 | require 'clnn' 57 | model:cl() 58 | end 59 | 60 | 61 | -- Run the experiment 62 | 63 | for epoch = 1, options.maxEpoch do 64 | print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch) 65 | print("") 66 | 67 | local errors = torch.Tensor(dataset.examplesCount):fill(0) 68 | local timer = torch.Timer() 69 | 70 | local i = 1 71 | for examples in dataset:batches(options.batchSize) do 72 | collectgarbage() 73 | 74 | for _, example in ipairs(examples) do 75 | local input, target = unpack(example) 76 | 77 | if options.cuda then 78 | input = input:cuda() 79 | target = target:cuda() 80 | elseif options.opencl then 81 | input = input:cl() 82 | target = target:cl() 83 | end 84 | 85 | local err = model:train(input, target) 86 | 87 | -- Check if error is NaN. If so, it's probably a bug. 88 | if err ~= err then 89 | error("Invalid error! Exiting.") 90 | end 91 | 92 | errors[i] = err 93 | xlua.progress(i, dataset.examplesCount) 94 | i = i + 1 95 | end 96 | end 97 | 98 | timer:stop() 99 | 100 | print("\nFinished in " .. xlua.formatTime(timer:time().real) .. " " .. (dataset.examplesCount / timer:time().real) .. ' examples/sec.') 101 | print("\nEpoch stats:") 102 | print(" LR= " .. model.learningRate) 103 | print(" Errors: min= " .. errors:min()) 104 | print(" max= " .. errors:max()) 105 | print(" median= " .. 
errors:median()[1]) 106 | print(" mean= " .. errors:mean()) 107 | print(" std= " .. errors:std()) 108 | 109 | -- Save the model if it improved. 110 | if minMeanError == nil or errors:mean() < minMeanError then 111 | print("\n(Saving model ...)") 112 | torch.save("data/model.t7", model) 113 | minMeanError = errors:mean() 114 | end 115 | 116 | model.learningRate = model.learningRate + decayFactor 117 | model.learningRate = math.max(options.minLR, model.learningRate) 118 | end 119 | 120 | -- Load testing script 121 | require "eval" 122 | --------------------------------------------------------------------------------
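
The learning-rate schedule in train.lua decays the LR linearly each epoch by `(minLR - learningRate) / saturateEpoch`, clamping at `minLR` once `saturateEpoch` is reached. A minimal Python sketch of that schedule (an illustrative reimplementation of the Lua logic above, not part of the repo):

```python
# Sketch of train.lua's linear LR decay, reimplemented for illustration.
def lr_schedule(lr0, min_lr, saturate_epoch, max_epoch):
    decay = (min_lr - lr0) / saturate_epoch  # negative increment per epoch
    lrs, lr = [], lr0
    for _ in range(max_epoch):
        lrs.append(lr)                    # LR used during this epoch
        lr = max(min_lr, lr + decay)      # updated after the epoch, clamped
    return lrs
```

With the defaults above (`--learningRate 0.05 --minLR 0.00001 --saturateEpoch 20 --maxEpoch 50`), the LR falls linearly for the first 20 epochs and then stays at `minLR` for the remaining 30.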