├── .gitignore ├── 163qr.jpeg ├── a.lua ├── a.png ├── b.png ├── c.png ├── cornell_movie_dialogs.lua ├── cornell_movie_dialogs2.lua ├── data ├── .gitkeep ├── fate2.jpeg ├── qq2.jpeg └── qqun.png ├── dataset.lua ├── eval-server.lua ├── eval.lua ├── lstm_text_generation.py ├── material └── stub ├── movie_script_parser.lua ├── neuralconvo.lua ├── readme.md ├── run_server.sh ├── seq2seq.lua ├── tokenizer.lua └── train.lua /.gitignore: -------------------------------------------------------------------------------- 1 | log/* 2 | *.log 3 | 4 | # Compiled Lua sources 5 | luac.out 6 | 7 | # luarocks build files 8 | *.src.rock 9 | *.zip 10 | *.tar.gz 11 | 12 | # Object files 13 | *.o 14 | *.os 15 | *.ko 16 | *.obj 17 | *.elf 18 | 19 | # Precompiled Headers 20 | *.gch 21 | *.pch 22 | 23 | # Libraries 24 | *.lib 25 | *.a 26 | *.la 27 | *.lo 28 | *.def 29 | *.exp 30 | 31 | # Shared objects (inc. Windows DLLs) 32 | *.dll 33 | *.so 34 | *.so.* 35 | *.dylib 36 | 37 | # Executables 38 | *.exe 39 | *.out 40 | *.app 41 | *.i*86 42 | *.x86_64 43 | *.hex 44 | 45 | *.conv 46 | -------------------------------------------------------------------------------- /163qr.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/163qr.jpeg -------------------------------------------------------------------------------- /a.lua: -------------------------------------------------------------------------------- 1 | --- Lexical scanner for creating a sequence of tokens from text. 2 | -- `lexer.scan(s)` returns an iterator over all tokens found in the 3 | -- string `s`. This iterator returns two values, a token type string 4 | -- (such as 'string' for quoted string, 'iden' for identifier) and the value of the 5 | -- token. 6 | -- 7 | -- Versions specialized for Lua and C are available; these also handle block comments 8 | -- and classify keywords as 'keyword' tokens. For example: 9 | -- 10 | -- > s = 'for i=1,n do' 11 | -- > for t,v in lexer.lua(s) do print(t,v) end 12 | -- keyword for 13 | -- iden i 14 | -- = = 15 | -- number 1 16 | -- , , 17 | -- iden n 18 | -- keyword do 19 | -- 20 | -- See the Guide for further @{06-data.md.Lexical_Scanning|discussion} 21 | -- @module pl.lexer 22 | 23 | local yield,wrap = coroutine.yield,coroutine.wrap 24 | local strfind = string.find 25 | local strsub = string.sub 26 | local append = table.insert 27 | 28 | local function assert_arg(idx,val,tp) 29 | if type(val) ~= tp then 30 | error("argument "..idx.." 
must be "..tp, 2) 31 | end 32 | end 33 | 34 | local lexer = {} 35 | 36 | local NUMBER1 = '^[%+%-]?%d+%.?%d*[eE][%+%-]?%d+' 37 | local NUMBER2 = '^[%+%-]?%d+%.?%d*' 38 | local NUMBER3 = '^0x[%da-fA-F]+' 39 | local NUMBER4 = '^%d+%.?%d*[eE][%+%-]?%d+' 40 | local NUMBER5 = '^%d+%.?%d*' 41 | local IDEN = '^[%a_][%w_]*' 42 | local WSPACE = '^%s+' 43 | local STRING1 = "^(['\"])%1" -- empty string 44 | local STRING2 = [[^(['"])(\*)%2%1]] 45 | local STRING3 = [[^(['"]).-[^\](\*)%2%1]] 46 | local CHAR1 = "^''" 47 | local CHAR2 = [[^'(\*)%1']] 48 | local CHAR3 = [[^'.-[^\](\*)%1']] 49 | local PREPRO = '^#.-[^\\]\n' 50 | 51 | local plain_matches,lua_matches,cpp_matches,lua_keyword,cpp_keyword 52 | 53 | local function tdump(tok) 54 | return yield(tok,tok) 55 | end 56 | 57 | local function ndump(tok,options) 58 | if options and options.number then 59 | tok = tonumber(tok) 60 | end 61 | return yield("number",tok) 62 | end 63 | 64 | -- regular strings, single or double quotes; usually we want them 65 | -- without the quotes 66 | local function sdump(tok,options) 67 | if options and options.string then 68 | tok = tok:sub(2,-2) 69 | end 70 | return yield("string",tok) 71 | end 72 | 73 | -- long Lua strings need extra work to get rid of the quotes 74 | local function sdump_l(tok,options,findres) 75 | if options and options.string then 76 | local quotelen = 3 77 | if findres[3] then 78 | quotelen = quotelen + findres[3]:len() 79 | end 80 | tok = tok:sub(quotelen, -quotelen) 81 | if tok:sub(1, 1) == "\n" then 82 | tok = tok:sub(2) 83 | end 84 | end 85 | return yield("string",tok) 86 | end 87 | 88 | local function chdump(tok,options) 89 | if options and options.string then 90 | tok = tok:sub(2,-2) 91 | end 92 | return yield("char",tok) 93 | end 94 | 95 | local function cdump(tok) 96 | return yield('comment',tok) 97 | end 98 | 99 | local function wsdump (tok) 100 | return yield("space",tok) 101 | end 102 | 103 | local function pdump (tok) 104 | return yield('prepro',tok) 105 | end 106 | 107 | local function plain_vdump(tok) 108 | return yield("iden",tok) 109 | end 110 | 111 | local function lua_vdump(tok) 112 | if lua_keyword[tok] then 113 | return yield("keyword",tok) 114 | else 115 | return yield("iden",tok) 116 | end 117 | end 118 | 119 | local function cpp_vdump(tok) 120 | if cpp_keyword[tok] then 121 | return yield("keyword",tok) 122 | else 123 | return yield("iden",tok) 124 | end 125 | end 126 | 127 | --- create a plain token iterator from a string or file-like object. 128 | -- @tparam string|file s a string or a file-like object with `:read()` method returning lines. 129 | -- @tab matches an optional match table - array of token descriptions. 130 | -- A token is described by a `{pattern, action}` pair, where `pattern` should match 131 | -- token body and `action` is a function called when a token of described type is found. 132 | -- @tab[opt] filter a table of token types to exclude, by default `{space=true}` 133 | -- @tab[opt] options a table of options; by default, `{number=true,string=true}`, 134 | -- which means convert numbers and strip string quotes. 
135 | function lexer.scan(s,matches,filter,options) 136 | local file = type(s) ~= 'string' and s 137 | filter = filter or {space=true} 138 | options = options or {number=true,string=true} 139 | if filter then 140 | if filter.space then filter[wsdump] = true end 141 | if filter.comments then 142 | filter[cdump] = true 143 | end 144 | end 145 | if not matches then 146 | if not plain_matches then 147 | plain_matches = { 148 | {WSPACE,wsdump}, 149 | {NUMBER3,ndump}, 150 | {IDEN,plain_vdump}, 151 | {NUMBER1,ndump}, 152 | {NUMBER2,ndump}, 153 | {STRING1,sdump}, 154 | {STRING2,sdump}, 155 | {STRING3,sdump}, 156 | {'^.',tdump} 157 | } 158 | end 159 | matches = plain_matches 160 | end 161 | local function lex() 162 | local line_nr = 0 163 | local next_line = file and file:read() 164 | local sz = file and 0 or #s 165 | local idx = 1 166 | 167 | while true do 168 | if idx > sz then 169 | if file then 170 | if not next_line then return end 171 | s = next_line 172 | line_nr = line_nr + 1 173 | next_line = file:read() 174 | if next_line then 175 | s = s .. '\n' 176 | end 177 | idx, sz = 1, #s 178 | else 179 | return 180 | end 181 | end 182 | 183 | for _,m in ipairs(matches) do 184 | local pat = m[1] 185 | local fun = m[2] 186 | local findres = {strfind(s,pat,idx)} 187 | local i1, i2 = findres[1], findres[2] 188 | if i1 then 189 | local tok = strsub(s,i1,i2) 190 | idx = i2 + 1 191 | local res 192 | if not (filter and filter[fun]) then 193 | lexer.finished = idx > sz 194 | res = fun(tok, options, findres) 195 | end 196 | if res then 197 | local tp = type(res) 198 | -- insert a token list 199 | if tp == 'table' then 200 | yield('','') 201 | for _,t in ipairs(res) do 202 | yield(t[1],t[2]) 203 | end 204 | elseif tp == 'string' then -- or search up to some special pattern 205 | i1,i2 = strfind(s,res,idx) 206 | if i1 then 207 | tok = strsub(s,i1,i2) 208 | idx = i2 + 1 209 | yield('',tok) 210 | else 211 | yield('','') 212 | idx = sz + 1 213 | end 214 | else 215 | yield(line_nr,idx) 216 | end 217 | end 218 | 219 | break 220 | end 221 | end 222 | end 223 | end 224 | return wrap(lex) 225 | end 226 | 227 | local function isstring (s) 228 | return type(s) == 'string' 229 | end 230 | 231 | --- insert tokens into a stream. 232 | -- @param tok a token stream 233 | -- @param a1 a string is the type, a table is a token list and 234 | -- a function is assumed to be a token-like iterator (returns type & value) 235 | -- @string a2 a string is the value 236 | function lexer.insert (tok,a1,a2) 237 | if not a1 then return end 238 | local ts 239 | if isstring(a1) and isstring(a2) then 240 | ts = {{a1,a2}} 241 | elseif type(a1) == 'function' then 242 | ts = {} 243 | for t,v in a1() do 244 | append(ts,{t,v}) 245 | end 246 | else 247 | ts = a1 248 | end 249 | tok(ts) 250 | end 251 | 252 | --- get everything in a stream upto a newline. 253 | -- @param tok a token stream 254 | -- @return a string 255 | function lexer.getline (tok) 256 | local t,v = tok('.-\n') 257 | return v 258 | end 259 | 260 | --- get current line number. 261 | -- Only available if the input source is a file-like object. 262 | -- @param tok a token stream 263 | -- @return the line number and current column 264 | function lexer.lineno (tok) 265 | return tok(0) 266 | end 267 | 268 | --- get the rest of the stream. 269 | -- @param tok a token stream 270 | -- @return a string 271 | function lexer.getrest (tok) 272 | local t,v = tok('.+') 273 | return v 274 | end 275 | 276 | --- get the Lua keywords as a set-like table. 277 | -- So `res["and"]` etc would be `true`. 
278 | -- @return a table 279 | function lexer.get_keywords () 280 | if not lua_keyword then 281 | lua_keyword = { 282 | ["and"] = true, ["break"] = true, ["do"] = true, 283 | ["else"] = true, ["elseif"] = true, ["end"] = true, 284 | ["false"] = true, ["for"] = true, ["function"] = true, 285 | ["if"] = true, ["in"] = true, ["local"] = true, ["nil"] = true, 286 | ["not"] = true, ["or"] = true, ["repeat"] = true, 287 | ["return"] = true, ["then"] = true, ["true"] = true, 288 | ["until"] = true, ["while"] = true 289 | } 290 | end 291 | return lua_keyword 292 | end 293 | 294 | --- create a Lua token iterator from a string or file-like object. 295 | -- Will return the token type and value. 296 | -- @string s the string 297 | -- @tab[opt] filter a table of token types to exclude, by default `{space=true,comments=true}` 298 | -- @tab[opt] options a table of options; by default, `{number=true,string=true}`, 299 | -- which means convert numbers and strip string quotes. 300 | function lexer.lua(s,filter,options) 301 | filter = filter or {space=true,comments=true} 302 | lexer.get_keywords() 303 | if not lua_matches then 304 | lua_matches = { 305 | {WSPACE,wsdump}, 306 | {NUMBER3,ndump}, 307 | {IDEN,lua_vdump}, 308 | {NUMBER4,ndump}, 309 | {NUMBER5,ndump}, 310 | {STRING1,sdump}, 311 | {STRING2,sdump}, 312 | {STRING3,sdump}, 313 | {'^%-%-%[(=*)%[.-%]%1%]',cdump}, 314 | {'^%-%-.-\n',cdump}, 315 | {'^%[(=*)%[.-%]%1%]',sdump_l}, 316 | {'^==',tdump}, 317 | {'^~=',tdump}, 318 | {'^<=',tdump}, 319 | {'^>=',tdump}, 320 | {'^%.%.%.',tdump}, 321 | {'^%.%.',tdump}, 322 | {'^.',tdump} 323 | } 324 | end 325 | return lexer.scan(s,lua_matches,filter,options) 326 | end 327 | 328 | --- create a C/C++ token iterator from a string or file-like object. 329 | -- Will return the token type type and value. 330 | -- @string s the string 331 | -- @tab[opt] filter a table of token types to exclude, by default `{space=true,comments=true}` 332 | -- @tab[opt] options a table of options; by default, `{number=true,string=true}`, 333 | -- which means convert numbers and strip string quotes. 
334 | function lexer.cpp(s,filter,options) 335 | filter = filter or {space=true,comments=true} 336 | if not cpp_keyword then 337 | cpp_keyword = { 338 | ["class"] = true, ["break"] = true, ["do"] = true, ["sizeof"] = true, 339 | ["else"] = true, ["continue"] = true, ["struct"] = true, 340 | ["false"] = true, ["for"] = true, ["public"] = true, ["void"] = true, 341 | ["private"] = true, ["protected"] = true, ["goto"] = true, 342 | ["if"] = true, ["static"] = true, ["const"] = true, ["typedef"] = true, 343 | ["enum"] = true, ["char"] = true, ["int"] = true, ["bool"] = true, 344 | ["long"] = true, ["float"] = true, ["true"] = true, ["delete"] = true, 345 | ["double"] = true, ["while"] = true, ["new"] = true, 346 | ["namespace"] = true, ["try"] = true, ["catch"] = true, 347 | ["switch"] = true, ["case"] = true, ["extern"] = true, 348 | ["return"] = true,["default"] = true,['unsigned'] = true,['signed'] = true, 349 | ["union"] = true, ["volatile"] = true, ["register"] = true,["short"] = true, 350 | } 351 | end 352 | if not cpp_matches then 353 | cpp_matches = { 354 | {WSPACE,wsdump}, 355 | {PREPRO,pdump}, 356 | {NUMBER3,ndump}, 357 | {IDEN,cpp_vdump}, 358 | {NUMBER4,ndump}, 359 | {NUMBER5,ndump}, 360 | {CHAR1,chdump}, 361 | {CHAR2,chdump}, 362 | {CHAR3,chdump}, 363 | {STRING1,sdump}, 364 | {STRING2,sdump}, 365 | {STRING3,sdump}, 366 | {'^//.-\n',cdump}, 367 | {'^/%*.-%*/',cdump}, 368 | {'^==',tdump}, 369 | {'^!=',tdump}, 370 | {'^<=',tdump}, 371 | {'^>=',tdump}, 372 | {'^->',tdump}, 373 | {'^&&',tdump}, 374 | {'^||',tdump}, 375 | {'^%+%+',tdump}, 376 | {'^%-%-',tdump}, 377 | {'^%+=',tdump}, 378 | {'^%-=',tdump}, 379 | {'^%*=',tdump}, 380 | {'^/=',tdump}, 381 | {'^|=',tdump}, 382 | {'^%^=',tdump}, 383 | {'^::',tdump}, 384 | {'^.',tdump} 385 | } 386 | end 387 | return lexer.scan(s,cpp_matches,filter,options) 388 | end 389 | 390 | --- get a list of parameters separated by a delimiter from a stream. 391 | -- @param tok the token stream 392 | -- @string[opt=')'] endtoken end of list. Can be '\n' 393 | -- @string[opt=','] delim separator 394 | -- @return a list of token lists. 395 | function lexer.get_separated_list(tok,endtoken,delim) 396 | endtoken = endtoken or ')' 397 | delim = delim or ',' 398 | local parm_values = {} 399 | local level = 1 -- used to count ( and ) 400 | local tl = {} 401 | local function tappend (tl,t,val) 402 | val = val or t 403 | append(tl,{t,val}) 404 | end 405 | local is_end 406 | if endtoken == '\n' then 407 | is_end = function(t,val) 408 | return t == 'space' and val:find '\n' 409 | end 410 | else 411 | is_end = function (t) 412 | return t == endtoken 413 | end 414 | end 415 | local token,value 416 | while true do 417 | token,value=tok() 418 | if not token then return nil,'EOS' end -- end of stream is an error! 419 | if is_end(token,value) and level == 1 then 420 | append(parm_values,tl) 421 | break 422 | elseif token == '(' then 423 | level = level + 1 424 | tappend(tl,'(') 425 | elseif token == ')' then 426 | level = level - 1 427 | if level == 0 then -- finished with parm list 428 | append(parm_values,tl) 429 | break 430 | else 431 | tappend(tl,')') 432 | end 433 | elseif token == delim and level == 1 then 434 | append(parm_values,tl) -- a new parm 435 | tl = {} 436 | else 437 | tappend(tl,token,value) 438 | end 439 | end 440 | return parm_values,{token,value} 441 | end 442 | 443 | --- get the next non-space token from the stream. 444 | -- @param tok the token stream. 
 445 | function lexer.skipws (tok) 446 |   local t,v = tok() 447 |   while t == 'space' do 448 |     t,v = tok() 449 |   end 450 |   return t,v 451 | end 452 | 453 | local skipws = lexer.skipws 454 | 455 | --- get the next token, which must be of the expected type. 456 | -- Throws an error if this type does not match! 457 | -- @param tok the token stream 458 | -- @string expected_type the token type 459 | -- @bool no_skip_ws whether we should skip whitespace 460 | function lexer.expecting (tok,expected_type,no_skip_ws) 461 |   assert_arg(1,tok,'function') 462 |   assert_arg(2,expected_type,'string') 463 |   local t,v 464 |   if no_skip_ws then 465 |     t,v = tok() 466 |   else 467 |     t,v = skipws(tok) 468 |   end 469 |   if t ~= expected_type then error ("expecting "..expected_type,2) end 470 |   return v 471 | end 472 | 473 | return lexer 474 | -------------------------------------------------------------------------------- /a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/a.png -------------------------------------------------------------------------------- /b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/b.png -------------------------------------------------------------------------------- /c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/c.png -------------------------------------------------------------------------------- /cornell_movie_dialogs.lua: -------------------------------------------------------------------------------- 1 | local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") 2 | local stringx = require "pl.stringx" 3 | local xlua = require "xlua" 4 | 5 | function CornellMovieDialogs:__init(dir) 6 |   self.dir = dir 7 | end 8 | 9 | 10 | function CornellMovieDialogs:load() 11 |   local lines = {} 12 |   local conversations = {} 13 |   local count = 1 14 | 15 |   print("-- Parsing Cornell movie dialogs data set ...") 16 | 17 | 18 |   local f = assert(io.open('../xiaohuangji50w_fenciA.conv', 'r')) 19 | 20 |   while true do 21 |     local line = f:read("*line") 22 |     if line == nil then 23 |       f:close() 24 |       break 25 |     end 26 | 27 |     lines[count] = line 28 |     count = count + 1 29 |   end 30 | 31 |   print("Total lines = "..(count - 1)) -- count was 1-based, so it overshoots by one 32 |   local tmpconv = nil 33 | 34 |   local TOTAL = #lines 35 |   local count = 0 36 | 37 |   for i, line in ipairs(lines) do 38 |     --print(i..' 
'..line) 39 |     if string.sub(line, 0, 1) == "E" then 40 | 41 |       if tmpconv ~= nil then 42 |         --print('new conv'..#tmpconv) 43 |         table.insert(conversations, tmpconv) 44 |       end 45 |       --print('e make the tmpconv') 46 |       tmpconv = {} 47 | 48 |     end 49 | 50 |     if string.sub(line, 0, 1) == "M" then 51 |       --print('insert into conv') 52 |       local tmpl = string.sub(line, 3, #line) 53 |       --print(tmpl) 54 |       table.insert(tmpconv, tmpl) 55 |     end 56 | 57 |     count = count + 1 58 |     if count%1000 == 0 then 59 |       xlua.progress(count, TOTAL) 60 |     end 61 |   end 62 |   if tmpconv ~= nil then table.insert(conversations, tmpconv) end -- don't drop the last conversation 63 |   return conversations 64 | end 65 | -------------------------------------------------------------------------------- /cornell_movie_dialogs2.lua: -------------------------------------------------------------------------------- 1 | local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") 2 | local stringx = require "pl.stringx" 3 | local xlua = require "xlua" 4 | 5 | local function parsedLines(file, fields) 6 |   local f = assert(io.open(file, 'r')) 7 | 8 |   return function() 9 |     local line = f:read("*line") 10 | 11 |     if line == nil then 12 |       f:close() 13 |       return 14 |     end 15 | 16 |     local values = stringx.split(line, " +++$+++ ") 17 |     local t = {} 18 | 19 |     for i,field in ipairs(fields) do 20 |       t[field] = values[i] 21 |     end 22 | 23 |     return t 24 |   end 25 | end 26 | 27 | function CornellMovieDialogs:__init(dir) 28 |   self.dir = dir 29 | end 30 | 31 | local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} 32 | local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} 33 | local TOTAL_LINES = 387810 34 | 35 | local function progress(c) 36 |   if c % 10000 == 0 then 37 |     xlua.progress(c, TOTAL_LINES) 38 |   end 39 | end 40 | 41 | function CornellMovieDialogs:load() 42 |   local lines = {} 43 |   local conversations = {} 44 |   local count = 0 45 | 46 |   print("-- Parsing Cornell movie dialogs data set ...") 47 | 48 |   for line in parsedLines(self.dir .. "/movie_lines.txt", MOVIE_LINES_FIELDS) do 49 |     lines[line.lineID] = line 50 |     line.lineID = nil 51 |     -- Remove unused fields 52 |     line.characterID = nil 53 |     line.movieID = nil 54 |     count = count + 1 55 |     progress(count) 56 |   end 57 | 58 |   for conv in parsedLines(self.dir .. 
"/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do 59 | local conversation = {} 60 | local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") 61 | for i,lineID in ipairs(lineIDs) do 62 | table.insert(conversation, lines[lineID]) 63 | end 64 | table.insert(conversations, conversation) 65 | count = count + 1 66 | progress(count) 67 | end 68 | 69 | xlua.progress(TOTAL_LINES, TOTAL_LINES) 70 | 71 | return conversations 72 | end 73 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/data/.gitkeep -------------------------------------------------------------------------------- /data/fate2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/data/fate2.jpeg -------------------------------------------------------------------------------- /data/qq2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/data/qq2.jpeg -------------------------------------------------------------------------------- /data/qqun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/data/qqun.png -------------------------------------------------------------------------------- /dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Format movie dialog data as a table of line 1: 3 | 4 | { {word_ids of character1}, {word_ids of character2} } 5 | 6 | Then flips it around and get the dialog from the other character's perspective: 7 | 8 | { {word_ids of character2}, {word_ids of character1} } 9 | 10 | Also builds the vocabulary. 11 | ]]-- 12 | 13 | local DataSet = torch.class("neuralconvo.DataSet") 14 | local xlua = require "xlua" 15 | local tokenizer = require "tokenizer" 16 | local list = require "pl.List" 17 | 18 | function DataSet:__init(loader, options) 19 | options = options or {} 20 | 21 | self.examplesFilename = "data/examples.t7" 22 | 23 | -- Discard words with lower frequency then this 24 | self.minWordFreq = options.minWordFreq or 1 25 | 26 | -- Maximum number of words in an example sentence 27 | self.maxExampleLen = options.maxExampleLen or 25 28 | 29 | -- Load only first fews examples (approximately) 30 | self.loadFirst = options.loadFirst 31 | 32 | self.examples = {} 33 | self.word2id = {} 34 | self.id2word = {} 35 | self.wordsCount = 0 36 | 37 | self:load(loader) 38 | end 39 | 40 | function DataSet:load(loader) 41 | local filename = "data/vocab.t7" 42 | 43 | if path.exists(filename) then 44 | --if false then 45 | print("Loading vocabulary from " .. filename .. " ...") 46 | local data = torch.load(filename) 47 | self.word2id = data.word2id 48 | self.id2word = data.id2word 49 | self.wordsCount = data.wordsCount 50 | self.goToken = data.goToken 51 | self.eosToken = data.eosToken 52 | self.unknownToken = data.unknownToken 53 | self.examplesCount = data.examplesCount 54 | --print(self.word2id) 55 | else 56 | print("" .. filename .. 
" not found") 57 | self:visit(loader:load()) 58 | print("Writing " .. filename .. " ...") 59 | torch.save(filename, { 60 | word2id = self.word2id, 61 | id2word = self.id2word, 62 | wordsCount = self.wordsCount, 63 | goToken = self.goToken, 64 | eosToken = self.eosToken, 65 | unknownToken = self.unknownToken, 66 | examplesCount = self.examplesCount 67 | }) 68 | end 69 | end 70 | 71 | function DataSet:visit(conversations) 72 | -- Table for keeping track of word frequency 73 | self.wordFreq = {} 74 | self.examples = {} 75 | 76 | -- Add magic tokens 77 | self.goToken = self:makeWordId("") -- Start of sequence 78 | self.eosToken = self:makeWordId("") -- End of sequence 79 | self.unknownToken = self:makeWordId("") -- Word dropped from vocabulary 80 | 81 | print("-- Pre-processing data") 82 | 83 | local total = self.loadFirst or #conversations * 2 84 | 85 | for i, conversation in ipairs(conversations) do 86 | --print(i) 87 | if i > total then break end 88 | self:visitConversation(conversation) 89 | xlua.progress(i, total) 90 | end 91 | 92 | -- Revisit from the perspective of 2nd character 93 | for i, conversation in ipairs(conversations) do 94 | --print(i) 95 | if #conversations + i > total then break end 96 | self:visitConversation(conversation, 2) 97 | xlua.progress(#conversations + i, total) 98 | end 99 | 100 | print("-- Removing low frequency words") 101 | print("sfsgsdfgdf") 102 | 103 | for i, datum in ipairs(self.examples) do 104 | self:removeLowFreqWords(datum[1]) 105 | self:removeLowFreqWords(datum[2]) 106 | xlua.progress(i, #self.examples) 107 | end 108 | 109 | self.wordFreq = nil 110 | 111 | self.examplesCount = #self.examples 112 | self:writeExamplesToFile() 113 | self.examples = nil 114 | 115 | collectgarbage() 116 | end 117 | 118 | function DataSet:writeExamplesToFile() 119 | print("Writing " .. self.examplesFilename .. 
" ...") 120 | local file = torch.DiskFile(self.examplesFilename, "w") 121 | 122 | for i, example in ipairs(self.examples) do 123 | file:writeObject(example) 124 | xlua.progress(i, #self.examples) 125 | end 126 | 127 | file:close() 128 | end 129 | 130 | function DataSet:batches(size) 131 | local file = torch.DiskFile(self.examplesFilename, "r") 132 | file:quiet() 133 | local done = false 134 | 135 | return function() 136 | if done then 137 | return 138 | end 139 | 140 | local inputSeqs,targetSeqs = {},{} 141 | local maxInputSeqLen,maxTargetOutputSeqLen = 0,0 142 | 143 | for i = 1, size do 144 | local example = file:readObject() 145 | if example == nil then 146 | done = true 147 | file:close() 148 | return examples 149 | end 150 | inputSeq,targetSeq = unpack(example) 151 | if inputSeq:size(1) > maxInputSeqLen then 152 | maxInputSeqLen = inputSeq:size(1) 153 | end 154 | if targetSeq:size(1) > maxTargetOutputSeqLen then 155 | maxTargetOutputSeqLen = targetSeq:size(1) 156 | end 157 | table.insert(inputSeqs, inputSeq) 158 | table.insert(targetSeqs, targetSeq) 159 | end 160 | 161 | local encoderInputs,decoderInputs,decoderTargets = nil,nil,nil 162 | if size == 1 then 163 | encoderInputs = torch.IntTensor(maxInputSeqLen):fill(0) 164 | decoderInputs = torch.IntTensor(maxTargetOutputSeqLen-1):fill(0) 165 | decoderTargets = torch.IntTensor(maxTargetOutputSeqLen-1):fill(0) 166 | else 167 | encoderInputs = torch.IntTensor(maxInputSeqLen,size):fill(0) 168 | decoderInputs = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0) 169 | decoderTargets = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0) 170 | end 171 | 172 | for samplenb = 1, #inputSeqs do 173 | for word = 1,inputSeqs[samplenb]:size(1) do 174 | eosOffset = maxInputSeqLen - inputSeqs[samplenb]:size(1) -- for left padding 175 | if size == 1 then 176 | encoderInputs[word] = inputSeqs[samplenb][word] 177 | else 178 | encoderInputs[word+eosOffset][samplenb] = inputSeqs[samplenb][word] 179 | end 180 | end 181 | end 182 | 183 | for samplenb = 1, #targetSeqs do 184 | trimmedEosToken = targetSeqs[samplenb]:sub(1,-2) 185 | for word = 1, trimmedEosToken:size(1) do 186 | if size == 1 then 187 | decoderInputs[word] = trimmedEosToken[word] 188 | else 189 | decoderInputs[word][samplenb] = trimmedEosToken[word] 190 | end 191 | end 192 | end 193 | 194 | for samplenb = 1, #targetSeqs do 195 | trimmedGoToken = targetSeqs[samplenb]:sub(2,-1) 196 | for word = 1, trimmedGoToken:size(1) do 197 | if size == 1 then 198 | decoderTargets[word] = trimmedGoToken[word] 199 | else 200 | decoderTargets[word][samplenb] = trimmedGoToken[word] 201 | end 202 | end 203 | end 204 | 205 | return encoderInputs,decoderInputs,decoderTargets 206 | end 207 | end 208 | 209 | function DataSet:removeLowFreqWords(input) 210 | for i = 1, input:size(1) do 211 | local id = input[i] 212 | local word = self.id2word[id] 213 | 214 | if word == nil then 215 | -- Already removed 216 | input[i] = self.unknownToken 217 | 218 | elseif self.wordFreq[word] < self.minWordFreq then 219 | input[i] = self.unknownToken 220 | 221 | self.word2id[word] = nil 222 | self.id2word[id] = nil 223 | self.wordsCount = self.wordsCount - 1 224 | end 225 | end 226 | end 227 | 228 | function DataSet:visitConversation(lines, start) 229 | start = start or 1 230 | 231 | --print("conv lines "..#lines) 232 | 233 | for i = start, #lines, 2 do 234 | local input = lines[i] 235 | local target = lines[i+1] 236 | 237 | if target then 238 | local inputIds = self:visitText(input) 239 | local targetIds = self:visitText(target, 2) 
240 | 241 | if inputIds and targetIds then 242 | -- Revert inputs 243 | inputIds = list.reverse(inputIds) 244 | 245 | table.insert(targetIds, 1, self.goToken) 246 | table.insert(targetIds, self.eosToken) 247 | 248 | table.insert(self.examples, { torch.IntTensor(inputIds), torch.IntTensor(targetIds) }) 249 | end 250 | end 251 | end 252 | end 253 | 254 | function DataSet:visitText(text, additionalTokens) 255 | local words = {} 256 | additionalTokens = additionalTokens or 0 257 | 258 | if text == "" or text == nil then 259 | print "zero text" 260 | return 261 | end 262 | 263 | --print(text) 264 | local values = stringx.split(text, "/") 265 | for i, word in ipairs(values) do 266 | --print("spword:"..word) 267 | table.insert(words, self:makeWordId(word)) 268 | if #words >= self.maxExampleLen - additionalTokens then 269 | break 270 | end 271 | end 272 | 273 | --[[ 274 | for t, word in tokenizer.tokenize(text) do 275 | print(word) 276 | table.insert(words, self:makeWordId(word)) 277 | -- Only keep the first sentence 278 | if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then 279 | break 280 | end 281 | end 282 | ]]-- 283 | if #words == 0 then 284 | return 285 | end 286 | 287 | return words 288 | end 289 | 290 | function DataSet:makeWordId(word) 291 | --word = word:lower() 292 | --print(word) 293 | local id = self.word2id[word] 294 | 295 | if id then 296 | self.wordFreq[word] = self.wordFreq[word] + 1 297 | --print("more freq > 1") 298 | else 299 | --print("to dict word = "..word) 300 | self.wordsCount = self.wordsCount + 1 301 | id = self.wordsCount 302 | self.id2word[id] = word 303 | self.word2id[word] = id 304 | self.wordFreq[word] = 1 305 | end 306 | 307 | return id 308 | end 309 | -------------------------------------------------------------------------------- /eval-server.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | local tokenizer = require "tokenizer" 3 | local list = require "pl.List" 4 | local options = {} 5 | 6 | if dataset == nil then 7 | cmd = torch.CmdLine() 8 | cmd:text('Options:') 9 | cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') 10 | cmd:option('--opencl', false, 'use OpenCL. Training must be done on OpenCL') 11 | cmd:option('--debug', false, 'show debug info') 12 | cmd:text() 13 | options = cmd:parse(arg) 14 | 15 | -- Data 16 | dataset = neuralconvo.DataSet() 17 | 18 | -- Enabled CUDA 19 | if options.cuda then 20 | require 'cutorch' 21 | require 'cunn' 22 | elseif options.opencl then 23 | require 'cltorch' 24 | require 'clnn' 25 | end 26 | end 27 | 28 | if model == nil then 29 | print("-- Loading model") 30 | model = torch.load("data/model.t7") 31 | end 32 | 33 | -- Word IDs to sentence 34 | function pred2sent(wordIds, i) 35 | local words = {} 36 | i = i or 1 37 | 38 | for _, wordId in ipairs(wordIds) do 39 | local word = dataset.id2word[wordId[i]] 40 | --print(wordId[i]..word) 41 | table.insert(words, word) 42 | end 43 | 44 | return tokenizer.join(words) 45 | end 46 | 47 | function printProbabilityTable(wordIds, probabilities, num) 48 | print(string.rep("-", num * 22)) 49 | 50 | for p, wordId in ipairs(wordIds) do 51 | local line = "| " 52 | for i = 1, num do 53 | local word = dataset.id2word[wordId[i]] 54 | line = line .. string.format("%-10s(%4d%%)", word, probabilities[p][i] * 100) .. 
" | " 55 | end 56 | print(line) 57 | end 58 | 59 | print(string.rep("-", num * 22)) 60 | end 61 | 62 | function say(text) 63 | local wordIds = {} 64 | 65 | 66 | 67 | --print(text) 68 | local values = {} 69 | for w in text:gmatch("[\33-\127\192-\255]+[\128-\191]*") do 70 | table.insert(values, w) 71 | end 72 | 73 | for i, word in ipairs(values) do 74 | local id = dataset.word2id[word] or dataset.unknownToken 75 | --print(i.." "..word.." "..id) 76 | 77 | table.insert(wordIds, id) 78 | 79 | end 80 | 81 | --[[ 82 | for t, word in tokenizer.tokenize(text) do 83 | local id = dataset.word2id[word:lower()] or dataset.unknownToken 84 | table.insert(wordIds, id) 85 | end 86 | ]]-- 87 | 88 | local input = torch.Tensor(list.reverse(wordIds)) 89 | local wordIds, probabilities = model:eval(input) 90 | 91 | local ret = pred2sent(wordIds) 92 | print(">> " .. ret) 93 | 94 | if options.debug then 95 | printProbabilityTable(wordIds, probabilities, 4) 96 | end 97 | 98 | return ret 99 | 100 | end 101 | 102 | 103 | --[[ http server using ASyNC]]-- 104 | 105 | function unescape (s) 106 | s = string.gsub(s, "+", " ") 107 | s = string.gsub(s, "%%(%x%x)", function (h) 108 | return string.char(tonumber(h, 16)) 109 | end) 110 | return s 111 | end 112 | 113 | 114 | local async = require 'async' 115 | require('pl.text').format_operator() 116 | 117 | async.http.listen('http://0.0.0.0:8082/', function(req,res) 118 | print('request:',req) 119 | local resp 120 | 121 | if req.url.path == '/' and req.url.query ~= nil and #req.url.query > 0 then 122 | 123 | local text_in = unescape(req.url.query) 124 | print(text_in) 125 | local ret = say(text_in) 126 | resp = [[${data}]] % {data = ret} 127 | 128 | else 129 | resp = 'Oops~ This is a wrong place, please goto here!' 130 | 131 | end 132 | 133 | -- if req.url.path == '/test' then 134 | -- resp = [[ 135 | --

You requested route /test

136 | -- ]] 137 | -- else 138 | -- -- Produce a random story: 139 | -- resp = [[ 140 | --

From my server

141 | --

It's working!

142 | --

Randomly generated number: ${number}

143 | --

A variable in the global scope: ${ret}

144 | -- ]] % { 145 | -- number = math.random(), 146 | -- ret = ret 147 | -- } 148 | -- end 149 | 150 | res(resp, {['Content-Type']='text/html; charset=UTF-8'}) 151 | end) 152 | 153 | print('server listening to port 8082') 154 | 155 | async.go() -------------------------------------------------------------------------------- /eval.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | local tokenizer = require "tokenizer" 3 | local list = require "pl.List" 4 | local options = {} 5 | 6 | if dataset == nil then 7 | cmd = torch.CmdLine() 8 | cmd:text('Options:') 9 | cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') 10 | cmd:option('--opencl', false, 'use OpenCL. Training must be done on OpenCL') 11 | cmd:option('--debug', false, 'show debug info') 12 | cmd:text() 13 | options = cmd:parse(arg) 14 | 15 | -- Data 16 | dataset = neuralconvo.DataSet() 17 | 18 | -- Enabled CUDA 19 | if options.cuda then 20 | require 'cutorch' 21 | require 'cunn' 22 | elseif options.opencl then 23 | require 'cltorch' 24 | require 'clnn' 25 | end 26 | end 27 | 28 | if model == nil then 29 | print("-- Loading model") 30 | model = torch.load("data/model.t7") 31 | end 32 | 33 | -- Word IDs to sentence 34 | function pred2sent(wordIds, i) 35 | local words = {} 36 | i = i or 1 37 | 38 | for _, wordId in ipairs(wordIds) do 39 | local word = dataset.id2word[wordId[i]] 40 | --print(wordId[i]..word) 41 | table.insert(words, word) 42 | end 43 | 44 | return tokenizer.join(words) 45 | end 46 | 47 | function printProbabilityTable(wordIds, probabilities, num) 48 | print(string.rep("-", num * 22)) 49 | 50 | for p, wordId in ipairs(wordIds) do 51 | local line = "| " 52 | for i = 1, num do 53 | local word = dataset.id2word[wordId[i]] 54 | line = line .. string.format("%-10s(%4d%%)", word, probabilities[p][i] * 100) .. " | " 55 | end 56 | print(line) 57 | end 58 | 59 | print(string.rep("-", num * 22)) 60 | end 61 | 62 | 63 | function say(text) 64 | local wordIds = {} 65 | 66 | 67 | 68 | --print(text) 69 | local values = {} 70 | for w in text:gmatch("[\33-\127\192-\255]+[\128-\191]*") do 71 | table.insert(values, w) 72 | end 73 | -- print(values) 74 | for i, word in ipairs(values) do 75 | local id = dataset.word2id[word] or dataset.unknownToken 76 | -- print(i.." "..word.." "..id) 77 | 78 | table.insert(wordIds, id) 79 | 80 | end 81 | 82 | --[[ 83 | for t, word in tokenizer.tokenize(text) do 84 | local id = dataset.word2id[word:lower()] or dataset.unknownToken 85 | table.insert(wordIds, id) 86 | end 87 | ]]-- 88 | 89 | local input = torch.Tensor(list.reverse(wordIds)) 90 | local wordIds, probabilities = model:eval(input) 91 | 92 | 93 | print(">> " .. pred2sent(wordIds)) 94 | 95 | if options.debug then 96 | printProbabilityTable(wordIds, probabilities, 4) 97 | end 98 | end 99 | -------------------------------------------------------------------------------- /lstm_text_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | '''(C) 2016 rust 6 | 7 | Example script to generate text from Nietzsche's writings. 8 | 9 | At least 20 epochs are required before the generated text 10 | starts sounding coherent. 11 | 12 | It is recommended to run this script on GPU, as recurrent 13 | networks are quite computationally intensive. 14 | 15 | If you try this script on new data, make sure your corpus 16 | has at least ~100k characters. ~1M is better. 
 17 | 18 | 19 | 20 | ''' 21 | 22 | from __future__ import print_function 23 | from keras.models import Sequential 24 | from keras.layers import Dense, Activation, Dropout 25 | from keras.layers import LSTM 26 | from keras.utils.data_utils import get_file 27 | import numpy as np 28 | import random 29 | import sys 30 | 31 | 32 | # pandas failed to install, so parse the CSV directly instead 33 | text_lines = open('000001.csv').readlines()[1:] 34 | print ('day count = ' + str(len(text_lines))) 35 | 36 | print (text_lines[0]) 37 | 38 | sources_all = [] 39 | targets_all = [] 40 | 41 | for line in reversed(text_lines): 42 |     #print(line) 43 |     lw = line.split(',') 44 |     S = [float(lw[3]), float(lw[5])/10000, float(lw[6])/10000] 45 |     if len(sources_all) == 0: 46 |         S.append(0.00001) 47 |         S.append(0.00001) 48 |         S.append(0.00001) 49 | 50 |     else: 51 |         last = sources_all[-1] 52 |         S.append((S[0] - last[0])/last[0]) 53 |         S.append((S[1] - last[1])/last[1]) 54 |         S.append((S[2] - last[2])/last[2]) 55 | 56 | 57 |     sources_all.append(S) 58 | 59 |     T = S[3] 60 |     targets_all.append(T) 61 | 62 | 63 | print(len(sources_all)) 64 | print(len(targets_all)) 65 | 66 | 67 | sources = sources_all[:5000] 68 | targets = targets_all[:5000] 69 | sources_test = sources_all[5000:] 70 | targets_test = targets_all[5000:] 71 | 72 | 73 | ''' 74 | path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt") 75 | text = open(path).read().lower() 76 | print('corpus length:', len(text)) 77 | 78 | chars = sorted(list(set(text))) 79 | print('total chars:', len(chars)) 80 | char_indices = dict((c, i) for i, c in enumerate(chars)) 81 | indices_char = dict((i, c) for i, c in enumerate(chars)) 82 | ''' 83 | 84 | # cut the text in semi-redundant sequences of maxlen characters 85 | maxlen = 40 86 | step = 1 87 | sentences = [] 88 | next_chars = [] 89 | for i in range(0, len(sources) - maxlen, step): 90 |     sentences.append(sources[i: i + maxlen]) 91 |     next_chars.append(targets[i + maxlen]) 92 | print('nb sequences:', len(sentences)) 93 | 94 | sentences_test = [] 95 | next_chars_test = [] 96 | for i in range(0, len(sources_test) - maxlen, step): 97 |     sentences_test.append(sources_test[i: i + maxlen]) 98 |     next_chars_test.append(targets_test[i + maxlen]) 99 | print('nb test sequences:', len(sentences_test)) 100 | 101 | print('Vectorization...') 102 | X = np.zeros((len(sentences), maxlen, 6), dtype=np.float32) 103 | y = np.zeros((len(sentences), 1), dtype=np.float32) 104 | for i, sentence in enumerate(sentences): 105 |     for t, char in enumerate(sentence): 106 |         for g in xrange(0,6): 107 |             X[i, t, g] = char[g] 108 |     y[i, 0] = next_chars[i] 109 | 110 | 111 | # build the model: 2 stacked LSTMs 112 | print('Build model...') 113 | model = Sequential() 114 | model.add(LSTM(1024, return_sequences=True, input_shape=(maxlen, 6))) 115 | model.add(LSTM(1024, return_sequences=False)) 116 | model.add(Dropout(0.2)) 117 | model.add(Dense(1)) 118 | #model.add(Dense(1)) 119 | #model.add(Activation('softmax')) 120 | model.add(Activation('linear')) 121 | 122 | model.compile(loss='mse', optimizer='rmsprop') 123 | 124 | ''' 125 | def sample(a, temperature=1.0): 126 |     # helper function to sample an index from a probability array 127 |     a = np.log(a) / temperature 128 |     a = np.exp(a) / np.sum(np.exp(a)) 129 |     return np.argmax(np.random.multinomial(1, a, 1)) 130 | ''' 131 | 132 | # train the model, output generated text after each iteration 133 | for iteration in range(1, 60): 134 |     print() 135 |     print('-' * 50) 136 |     print('Iteration', iteration) 137 |     model.fit(X, y, 
batch_size=128, nb_epoch=1) 138 | 139 | 140 |     predret = [] 141 |     for sent,targ in zip(sentences_test[:20], targets_test[:20]): 142 |         x = np.zeros((1, maxlen, 6)) 143 |         for t, char in enumerate(sent): 144 |             for g in xrange(0,6): 145 |                 x[0, t, g] = char[g] 146 | 147 |         preds = model.predict(x, verbose=0)[0] 148 |         print(preds[0], targ) 149 | 150 | 151 | 152 | print("cbf done!") 153 | -------------------------------------------------------------------------------- /material/stub: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aceimnorstuvwxz/chatbot-zh-torch7/0c0751aed0c163c1a2f4b3226780d8b6b5684712/material/stub -------------------------------------------------------------------------------- /movie_script_parser.lua: -------------------------------------------------------------------------------- 1 | local Parser = torch.class("neuralconvo.MovieScriptParser") 2 | 3 | function Parser:parse(file) 4 |   local f = assert(io.open(file, 'r')) 5 |   self.input = f:read("*all") 6 |   f:close() 7 | 8 |   self.pos = 0 9 |   self.match = nil 10 | 11 |   -- Find start of script 12 |   repeat self:acceptLine() until self:accept("<pre>") 13 | 14 |   local dialogs = {} 15 | 16 |   -- Apply rules until end of script 17 |   while not self:accept("</pre>") and self:acceptLine() do 18 |     local dialog = self:parseDialog() 19 |     if dialog then 20 |       table.insert(dialogs, dialog) 21 |     end 22 |   end 23 | 24 |   return dialogs 25 | end 26 | 27 | -- Returns true if regexp matches and advance position 28 | function Parser:accept(regexp) 29 |   local match = string.match(self.input, "^" .. regexp, self.pos) 30 |   if match then 31 |     self.pos = self.pos + #match 32 |     self.match = match 33 |     return true 34 |   end 35 | end 36 | 37 | -- Accept anything up to the end of line 38 | function Parser:acceptLine() 39 |   return self:accept(".-\n") 40 | end 41 | 42 | function Parser:acceptSep() 43 |   while self:accept("</?b>") or self:accept(" +") do end 44 |   return self:accept("\n") 45 | end 46 | 47 | function Parser:parseDialog() 48 |   local dialogs = {} 49 | 50 |   repeat 51 |     local dialog = self:parseSpeech() 52 |     if dialog then 53 |       table.insert(dialogs, dialog) 54 |     end 55 |   until not self:acceptSep() 56 | 57 |   if #dialogs > 0 then 58 |     return dialogs 59 |   end 60 | end 61 | 62 | -- Matches: 63 | -- 64 | --   NAME 65 | --       some nice text 66 | --       more text. 67 | -- 68 | -- or 69 | -- 70 | --   NAME; text 71 | function Parser:parseSpeech() 72 |   local name 73 | 74 |   self:accept("<b>") 75 |   self:accept("<i>") 76 | 77 |   -- Get the character name (all caps) 78 |   -- TODO remove parenthesis from name 79 |   if self:accept(" +") and self:accept("[A-Z][A-Z%- %.%(%)]+") then 80 |     name = self.match 81 |   else 82 |     return 83 |   end 84 | 85 |   -- Handle inline dialog: `NAME; text` 86 |   if self:accept(";") and self:accept("[^\n]+") then 87 |     return { 88 |       character = name, 89 |       text = self.match 90 |     } 91 |   end 92 | 93 |   self:accept("\n") 94 | 95 |   if not self:accept("</b>") then 96 |     return 97 |   end 98 | 99 |   -- Get the dialog lines 100 |   -- TODO remove parenthesis from text 101 |   local lines = {} 102 |   while self:accept(" +") do 103 |     -- The actual line of dialog 104 |     if self:accept("[^\n]+") then 105 |       table.insert(lines, self.match) 106 |     end 107 |     self:accept("\n") 108 |   end 109 | 110 |   if #lines > 0 then 111 |     return { 112 |       character = name, 113 |       text = table.concat(lines) 114 |     } 115 |   end 116 | end 117 | -------------------------------------------------------------------------------- /neuralconvo.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'rnn' 4 | 5 | neuralconvo = {} 6 | 7 | torch.include('neuralconvo', 'cornell_movie_dialogs.lua') 8 | torch.include('neuralconvo', 'dataset.lua') 9 | torch.include('neuralconvo', 'movie_script_parser.lua') 10 | torch.include('neuralconvo', 'seq2seq.lua') 11 | 12 | return neuralconvo 13 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Neural Conversational Model in Torch 2 | 3 | Forked from https://github.com/chenb67/neuralconvo 4 | 5 | If you run into problems during training, please check the issues of the original fork first; this fork only adapts the code to Chinese! 6 | 7 | 8 | 9 | ## How 10 | Use https://github.com/dgkae/dgk_lost_conv as the training corpus. Chinese sentences should be pre-segmented into semantic words separated by '/'. We modified cornell_movie_dialogs.lua to support this. Lua stores all strings (e.g. Chinese) as multibyte sequences, so the stock pl.lexer does not work on Chinese text; we therefore use an external word-segmentation tool and '/' as the separator, as shown in the sample below. 
 11 | 12 | ## Result 13 | 14 | ![result](a.png) 15 | ![result2](b.png) 16 | 17 | 18 | 19 | 20 | ## Rwt 21 | This repo is no longer maintained. There are a few chat groups: 22 | -------------------------------------------------------------------------------- /run_server.sh: -------------------------------------------------------------------------------- 1 | th -i eval-server.lua --cuda 2 | -------------------------------------------------------------------------------- /seq2seq.lua: -------------------------------------------------------------------------------- 1 | -- Based on https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling.lua 2 | local Seq2Seq = torch.class("neuralconvo.Seq2Seq") 3 | 4 | function Seq2Seq:__init(vocabSize, hiddenSize) 5 |   self.vocabSize = assert(vocabSize, "vocabSize required at arg #1") 6 |   self.hiddenSize = assert(hiddenSize, "hiddenSize required at arg #2") 7 | 8 |   self:buildModel() 9 | end 10 | 11 | function Seq2Seq:buildModel() 12 |   self.encoder = nn.Sequential() 13 |   self.encoder:add(nn.LookupTableMaskZero(self.vocabSize, self.hiddenSize)) 14 |   self.encoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize):maskZero(1) 15 |   self.encoder:add(nn.Sequencer(self.encoderLSTM)) 16 |   self.encoder:add(nn.Select(1,-1)) 17 | 18 |   self.decoder = nn.Sequential() 19 |   self.decoder:add(nn.LookupTableMaskZero(self.vocabSize, self.hiddenSize)) 20 |   self.decoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize):maskZero(1) 21 |   self.decoder:add(nn.Sequencer(self.decoderLSTM)) 22 |   self.decoder:add(nn.Sequencer(nn.MaskZero(nn.Linear(self.hiddenSize, self.vocabSize),1))) 23 |   self.decoder:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(),1))) 24 | 25 |   self.encoder:zeroGradParameters() 26 |   self.decoder:zeroGradParameters() 27 | end 28 | 29 | function Seq2Seq:cuda() 30 |   self.encoder:cuda() 31 |   self.decoder:cuda() 32 | 33 |   if self.criterion then 34 |     self.criterion:cuda() 35 |   end 36 | end 37 | 38 | function Seq2Seq:cl() 39 |   self.encoder:cl() 40 |   self.decoder:cl() 41 | 42 |   if self.criterion then 43 |     self.criterion:cl() 44 |   end 45 | end 46 | 47 | --[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]-- 48 | function Seq2Seq:forwardConnect(inputSeqLen) 49 |   self.decoderLSTM.userPrevOutput = 50 |     nn.rnn.recursiveCopy(self.decoderLSTM.userPrevOutput, self.encoderLSTM.outputs[inputSeqLen]) 51 |   self.decoderLSTM.userPrevCell = 52 |     nn.rnn.recursiveCopy(self.decoderLSTM.userPrevCell, self.encoderLSTM.cells[inputSeqLen]) 53 | end 54 | 55 | --[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]-- 56 | function Seq2Seq:backwardConnect() 57 |   self.encoderLSTM.userNextGradCell = 58 |     nn.rnn.recursiveCopy(self.encoderLSTM.userNextGradCell, self.decoderLSTM.userGradPrevCell) 59 |   self.encoderLSTM.gradPrevOutput = 60 |     nn.rnn.recursiveCopy(self.encoderLSTM.gradPrevOutput, self.decoderLSTM.userGradPrevOutput) 61 | end 62 | 63 | function Seq2Seq:train(encoderInputs, decoderInputs, decoderTargets) 64 | 65 |   -- Forward pass 66 |   local encoderOutput = self.encoder:forward(encoderInputs) 67 |   self:forwardConnect(encoderInputs:size(1)) 68 |   local decoderOutput = self.decoder:forward(decoderInputs) 69 |   local Edecoder = self.criterion:forward(decoderOutput, decoderTargets) 70 | 71 |   if Edecoder ~= Edecoder then -- Exit early on bad error (NaN) 72 |     return Edecoder 73 |   end 74 | 75 |   -- Backward pass 76 |   local gEdec = self.criterion:backward(decoderOutput, decoderTargets) 77 |   self.decoder:backward(decoderInputs, gEdec) 78 |   self:backwardConnect() 79 |   self.encoder:backward(encoderInputs, 
encoderOutput:zero()) 80 | 81 |   self.encoder:updateGradParameters(self.momentum) 82 |   self.decoder:updateGradParameters(self.momentum) 83 |   self.decoder:updateParameters(self.learningRate) 84 |   self.encoder:updateParameters(self.learningRate) 85 |   self.encoder:zeroGradParameters() 86 |   self.decoder:zeroGradParameters() 87 | 88 |   self.decoder:forget() 89 |   self.encoder:forget() 90 | 91 |   return Edecoder 92 | end 93 | 94 | local MAX_OUTPUT_SIZE = 20 95 | 96 | function Seq2Seq:eval(input) 97 |   assert(self.goToken, "No goToken specified") 98 |   assert(self.eosToken, "No eosToken specified") 99 | 100 |   self.encoder:forward(input) 101 |   self:forwardConnect(input:size(1)) 102 | 103 |   local predictions = {} 104 |   local probabilities = {} 105 | 106 |   -- Forward <go> and feed all of its output recursively back to the decoder 107 |   local output = {self.goToken} 108 |   for i = 1, MAX_OUTPUT_SIZE do 109 |     local prediction = self.decoder:forward(torch.Tensor(output))[#output] 110 |     -- prediction contains the probabilities for each word ID. 111 |     -- The index of the probability is the word ID. 112 |     local prob, wordIds = prediction:topk(5, 1, true, true) 113 | 114 |     -- First one is the most likely. 115 |     next_output = wordIds[1] 116 |     table.insert(output, next_output) 117 | 118 |     -- Terminate on EOS token 119 |     if next_output == self.eosToken then 120 |       break 121 |     end 122 | 123 |     table.insert(predictions, wordIds) 124 |     table.insert(probabilities, prob) 125 |   end 126 | 127 |   self.decoder:forget() 128 |   self.encoder:forget() 129 | 130 |   return predictions, probabilities 131 | end 132 | -------------------------------------------------------------------------------- /tokenizer.lua: -------------------------------------------------------------------------------- 1 | local lexer = require "pl.lexer" 2 | local yield = coroutine.yield 3 | local M = {} 4 | 5 | local function word(token) 6 |   return yield("word", token) 7 | end 8 | 9 | local function quote(token) 10 |   return yield("quote", token) 11 | end 12 | 13 | local function space(token) 14 |   return yield("space", token) 15 | end 16 | 17 | local function tag(token) 18 |   return yield("tag", token) 19 | end 20 | 21 | local function punct(token) 22 |   return yield("punct", token) 23 | end 24 | 25 | local function endpunct(token) 26 |   return yield("endpunct", token) 27 | end 28 | 29 | local function unknown(token) 30 |   print("unknown") 31 |   return yield("unknown", token) 32 | end 33 | 34 | function M.tokenize(text) 35 | 36 |   print(text) 37 | 38 |   --{ "^[\128-\193]+", word }, 39 |   return lexer.scan(text, { 40 |     { "^%s+", space }, 41 |     { "^['\"]", quote }, 42 |     { "^%w+", word }, 43 |     { "^%-+", space }, 44 |     { "^[,:;%-]", punct }, 45 |     { "^%.+", endpunct }, 46 |     { "^[%.%?!]", endpunct }, 47 |     { "^</?.->", tag }, 48 |     { "^.", unknown }, 49 |   }, { [space]=true, [tag]=true }) 50 | end 51 | 52 | function M.join(words) 53 |   local s = table.concat(words, " ") 54 |   s = s:gsub("^%l", string.upper) 55 |   s = s:gsub(" (') ", "%1") 56 |   s = s:gsub(" ([,:;%-%.%?!])", "%1") 57 |   return s 58 | end 59 | 60 | return M -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'neuralconvo' 2 | require 'xlua' 3 | 4 | cmd = torch.CmdLine() 5 | cmd:text('Options:') 6 | cmd:option('--dataset', 0, 'approximate size of dataset to use (0 = all)') 7 | cmd:option('--minWordFreq', 1, 'minimum frequency of words kept in vocab') 8 | cmd:option('--cuda', false, 'use CUDA') 9 | 
cmd:option('--opencl', false, 'use OpenCL') 10 | cmd:option('--hiddenSize', 300, 'number of hidden units in LSTM') 11 | cmd:option('--learningRate', 0.05, 'learning rate at t=0') 12 | cmd:option('--momentum', 0.9, 'momentum') 13 | cmd:option('--minLR', 0.00001, 'minimum learning rate') 14 | cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR') 15 | cmd:option('--maxEpoch', 30, 'maximum number of epochs to run') 16 | cmd:option('--batchSize', 10, 'number of examples to load at once') 17 | 18 | cmd:text() 19 | options = cmd:parse(arg) 20 | 21 | if options.dataset == 0 then 22 |   options.dataset = nil 23 | end 24 | 25 | -- Data 26 | print("-- Loading dataset") 27 | dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"), 28 |                     { 29 |                       loadFirst = options.dataset, 30 |                       minWordFreq = options.minWordFreq 31 |                     }) 32 | 33 | print("\nDataset stats:") 34 | print("  Vocabulary size: " .. dataset.wordsCount) 35 | print("         Examples: " .. dataset.examplesCount) 36 | 37 | -- Model 38 | model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize) 39 | model.goToken = dataset.goToken 40 | model.eosToken = dataset.eosToken 41 | 42 | -- Training parameters 43 | if options.batchSize > 1 then 44 |   model.criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(),1)) 45 | else 46 |   model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) 47 | end 48 | model.learningRate = options.learningRate 49 | model.momentum = options.momentum 50 | local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch 51 | local minMeanError = nil 52 | 53 | -- Enable CUDA 54 | if options.cuda then 55 |   require 'cutorch' 56 |   require 'cunn' 57 |   model:cuda() 58 | elseif options.opencl then 59 |   require 'cltorch' 60 |   require 'clnn' 61 |   model:cl() 62 | end 63 | 64 | 65 | -- Run the experiment 66 | 67 | 68 | 69 | for epoch = 1, options.maxEpoch do 70 |   print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch) 71 |   print("") 72 | 73 |   local errors = torch.Tensor(dataset.examplesCount):fill(0) 74 |   local timer = torch.Timer() 75 | 76 |   local i = 1 77 |   for encInputs, decInputs, decTargets in dataset:batches(options.batchSize) do 78 |     collectgarbage() 79 | 80 |     if options.cuda then 81 |       encInputs = encInputs:cuda() 82 |       decInputs = decInputs:cuda() 83 |       decTargets = decTargets:cuda() 84 |     elseif options.opencl then 85 |       encInputs = encInputs:cl() 86 |       decInputs = decInputs:cl() 87 |       decTargets = decTargets:cl() 88 |     end 89 | 90 |     local err = model:train(encInputs, decInputs, decTargets) 91 | 92 |     -- Check if error is NaN. If so, it's probably a bug. 93 |     if err ~= err then 94 |       error("Invalid error! Exiting.") 95 |     end 96 | 97 |     errors[i] = err 98 |     xlua.progress(i * options.batchSize, dataset.examplesCount) 99 |     i = i + 1 100 |   end 101 | 102 |   timer:stop() 103 | 104 |   print("\nFinished in " .. xlua.formatTime(timer:time().real) .. " " .. (dataset.examplesCount / timer:time().real) .. ' examples/sec.') 105 |   print("\nEpoch stats:") 106 |   print("           LR= " .. model.learningRate) 107 |   print("  Errors: min= " .. errors:min()) 108 |   print("          max= " .. errors:max()) 109 |   print("       median= " .. errors:median()[1]) 110 |   print("         mean= " .. errors:mean()) 111 |   print("          std= " .. errors:std()) 112 | 113 |   -- Save the model if it improved. 
114 | if minMeanError == nil or errors:mean() < minMeanError then 115 | print("\n(Saving model ...)") 116 | torch.save("data/model.t7", model) 117 | minMeanError = errors:mean() 118 | end 119 | 120 | model.learningRate = model.learningRate + decayFactor 121 | model.learningRate = math.max(options.minLR, model.learningRate) 122 | end 123 | 124 | -- Load testing script 125 | require "eval" 126 | --------------------------------------------------------------------------------
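A minimal end-to-end usage sketch (assuming a working Torch + rnn setup and the corpus file `../xiaohuangji50w_fenciA.conv` hardcoded in cornell_movie_dialogs.lua; the flags are the ones defined by train.lua above):

```
# Build the vocabulary and train; writes data/vocab.t7, data/examples.t7 and data/model.t7
th train.lua --cuda --hiddenSize 300 --maxEpoch 30

# Chat interactively: eval.lua loads data/model.t7 and defines say()
th -i eval.lua
# th> say("你好")

# Or serve over HTTP on port 8082 (this is what run_server.sh does)
th -i eval-server.lua --cuda
```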