├── test_slider.lua ├── convert_space_sep_data_to_lua_tensor.py ├── test_normalizer.lua ├── test_codec.lua ├── test_gru.lua ├── test_thread1.lua ├── random-test.py ├── utils ├── decoder.lua ├── levenshtein.lua └── logs.lua ├── sample_setting.json ├── test_cuda_ctc.lua ├── ocr.lua ├── codec.lua ├── .gitignore ├── CMakeLists.txt ├── gaussian_test.py ├── slider.lua ├── test_cuda.lua ├── test_sharedclone.lua ├── test_ctc.lua ├── pretrain.lua ├── README.md ├── solve.lua ├── contrast.txt ├── ctc.lua ├── ctc_lua.lua ├── test_thread.lua ├── LICENSE ├── loader.lua ├── normalizer.cc ├── GRU.lua ├── ctc_log.lua ├── libctc.cc ├── main.lua ├── rbm.lua └── test_ctc_large.lua /test_slider.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | require 'slider' 3 | 4 | 5 | im = image.load('scaled.png')[1] 6 | 7 | slider = Slider() 8 | slider:load(im) 9 | 10 | -- print(im) 11 | 12 | local s = slider:slide() 13 | 14 | while s do 15 | image.display(s) 16 | s = slider:slide() 17 | end 18 | 19 | -- image.display(im[{{1, im:size(1)}, {50, 150}}]) 20 | 21 | -------------------------------------------------------------------------------- /convert_space_sep_data_to_lua_tensor.py: -------------------------------------------------------------------------------- 1 | f = open("data1.txt", "r") 2 | lines = f.readlines() 3 | f.close() 4 | 5 | output = "data = torch.Tensor{\n" 6 | 7 | for line in lines: 8 | nums = line.split() 9 | output += "\t{" 10 | output += ",".join(nums) 11 | output += "},\n" 12 | 13 | output += "}" 14 | 15 | f = open("data.lua", "w") 16 | f.write(output) 17 | f.close() -------------------------------------------------------------------------------- /test_normalizer.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | require 'normalizer' 3 | 4 | im = image.load("26.png", 1) 5 | 6 | if im:dim() == 3 then 7 | im = im[1] 8 | end 9 | 10 | output = torch.Tensor() 11 | 12 | w = im:size()[2] 13 | h = im:size()[1] 14 | 15 | ones = torch.ones(h, w) 16 | 17 | im = ones - im 18 | 19 | normalizer.normalize(im, output) 20 | 21 | image.save("out.png", output) 22 | 23 | 24 | -------------------------------------------------------------------------------- /test_codec.lua: -------------------------------------------------------------------------------- 1 | require 'loader' 2 | 3 | l = Loader() 4 | l:load("1.txt") 5 | 6 | INIT_LAMBDA = 3.0 7 | 8 | function getLambda(i, total) 9 | return (1 - i / total) * INIT_LAMBDA 10 | end 11 | 12 | lambda = 3.0 13 | 14 | for i = 1, 100 do 15 | print(l:pickWithWeight()) 16 | 17 | if i % 10 == 0 then 18 | lambda = getLambda(i, 100) 19 | print(">>> updated lambda = " .. 
lambda) 20 | l:updateWeight(lambda) 21 | end 22 | end 23 | 24 | -------------------------------------------------------------------------------- /test_gru.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'rnn' 4 | require 'GRU' 5 | GRU = require 'GRU_char-rnn' 6 | 7 | local myGRU = nn.GRU(10, 20) 8 | 9 | local module = myGRU.recurrentModule 10 | 11 | local input = torch.rand(10) 12 | local output_ = torch.rand(20) 13 | 14 | print(module) 15 | print(module:forward({input, output_})) 16 | 17 | gru = GRU.gru(10, 20, 1) 18 | 19 | graph.dot(gru.fg, "GRU") 20 | 21 | print(gru:forward({input, output_})[1]) -------------------------------------------------------------------------------- /test_thread1.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | net = nn.Linear(10, 10) 4 | 5 | params, grad_params = net:getParameters() 6 | 7 | n = net:clone() 8 | 9 | p, g = n:getParameters() 10 | 11 | p = params 12 | 13 | net:share(n, 'weight', 'bias') 14 | 15 | 16 | inp = torch.randn(10) 17 | 18 | n:forward(inp) 19 | 20 | grad = torch.randn(10) 21 | 22 | n:backward(inp, grad) 23 | 24 | n:updateParameters(1e-3) 25 | 26 | print(p:sum()) 27 | print(params:sum()) 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /random-test.py: -------------------------------------------------------------------------------- 1 | import random 2 | N = 1000000 3 | wt = [10, 20, 40] 4 | wtp = [1.*x/sum(wt) for x in wt] 5 | result = [] 6 | p = [random.normalvariate(1./x, 1./x/3.) for x in wtp] 7 | for i in xrange(N): 8 | minp = 1.e9 9 | minj = -1 10 | for j, pp in enumerate(p): 11 | if pp < minp: 12 | minp = pp 13 | minj = j 14 | result.append(minj) 15 | for j, pp in enumerate(p): 16 | p[j] -= minp 17 | p[minj] = random.normalvariate(1./wtp[minj], 1./wtp[minj]/3.) 
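# Note: this script simulates the weighted sampling scheme used by
# Loader:pickWithWeight() in loader.lua. Each class j keeps a "next pick time"
# drawn from Normal(1/wtp[j], (1/wtp[j])/3); the smallest one wins, its value is
# subtracted from all, and the winner is redrawn. Over many iterations the pick
# frequency of class j approaches its normalized weight wtp[j], so with
# wt = [10, 20, 40] roughly 1/7, 2/7 and 4/7 of the entries in `result` should
# be 0, 1 and 2. A quick (hypothetical) check of the outcome:
# print [1. * result.count(j) / N for j in xrange(len(wt))]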
18 | 19 | -------------------------------------------------------------------------------- /utils/decoder.lua: -------------------------------------------------------------------------------- 1 | decoder = {} 2 | 3 | function decoder.best_path_decode(outputTable, codec) 4 | local result = {} 5 | 6 | 7 | local class_num = outputTable[1]:size()[1] 8 | local last_max_class = nil; 9 | local last_max = -1; 10 | 11 | for i = 1, #outputTable do 12 | local max_val, max = torch.max(outputTable[i], 1) 13 | max = max[1] 14 | 15 | max_val = max_val[1] 16 | 17 | if max ~= last_max_class then 18 | if max ~= class_num then 19 | table.insert(result, max) 20 | end 21 | last_max_class = max 22 | end 23 | 24 | 25 | end 26 | 27 | return codec:decode(result) 28 | end -------------------------------------------------------------------------------- /sample_setting.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "sample_exp", 3 | "raw_input": false, 4 | "hidden_size": 200, 5 | "nthread": 3, 6 | "ctc_lua": false, 7 | "test_every": 1000, 8 | "stride": 5, 9 | "save_every": 10000, 10 | "gpu": false, 11 | "show_every": 10, 12 | "input_size": 48, 13 | "clamp_size": 1, 14 | "max_param_norm": false, 15 | "testing_ratio": 1, 16 | "momentum": 0.9, 17 | "dropout_rate": 0, 18 | "max_iter": 10000000000, 19 | "feature_size": 240, 20 | "learning_rate": 0.0001, 21 | "training_list_file": "1.txt", 22 | "omp_threads": 1, 23 | "recurrent_unit": "gru", 24 | "windows_size": 10 25 | } -------------------------------------------------------------------------------- /test_cuda_ctc.lua: -------------------------------------------------------------------------------- 1 | require 'cutorch' 2 | require 'ctc' 3 | 4 | t = torch.randn(100000):cuda() 5 | t2 = torch.randn(100000) 6 | 7 | timer = torch.Timer() 8 | 9 | last = 0 10 | 11 | for i = 1, (#t)[1] do 12 | t[i] = t[i] + 1 13 | end 14 | 15 | now = timer:time().real 16 | print(now - last) 17 | last = now 18 | 19 | 20 | for i = 1, (#t2)[1] do 21 | t2[i] = t2[i] + 1 22 | end 23 | 24 | 25 | now = timer:time().real 26 | print(now - last) 27 | last = now 28 | 29 | t3 = t:float() 30 | 31 | for i = 1, (#t3)[1] do 32 | t3[i] = t3[i] + 1 33 | end 34 | 35 | 36 | now = timer:time().real 37 | print(now - last) 38 | last = now 39 | 40 | print(t3) -------------------------------------------------------------------------------- /ocr.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | require 'cutorch' 4 | require 'cunn' 5 | require 'rnn' 6 | 7 | require 'image' 8 | 9 | local im = image.load("1.png", 1)[1]:t() 10 | 11 | local size = im:size() 12 | local im_w = size[1] 13 | local im_h = size[2] 14 | 15 | local input_size = im_h 16 | local hidden_size = 400 17 | local class_num = 10 18 | 19 | 20 | local net = nn.Sequential() 21 | 22 | im = im:float():cuda() 23 | 24 | net:add(nn.SplitTable(1)) 25 | net:add(nn.BiSequencer(nn.LSTM(input_size, hidden_size))) 26 | output = nn.Sequential() 27 | output:add(nn.Linear(hidden_size * 2, class_num + 1)) 28 | output:add(nn.SoftMax()) 29 | net:add(nn.Sequencer(output)) 30 | net:cuda() 31 | 32 | print(net:forward(im)) 33 | -------------------------------------------------------------------------------- /codec.lua: -------------------------------------------------------------------------------- 1 | utf8 = utf8 or require 'utf8' 2 | 3 | Codec = { 4 | codec = {}, 5 | codec_inv = {}, 6 | codec_size = 0 7 | } 8 | 9 | setmetatable(Codec, { 10 | __call = 11 | 
function (cls, ...) 12 | return cls:new(...) 13 | end 14 | }) 15 | 16 | function Codec:new(o) 17 | o = o or {} 18 | setmetatable(o, self) 19 | self.__index = self 20 | return o 21 | end 22 | 23 | function Codec:encode(src) 24 | local result = {} 25 | for _, v, _ in utf8.iter(src) do 26 | table.insert(result, self.codec[v]) 27 | end 28 | 29 | return result 30 | end 31 | 32 | function Codec:decode(src) 33 | local result = "" 34 | for _, v in ipairs(src) do 35 | result = result .. self.codec_inv[v] 36 | end 37 | 38 | return result 39 | end -------------------------------------------------------------------------------- /utils/levenshtein.lua: -------------------------------------------------------------------------------- 1 | utf8 = require 'utf8' 2 | 3 | function utf8.levenshtein(str1, str2) 4 | local len1 = utf8.len(str1) 5 | local len2 = utf8.len(str2) 6 | 7 | local matrix = {} 8 | local cost 9 | 10 | if (len1 == 0) then 11 | return len2 12 | elseif (len2 == 0) then 13 | return len1 14 | elseif (str1 == str2) then 15 | return 0 16 | end 17 | 18 | for i = 0, len1 do 19 | matrix[i] = {} 20 | matrix[i][0] = i 21 | end 22 | for j = 1, len2 do 23 | matrix[0][j] = j 24 | end 25 | 26 | for i = 1, len1 do 27 | for j = 1, len2 do 28 | if (utf8.at(str1, i) == utf8.at(str2, j)) then 29 | cost = 0 30 | else 31 | cost = 1 32 | end 33 | 34 | matrix[i][j] = math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost) 35 | end 36 | end 37 | 38 | return matrix[len1][len2] 39 | end 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 1/* 2 | samples/* 3 | build/* 4 | ctc/* 5 | normalizer/* 6 | 7 | models/* 8 | experiments/* 9 | 10 | rbm/* 11 | *.rbm 12 | *.json 13 | *.txt~ 14 | *~ 15 | 16 | *.txt 17 | *.codec 18 | !CMakeLists.txt 19 | !sample_setting.json 20 | 21 | .* 22 | 23 | *.uma 24 | *.png 25 | *.tif 26 | 27 | .DS_Store 28 | *.dSYM 29 | 30 | # Compiled Lua sources 31 | luac.out 32 | 33 | # luarocks build files 34 | *.src.rock 35 | *.zip 36 | *.tar.gz 37 | 38 | # Object files 39 | *.o 40 | *.os 41 | *.ko 42 | *.obj 43 | *.elf 44 | 45 | # Precompiled Headers 46 | *.gch 47 | *.pch 48 | 49 | # Libraries 50 | *.lib 51 | *.a 52 | *.la 53 | *.lo 54 | *.def 55 | *.exp 56 | 57 | # Shared objects (inc. 
Windows DLLs) 58 | *.dll 59 | *.so 60 | *.so.* 61 | *.dylib 62 | 63 | # Executables 64 | *.exe 65 | *.out 66 | *.app 67 | *.i*86 68 | *.x86_64 69 | *.hex 70 | 71 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 2 | CMAKE_POLICY(VERSION 2.6) 3 | FIND_PACKAGE(Torch REQUIRED) 4 | FIND_PACKAGE(OpenMP) 5 | 6 | # set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 7 | 8 | 9 | if(OPENMP_FOUND) 10 | message("OPENMP FOUND") 11 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 12 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 13 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 14 | endif() 15 | 16 | foreach(target libctc normalizer) 17 | message(${target}) 18 | add_library(${target} SHARED "${target}.cc") 19 | set_target_properties(${target} 20 | PROPERTIES PREFIX "" 21 | SUFFIX ".so" 22 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/../" 23 | ) 24 | target_link_libraries(${target} TH luaT luajit) 25 | endforeach(target) 26 | 27 | 28 | -------------------------------------------------------------------------------- /gaussian_test.py: -------------------------------------------------------------------------------- 1 | from scipy import misc, ndimage 2 | from scipy.ndimage import filters 3 | import matplotlib.pyplot as plt 4 | from pylab import * 5 | from numpy import * 6 | import PIL 7 | 8 | 9 | 10 | def pil2array(im,alpha=0): 11 | if im.mode=="L": 12 | a = fromstring(im.tostring(),'B') 13 | a.shape = im.size[1],im.size[0] 14 | return a 15 | if im.mode=="RGB": 16 | a = fromstring(im.tostring(),'B') 17 | a.shape = im.size[1],im.size[0],3 18 | return a 19 | if im.mode=="RGBA": 20 | a = fromstring(im.tostring(),'B') 21 | a.shape = im.size[1],im.size[0],4 22 | if not alpha: a = a[:,:,:3] 23 | return a 24 | return pil2array(im.convert("L")) 25 | 26 | im = PIL.Image.open('bq01_006-1.png') 27 | im = pil2array(im) 28 | 29 | im = im / 255.0 30 | 31 | print(im) 32 | 33 | 34 | h = im.shape[0] 35 | w = im.shape[1] 36 | 37 | smooth = filters.gaussian_filter(im, (h * 0.5, h * 1.0), mode='constant') 38 | 39 | smooth += 0.001*filters.uniform_filter(smooth, (h*0.5, w), mode='constant') 40 | 41 | print(smooth.shape) 42 | 43 | a = argmax(smooth, axis=0) 44 | a = filters.gaussian_filter(a, h * 0.3) 45 | 46 | center = array(a,'i') 47 | # print(center) 48 | deltas = abs(arange(h)[:, newaxis] - center[newaxis, :]) 49 | mad = mean(deltas[im != 0]) 50 | r = int(1 + 4 * mad) 51 | 52 | plt.imshow(smooth, cmap=cm.gray) 53 | plot(center) 54 | plt.show() -------------------------------------------------------------------------------- /slider.lua: -------------------------------------------------------------------------------- 1 | Slider = { 2 | pos = 0, 3 | total = 0, 4 | stride = 0, 5 | win_width = 0, 6 | height = 0, 7 | width = 0, 8 | im = nil 9 | } 10 | 11 | setmetatable(Slider, { 12 | __call = 13 | function (cls, ...) 14 | return cls:new(...) 
15 | end 16 | }) 17 | 18 | function Slider:new(o, win_width, stride) 19 | local o = o or {} 20 | self.win_width = win_width or 10 21 | self.stride = stride or self.win_width / 2 22 | setmetatable(o, self) 23 | self.__index = self 24 | return o 25 | end 26 | 27 | function Slider:load(im) 28 | assert(im:dim() == 2, "[Slider] a 2-dimensional tensor expected.") 29 | 30 | self.height = im:size(1) 31 | self.width = im:size(2) 32 | 33 | self.total = math.ceil(self.width / self.stride) -- total # of windows 34 | 35 | self.pos = 0 36 | self.im = im 37 | end 38 | 39 | function Slider:slide() 40 | assert(self.im, "[Slider] need to load an image before sliding") 41 | 42 | if self.pos >= self.total then 43 | return nil 44 | end 45 | 46 | local _start = self.pos * self.stride + 1 47 | local _end = _start + self.win_width 48 | local ret 49 | 50 | if _end > self.width then 51 | ret = torch.zeros(self.height, self.win_width) 52 | ret[{{1, self.height}, {1, self.width - _start + 1}}] = self.im[{{1, self.height}, {_start, self.width}}] 53 | else 54 | ret = self.im[{{1, self.height}, {_start, _start + self.win_width - 1}}] 55 | 56 | end 57 | 58 | self.pos = self.pos + 1 59 | 60 | return ret 61 | end 62 | 63 | function Slider:genSequence() 64 | local seq = {} 65 | 66 | local s = self:slide() 67 | 68 | while s do 69 | s = s:reshape(s:nElement()) 70 | 71 | table.insert(seq, s) 72 | s = self:slide() 73 | end 74 | 75 | return seq 76 | end -------------------------------------------------------------------------------- /test_cuda.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'image' 4 | require 'optim' 5 | 6 | require 'loader' 7 | 8 | opts = { 9 | gpu = false, 10 | dropout_rate = 0.4, 11 | input_size = 64, 12 | hidden_size = 400, 13 | learning_rate = 1e-4, 14 | momentum = 0.9 15 | } 16 | 17 | local class_num = 100 18 | 19 | local net = nn.Sequential() 20 | 21 | net:add(nn.Dropout(opts.dropout_rate)) 22 | net:add(nn.SplitTable(1)) 23 | net:add(nn.BiSequencer(nn.LSTM(opts.input_size, opts.hidden_size))) 24 | net:add(nn.BiSequencer(nn.LSTM(opts.hidden_size * 2, opts.hidden_size))) 25 | output = nn.Sequential() 26 | output:add(nn.Linear(opts.hidden_size * 2, class_num + 1)) 27 | output:add(nn.SoftMax()) 28 | net:add(nn.Sequencer(output)) 29 | 30 | if opts.gpu then 31 | require 'cutorch' 32 | require 'cunn' 33 | net:cuda() 34 | cutorch.setDevice(1) 35 | cutorch.manualSeed(450) 36 | else 37 | torch.manualSeed(450) 38 | end 39 | 40 | 41 | timer = torch.Timer() 42 | 43 | 44 | 45 | 46 | im = torch.randn(64, 64) 47 | 48 | 49 | if opts.gpu then 50 | im = im:cuda() 51 | end 52 | 53 | base = timer:time().real 54 | 55 | outputTable = net:forward(im) 56 | 57 | print(timer:time().real - base) 58 | 59 | 60 | 61 | ims = {} 62 | 63 | for i = 1, 10 do 64 | table.insert(ims, torch.randn(64 * i, 64)) 65 | end 66 | 67 | 68 | 69 | if opts.gpu then 70 | for i = 1, 10 do 71 | ims[i] = ims[i]:cuda() 72 | end 73 | end 74 | 75 | base = timer:time().real 76 | 77 | last = base 78 | 79 | for i = 1, 10 do 80 | 81 | outputTable = net:forward(ims[i]) 82 | last = timer:time().real 83 | end 84 | 85 | base = last 86 | 87 | for i = 1, 10 do 88 | 89 | outputTable = net:forward(ims[i]) 90 | print((timer:time().real - last)) 91 | last = timer:time().real 92 | end 93 | 94 | print(last - base) 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /utils/logs.lua: 
-------------------------------------------------------------------------------- 1 | logs = { 2 | EXP_MAX = 1e10, 3 | EXP_MIN = 1e-10, 4 | LOG_ZERO = -1e10, 5 | LOG_INF = 1e10 6 | } 7 | 8 | logs.EXP_LIMIT = math.log(logs.EXP_MAX) 9 | 10 | function logs.safe_log(x) 11 | if x == 0 then 12 | return logs.LOG_ZERO 13 | elseif x > 0 then 14 | return math.log(x) 15 | else 16 | error("passing a negative number to the log function.") 17 | end 18 | end 19 | 20 | function logs.safe_exp(x) 21 | if x == logs.LOG_ZERO then 22 | return 0 23 | end 24 | if x >= logs.EXP_LIMIT then 25 | return logs.EXP_MAX 26 | end 27 | return math.exp(x) 28 | end 29 | 30 | function logs.log_add(x, y) 31 | 32 | if math.abs(x - y) > 10 then 33 | return math.max(x, y) 34 | end 35 | 36 | if x < y then 37 | return y + math.log(1.0 + logs.safe_exp(x - y)) 38 | else 39 | return x + math.log(1.0 + logs.safe_exp(y - x)) 40 | end 41 | end 42 | 43 | function logs.log_sub(x, y) 44 | if y == logs.LOG_ZERO then 45 | return x 46 | end 47 | if y >= x then 48 | return logs.LOG_ZERO 49 | end 50 | return x + math.log(1.0 - logs.safe_exp(y - x)) 51 | end 52 | 53 | function logs.log_mul(x, y) 54 | if y == logs.LOG_ZERO or x == logs.LOG_ZERO then 55 | return logs.LOG_ZERO 56 | end 57 | 58 | return x + y 59 | end 60 | 61 | function logs.log_div(x, y) 62 | if x == logs.LOG_ZERO then 63 | return logs.LOG_ZERO 64 | end 65 | 66 | if y == logs.LOG_ZERO then 67 | return logs.LOG_INF 68 | end 69 | 70 | return x - y 71 | end 72 | 73 | function logs.log_sum(...) 74 | local arg = table.pack(...) 75 | if arg["n"] == 1 then 76 | return arg[1] 77 | end 78 | 79 | local max = math.max(unpack(arg)) 80 | 81 | local result = 0.0 82 | 83 | for i, v in ipairs(arg) do 84 | result = result + logs.safe_exp(v - max) 85 | end 86 | 87 | result = max + logs.safe_log(result) 88 | 89 | return result 90 | end 91 | 92 | -------------------------------------------------------------------------------- /test_sharedclone.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'image' 4 | require 'optim' 5 | 6 | require 'loader' 7 | require 'ctc_log' 8 | require 'utils/decoder' 9 | 10 | local threads = require 'threads' 11 | 12 | 13 | -- initialize 14 | torch.setdefaulttensortype('torch.FloatTensor') 15 | torch.manualSeed(450) 16 | 17 | -- debug switch 18 | DEBUG = false 19 | 20 | -- timer initialize 21 | base = 0 22 | timer = torch.Timer() 23 | 24 | function show_log(log) 25 | local now = timer:time().real 26 | local cost = now - base 27 | base = now 28 | print(string.format("[%.4f][%.4f]%s", now, cost, log)) 29 | end 30 | 31 | -- settings 32 | 33 | DROPOUT_RATE = 0.4 34 | GPU_ENABLED = false 35 | local input_size = 32 36 | local hidden_size = 200 37 | clamp_size = 5 38 | 39 | -- configuration 40 | training_list_file = "1.txt" 41 | 42 | -- GPU 43 | 44 | if GPU_ENABLED then 45 | require 'cutorch' 46 | require 'cunn' 47 | end 48 | 49 | -- load samples 50 | 51 | show_log("Loading samples...") 52 | 53 | local loader = Loader() 54 | loader:load(training_list_file) 55 | local codec = loader:codec() 56 | 57 | show_log(string.format("Loading finished. 
Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size)) 58 | 59 | local class_num = codec.codec_size 60 | 61 | -- build network 62 | 63 | show_log("Building networks...") 64 | 65 | local net 66 | 67 | net = nn.Sequential() 68 | 69 | net:add(nn.Dropout(DROPOUT_RATE)) 70 | net:add(nn.SplitTable(1)) 71 | net:add(nn.BiSequencer(nn.FastLSTM(input_size, hidden_size))) 72 | 73 | output = nn.Sequential() 74 | output:add(nn.Linear(hidden_size * 2, class_num + 1)) 75 | output:add(nn.SoftMax()) 76 | net:add(nn.Sequencer(output)) 77 | net:float() 78 | 79 | -- prepare prarmeters and training method 80 | 81 | local params, grad_params 82 | 83 | params, grad_params = net:getParameters() 84 | 85 | n = net:sharedClone(true, false) 86 | 87 | p, gd = n:getParameters() 88 | 89 | local sample = loader:pick() 90 | local im = sample.img 91 | local target = codec:encode(sample.gt) 92 | 93 | outputTable = n:forward(im) 94 | loss, grad = ctc.getCTCCostAndGrad(outputTable, target) 95 | 96 | n:backward(im, grad) 97 | 98 | 99 | -------------------------------------------------------------------------------- /test_ctc.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'ctc_log' 3 | 4 | 5 | 6 | 7 | torch.setdefaulttensortype('torch.DoubleTensor') 8 | 9 | outputTable = torch.Tensor{ 10 | {0.0684907, 0.0683173, 0.0682402, 0.0682124, 0.0682041, 0.068242, 0.0682717}, 11 | {0.067452, 0.0671584, 0.066958, 0.0667768, 0.0665642, 0.0662235, 0.0656839}, 12 | {0.0722717, 0.0725465, 0.0726959, 0.0727905, 0.072852, 0.0729304, 0.0730594}, 13 | {0.0719666, 0.0716504, 0.0714893, 0.0713959, 0.0713595, 0.0713568, 0.0714108}, 14 | {0.0744622, 0.0744583, 0.0744827, 0.0745185, 0.0745582, 0.0746348, 0.0747697}, 15 | {0.0660699, 0.0658052, 0.0656008, 0.0654654, 0.0653909, 0.0653646, 0.0653229}, 16 | {0.0736831, 0.0741657, 0.0743352, 0.0743431, 0.0742691, 0.0740992, 0.0738383}, 17 | {0.0771784, 0.0769233, 0.0768525, 0.0768478, 0.0768212, 0.0767496, 0.076588}, 18 | {0.0770828, 0.0770411, 0.0770657, 0.0770728, 0.0770434, 0.0769496, 0.0767167}, 19 | {0.0658117, 0.0656928, 0.0656175, 0.0656082, 0.0656804, 0.0658425, 0.0661498}, 20 | {0.0691583, 0.0690236, 0.0690129, 0.069102, 0.0692818, 0.0695909, 0.0701721}, 21 | {0.0700088, 0.0704332, 0.0706438, 0.0707642, 0.0708752, 0.0710007, 0.0711254}, 22 | {0.0727488, 0.0726681, 0.0726766, 0.0727124, 0.0727471, 0.0728149, 0.0730135}, 23 | {0.0736152, 0.0741163, 0.074329, 0.07439, 0.0743531, 0.0742006, 0.0738778}, 24 | } 25 | 26 | target = {4, 3, 13, 1, 10, 7} 27 | 28 | function toMatrix(outputTable) 29 | local net = nn.Sequential() 30 | net:add(nn.JoinTable(1)) 31 | net:add(nn.Reshape(#outputTable, outputTable[1]:size(1))) 32 | return net:forward(outputTable) 33 | end 34 | 35 | -- outputTable = nn.Log():forward(outputTable:t()) 36 | 37 | nrow = outputTable:size(2) 38 | 39 | splitedOutputTable = nn.SplitTable(1):forward(outputTable:t()) 40 | 41 | c_pzx, c_grad = ctc.getCTCCostAndGrad(splitedOutputTable, target) 42 | 43 | c_m = toMatrix(c_grad):float() 44 | 45 | 46 | eps = 1e-6 47 | 48 | ctc_lua = false 49 | 50 | est_grad = torch.Tensor(nrow) 51 | 52 | for i = 1, nrow do 53 | outputTable[1][i] = outputTable[1][i] + eps 54 | 55 | splitedOutputTable = nn.SplitTable(1):forward(outputTable:t()) 56 | loss1, _ = ctc.getCTCCostAndGrad(splitedOutputTable, target) 57 | 58 | outputTable[1][i] = outputTable[1][i] - 2 * eps 59 | splitedOutputTable = nn.SplitTable(1):forward(outputTable:t()) 60 | loss2, _ = 
ctc.getCTCCostAndGrad(splitedOutputTable, target) 61 | 62 | outputTable[1][i] = outputTable[1][i] + eps 63 | 64 | est_grad[i] = (loss1 - loss2) / (2 * eps) 65 | end 66 | 67 | print(est_grad) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /pretrain.lua: -------------------------------------------------------------------------------- 1 | require 'slider' 2 | require 'loader' 3 | require 'image' 4 | RBM = require 'rbm' 5 | require 'cutorch' 6 | require 'cunn' 7 | 8 | opt = { 9 | input_size = 48, 10 | epoch = 30, 11 | 12 | -- sliding window 13 | window_size = 10, 14 | stride = 5, 15 | 16 | -- RBM hyperparameters 17 | hidden_size = 48 * 5, 18 | 19 | -- miscellaneous 20 | output_file = "wwr.rbm" 21 | } 22 | 23 | 24 | 25 | -- load samples 26 | 27 | loader = Loader() 28 | loader:load('wwr.txt') 29 | loader:targetHeight(opt.input_size) 30 | 31 | torch.setdefaulttensortype('torch.CudaTensor') 32 | 33 | -- setup RBM 34 | 35 | local n_visible = opt.input_size * opt.window_size 36 | 37 | local rbm = RBM.new{n_visible=n_visible, n_hidden=opt.hidden_size, CDsteps=1, momentum={0.5, 0.9}, 38 | momentumAfter={5}, v_activation='binary', h_activation='relu', 39 | learningRate=0.01} 40 | 41 | -- train 42 | 43 | for i = 1, opt.epoch do 44 | -- for each sample 45 | local im, p, total = loader:pickInSequential() 46 | local input 47 | while im do 48 | xlua.progress(p, total) 49 | im = im.img 50 | slider = Slider() 51 | slider:load(im:t()) 52 | 53 | -- for each window 54 | input = slider:genSequence() 55 | inputMatrix = nn.JoinTable(1):forward(input):reshape(slider.total, input[1]:size(1)):cuda() 56 | 57 | rbm:updateParameters(inputMatrix) 58 | 59 | im, p, total = loader:pickInSequential() 60 | end 61 | 62 | loader:reset() 63 | 64 | print(string.format("total progress %d / %d eps.", i, opt.epoch)) 65 | end 66 | 67 | -- save 68 | 69 | rbm_data = { 70 | n_visible = n_visible, 71 | n_hidden = opt.hidden_size, 72 | encoder = rbm.encoder:double(), 73 | decoder = rbm.decoder:double() 74 | } 75 | 76 | paths.mkdir("rbm") 77 | 78 | local output_path = "rbm/" .. opt.output_file 79 | torch.save(output_path, rbm_data) 80 | print("RBM network saved at " .. output_path) 81 | 82 | -- test 83 | 84 | mlp = nn.Sequential() 85 | mlp:add(rbm.encoder) 86 | mlp:add(rbm.decoder) 87 | mlp:cuda() 88 | 89 | loader:reset() 90 | local im = loader:pickInSequential().img:cuda() 91 | slider = Slider() 92 | slider:load(im:t()) 93 | 94 | local input = slider:slide() 95 | local output = mlp:forward(input:reshape(input:nElement())) 96 | 97 | input = input:double() 98 | output = output:double() 99 | 100 | torch.setdefaulttensortype('torch.DoubleTensor') 101 | 102 | image.save("1.png", input:reshape(opt.input_size, opt.window_size)) 103 | image.save("2.png", output:reshape(opt.input_size, opt.window_size)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # umaru 2 | An OCR system based on Torch, using LSTM/GRU-RNN and CTC, and referring to the works of rnnlib and clstm. 3 | 4 | ## Notice 5 | 6 | This work is now completely UNSTABLE, EXPERIMENTAL and UNDER DEVELOPMENT. 
7 | 8 | ## Dependencies 9 | 10 | - [torch](https://github.com/torch/torch7) (and following packages) 11 | - image 12 | - nn/cunn 13 | - optim 14 | - [rnn](https://github.com/Element-Research/rnn) 15 | - [json](https://github.com/clementfarabet/lua---json) 16 | - [utf8](https://github.com/clementfarabet/lua-utf8) 17 | - [torchRBM](https://github.com/nhammerla/torchRBM) 18 | 19 | ## Build 20 | 21 | ```sh 22 | $ ./build.sh 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### General 28 | 29 | - You can modify the settings in `main.lua` directly and execute `th main.lua`. The input format is clstm-like (a `.png` and `.gt.txt` pair per sample), and all input file paths should be listed in a text file. 30 | - Or, if you prefer a JSON-format configuration file, follow the example below and run: 31 | 32 | ```sh 33 | $ th main.lua -setting [setting file] 34 | ``` 35 | 36 | ### Run Folder 37 | 38 | A folder is created in the `experiments` folder for every experiment. You can check out the log, settings and saved models there. 39 | 40 | ## Example Configuration File 41 | 42 | Descriptions for each option can be found in `main.lua`. 43 | 44 | ```js 45 | { 46 | "project_name": "uy_rbm_noised", 47 | "raw_input": false, 48 | "hidden_size": 200, 49 | "nthread": 3, 50 | "clamp_size": 1, 51 | "ctc_lua": false, 52 | "recurrent_unit": "gru", 53 | "test_every": 2000, 54 | "omp_threads": 1, 55 | "show_every": 10, 56 | "testing_list_file": "wwr.txt", 57 | "input_size": 48, 58 | "testing_ratio": 1, 59 | "max_param_norm": false, 60 | "training_list_file": "full-train.txt", 61 | "feature_size": 240, 62 | "momentum": 0.9, 63 | "dropout_rate": 0.5, 64 | "max_iter": 10000000000, 65 | "save_every": 10000, 66 | "learning_rate": 0.0001, 67 | "stride": 5, 68 | "gpu": false, 69 | "rbm_network_file": "rbm/wwr.rbm", 70 | "windows_size": 10 71 | } 72 | ``` 73 | 74 | 75 | ## LICENSE 76 | 77 | BSD 3-Clause License 78 | 79 | ## References 80 | 81 | * [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555) 82 | * [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](ftp://ftp.idsia.ch/pub/juergen/icml2006.pdf) 83 | * [RNNLIB: Connectionist Temporal Classification and Transcription Layer](http://wantee.github.io/blog/2015/02/08/rnnlib-connectionist-temporal-classification-and-transcription-layer/) 84 | * [rnnlib](http://sourceforge.net/p/rnnl/wiki/Home/) 85 | * [clstm](https://github.com/tmbdev/clstm) -------------------------------------------------------------------------------- /solve.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'GRU' 4 | require 'image' 5 | require 'optim' 6 | 7 | require 'loader' 8 | require 'ctc_log' 9 | require 'utils.decoder' 10 | require 'utils.levenshtein' 11 | 12 | -- initialize 13 | torch.setdefaulttensortype('torch.FloatTensor') 14 | 15 | -- debug switch 16 | DEBUG = false 17 | 18 | -- timer initialize 19 | base = 0 20 | timer = torch.Timer() 21 | 22 | function show_log(log) 23 | local now = timer:time().real 24 | local cost = now - base 25 | base = now 26 | -- print(string.format("[%.4f][%.4f]%s", now, cost, log)) 27 | print(string.format("%s", log)) 28 | end 29 | 30 | -- settings 31 | 32 | GPU_ENABLED = false 33 | local input_size = 48 34 | 35 | -- configuration 36 | list_file = "wwr.txt" 37 | using_model_file = "umaru_model_15-09-10_21:51:30_30000.uma" 38 | using_codec = "full-train.codec" 
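-- Note: `using_model_file` is a trained network snapshot (a `.uma` file,
-- presumably written by main.lua every `save_every` iterations), and
-- `using_codec` is the character/class mapping saved at training time.
-- The codec must match the one the model was trained with, otherwise the
-- class indices emitted by decoder.best_path_decode() below decode to the
-- wrong characters.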
39 | 40 | -- GPU 41 | 42 | if GPU_ENABLED then 43 | require 'cutorch' 44 | require 'cunn' 45 | end 46 | 47 | -- load samples 48 | 49 | show_log("Loading samples...") 50 | 51 | loader = Loader() 52 | loader:load(list_file) 53 | loader:targetHeight(input_size) 54 | codec = loader:codec() 55 | 56 | if using_codec then 57 | codec = loader:loadCodec(using_codec) 58 | end 59 | 60 | show_log(string.format("Loading finished. Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size)) 61 | 62 | local class_num = codec.codec_size 63 | 64 | -- build network 65 | 66 | show_log("Building networks...") 67 | 68 | local net 69 | 70 | if using_model_file then 71 | net = torch.load(using_model_file) 72 | net:evaluate() 73 | else 74 | error("There must be a model file.") 75 | end 76 | 77 | if GPU_ENABLED then 78 | net:cuda() 79 | end 80 | 81 | 82 | show_log(string.format("Start solving with model file: %s", using_model_file)) 83 | 84 | -- solving 85 | 86 | local sample = loader:pickInSequential() 87 | 88 | begin_time = timer:time().real 89 | local dist, tmp_dist, out = 0, 0, 0 90 | local len, tmp_len = 0, 0 91 | 92 | while sample do 93 | local im = sample.img 94 | local target = codec:encode(sample.gt) 95 | 96 | net:forget() 97 | 98 | outputTable = net:forward(im) 99 | 100 | out = decoder.best_path_decode(outputTable, codec) 101 | 102 | tmp_dist = utf8.levenshtein(out, sample.gt) 103 | tmp_len = utf8.len(sample.gt) 104 | dist = dist + tmp_dist 105 | len = len + tmp_len 106 | 107 | print("") 108 | show_log("FILE " .. sample.src) 109 | show_log("TARGET " .. sample.gt) 110 | show_log("OUTPUT " .. out) 111 | show_log("DISTANCE " .. tmp_dist) 112 | show_log("ERROR " .. string.format("%.2f%%", dist / len * 100)) 113 | 114 | sample = loader:pickInSequential() 115 | end 116 | 117 | 118 | -------------------------------------------------------------------------------- /contrast.txt: -------------------------------------------------------------------------------- 1 | output 2 | DIMENSIONS: 7 3 | 0.0684907 0.0683173 0.0682402 0.0682124 0.0682041 0.068242 0.0682717 4 | 0.067452 0.0671584 0.066958 0.0667768 0.0665642 0.0662235 0.0656839 5 | 0.0722717 0.0725465 0.0726959 0.0727905 0.072852 0.0729304 0.0730594 6 | 0.0719666 0.0716504 0.0714893 0.0713959 0.0713595 0.0713568 0.0714108 7 | 0.0744622 0.0744583 0.0744827 0.0745185 0.0745582 0.0746348 0.0747697 8 | 0.0660699 0.0658052 0.0656008 0.0654654 0.0653909 0.0653646 0.0653229 9 | 0.0736831 0.0741657 0.0743352 0.0743431 0.0742691 0.0740992 0.0738383 10 | 0.0771784 0.0769233 0.0768525 0.0768478 0.0768212 0.0767496 0.076588 11 | 0.0770828 0.0770411 0.0770657 0.0770728 0.0770434 0.0769496 0.0767167 12 | 0.0658117 0.0656928 0.0656175 0.0656082 0.0656804 0.0658425 0.0661498 13 | 0.0691583 0.0690236 0.0690129 0.069102 0.0692818 0.0695909 0.0701721 14 | 0.0700088 0.0704332 0.0706438 0.0707642 0.0708752 0.0710007 0.0711254 15 | 0.0727488 0.0726681 0.0726766 0.0727124 0.0727471 0.0728149 0.0730135 16 | 0.0736152 0.0741163 0.074329 0.07439 0.0743531 0.0742006 0.0738778 17 | target 18 | 3 2 12 0 9 6 19 | DIMENSIONS: 7 20 | -2.6089 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 21 | -2.63155 -4.56297 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 22 | -1e+100 -5.23367 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 23 | -1e+100 -5.25508 -6.48538 -1e+100 -1e+100 -1e+100 -1e+100 24 | -1e+100 -1e+100 -7.85434 -1e+100 -1e+100 -1e+100 -1e+100 25 | -1e+100 -1e+100 -7.87682 -8.6991 -1e+100 -1e+100 -1e+100 26 | -1e+100 -1e+100 -1e+100 -10.4753 -1e+100 -1e+100 -1e+100 27 | -1e+100 
-1e+100 -1e+100 -10.5619 -11.1033 -1e+100 -1e+100 28 | -1e+100 -1e+100 -1e+100 -1e+100 -13.1609 -1e+100 -1e+100 29 | -1e+100 -1e+100 -1e+100 -1e+100 -13.2849 -13.6082 -1e+100 30 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -15.8859 -1e+100 31 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -15.8873 -16.0277 32 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -18.4926 33 | DIMENSIONS: 7 34 | -15.8903 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 35 | -13.3955 -13.2543 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 36 | -1e+100 -13.2543 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 37 | -1e+100 -10.9558 -10.6329 -1e+100 -1e+100 -1e+100 -1e+100 38 | -1e+100 -1e+100 -10.6329 -1e+100 -1e+100 -1e+100 -1e+100 39 | -1e+100 -1e+100 -8.56104 -8.01162 -1e+100 -1e+100 -1e+100 40 | -1e+100 -1e+100 -1e+100 -8.01162 -1e+100 -1e+100 -1e+100 41 | -1e+100 -1e+100 -1e+100 -6.17004 -5.32637 -1e+100 -1e+100 42 | -1e+100 -1e+100 -1e+100 -1e+100 -5.32637 -1e+100 -1e+100 43 | -1e+100 -1e+100 -1e+100 -1e+100 -3.8497 -2.60588 -1e+100 44 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -2.60588 -1e+100 45 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -1.91246 0 46 | -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 -1e+100 0 47 | DIMENSIONS: 7 48 | -0 -0 -0 -6.68062 -9.0399 -0 -0 49 | -0 -0 -0 -0 -0 -0 -0 50 | -0 -10.5774 -4.26011 -0 -0 -0 -0 51 | -12.8138 -2.14838 -0 -0 -0 -0 -0 52 | -0 -0 -0 -0 -0 -0 -0 53 | -0 -0 -0 -0 -0 -0 -0 54 | -0 -0 -0 -0 -0 -2.11426 -12.482 55 | -0 -0 -0 -0 -0 -0 -0 56 | -0 -0 -0 -0 -0 -0 -0 57 | -0 -0 -0 -0 -4.63865 -11.6174 -0 58 | -0 -0 -0 -0 -0 -0 -0 59 | -0 -0 -0 -0 -0 -0 -0 60 | -0 -0 -8.4144 -6.40188 -0 -0 -0 61 | -1.05736 -1.06203 -1.05985 -1.05932 -1.05945 -1.05685 -1.06058 62 | 63 | -------------------------------------------------------------------------------- /ctc.lua: -------------------------------------------------------------------------------- 1 | require 'utils/logs' 2 | 3 | ctc = {} 4 | 5 | --[[ 6 | getOnehotMatrix 7 | 8 | target - a vector of number of class 9 | 10 | return a L * C onehot Matrix, C is the number of kinds of classes. 11 | ]] 12 | 13 | function ctc.getOnehotMatrix(target, class_num) 14 | onehot = torch.zeros((#target)[1], class_num) 15 | for i = 1, (#target)[1] do 16 | c = target[i] 17 | if c > 0 then 18 | onehot[i][c] = 1 19 | else 20 | onehot[i][class_num] = 1 21 | end 22 | end 23 | return onehot 24 | end 25 | 26 | --[[ 27 | getFilledTarget 28 | 29 | target - a unicode string of ground truth 30 | 31 | return a 2L + 1 vector of number of class. 
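For example, the target "12" is filled to {0, 1, 0, 2, 0}: the label classes
sit at the even positions and the 0 slots mark blanks, which getOnehotMatrix
later maps to the last class (class_num, the blank class).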
32 | ]] 33 | 34 | function ctc.getFilledTarget(target) 35 | local filled = torch.zeros(#target * 2 + 1) 36 | for i = 1, (#filled)[1] do 37 | if i % 2 == 0 then 38 | filled[i] = string.sub(target, i / 2, i / 2) 39 | end 40 | end 41 | return filled 42 | end 43 | 44 | function ctc.toMatrix(outputTable) 45 | local net = nn.Sequential() 46 | net:add(nn.JoinTable(1)) 47 | net:add(nn.Reshape(#outputTable, outputTable[1]:size(1))) 48 | return net:forward(outputTable) 49 | end 50 | 51 | --[[ 52 | getForwardVariable 53 | 54 | calculate ForwardVariable for any (t, u) 55 | 56 | - outputTable: a T * class_num matrix 57 | - alignedTable: a T * (2L + 1) matrix 58 | - target: a (2L + 1) * class_num matrix 59 | ]]-- 60 | function ctc.getForwardVariable(outputTable, alignedTable, target) 61 | local T = (#outputTable)[1] 62 | -- create a T * (2L + 1) Matrix 63 | 64 | local L = (#target)[1] 65 | local fvs = torch.zeros(T, L) 66 | 67 | -- calculate using dynamic programming 68 | 69 | -- initialize 70 | 71 | fvs[1][1] = alignedTable[1][1] 72 | fvs[1][2] = alignedTable[1][2] 73 | 74 | local upper_bound = 2 75 | 76 | -- calculate 77 | for i = 2, T do 78 | upper_bound = upper_bound + 2 79 | if upper_bound > L then 80 | upper_bound = L 81 | end 82 | for u = 1, upper_bound do 83 | 84 | -- l'[u] is a blank when u is odd; blanks cannot take the skip transition 85 | 86 | if u % 2 == 1 then 87 | fvs[i][u] = fvs[i][u] + fvs[i - 1][u] 88 | if u > 1 then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 1] end 89 | fvs[i][u] = fvs[i][u] * alignedTable[i][u] 90 | else 91 | if u > 2 and not target[u - 2]:equal(target[u]) then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 2] end 92 | if u > 1 then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 1] end 93 | fvs[i][u] = fvs[i][u] + fvs[i - 1][u] 94 | fvs[i][u] = fvs[i][u] * alignedTable[i][u] 95 | end 96 | end 97 | 98 | 99 | end 100 | 101 | return fvs 102 | end 103 | 104 | function ctc.getBackwardVariable(outputTable, alignedTable, target) 105 | local T = (#outputTable)[1] 106 | -- create a T * (2L + 1) Matrix 107 | 108 | local L = (#target)[1] 109 | local bvs = torch.zeros(T, L) 110 | 111 | -- initialize 112 | 113 | bvs[T][L] = 1 114 | bvs[T][L - 1] = 1 115 | 116 | -- calculate using dynamic programming 117 | 118 | for i = T - 1, 1, -1 do 119 | for u = L, 1, -1 do 120 | if u % 2 == 1 then 121 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u] * bvs[i + 1][u] 122 | if u < L then bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 1] * bvs[i + 1][u + 1] end 123 | else 124 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u] * bvs[i + 1][u] 125 | if u < L then bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 1] * bvs[i + 1][u + 1] end 126 | if u < L - 1 and not target[u + 2]:equal(target[u]) then 127 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 2] * bvs[i + 1][u + 2] 128 | end 129 | end 130 | end 131 | end 132 | 133 | return bvs 134 | end 135 | 136 | function ctc.getCTCCost(outputTable, target) 137 | -- convert target to a (2L + 1) x class_num one-hot matrix 138 | local class_num = (#(outputTable[1]))[1] 139 | 140 | target = ctc.getFilledTarget(target) 141 | target = ctc.getOnehotMatrix(target, class_num) 142 | 143 | outputTable = ctc.toMatrix(outputTable) 144 | 145 | -- get aligned_table 146 | -- outputTable: Tx(cls+1) 147 | -- target: L'x(cls+1) --> targetT : (cls+1)xL' 148 | -- alignedTable = T x L' 149 | local alignedTable = outputTable * target:t() 150 | 151 | fvs = ctc.getForwardVariable(outputTable, alignedTable, target) 152 | 153 | -- calculate backwardVariable 154 | 155 | bvs = ctc.getBackwardVariable(outputTable, alignedTable, target) 156 | 157 | print(bvs) 158 | end 159 | 
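As a quick illustration of how getFilledTarget, getOnehotMatrix, toMatrix and the forward/backward recursions above fit together, here is a minimal smoke test. It is a sketch, not a file in this repo: it assumes an 11-class toy output (10 character classes plus the trailing blank, the same shape as the data in test_ctc.lua and contrast.txt) and a digit-string target, which is the form getFilledTarget expects.

```lua
require 'nn'
require 'ctc'

-- five frames of a toy softmax output over 10 characters + 1 blank
local T, class_num = 5, 11
local outputTable = {}
for t = 1, T do
    table.insert(outputTable, nn.SoftMax():forward(torch.randn(class_num)))
end

-- the target is a digit string; each character becomes a class index,
-- and blanks are interleaved around the labels by getFilledTarget
ctc.getCTCCost(outputTable, "12")  -- prints the backward variables (bvs)
```

Note that getCTCCost currently only prints the backward variables; the gradient-producing version of this interface lives in ctc_log.lua (see test_ctc.lua for a finite-difference check against it).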
-------------------------------------------------------------------------------- /ctc_lua.lua: -------------------------------------------------------------------------------- 1 | require 'utils/logs' 2 | 3 | ctc = {} 4 | 5 | --[[ 6 | getOnehotMatrix 7 | 8 | target - a vector of number of class 9 | 10 | return a L * C onehot Matrix, C is the number of kinds of classes. 11 | ]] 12 | 13 | function ctc.getOnehotMatrix(target, class_num) 14 | onehot = torch.zeros((#target)[1], class_num) 15 | for i = 1, (#target)[1] do 16 | c = target[i] 17 | if c > 0 then 18 | onehot[i][c] = 1 19 | else 20 | onehot[i][class_num] = 1 21 | end 22 | end 23 | return onehot 24 | end 25 | 26 | --[[ 27 | getFilledTarget 28 | 29 | target - a unicode string of ground truth 30 | 31 | return a 2L + 1 vector of number of class. 32 | ]] 33 | 34 | function ctc.getFilledTarget(target) 35 | local filled = torch.zeros(#target * 2 + 1) 36 | for i = 1, (#filled)[1] do 37 | if i % 2 == 0 then 38 | filled[i] = string.sub(target, i / 2, i / 2) 39 | end 40 | end 41 | return filled 42 | end 43 | 44 | function ctc.toMatrix(outputTable) 45 | local net = nn.Sequential() 46 | net:add(nn.JoinTable(1)) 47 | net:add(nn.Reshape(#outputTable, 11)) 48 | return net:forward(outputTable) 49 | end 50 | 51 | --[[ 52 | getForwardVariable 53 | 54 | calculate ForwardVariable for any (t, u) 55 | 56 | - outputTable: a T * (2C + 1) matrix 57 | - alignedTable: a T * L matrix 58 | - target: a (2L + 1) * (2C + 1) matrix 59 | ]]-- 60 | function ctc.getForwardVariable(outputTable, alignedTable, target) 61 | local T = (#outputTable)[1] 62 | -- create a T * (2L + 1) Matrix 63 | 64 | local L = (#target)[1] 65 | local fvs = torch.zeros(T, L) 66 | 67 | -- calculate using dynamic programming 68 | 69 | -- initialize 70 | 71 | fvs[1][1] = alignedTable[1][1] 72 | fvs[1][2] = alignedTable[1][2] 73 | 74 | local upper_bound = 2 75 | 76 | -- calculate 77 | for i = 2, T do 78 | upper_bound = upper_bound + 2 79 | if upper_bound > L then 80 | upper_bound = L 81 | end 82 | for u = 1, upper_bound do 83 | 84 | -- if l'[u] is not blank 85 | 86 | if u % 2 == 1 then 87 | fvs[i][u] = fvs[i][u] + fvs[i - 1][u] 88 | if u > 1 then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 1] end 89 | fvs[i][u] = fvs[i][u] * alignedTable[i][u] 90 | else 91 | if u > 2 and target[u - 2] ~= target[u] then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 2] end 92 | if u > 1 then fvs[i][u] = fvs[i][u] + fvs[i - 1][u - 1] end 93 | fvs[i][u] = fvs[i][u] + fvs[i - 1][u] 94 | fvs[i][u] = fvs[i][u] * alignedTable[i][u] 95 | end 96 | end 97 | 98 | 99 | end 100 | 101 | return fvs 102 | end 103 | 104 | function ctc.getBackwardVariable(outputTable, alignedTable, target) 105 | local T = (#outputTable)[1] 106 | -- create a T * (2L + 1) Matrix 107 | 108 | local L = (#target)[1] 109 | local bvs = torch.zeros(T, L) 110 | 111 | -- initialize 112 | 113 | bvs[T][L] = 1 114 | bvs[T][L - 1] = 1 115 | 116 | -- calculate using dynamic programming 117 | 118 | for i = T - 1, 1, -1 do 119 | for u = L, 1, -1 do 120 | if i % 2 == 1 then 121 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u] * bvs[i + 1][u] 122 | if u < L then bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 1] * bvs[i + 1][u + 1] end 123 | else 124 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u] * bvs[i + 1][u] 125 | if u < L then bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 1] * bvs[i + 1][u + 1] end 126 | if u < L - 1 and target[u + 2] ~= target[u] then 127 | bvs[i][u] = bvs[i][u] + alignedTable[i + 1][u + 2] * bvs[i + 1][u + 2] 128 | end 129 | end 130 | end 131 | end 132 | 
133 | return bvs 134 | end 135 | 136 | function ctc.getCTCCost(outputTable, target) 137 | -- convert target to one-hot Matrix (class + 1 * len(target)) 138 | local class_num = (#(outputTable[1]))[1] 139 | 140 | target = ctc.getFilledTarget(target) 141 | target = ctc.getOnehotMatrix(target, class_num) 142 | 143 | outputTable = ctc.toMatrix(outputTable) 144 | 145 | -- get aligned_table 146 | -- outputTable: Tx(cls+1) 147 | -- target: L'x(cls+1) --> targetT : (cls+1)xL' 148 | -- alienged_table = TxL' 149 | local alignedTable = outputTable * target:t() 150 | 151 | fvs = ctc.getForwardVariable(outputTable, alignedTable, target) 152 | 153 | -- calculate backwardVariable 154 | 155 | bvs = ctc.getBackwardVariable(outputTable, alignedTable, target) 156 | 157 | print(bvs) 158 | end 159 | -------------------------------------------------------------------------------- /test_thread.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'image' 4 | require 'optim' 5 | require 'lfs' 6 | require 'json' 7 | 8 | require 'loader' 9 | require 'ctc_log' 10 | require 'utils/decoder' 11 | 12 | local threads = require 'threads' 13 | 14 | -- timer initialize 15 | base = 0 16 | timer = torch.Timer() 17 | 18 | -- initialize 19 | torch.setdefaulttensortype('torch.FloatTensor') 20 | torch.manualSeed(os.time()) 21 | 22 | -- debug switch 23 | DEBUG = false 24 | 25 | 26 | 27 | function show_log(log) 28 | local now = timer:time().real 29 | local cost = now - base 30 | base = now 31 | print(string.format("[%.4f][%.4f]%s", now, cost, log)) 32 | end 33 | 34 | -- settings 35 | 36 | opt = { 37 | -- project 38 | project_name = os.date("%y-%m-%d_") .. torch.random() % 10000, 39 | 40 | -- hyperparameters 41 | input_size = 48, 42 | hidden_size = 200, 43 | clamp_size = 1, 44 | learning_rate = 1e-4, 45 | momentum = 0.9, 46 | dropout_rate = 0.5, 47 | 48 | -- configurations 49 | gpu = false, 50 | 51 | -- threading 52 | nthread = 3, 53 | 54 | -- samples 55 | training_list_file = "wwr.txt", 56 | codec_file = "", 57 | 58 | -- miscellaneous 59 | max_iter = 1e10 60 | } 61 | 62 | cmd = torch.CmdLine() 63 | 64 | 65 | 66 | show_log("======== UMARU ========") 67 | show_log("project: " .. opt.project_name) 68 | show_log("") 69 | 70 | -- preparation for model saving and logging 71 | 72 | lfs.mkdir("models") 73 | 74 | local project_dir = "models/" .. opt.project_name .. "/" 75 | 76 | lfs.mkdir("models/" .. opt.project_name) 77 | 78 | json.save(project_dir .. "settings.json", opt) 79 | 80 | -- GPU 81 | 82 | if opt.gpu then 83 | require 'cutorch' 84 | require 'cunn' 85 | end 86 | 87 | -- load samples 88 | 89 | show_log("Loading samples...") 90 | 91 | local loader = Loader() 92 | loader:targetHeight(opt.input_size) 93 | loader:load(opt.training_list_file) 94 | local codec = loader:codec() 95 | 96 | show_log(string.format("Loading finished. 
Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size)) 97 | show_log(string.format("lr = %f, opt.momentum = %.4f clamp = %.2f", opt.learning_rate, opt.momentum, opt.clamp_size)) 98 | show_log(string.format("using %d threads.", opt.nthread)) 99 | 100 | local class_num = codec.codec_size 101 | 102 | -- build network 103 | 104 | show_log("Building networks...") 105 | 106 | local net 107 | 108 | net = nn.Sequential() 109 | 110 | net:add(nn.Dropout(opt.dropout_rate)) 111 | net:add(nn.SplitTable(1)) 112 | net:add(nn.BiSequencer(nn.FastLSTM(opt.input_size, opt.hidden_size))) 113 | 114 | output = nn.Sequential() 115 | output:add(nn.Dropout(opt.dropout_rate)) 116 | output:add(nn.Linear(opt.hidden_size * 2, class_num + 1)) 117 | output:add(nn.SoftMax()) 118 | net:add(nn.Sequencer(output)) 119 | net:float() 120 | 121 | -- prepare prarmeters and training method 122 | 123 | local params, grad_params 124 | 125 | params, grad_params = net:getParameters() 126 | 127 | 128 | state = { 129 | learningRate = opt.learning_rate, 130 | momentum = opt.momentum 131 | } 132 | 133 | threads.serialization('threads.sharedserialize') 134 | 135 | local pool = threads(opt.nthread, 136 | function(id) 137 | require 'nn' 138 | require 'rnn' 139 | require 'ctc_log' 140 | end, 141 | 142 | function() 143 | torch.setdefaulttensortype('torch.FloatTensor') 144 | local n = net:clone() 145 | local p, gp = n:getParameters() 146 | 147 | n:zeroGradParameters() 148 | torch.manualSeed(450) 149 | 150 | local loss, grad 151 | 152 | function eval(id, ps, im, target) 153 | -- n:zeroGradParameters() 154 | p:copy(ps) 155 | n:forget() 156 | -- print(p:sum()) 157 | outputTable = n:forward(im) 158 | loss, grad = ctc.getCTCCostAndGrad(outputTable, target) 159 | n:backward(im, grad) 160 | 161 | -- print("loss " .. loss) 162 | return outputTable, loss, gp 163 | end 164 | end 165 | ) 166 | 167 | -- training 168 | 169 | begin_time = 0 170 | 171 | state = { 172 | learningRate = opt.learning_rate, 173 | momentum = opt.momentum 174 | } 175 | 176 | 177 | 178 | for i = 1, opt.max_iter do 179 | local totalerr = 0 180 | local totalgrad = nil 181 | 182 | local feval = function(params) 183 | grad_params:zero() 184 | 185 | for j = 1, opt.nthread do 186 | local sample = loader:pick() 187 | local im = sample.img 188 | local target = codec:encode(sample.gt) 189 | pool:addjob( 190 | function(idx) 191 | local im = im 192 | local target = target 193 | local ps = params 194 | 195 | return eval(idx, ps, im, target) 196 | end, 197 | 198 | function(out, loss, gp) 199 | totalerr = totalerr + loss 200 | 201 | if i % 10 == 0 and j == 1 then 202 | show_log("LOSS " .. loss) 203 | print(sample.gt) 204 | print(decoder.best_path_decode(out, codec)) 205 | print(string.format("%.2f sec/ep.", timer:time().real / (i * opt.nthread))) 206 | end 207 | 208 | gp:cmul(gp:eq(gp):float()) 209 | gp:clamp(-opt.clamp_size, opt.clamp_size) 210 | grad_params = grad_params + gp 211 | 212 | end, 213 | idx 214 | ) 215 | end 216 | 217 | pool:synchronize() 218 | 219 | return totalerr, grad_params 220 | end 221 | 222 | optim.sgd(feval, params, state) 223 | 224 | 225 | end 226 | 227 | pool:terminate() 228 | 229 | 230 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | For umaru 2 | 3 | Copyright (c) 2015, Zhu Jiadong 4 | 5 | All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 16 | 17 | For Torch7 18 | 19 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 20 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 21 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 22 | Copyright (c) 2011-2013 NYU (Clement Farabet) 23 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 24 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 25 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 26 | 27 | All rights reserved. 28 | 29 | Redistribution and use in source and binary forms, with or without 30 | modification, are permitted provided that the following conditions are met: 31 | 32 | 1. Redistributions of source code must retain the above copyright 33 | notice, this list of conditions and the following disclaimer. 34 | 35 | 2. Redistributions in binary form must reproduce the above copyright 36 | notice, this list of conditions and the following disclaimer in the 37 | documentation and/or other materials provided with the distribution. 38 | 39 | 3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America 40 | and IDIAP Research Institute nor the names of its contributors may be 41 | used to endorse or promote products derived from this software without 42 | specific prior written permission. 43 | 44 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 45 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 46 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 47 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 48 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 49 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 50 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 51 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 52 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 53 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 54 | POSSIBILITY OF SUCH DAMAGE. 55 | 56 | For json4Lua 57 | 58 | The MIT License 59 | 60 | Copyright (c) 2009 Craig Mason-Jones 61 | 62 | Permission is hereby granted, free of charge, to any person obtaining a copy 63 | of this software and associated documentation files (the "Software"), to deal 64 | in the Software without restriction, including without limitation the rights 65 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 66 | copies of the Software, and to permit persons to whom the Software is 67 | furnished to do so, subject to the following conditions: 68 | 69 | The above copyright notice and this permission notice shall be included in 70 | all copies or substantial portions of the Software. 71 | 72 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 73 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 74 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 75 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 76 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 77 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 78 | THE SOFTWARE. 79 | 80 | For lua-utf8 81 | 82 | Copyright (c) Clement Farabet 83 | 84 | All rights reserved. 85 | 86 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 87 | 88 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 89 | 90 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 91 | 92 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 93 | 94 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /loader.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | require 'codec' 3 | require 'normalizer' 4 | local lfs = require 'lfs' 5 | local utf8 = require 'utf8' 6 | 7 | Loader = { 8 | samples = {}, 9 | training = {}, 10 | testing = {}, 11 | weights = nil, 12 | p = nil, 13 | codec_table = {}, 14 | codec_inv = {}, 15 | codec_size = 0, 16 | codec_obj = nil, 17 | threshold = 3, 18 | lambda = 3.0, 19 | pos = 1, 20 | target_height = 32 21 | } 22 | 23 | setmetatable(Loader, { 24 | __call = 25 | function (cls, ...) 26 | return cls:new(...) 27 | end 28 | }) 29 | 30 | function Loader:new(o) 31 | o = o or {} 32 | setmetatable(o, self) 33 | self.__index = self 34 | return o 35 | end 36 | 37 | function Loader:shuffle() 38 | for i = 1, #self.samples do 39 | local j = torch.random(#self.samples) 40 | self.samples[i], self.samples[j] = self.samples[j], self.samples[i] 41 | end 42 | end 43 | 44 | function Loader:__split(rate) 45 | assert(rate <= 1 and rate > 0, "invalid rate") 46 | local ntrain = math.floor(#self.samples * rate) 47 | local ntest = #self.samples - ntrain 48 | 49 | for i = 1, #self.samples do 50 | if i <= ntrain then 51 | table.insert(self.training, self.samples[i]) 52 | else 53 | table.insert(self.testing, self.samples[i]) 54 | end 55 | end 56 | end 57 | 58 | function Loader:targetHeight(target_height) 59 | self.target_height = target_height or self.target_height 60 | return self.target_height 61 | end 62 | 63 | function Loader:__getNormalizedImage(src) 64 | local defaultTensorType = torch.getdefaulttensortype() 65 | torch.setdefaulttensortype('torch.DoubleTensor') 66 | local im = image.load(src, 1) 67 | 68 | if im:dim() == 3 then 69 | im = im[1] 70 | end 71 | 72 | local output = torch.DoubleTensor() 73 | 74 | local w = im:size()[2] 75 | local h = im:size()[1] 76 | 77 | local ones = torch.ones(h, w) 78 | 79 | im = ones - im 80 | normalizer.normalize(im:double(), output, self.target_height) 81 | -- image.save("normalized.png", output:float()) 82 | 83 | --local target_width = self.target_height / h * w 84 | 85 | --output = image.scale(im, target_width, self.target_height) 86 | 87 | -- image.save("scaled.png", output) 88 | torch.setdefaulttensortype(defaultTensorType) 89 | return output 90 | end 91 | 92 | function Loader:load(file, rate) 93 | self.samples = {} 94 | local f = assert(io.open(file, "r")) 95 | for line in f:lines() do 96 | local src = line 97 | 98 | if lfs.attributes(src, "size") < 200 then 99 | print("found invalid sample " .. src) 100 | goto continue 101 | end 102 | 103 | 104 | local gt = src:gsub("[.].*", ".gt.txt") 105 | local cf = io.open(gt, "r") 106 | 107 | if cf == nil then 108 | print("ground truth not found " .. 
91 | 
92 | function Loader:load(file, rate)
93 |     self.samples = {}
94 |     local f = assert(io.open(file, "r"))
95 |     for line in f:lines() do
96 |         local src = line
97 | 
98 |         if lfs.attributes(src, "size") < 200 then
99 |             print("found invalid sample " .. src)
100 |             goto continue
101 |         end
102 | 
103 | 
104 |         local gt = src:gsub("[.].*", ".gt.txt") -- "foo.png" -> "foo.gt.txt"
105 |         local cf = io.open(gt, "r")
106 | 
107 |         if cf == nil then
108 |             print("ground truth not found " .. gt)
109 |             goto continue
110 |         end
111 | 
112 |         local gt = cf:read("*line")
113 |         cf:close()
114 | 
115 |         for _, c, _ in utf8.iter(gt) do
116 |             if self.codec_table[c] == nil then
117 |                 self.codec_size = self.codec_size + 1
118 |                 self.codec_table[c] = self.codec_size
119 |             end
120 | 
121 |         end
122 | 
123 |         table.insert(self.samples, {src = src, gt = gt, img = nil})
124 | 
125 |         ::continue::
126 |     end
127 |     f:close()
128 | 
129 |     for k, v in pairs(self.codec_table) do
130 |         self.codec_inv[v] = k
131 |     end
132 | 
133 |     self.codec_obj = nil
134 |     self.weights = nil
135 | 
136 |     rate = rate or 1
137 |     self:__split(rate)
138 | 
139 |     -- return self.samples
140 | end
141 | 
142 | function Loader:loadTesting(file)
143 |     local f = assert(io.open(file, "r"))
144 |     for line in f:lines() do
145 |         local src = line
146 | 
147 |         if lfs.attributes(src, "size") < 200 then
148 |             print("found invalid sample " .. src)
149 |             goto continue
150 |         end
151 | 
152 |         local gt = src:gsub("[.].*", ".gt.txt")
153 |         local cf = io.open(gt, "r")
154 | 
155 |         if cf == nil then
156 |             print("ground truth not found " .. gt) -- was a stray copy of the "found invalid sample" message above
157 |             goto continue
158 |         end
159 | 
160 |         local gt = cf:read("*line")
161 |         cf:close()
162 | 
163 |         for _, c, _ in utf8.iter(gt) do
164 |             if self.codec_table[c] == nil then
165 |                 print("character '" .. c .. "' appears in the testing set but not in the training set.")
166 |             end
167 |         end
168 | 
169 |         local sample = {src = src, gt = gt, img = nil}
170 | 
171 |         table.insert(self.samples, sample)
172 |         table.insert(self.testing, sample)
173 | 
174 |         ::continue::
175 |     end
176 |     f:close()
177 | end
178 | 
179 | function Loader:__pick(index, from)
180 |     from = from or "training"
181 | 
182 |     if self[from][index].img == nil then
183 | 
184 |         local t = self[from][index].src:sub(-3, -1) -- was an accidental global
185 | 
186 |         if (t == "png") then
187 |             self[from][index].img = self:__getNormalizedImage(self[from][index].src):t()
188 |         elseif (t == ".ft") then
189 |             self[from][index].img = torch.load(self[from][index].src):t()
190 |         end
191 | 
192 |         if false then -- disabled: main.lua moves images to the GPU itself
193 |             self[from][index].img = self[from][index].img:cuda()
194 |         end
195 |     end
196 | 
197 |     return self[from][index]
198 | end
199 | 
200 | function Loader:pick(from) -- `from` was missing from the signature, so the set name could never reach this function
201 |     from = from or "training"
202 |     assert(self[from], "invalid set name.")
203 | 
204 |     local index = torch.random(#self[from])
205 | 
206 |     return self:__pick(index, from)
207 | end
208 | 
209 | function Loader:pickWithWeight()
210 | 
211 |     if self.weights == nil then
212 |         self.weights = torch.zeros(#self.training)
213 |         for i, v in ipairs(self.training) do -- was ipairs(self.samples), which overflows the weight vector whenever a testing split exists
214 |             self.weights[i] = math.pow(1.0 / math.max(utf8.len(v.gt), self.threshold), self.lambda)
215 |         end
216 |         self.weights = torch.div(self.weights, torch.sum(self.weights))
217 | 
218 |         self.p = torch.zeros(#self.training)
219 |         local i = 0
220 |         self.p:apply(function()
221 |             i = i + 1
222 |             return torch.normal(1.0 / self.weights[i], 1.0 / self.weights[i] / 3.0)
223 |         end)
224 |     end
225 |     local _, index = torch.min(self.p, 1)
226 |     index = index[1]
227 |     self.p[index] = torch.normal(1.0 / self.weights[index], 1.0 / self.weights[index] / 3.0) + 1
228 | 
229 |     return self:__pick(index)
230 | end
231 | 
232 | function Loader:reset()
233 |     self.pos = 1
234 | end
235 | 
236 | function Loader:pickInSequential(from)
237 |     from = from or "samples"
238 |     if self.pos <= #self[from] then
239 |         self.pos = self.pos + 1
240 |         return self:__pick(self.pos - 1, from), self.pos - 1, #self[from]
241 |     else
242 |         return nil
243 |     end
244 | end
245 | 
246 | function Loader:updateWeight(lambda)
247 |     self.lambda = lambda
248 |     self.weights = 
nil 249 | end 250 | 251 | function Loader:codec() 252 | self.codec_obj = self.codec_obj or Codec:new{ 253 | codec = self.codec_table, 254 | codec_inv = self.codec_inv, 255 | codec_size = self.codec_size 256 | } 257 | 258 | return self.codec_obj 259 | end 260 | 261 | function Loader:loadCodec(codec_file) 262 | self.codec_obj = Codec(torch.load(codec_file)) 263 | 264 | return self.codec_obj 265 | end -------------------------------------------------------------------------------- /normalizer.cc: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include 4 | } 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | const double SIGMA_RATE_VERT = 0.5; 11 | const double SIGMA_RATE_HORZ = 1.0; 12 | const double SIGMA_RATE_CENTER = 0.3; 13 | const double RANGE_RATE = 4.0; 14 | const long TARGET_HEIGHT = 32; 15 | 16 | template 17 | static void gauss1dWithMask(double * out, T * in, long size, double * mask, long mask_size) { 18 | // apply it 19 | long range = (mask_size - 1) / 2; 20 | int n = size; 21 | for (int i = 0; i < n; i++) { 22 | double total = 0.0; 23 | for (int j = 0; j < mask_size; j++) { 24 | int index = i+j-range; 25 | if (index < 0) 26 | index = 0; 27 | if (index >= n) 28 | index = n-1; 29 | total += in[index] * mask[j]; 30 | } 31 | out[i] = double(total); 32 | } 33 | } 34 | 35 | static void create1DMask(double * & mask, double sigma, long & mask_size) { 36 | int range = 1 + int(3.0*sigma); 37 | mask_size = 2 * range + 1; 38 | mask = new double[mask_size]; 39 | for (int i = 0; i <= range; i++) { 40 | double sd = sigma * sigma; 41 | double y = exp(-i*i/2.0/sd); 42 | mask[range+i] = mask[range-i] = y; 43 | } 44 | double total = 0.0; 45 | for (int i = 0; i < mask_size; i++) 46 | total += mask[i]; 47 | for (int i = 0; i < mask_size; i++) { 48 | mask[i] /= total; 49 | } 50 | 51 | 52 | } 53 | 54 | template 55 | static void gauss1d(double * out, T * in, long size, double sigma) { 56 | double * mask = NULL; 57 | long ms; 58 | create1DMask(mask, sigma, ms); 59 | gauss1dWithMask(out, in, size, mask, ms); 60 | delete [] mask; 61 | } 62 | 63 | static void getDim1(double * in, double * out, long w, long h, long index) { 64 | for (int i = 0; i < h; i++) { 65 | out[i] = in[i * w + index]; 66 | } 67 | } 68 | 69 | static void setDim1(double * in, double * out, long w, long h, long index) { 70 | for (int i = 0; i < h; i++) { 71 | out[i * w + index] = in[i]; 72 | } 73 | } 74 | 75 | static void gauss2d(double * src, long w, long h, double sigmaX, double sigmaY) { 76 | double tmp[h]; 77 | double in_tmp[h]; 78 | double * maskX = NULL, * maskY = NULL; 79 | long msX, msY; 80 | 81 | create1DMask(maskY, sigmaY, msY); 82 | create1DMask(maskX, sigmaX, msX); 83 | 84 | 85 | for (int i = 0; i < w; i++) { 86 | getDim1(src, in_tmp, w, h, i); 87 | gauss1dWithMask(tmp, in_tmp, h, maskY, msY); 88 | setDim1(tmp, src, w, h, i); 89 | } 90 | 91 | double tmp2[w]; 92 | for (int i = 0; i < h; i++) { 93 | memcpy(tmp2, src + w * i, w * sizeof(double)); 94 | gauss1dWithMask(src + w * i, tmp2, w, maskX, msX); 95 | } 96 | 97 | delete [] maskX; 98 | delete [] maskY; 99 | } 100 | 101 | 102 | 103 | static double bilinear(double * in, int w, int h, double x, double y) { 104 | int xi = int(x), yi = int(y), xt = xi + 1, yt = yi + 1; 105 | double xf = x - xi, yf = y - yi; 106 | 107 | // printf("(%d, %d)\n", xi, yi); 108 | 109 | if (xi > w - 1 || yi > h - 1 || x < 0 || y < 0) { 110 | return 0; 111 | } 112 | 113 | xi = xi < 0 ? 0 : xi; 114 | yi = yi < 0 ? 
0 : yi; 115 | 116 | 117 | xi = xi > w - 1 ? w - 1 : xi; 118 | yi = yi > h - 1 ? h - 1 : yi; 119 | 120 | 121 | xt = xt > w - 1 ? w - 1 : xt; 122 | yt = yt > h - 1 ? h - 1 : yt; 123 | 124 | 125 | 126 | double p00 = in[yi * w + xi]; 127 | double p01 = in[yt * w + xi]; 128 | double p10 = in[yi * w + xt]; 129 | double p11 = in[yt * w + xt]; 130 | 131 | double result = p00 * (1.0 - xf) * (1.0 - yf) + p10 * xf * (1.0 - yf) + p01 * (1.0 - xf) * yf + p11 * xf * yf; 132 | if (result < 0) { 133 | printf("warning result < 0. %.4lf, %.4lf, %.4f, %.4f\n" \ 134 | "%.4f, %.4f, %.4f, %.4f\n", x, y, xf, yf, p00, p01, p10, p11); 135 | } 136 | 137 | return result; 138 | } 139 | 140 | static void measure(THDoubleTensor * src, double * & center, double & mean, int & r) { 141 | long h = src->size[0]; 142 | long w = src->size[1]; 143 | double sigmaX = h * SIGMA_RATE_HORZ; 144 | double sigmaY = h * SIGMA_RATE_VERT; 145 | double * dataSrc = THDoubleTensor_data(src); 146 | THDoubleTensor * smooth = THDoubleTensor_newClone(src); 147 | double * dataSmooth = THDoubleTensor_data(smooth); 148 | gauss2d(dataSmooth, w, h, sigmaX, sigmaY); 149 | 150 | THDoubleTensor * minVT = THDoubleTensor_new(); 151 | THLongTensor * minT = THLongTensor_new(); 152 | 153 | THDoubleTensor_max(minVT, minT, smooth, 0); 154 | 155 | long * min = THLongTensor_data(minT); 156 | 157 | center = new double[w]; 158 | 159 | gauss1d(center, min, w, h * SIGMA_RATE_CENTER); 160 | 161 | double s1 = 0.0, sy = 0.0; 162 | 163 | for (int i = 0; i < h; i++) { 164 | for (int j = 0; j < w; j++) { 165 | s1 += dataSrc[i * w + j]; 166 | sy += dataSrc[i * w + j] * fabs(i - center[j]); 167 | } 168 | } 169 | 170 | 171 | mean = sy / s1; 172 | r = int(mean * RANGE_RATE + 1); 173 | 174 | THDoubleTensor_free(minVT); 175 | THLongTensor_free(minT); 176 | THDoubleTensor_free(smooth); 177 | 178 | /* printf("mean = %lf r = %d\n", mean, r); */ 179 | } 180 | 181 | static void normalize 182 | (THDoubleTensor * src, THDoubleTensor * out, double * center, double mean, int r, int target_height) { 183 | long h = src->size[0]; 184 | long w = src->size[1]; 185 | float scale = (2.0 * r) / target_height; 186 | int target_width = fmax(int(w / scale), 1); 187 | 188 | double * inData = THDoubleTensor_data(src); 189 | 190 | THDoubleTensor_resize2d(out, target_height, target_width); 191 | 192 | // printf("scale = %.4f\n", scale); 193 | 194 | double * outData = THDoubleTensor_data(out); 195 | 196 | 197 | for (int i = 0; i < target_height; i++) { 198 | for (int j = 0; j < target_width; j++) { 199 | float x = scale * j; 200 | float y = scale * (i - target_height / 2) + center[int(x)]; 201 | // printf(" = %d\n", (i - target_height / 2)); 202 | outData[i * target_width + j] = bilinear(inData, w, h, x, y); 203 | } 204 | } 205 | 206 | } 207 | 208 | static int normalizer_gauss1d(lua_State * L) 209 | { 210 | THDoubleTensor * input = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 211 | double sigma = luaL_checknumber(L, 2); 212 | int size = input->size[0]; 213 | 214 | printf("sigma = %.4lf\n", sigma); 215 | 216 | double * data = THDoubleTensor_data(input); 217 | 218 | THDoubleTensor * outputT = THDoubleTensor_newClone(input); 219 | 220 | double * output = THDoubleTensor_data(outputT); 221 | 222 | gauss1d(output, data, size, sigma); 223 | 224 | return 0; 225 | } 226 | 227 | static int normalizer_gauss2d(lua_State * L) 228 | { 229 | THDoubleTensor * input = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 230 | double sigmaX = luaL_checknumber(L, 2); 231 | double sigmaY = 
luaL_checknumber(L, 3); 232 | long h = input->size[0]; 233 | long w = input->size[1]; 234 | 235 | double * data = THDoubleTensor_data(input); 236 | 237 | printf("w = %ld, h = %ld\n", w, h); 238 | printf("sigmaX = %.4lf sigmaY = %.4lf\n", sigmaX, sigmaY); 239 | 240 | gauss2d(data, w, h, sigmaX, sigmaY); 241 | 242 | return 0; 243 | } 244 | 245 | static int normalizer_normalize(lua_State * L) 246 | { 247 | THDoubleTensor * input = 248 | (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 249 | THDoubleTensor * output = 250 | (THDoubleTensor *)luaT_checkudata(L, 2, "torch.DoubleTensor"); 251 | 252 | int target_height = TARGET_HEIGHT; 253 | 254 | if(lua_isnumber(L, 3)) { 255 | target_height = luaL_checkint(L, 3); 256 | } 257 | 258 | double * center = NULL; 259 | double mean = 0.0; 260 | int r = 0; 261 | 262 | measure(input, center, mean, r); 263 | normalize(input, output, center, mean, r, target_height); 264 | 265 | 266 | delete [] center; 267 | return 0; 268 | } 269 | 270 | static const struct luaL_reg normalizer[] = { 271 | {"gauss1d", normalizer_gauss1d}, 272 | {"gauss2d", normalizer_gauss2d}, 273 | {"normalize", normalizer_normalize}, 274 | {NULL, NULL} 275 | }; 276 | 277 | LUA_EXTERNC int luaopen_normalizer(lua_State *L) { 278 | luaL_openlib(L, "normalizer", normalizer, 0); 279 | return 1; 280 | } -------------------------------------------------------------------------------- /GRU.lua: -------------------------------------------------------------------------------- 1 | local GRU, parent 2 | 3 | GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent') 4 | 5 | function GRU:__init(inputSize, outputSize, rho) 6 | parent.__init(self, rho or 9999) 7 | self.inputSize = inputSize 8 | self.outputSize = outputSize 9 | 10 | self.recurrentModule = self:buildModel() 11 | 12 | self.modules[1] = self.recurrentModule 13 | self.sharedClones[1] = self.recurrentModule 14 | 15 | self.zeroTensor = torch.Tensor() 16 | 17 | self.cells = {} 18 | self.gradCells = {} 19 | end 20 | 21 | function GRU:buildGate() 22 | local gate = nn.Sequential() 23 | local i2g = nn.Linear(self.inputSize, self.outputSize) 24 | local o2g = nn.Linear(self.outputSize, self.outputSize) 25 | local para = nn.ParallelTable() 26 | para:add(i2g):add(o2g) 27 | gate:add(para) 28 | gate:add(nn.CAddTable()) 29 | gate:add(nn.Sigmoid()) 30 | 31 | return gate 32 | end 33 | 34 | function GRU:buildResetGate() 35 | self.resetGate = (self.resetGate == nil and self:buildGate() or self.resetGate) 36 | return self.resetGate 37 | end 38 | 39 | function GRU:buildUpdateGate() 40 | self.updateGate = (self.updateGate == nil and self:buildGate() or self.updateGate) 41 | return self.updateGate 42 | end 43 | 44 | 45 | -- outputCandidate = tanh(W * x + U(r . 
h[t - 1]))) 46 | function GRU:buildOutputCandidate() 47 | local hiddenCandidate = nn.Sequential() 48 | local left = nn.Sequential() 49 | -- select x 50 | left:add(nn.SelectTable(1)) 51 | left:add(nn.Linear(self.inputSize, self.outputSize)) 52 | local right = nn.Sequential() 53 | -- select (r, y[t - 1]) 54 | right:add(nn.NarrowTable(2, 2)) 55 | right:add(nn.CMulTable()) 56 | right:add(nn.Linear(self.outputSize, self.outputSize)) 57 | local para = nn.ConcatTable() 58 | para:add(left):add(right) 59 | 60 | hiddenCandidate:add(para) 61 | hiddenCandidate:add(nn.CAddTable()) 62 | hiddenCandidate:add(nn.Tanh()) 63 | 64 | return hiddenCandidate 65 | end 66 | 67 | -- input {input, output[t - 1]} 68 | 69 | function GRU:buildModel() 70 | self.resetGate = self:buildResetGate() 71 | self.updateGate = self:buildUpdateGate() 72 | self.outputCandidate = self:buildOutputCandidate() 73 | 74 | local cell = nn.Sequential() 75 | 76 | local concat = nn.ConcatTable() 77 | concat:add(nn.Identity()):add(self.resetGate):add(self.updateGate) 78 | 79 | 80 | cell:add(concat) 81 | cell:add(nn.FlattenTable()) 82 | 83 | 84 | local seq1 = nn.Sequential() 85 | 86 | 87 | -- select output[t - 1] 88 | seq1:add(nn.SelectTable(2)) 89 | 90 | local seq2 = nn.Sequential() 91 | seq2:add(nn.SelectTable(4)) 92 | seq2:add(nn.MulConstant(-1, false)) 93 | seq2:add(nn.AddConstant(1, false)) 94 | 95 | local seq3 = nn.Sequential() 96 | seq3:add(nn.NarrowTable(1, 3)) 97 | seq3:add(self.outputCandidate) 98 | 99 | local concat2 = nn.ConcatTable() 100 | -- input: {x, h[t - 1], r, z} 101 | -- output: h[t - 1] (1 - z) z ~h 102 | concat2:add(seq1) 103 | concat2:add(seq2) 104 | concat2:add(nn.SelectTable(4)) 105 | concat2:add(seq3) 106 | 107 | cell:add(concat2) 108 | 109 | 110 | -- cell:add(nn.FlattenTable()) 111 | 112 | 113 | local seq4 = nn.Sequential() 114 | seq4:add(nn.NarrowTable(1, 2)) 115 | seq4:add(nn.CMulTable()) 116 | 117 | local seq5 = nn.Sequential() 118 | seq5:add(nn.NarrowTable(3, 2)) 119 | seq5:add(nn.CMulTable()) 120 | 121 | -- input: {(1 - z) h[t - 1] z ~h} 122 | -- output: {(1 - z) * h[t - i], z * ~h} 123 | 124 | 125 | local concat3 = nn.ConcatTable() 126 | concat3:add(seq4):add(seq5) 127 | 128 | cell:add(concat3) 129 | 130 | 131 | cell:add(nn.CAddTable()) 132 | 133 | 134 | return cell 135 | end 136 | 137 | function GRU:updateOutput(input) 138 | local prevOutput, prevCell 139 | if self.step == 1 then 140 | prevOutput = self.zeroTensor 141 | 142 | assert(input:dim() == 1, "only support input with dimension == 1") 143 | 144 | self.zeroTensor:resize(self.outputSize):zero() 145 | else 146 | prevOutput = self.output 147 | end 148 | 149 | local output 150 | if self.train ~= false then 151 | self:recycle() 152 | local recurrentModule = self:getStepModule(self.step) 153 | 154 | -- print{input, prevOutput} 155 | 156 | output = recurrentModule:updateOutput{input, prevOutput} 157 | else 158 | output = self.recurrentModule:updateOutput{input, prevOutput} 159 | end 160 | 161 | if self.train ~= false then 162 | local input_ = self.inputs[self.step] 163 | self.inputs[self.step] = self.copyInputs 164 | and nn.rnn.recursiveCopy(input_, input) 165 | or nn.rnn.recursiveSet(input_, input) 166 | end 167 | 168 | self.outputs[self.step] = output 169 | 170 | self.output = output 171 | 172 | self.step = self.step + 1 173 | self.gradParametersAccumulated = false 174 | 175 | return self.output 176 | end 177 | 178 | function GRU:backwardThroughTime() 179 | assert(self.step > 1, "expecting at least one updateOutput") 180 | self.gradInputs = {} 181 | local rho 
= math.min(self.rho, self.step - 1)
182 |     local stop = self.step - rho
183 |     if self.fastBackward then
184 |         local gradInput, gradPrevOutput -- declared outside the loop; re-declaring them inside shadowed these and cut the gradient flow between steps
185 |         for step = self.step - 1, math.max(stop, 1), -1 do
186 |             local recurrentModule = self:getStepModule(step)
187 | 
188 |             local gradOutput = self.gradOutputs[step]
189 |             if gradPrevOutput then
190 |                 self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], gradPrevOutput)
191 |                 nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
192 |                 gradOutput = self._gradOutputs[step]
193 |             end
194 | 
195 |             local scale = self.scales[step]
196 |             local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
197 | 
198 |             local inputTable = {self.inputs[step], output}
199 |             local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
200 | 
201 |             gradInput, gradPrevOutput = unpack(gradInputTable) -- was `local gradInput, gradPrevOutput = ...`: the fresh locals died each iteration, so gradPrevOutput above was always nil and the return below returned nil
202 | 
203 |             table.insert(self.gradInputs, 1, gradInput)
204 |         end
205 |         return gradInput
206 |     else
207 |         local gradInput = self:updateGradInputThroughTime()
208 |         self:accGradParametersThroughTime()
209 |         return gradInput
210 |     end
211 | end
212 | 
213 | function GRU:updateGradInputThroughTime()
214 |     assert(self.step > 1, "expecting at least one updateOutput")
215 |     self.gradInputs = {}
216 |     local gradInput, gradPrevOutput
217 |     local rho = math.min(self.rho, self.step - 1)
218 |     local stop = self.step - rho
219 | 
220 |     for step = self.step - 1, math.max(stop, 1), -1 do
221 |         local recurrentModule = self:getStepModule(step)
222 | 
223 |         local gradOutput = self.gradOutputs[step]
224 |         if gradPrevOutput then
225 |             self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], gradPrevOutput)
226 |             nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
227 |             gradOutput = self._gradOutputs[step]
228 |         end
229 | 
230 |         local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
231 |         local inputTable = {self.inputs[step], output}
232 |         local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
233 | 
234 |         gradInput, gradPrevOutput = unpack(gradInputTable)
235 | 
236 |         table.insert(self.gradInputs, 1, gradInput)
237 |     end
238 | 
239 |     return gradInput
240 | end
241 | 
242 | function GRU:accGradParametersThroughTime()
243 |     local rho = math.min(self.rho, self.step - 1)
244 |     local stop = self.step - rho
245 |     for step = self.step - 1, math.max(stop, 1), -1 do
246 |         local recurrentModule = self:getStepModule(step)
247 | 
248 |         local scale = self.scales[step]
249 |         local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
250 |         local inputTable = {self.inputs[step], output}
251 |         local gradOutput = (step == self.step - 1) and self.gradOutputs[step] or self._gradOutputs[step]
252 | 
253 | 
254 | 
255 |         recurrentModule:accGradParameters(inputTable, gradOutput, scale)
256 |     end
257 | 
258 |     self.gradParametersAccumulated = true
259 |     -- (was `return gradInput`, an undefined global; there is nothing meaningful to return here)
260 | end
261 | 
262 | function GRU:accUpdateGradParametersThroughTime(lr)
263 |     local rho = math.min(self.rho, self.step - 1)
264 |     local stop = self.step - rho
265 | 
266 |     for step = self.step - 1, math.max(stop, 1), -1 do
267 |         local recurrentModule = self:getStepModule(step)
268 | 
269 |         local scale = self.scales[step]
270 |         local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
271 |         local inputTable = {self.inputs[step], output}
272 |         local gradOutput = (step == self.step - 1) and self.gradOutputs[step] or self._gradOutputs[step]
273 |         local gradOutputTable = {self.gradOutputs[step]} -- (unused)
274 | 
275 | 
recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr * scale) 276 | end 277 | 278 | return gradInput 279 | end -------------------------------------------------------------------------------- /ctc_log.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'utils/logs' 3 | require 'libctc' 4 | 5 | ctc = {} 6 | 7 | local base = 0 8 | local timer = nil; 9 | 10 | 11 | 12 | function print_timestamp(msg) 13 | if not DEBUG then 14 | return 15 | end 16 | 17 | print(msg .. " " .. timer:time().real - base) 18 | base = timer:time().real 19 | end 20 | 21 | --[[ 22 | getOnehotMatrix 23 | 24 | target - a vector of number of class 25 | 26 | return a (2L + 1) * C onehot Matrix, C is the number of kinds of classes. 27 | ]] 28 | 29 | function ctc.__getOnehotMatrix(target, class_num) 30 | onehot = torch.zeros((#target)[1], class_num) 31 | for i = 1, (#target)[1] do 32 | c = target[i] 33 | if c > 0 then 34 | onehot[i][c] = 1 35 | else 36 | onehot[i][class_num] = 1 37 | end 38 | end 39 | return onehot 40 | end 41 | 42 | --[[ 43 | getFilledTarget 44 | 45 | target - a unicode string of ground truth 46 | 47 | return a 2L + 1 vector of number of class. 48 | ]] 49 | 50 | function ctc.__getFilledTargetFromString(target) 51 | local filled = torch.zeros(#target * 2 + 1) 52 | for i = 1, (#filled)[1] do 53 | if i % 2 == 0 then 54 | filled[i] = string.sub(target, i / 2, i / 2) + 1 55 | end 56 | end 57 | return filled 58 | end 59 | 60 | function ctc.__getFilledTarget(target) 61 | local filled = torch.zeros(#target * 2 + 1) 62 | for i = 1, (#filled)[1] do 63 | if i % 2 == 0 then 64 | filled[i] = target[i / 2] 65 | end 66 | end 67 | return filled 68 | end 69 | 70 | function ctc.__toMatrix(outputTable, class_num) 71 | local net = nn.Sequential() 72 | net:add(nn.JoinTable(1)) 73 | net:add(nn.Reshape(#outputTable, class_num)) 74 | return net:forward(outputTable) 75 | end 76 | 77 | --[[ 78 | getForwardVariable 79 | 80 | calculate ForwardVariable for any (t, u) 81 | 82 | - outputTable: a T * (C + 1) matrix 83 | - alignedTable: a T * L matrix 84 | - target: a (2L + 1) * (C + 1) matrix 85 | ]]-- 86 | function ctc.__getForwardVariable(outputTable, alignedTable, target) 87 | local T = (#outputTable)[1] 88 | -- create a T * (2L + 1) Matrix 89 | 90 | local L = (#target)[1] 91 | local fvs = torch.ones(T, L) * logs.LOG_ZERO 92 | 93 | -- calculate using dynamic programming 94 | 95 | -- initialize 96 | 97 | fvs[1][1] = alignedTable[1][1] 98 | fvs[1][2] = alignedTable[1][2] 99 | 100 | local lower_bound = 0 101 | local upper_bound = 2 102 | 103 | -- calculate 104 | for i = 2, T do 105 | upper_bound = upper_bound + 2 106 | if upper_bound > L then 107 | upper_bound = L 108 | end 109 | lower_bound = L - 2 * (T - i) - 1 110 | if lower_bound < 1 then 111 | lower_bound = 1 112 | end 113 | for u = lower_bound, upper_bound do 114 | -- if l'[u] is blank 115 | 116 | if u % 2 == 1 then 117 | fvs[i][u] = logs.log_add(fvs[i][u], fvs[i - 1][u]) 118 | if u > 1 then fvs[i][u] = logs.log_add(fvs[i][u], fvs[i - 1][u - 1]) end 119 | fvs[i][u] = logs.log_mul(fvs[i][u], alignedTable[i][u]) 120 | else 121 | if u > 2 and target[u - 2] ~= target[u] then fvs[i][u] = logs.log_add(fvs[i][u], fvs[i - 1][u - 2]) end 122 | if u > 1 then fvs[i][u] = logs.log_add(fvs[i][u], fvs[i - 1][u - 1]) end 123 | fvs[i][u] = logs.log_add(fvs[i][u], fvs[i - 1][u]) 124 | fvs[i][u] = logs.log_mul(fvs[i][u], alignedTable[i][u]) 125 | end 126 | end 127 | 128 | end 129 | 130 | return fvs 131 | end 132 | 133 | 
function ctc.__getBackwardVariable(outputTable, alignedTable, target) 134 | local T = (#outputTable)[1] 135 | -- create a T * (2L + 1) Matrix 136 | 137 | local L = (#target)[1] 138 | local bvs = torch.ones(T, L) * logs.LOG_ZERO 139 | 140 | -- initialize 141 | 142 | bvs[T][L] = 0 143 | bvs[T][L - 1] = 0 144 | 145 | -- calculate using dynamic programming 146 | 147 | local upper_bound = L - 2 148 | local lower_bound 149 | 150 | for i = T - 1, 1, -1 do 151 | upper_bound = upper_bound - 2 152 | if upper_bound < 1 then 153 | upper_bound = 1 154 | end 155 | 156 | lower_bound = 2 * i 157 | if lower_bound > L - 1 then 158 | lower_bound = L - 1 159 | end 160 | 161 | -- print(lower_bound, upper_bound) 162 | 163 | for u = lower_bound, upper_bound, -1 do 164 | 165 | if u % 2 == 1 then 166 | bvs[i][u] = logs.log_mul(alignedTable[i + 1][u], bvs[i + 1][u]) 167 | bvs[i][u] = logs.log_add(bvs[i][u], logs.log_mul(alignedTable[i + 1][u + 1], bvs[i + 1][u + 1])) 168 | else 169 | bvs[i][u] = logs.log_mul(alignedTable[i + 1][u], bvs[i + 1][u]) 170 | bvs[i][u] = logs.log_add(bvs[i][u], logs.log_mul(alignedTable[i + 1][u + 1], bvs[i + 1][u + 1])) 171 | if u < L - 1 and target[u + 2] ~= target[u] then 172 | bvs[i][u] = logs.log_add(bvs[i][u], logs.log_mul(alignedTable[i + 1][u + 2], bvs[i + 1][u + 2])) 173 | end 174 | end 175 | end 176 | end 177 | 178 | return bvs 179 | end 180 | 181 | -- calculate cost matrix (Tx(cls+1)) 182 | 183 | function ctc.__getGrad(fb, pzx, class_num, outputTable, target) 184 | local T = (#fb)[1] 185 | local grad = torch.zeros(T, class_num) 186 | local temp_sum = 0 187 | local u = 0 188 | 189 | 190 | for t = 1, T do 191 | for k = 1, class_num do 192 | temp_sum = logs.LOG_ZERO 193 | grad[t][k] = logs.log_mul(-pzx, -outputTable[t][k]) 194 | u = k 195 | 196 | 197 | 198 | -- if current label is blank 199 | if u == class_num then u = 0 end 200 | for i = 1, (#target)[1] do 201 | if target[i] == u then 202 | -- print(fb[t][i]) 203 | temp_sum = logs.log_add(temp_sum, fb[t][i]) 204 | end 205 | end 206 | 207 | grad[t][k] = logs.log_mul(grad[t][k], temp_sum) 208 | 209 | grad[t][k] = -logs.safe_exp(grad[t][k]) 210 | end 211 | end 212 | return grad 213 | end 214 | 215 | function ctc.__getCost(fb, target) 216 | local cost = 0.0 217 | 218 | for i = 1, (#target)[1] do 219 | cost = logs.log_add(fb) 220 | end 221 | end 222 | 223 | function ctc.getCTCCostAndGrad(outputTable, target, gpu) 224 | if DEBUG then 225 | print("debug") 226 | timer = torch.Timer() 227 | base = 0; 228 | end 229 | 230 | -- convert target to one-hot Matrix (class + 1 * len(target)) 231 | local class_num = (#(outputTable[1]))[1] 232 | local T = #outputTable 233 | 234 | print_timestamp(" CTC begin") 235 | 236 | 237 | 238 | targetClasses = ctc.__getFilledTarget(target) 239 | 240 | 241 | 242 | targetMatrix = ctc.__getOnehotMatrix(targetClasses, class_num) 243 | 244 | outputTable = ctc.__toMatrix(outputTable, class_num) 245 | 246 | -- print(outputTable) 247 | 248 | 249 | if torch.type(outputTable) ~= "torch.DoubleTensor" then 250 | outputTable = outputTable:double() 251 | end 252 | 253 | orig = outputTable:clone() 254 | 255 | -- outputTable = outputTable:cmax(1e-4) 256 | -- local total = outputTable:sum(2):expand(outputTable:size()[1], outputTable:size()[2]) 257 | 258 | -- print(total) 259 | 260 | 261 | -- outputTable = torch.cdiv(outputTable, total) 262 | 263 | -- print(torch.dist(orig, outputTable)) 264 | 265 | 266 | print_timestamp(" perpare") 267 | 268 | 269 | 270 | outputTable:apply(function (x) 271 | x = logs.safe_log(x) 272 | return x 
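-- (the apply above moves the network outputs into the log domain; below,
-- log_mul is an addition and log_add a numerically stable log-sum-exp, so
-- products of many per-frame probabilities cannot underflow to zero)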
273 |     end)
274 | 
275 |     print_timestamp(" log")
276 | 
277 | 
278 | 
279 |     -- get aligned_table
280 |     -- outputTable: Tx(cls+1)
281 |     -- target: L'x(cls+1) --> targetT : (cls+1)xL'
282 |     -- aligned_table = TxL'
283 | 
284 | 
285 |     local alignedTable = outputTable * targetMatrix:double():t()
286 | 
287 | 
288 | 
289 | 
290 |     -- calculate forwardVariable (in log space)
291 | 
292 | 
293 | 
294 |     local fvs, bvs, fb, grad
295 | 
296 |     if ctc_lua then
297 |         fvs = ctc.__getForwardVariable(outputTable, alignedTable, targetClasses)
298 |     else
299 |         fvs = libctc.get_forward_variable(outputTable, alignedTable, targetClasses)
300 |     end
301 | 
302 | 
303 | 
304 | 
305 |     local L_1 = (#targetClasses)[1]
306 | 
307 |     -- calculate log(p(z|x)): a valid path must end in either the last label or the trailing blank
308 |     local pzx = logs.log_add(fvs[T][L_1], fvs[T][L_1-1])
309 | 
310 | 
311 |     -- calculate backwardVariable (in log space)
312 |     if ctc_lua then
313 |         bvs = ctc.__getBackwardVariable(outputTable, alignedTable, targetClasses)
314 |     else
315 |         bvs = libctc.get_backward_variable(outputTable, alignedTable, targetClasses)
316 |     end
317 | 
318 |     -- print(torch.dist(bvs, bvs1))
319 | 
320 |     print_timestamp(" fw bw")
321 | 
322 |     fb = fvs + bvs
323 | 
324 |     -- calculate gradient matrix (Tx(cls+1))
325 |     if ctc_lua then
326 |         grad = ctc.__getGrad(fb, pzx, class_num, outputTable, targetClasses)
327 |     else
328 |         grad = libctc.get_grad(fb, outputTable, targetClasses, pzx)
329 |     end
330 | 
331 |     print_timestamp(" get grad")
332 | 
333 |     --[[
334 |     print("=========FVS=========")
335 |     print(fvs:t())
336 |     print("=========BVS=========")
337 |     print(bvs:t())
338 |     print("=========GRAD=========")
339 |     print(grad)
340 |     ]]
341 | 
342 |     if gpu then
343 |         grad = grad:cuda()
344 |     end
345 | 
346 | 
347 |     grad = nn.SplitTable(1):forward(grad) -- back to one gradient tensor per timestep, matching the net's output table
348 | 
349 | 
350 |     return -pzx, grad
351 | 
352 | end
353 | 
354 | function ctc.test()
355 |     return gpu
356 | end
357 | 
--------------------------------------------------------------------------------
/libctc.cc:
--------------------------------------------------------------------------------
1 | extern "C" {
2 | #include <lua.h>
3 | #include <lauxlib.h>
4 | #include <luaT.h>
5 | }
6 | 
7 | #include <TH/TH.h>
8 | #include <cmath>
9 | #include <cassert> // the header names were lost in extraction; these six are the natural candidates for the luaT/TH and libc calls used below (<cstdio> may also be needed for perror/printf)
10 | 
11 | #define ENABLE_OPENMP
12 | 
13 | static const double EXP_MAX = 1e100;
14 | static const double EXP_MIN = 1e-100;
15 | static const double LOG_ZERO = -1e100;
16 | static const double LOG_INF = 1e100;
17 | static const double EXP_LIMIT = log(EXP_MAX);
18 | 
19 | 
20 | static double safe_log(double x) {
21 |     if (x == 0) {
22 |         return LOG_ZERO;
23 |     }
24 |     else if (x > 0) {
25 |         return log(x);
26 |     }
27 |     else {
28 |         perror("Error: passing a negative number to the log function.");
29 |         return LOG_ZERO;
30 |     }
31 | }
32 | 
33 | static double safe_exp(double x) {
34 |     if (x == LOG_ZERO) {
35 |         return 0;
36 |     }
37 |     if (x >= EXP_LIMIT) {
38 |         return EXP_MAX;
39 |     }
40 |     return exp(x);
41 | }
42 | 
43 | static double log_add(double x, double y) {
44 |     if (fabs(x - y) > 10) {
45 |         return fmax(x, y); // was a bare `fmax(x, y);` whose result was discarded, which made this shortcut a no-op
46 |     }
47 | 
48 |     if (x < y) {
49 |         return y + log(1.0 + safe_exp(x - y));
50 |     }
51 |     return x + log(1.0 + safe_exp(y - x));
52 | }
53 | 
54 | static double log_sub(double x, double y) {
55 |     if (y == LOG_ZERO) {
56 |         return x;
57 |     }
58 |     if (y >= x) {
59 |         return LOG_ZERO;
60 |     }
61 |     return x + log(1.0 - safe_exp(y - x));
62 | }
63 | 
64 | static double log_mul(double x, double y) {
65 |     if (y == LOG_ZERO || x == LOG_ZERO) {
66 |         return LOG_ZERO;
67 |     }
68 | 
69 |     return x + y;
70 | }
71 | 
72 | static THDoubleTensor * __get_forward_variable(THDoubleTensor * outputTable, THDoubleTensor * alignedTable, 
THDoubleTensor * targetT) { 73 | int T = outputTable->size[0]; 74 | int L = targetT->size[0]; 75 | 76 | double * aligned = THDoubleTensor_data(alignedTable); 77 | double * target = THDoubleTensor_data(targetT); 78 | 79 | 80 | THDoubleTensor * fvsT = THDoubleTensor_newWithSize2d(T, L); 81 | THDoubleStorage_fill(fvsT->storage, LOG_ZERO); 82 | double * fvs = THDoubleTensor_data(fvsT); 83 | 84 | fvs[0] = aligned[0]; 85 | fvs[1] = aligned[1]; 86 | 87 | int lower_bound = -1, upper_bound = 2; 88 | 89 | double fvs_tmp, fvs_i1u, fvs_i1u1, fvs_i1u2; 90 | 91 | for(int i = 1; i < T; i++) { 92 | // adjust bounds, some positions would never been visited 93 | 94 | upper_bound += 2; 95 | if (upper_bound > L) { 96 | upper_bound = L; 97 | } 98 | 99 | lower_bound = L - 2 * (T - i); 100 | if (lower_bound < 0) { 101 | lower_bound = 0; 102 | } 103 | 104 | assert(lower_bound >= 0 && lower_bound < T * L); 105 | assert(upper_bound >= 0 && upper_bound < T * L); 106 | 107 | for (int u = lower_bound; u < upper_bound; u++) { 108 | double tmp = LOG_ZERO; 109 | 110 | fvs_i1u = fvs[(i - 1) * L + u]; 111 | fvs_i1u1 = (u > 0) ? fvs[(i - 1) * L + u - 1] : LOG_ZERO; 112 | fvs_i1u2 = (u > 1 && target[u - 2] != target[u]) ? fvs[(i - 1) * L + u - 2] : LOG_ZERO; 113 | 114 | if (u % 2) { 115 | tmp = log_add(tmp, fvs_i1u); 116 | tmp = log_add(tmp, fvs_i1u1); 117 | tmp = log_add(tmp, fvs_i1u2); 118 | } 119 | else { 120 | tmp = log_add(tmp, fvs_i1u); 121 | tmp = log_add(tmp, fvs_i1u1); 122 | } 123 | fvs[i * L + u] = log_mul(tmp, aligned[i * L + u]); 124 | } 125 | 126 | } 127 | return fvsT; 128 | } 129 | 130 | static THDoubleTensor * __get_backward_variable(THDoubleTensor * outputTable, THDoubleTensor * alignedTable, THDoubleTensor * targetT) { 131 | int T = outputTable->size[0]; 132 | int L = targetT->size[0]; 133 | 134 | double * aligned = THDoubleTensor_data(alignedTable); 135 | double * target = THDoubleTensor_data(targetT); 136 | 137 | THDoubleTensor * bvsT = THDoubleTensor_newWithSize2d(T, L); 138 | THDoubleStorage_fill(bvsT->storage, LOG_ZERO); 139 | double * bvs = THDoubleTensor_data(bvsT); 140 | 141 | assert(T * L >= 2); 142 | 143 | bvs[T * L - 1] = 0; 144 | bvs[T * L - 2] = 0; 145 | 146 | int lower_bound = -1, upper_bound = L - 3; 147 | 148 | double bvs_tmp, bvs_i1u, bvs_i1u1, bvs_i1u2; 149 | 150 | 151 | for(int i = T - 2; i >= 0; i--) { 152 | // adjust bounds, some positions would never been visited 153 | 154 | upper_bound -= 2; 155 | if (upper_bound < 0) { 156 | upper_bound = 0; 157 | } 158 | 159 | lower_bound = 2 * i + 1; 160 | if (lower_bound > L - 1) { 161 | lower_bound = L - 1; 162 | } 163 | 164 | if (lower_bound < 0) { 165 | lower_bound = 0; 166 | } 167 | 168 | if (upper_bound > L - 2) { 169 | upper_bound = L - 2; 170 | } 171 | 172 | assert(upper_bound >= 0 && upper_bound < L); 173 | assert(lower_bound >= 0 && lower_bound < L); 174 | 175 | // printf("%d %d\n", upper_bound, lower_bound); 176 | 177 | for (int u = lower_bound; u >= upper_bound; u--) { 178 | 179 | double tmp = LOG_ZERO; 180 | 181 | assert((i * L + u < T * L) && (i * L + u) >= 0); 182 | assert(((i + 1) * L + u) >= 0 && ((i + 1) * L + u) < T * L); 183 | assert((u >= L - 1) || ((i + 1) * L + u + 1 >= 0 && ((i + 1) * L + u + 1 < T * L))); 184 | assert(!(u < L - 2 && target[u + 2] != target[u]) || ((i + 1) * L + u + 2) >= 0 && ((i + 1) * L + u + 2) < T * L); 185 | 186 | bvs_i1u = bvs[(i + 1) * L + u]; 187 | bvs_i1u1 = (u < L - 1) ? bvs[(i + 1) * L + u + 1] : LOG_ZERO; 188 | bvs_i1u2 = (u < L - 2 && target[u + 2] != target[u]) ? 
bvs[(i + 1) * L + u + 2] : LOG_ZERO; 189 | 190 | tmp = log_mul(aligned[(i + 1) * L + u], bvs_i1u); 191 | 192 | 193 | if (u < L - 1) { 194 | assert(((i + 1) * L + u + 1 >= 0) && ((i + 1) * L + u + 1 < T * L)); 195 | tmp = log_add(tmp, log_mul(aligned[(i + 1) * L + u + 1], bvs_i1u1)); 196 | } 197 | 198 | if (u % 2 && u < L - 2) { 199 | assert(((i + 1) * L + u + 2) >= 0 && ((i + 1) * L + u + 2) < T * L); 200 | tmp = log_add(tmp, log_mul(aligned[(i + 1) * L + u + 2], bvs_i1u2)); 201 | } 202 | 203 | bvs[i * L + u] = tmp; 204 | 205 | if ((u < L - 1) && (i + 1) * L + u + 1 >= T * L) { 206 | perror("out of range\n"); 207 | } 208 | } 209 | 210 | } 211 | return bvsT; 212 | } 213 | 214 | 215 | static THDoubleTensor * __get_grad(THDoubleTensor * fbT, THDoubleTensor * outputTable, THDoubleTensor * targetT, double pzx) { 216 | 217 | int T = fbT->size[0]; 218 | int L = targetT->size[0]; 219 | int class_num = outputTable->size[1]; 220 | 221 | int pos; 222 | 223 | THDoubleTensor * gradT = THDoubleTensor_newWithSize2d(T, class_num); 224 | double * fb = THDoubleTensor_data(fbT); 225 | double * output = THDoubleTensor_data(outputTable); 226 | double * grad = THDoubleTensor_data(gradT); 227 | double * target = THDoubleTensor_data(targetT); 228 | 229 | double tmp_sum = 0, u = 0, tmp = 0; 230 | 231 | int t; 232 | 233 | #ifdef ENABLE_OPENMP 234 | #pragma omp parallel for private(tmp_sum, u, tmp, pos) lastprivate(t) 235 | #endif 236 | for (t = 0; t < T; t++) { 237 | // printf("%d\n", t); 238 | 239 | for (int k = 0; k < class_num; k++) { 240 | pos = t * class_num + k; 241 | 242 | assert(pos >=0 && pos < class_num * T); 243 | 244 | tmp_sum = LOG_ZERO; 245 | tmp = log_mul(-pzx, -output[pos]); 246 | u = k + 1; 247 | 248 | if (u == class_num) { 249 | u = 0; 250 | } 251 | 252 | for (int i = 0; i < L; i++) { 253 | if (target[i] == u) { 254 | // printf("%.4f\n", fb[t * L + i]); 255 | 256 | assert((t * L + i) >=0 && (t * L + i) < T * L); 257 | 258 | tmp_sum = log_add(fb[t * L + i], tmp_sum); 259 | } 260 | } 261 | 262 | 263 | tmp = log_mul(tmp, tmp_sum); 264 | 265 | grad[pos] = -safe_exp(tmp); 266 | } 267 | } 268 | #pragma omp barrier 269 | 270 | return gradT; 271 | } 272 | 273 | static int ctc_get_forward_variable(lua_State * L) { 274 | THDoubleTensor * output = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 275 | THDoubleTensor * alignedTable = (THDoubleTensor *)luaT_checkudata(L, 2, "torch.DoubleTensor"); 276 | THDoubleTensor * target = (THDoubleTensor *)luaT_checkudata(L, 3, "torch.DoubleTensor"); 277 | 278 | 279 | 280 | THDoubleTensor * fvs = __get_forward_variable(output, \ 281 | alignedTable, target); 282 | 283 | 284 | luaT_pushudata(L, fvs, "torch.DoubleTensor"); 285 | 286 | return 1; 287 | } 288 | 289 | static int ctc_get_backward_variable(lua_State * L) { 290 | THDoubleTensor * output = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 291 | THDoubleTensor * alignedTable = (THDoubleTensor *)luaT_checkudata(L, 2, "torch.DoubleTensor"); 292 | THDoubleTensor * target = (THDoubleTensor *)luaT_checkudata(L, 3, "torch.DoubleTensor"); 293 | 294 | 295 | 296 | THDoubleTensor * bvs = __get_backward_variable(output, \ 297 | alignedTable, target); 298 | 299 | /* 300 | double * data = THDoubleTensor_data(bvs); 301 | 302 | for (int i = 0; i < bvs->size[0]; i++) { 303 | for (int j = 0; j < bvs->size[1]; j++) { 304 | printf("%.4f\t", data[i * bvs->size[1] + j] == -1e10 ? 
0 : data[i * bvs->size[1] + j]); 305 | } 306 | printf("\n"); 307 | } 308 | */ 309 | 310 | luaT_pushudata(L, bvs, "torch.DoubleTensor"); 311 | 312 | return 1; 313 | } 314 | 315 | 316 | 317 | static int ctc_get_grad(lua_State * L) { 318 | THDoubleTensor * fb = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor"); 319 | THDoubleTensor * outputTable = (THDoubleTensor *)luaT_checkudata(L, 2, "torch.DoubleTensor"); 320 | THDoubleTensor * target = (THDoubleTensor *)luaT_checkudata(L, 3, "torch.DoubleTensor"); 321 | double pzx = luaL_checknumber(L, 4); 322 | 323 | THDoubleTensor * grad = __get_grad(fb, outputTable, target, pzx); 324 | 325 | luaT_pushudata(L, grad, "torch.DoubleTensor"); 326 | 327 | return 1; 328 | } 329 | 330 | 331 | static const struct luaL_reg libctc[] = { 332 | {"get_forward_variable", ctc_get_forward_variable}, 333 | {"get_backward_variable", ctc_get_backward_variable}, 334 | {"get_grad", ctc_get_grad}, 335 | {NULL, NULL} 336 | }; 337 | 338 | LUA_EXTERNC int luaopen_libctc(lua_State *L) { 339 | luaL_openlib(L, "libctc", libctc, 0); 340 | return 1; 341 | } -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'GRU' 4 | require 'slider' 5 | require 'image' 6 | require 'optim' 7 | require 'json' 8 | 9 | require 'loader' 10 | require 'ctc_log' 11 | require 'utils.decoder' 12 | require 'utils.levenshtein' 13 | 14 | -- initialize 15 | torch.setdefaulttensortype('torch.DoubleTensor') 16 | torch.manualSeed(os.time()) 17 | 18 | -- debug switch 19 | DEBUG = false 20 | 21 | -- timer initialize 22 | base = 0 23 | timer = torch.Timer() 24 | 25 | function show_log(log) 26 | local now = timer:time().real 27 | local cost = now - base 28 | base = now 29 | print(string.format("[%.4f][%.4f]%s", now, cost, log)) 30 | end 31 | 32 | -- settings 33 | 34 | opt = { 35 | -- project 36 | project_name = os.date("%y-%m-%d_%H%M%S"), -- important !! the name of run folder, besure not to override a existing one. 37 | using_model_file = nil, 38 | 39 | recurrent_unit = "gru", 40 | 41 | -- hyperparameters 42 | input_size = 48, 43 | hidden_size = 200, 44 | clamp_size = 1, 45 | learning_rate = 1e-4, 46 | momentum = 0.9, 47 | dropout_rate = 0, 48 | max_param_norm = false, 49 | 50 | -- configurations 51 | gpu = false, -- might not help 52 | 53 | -- threading 54 | nthread = 3, 55 | omp_threads = 1, 56 | 57 | -- samples 58 | training_list_file = "wwr.txt", 59 | testing_list_file = nil, 60 | codec_file = nil, 61 | testing_ratio = 1, -- is valid unless testing_list_file == nil 62 | 63 | -- feature extracting layer 64 | raw_input = true, -- disable this layer? 65 | windows_size = 10, 66 | stride = 5, 67 | feature_size = 48 * 5, -- feature size 68 | rbm_network_file = "rbm/wwr.rbm", -- set if you want to use a pretrained rbm encoder 69 | 70 | -- miscellaneous 71 | max_iter = 1e10, 72 | show_every = 10, 73 | save_every = 10000, 74 | test_every = 1000, 75 | ctc_lua = false 76 | 77 | } 78 | 79 | -- load settings 80 | 81 | cmd = torch.CmdLine() 82 | cmd:option("-setting", "", "setting file") 83 | params = cmd:parse(arg) 84 | 85 | if params.setting ~= "" then 86 | show_log("loading setting file " .. params.setting) 87 | opt = json.load(params.setting) 88 | show_log("setting file loaded successfully.") 89 | end 90 | 91 | run_dir = "experiments/" .. opt.project_name .. "/" 92 | paths.mkdir(run_dir) 93 | 94 | -- logging 95 | 96 | cmd:log(run_dir .. 
"log.txt", opt) 97 | json.save(run_dir .. "setting.json", opt) 98 | 99 | 100 | 101 | -- apply settings 102 | if opt.omp_threads then 103 | torch.setnumthreads(opt.omp_threads) 104 | end 105 | 106 | ctc_lua = opt.ctc_lua 107 | 108 | -- curriculum training settings 109 | 110 | curriculum_training = false 111 | weight_change_iter_span = 10000 112 | lambda = 3 113 | lambda_change_every = 1000 114 | lambda_grad = lambda / (weight_change_iter_span / lambda_change_every) 115 | 116 | -- GPU 117 | 118 | if opt.gpu then 119 | require 'cutorch' 120 | require 'cunn' 121 | end 122 | 123 | -- load samples 124 | 125 | show_log("Loading samples...") 126 | 127 | loader = Loader() 128 | loader:targetHeight(opt.input_size) 129 | 130 | if opt.testing_list_file ~= nil then 131 | loader:load(opt.training_list_file, 1) 132 | loader:loadTesting(opt.testing_list_file) 133 | else 134 | loader:load(opt.training_list_file, opt.testing_ratio) 135 | end 136 | 137 | 138 | -- load codec 139 | 140 | if opt.codec_file then 141 | codec = loader:loadCodec(opt.codec_file) 142 | else 143 | codec = loader:codec() 144 | torch.save(opt.training_list_file:gsub("[.].*", ".codec"), codec) 145 | end 146 | 147 | 148 | show_log(string.format("Loading finished. Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size)) 149 | show_log(string.format("Splited into %d for training, %d for testing", #loader.training, #loader.testing)) 150 | 151 | 152 | 153 | local class_num = codec.codec_size 154 | 155 | -- build network 156 | 157 | show_log("Building networks...") 158 | 159 | local net, recurrent 160 | 161 | if opt.using_model_file then 162 | -- load current network 163 | net = torch.load(opt.using_model_file) 164 | else 165 | local rnn_input_size = opt.input_size 166 | 167 | net = nn.Sequential() 168 | 169 | -- build feature extracting layer 170 | 171 | if (opt.raw_input) then 172 | net:add(nn.SplitTable(1)) 173 | else 174 | local raw_input_size = opt.windows_size * opt.input_size 175 | rnn_input_size = opt.feature_size 176 | 177 | if (opt.rbm_network_file) then 178 | show_log("loading RBM nerwork...") 179 | local rbm = torch.load(opt.rbm_network_file) 180 | show_log(string.format("loaded RBM Layer with n_visual=%d, n_hidden=%d.", rbm.n_visible, rbm.n_hidden)) 181 | net:add(nn.Sequencer(rbm.encoder)) 182 | else 183 | net:add(nn.Sequencer(nn.Linear(raw_input_size, opt.feature_size))) 184 | end 185 | end 186 | 187 | 188 | -- build RNN layer 189 | 190 | if opt.recurrent_unit == "gru" then 191 | recurrent = nn.GRU(rnn_input_size, opt.hidden_size) 192 | elseif opt.recurrent_unit == "lstm" then 193 | recurrent = nn.LSTM(rnn_input_size, opt.hidden_size) 194 | elseif opt.recurrent_unit == "lstm_nopeephole" then 195 | recurrent = nn.LSTM(rnn_input_size, opt.hidden_size, 9999, false) 196 | elseif opt.recurrent_unit == "lstm_fast" then 197 | recurrent = nn.FastLSTM(rnn_input_size, opt.hidden_size) 198 | end 199 | 200 | net:add(nn.BiSequencer(recurrent)) 201 | output = nn.Sequential() 202 | output:add(nn.Dropout(opt.dropout_rate)) 203 | output:add(nn.Linear(opt.hidden_size * 2, class_num + 1)) 204 | output:add(nn.SoftMax()) 205 | net:add(nn.Sequencer(output)) 206 | net:double() 207 | end 208 | 209 | if opt.gpu then 210 | net:cuda() 211 | end 212 | 213 | -- prepare prarmeters and training method 214 | 215 | params, grad_params = net:getParameters() 216 | 217 | state = { 218 | learningRate = opt.learning_rate, 219 | momentum = opt.momentum 220 | } 221 | 222 | show_log(string.format("Start training. 
umaru~~")) 223 | 224 | -- training 225 | 226 | begin_time = 0 227 | 228 | -- sliding window 229 | 230 | get_input = function(im) 231 | local input 232 | if opt.raw_input then 233 | input = im 234 | else 235 | if opt.gpu then 236 | torch.setdefaulttensortype('torch.CudaTensor') 237 | end 238 | local slider = Slider() 239 | slider:load(im:t()) 240 | input = slider:genSequence() 241 | torch.setdefaulttensortype('torch.DoubleTensor') 242 | end 243 | 244 | return input 245 | end 246 | 247 | 248 | for i = 1, 1000000 do 249 | -- epoch begin 250 | 251 | local b1 = timer:time().real 252 | 253 | local sample 254 | 255 | -- pick a sample randomly 256 | 257 | if not curriculum_training then 258 | sample = loader:pick() 259 | else 260 | sample = loader:pickWithWeight() 261 | if i % lambda_change_every == 0 then 262 | lambda = lambda - lambda_grad 263 | loader:updateWeight(lambda) 264 | show_log("lambda was changed to " .. lambda) 265 | end 266 | end 267 | 268 | local im 269 | 270 | if opt.gpu then 271 | im = sample.img:cuda() 272 | else 273 | im = sample.img 274 | end 275 | 276 | -- encode target 277 | local target = codec:encode(sample.gt) 278 | 279 | -- forward backward 280 | local feval = function(params) 281 | net:forget() 282 | 283 | local input = get_input(im) 284 | 285 | -- forward 286 | local outputTable = net:forward(input) 287 | 288 | -- get CTC loss and Gradient 289 | local loss, grad = ctc.getCTCCostAndGrad(outputTable, target, opt.gpu) 290 | 291 | if opt.show_every > 0 and i % opt.show_every == 0 then 292 | print("") 293 | show_log("EPOCH " .. i) 294 | show_log("TARGET " .. sample.gt) 295 | show_log("OUTPUT " .. decoder.best_path_decode(outputTable, codec)) 296 | show_log("LOSS " .. loss) 297 | show_log("sec/ep " .. (timer:time().real - begin_time) / i) 298 | end 299 | 300 | -- backward 301 | net:backward(input, grad) 302 | 303 | -- process the gradients (avoiding gradient explosion) 304 | if opt.gpu then 305 | grad_params:cmul(grad_params:eq(grad_params)) 306 | else 307 | grad_params:cmul(grad_params:eq(grad_params):double()) 308 | end 309 | grad_params:clamp(-opt.clamp_size, opt.clamp_size) 310 | 311 | 312 | input = nil 313 | sample.img = nil 314 | 315 | return loss, grad_params 316 | end 317 | 318 | 319 | -- sgd optimize 320 | optim.sgd(feval, params, state) 321 | 322 | if opt.max_param_norm then 323 | net:maxParamNorm(2) 324 | end 325 | 326 | if i % 100 == 0 then 327 | show_log("Collecting garbage... before gc " .. collectgarbage("count")) 328 | collectgarbage() 329 | show_log("GC Finished. after gc " .. collectgarbage("count")) 330 | end 331 | 332 | -- model saving 333 | 334 | if opt.save_every > 0 and i % opt.save_every == 0 then 335 | print("") 336 | show_log("Saving model...") 337 | local filename = string.format("umaru_model_%s_%d.uma", os.date("%y-%m-%d_%X"), i) 338 | torch.save(run_dir .. 
filename, net) 339 | show_log(string.format("Saving finished, saved model file is at %s.", filename)) 340 | end 341 | 342 | 343 | -- testing 344 | 345 | if opt.test_every > 0 and i % opt.test_every == 0 and #loader.testing > 0 then 346 | show_log("testing...") 347 | net:evaluate() 348 | local dist, len = 0, 0 349 | loader:reset() 350 | local s = loader:pickInSequential("testing") 351 | 352 | while s do 353 | local im 354 | if opt.gpu then 355 | im = s.img:cuda() 356 | else 357 | im = s.img 358 | end 359 | 360 | local input = get_input(im) 361 | 362 | local out = decoder.best_path_decode(net:forward(input), codec) 363 | dist = dist + utf8.levenshtein(s.gt, out) 364 | len = len + utf8.len(s.gt) 365 | 366 | -- print("") 367 | -- show_log("FILE " .. s.src) 368 | -- show_log("TARGET " .. s.gt) 369 | -- show_log("OUTPUT " .. out) 370 | -- show_log("ERROR " .. string.format("%.2f%%", dist / len * 100)) 371 | 372 | s = loader:pickInSequential("testing") 373 | end 374 | 375 | show_log("testing finished, error rate: " .. string.format("%.2f%% at epoch %d.", dist / len * 100, i)) 376 | 377 | net:training() 378 | end 379 | 380 | end 381 | 382 | 383 | -------------------------------------------------------------------------------- /rbm.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015, Nils Hammerla 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 21 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | --]]
25 | 
26 | local Rbm = {}
27 | Rbm.__index = Rbm
28 | 
29 | function Rbm.new(arg)
30 |     local self = setmetatable({}, Rbm)
31 | 
32 |     -- parse parameters
33 |     -- network
34 |     self.n_visible = arg.n_visible
35 |     self.n_hidden = arg.n_hidden
36 | 
37 |     -- unit type, default: sigmoids
38 |     self.v_activation = arg.v_activation or 'binary'
39 |     self.h_activation = arg.h_activation or 'binary' -- was arg.v_activation, a copy-paste slip that forced both layers onto the visible activation
40 |     self.useStates = (arg.useStates == nil) and true or arg.useStates -- `arg.useStates or true` could never be false
41 | 
42 |     -- learning
43 |     self.learningRate = arg.learningRate or 0.1
44 |     self.minibatchSize = arg.minibatchSize or 100
45 |     self.momentum = arg.momentum -- was arg.momemtum (typo), so momentum could never be enabled
46 |     self.momentumAfter = arg.momentumAfter or 5
47 |     self.CDsteps = arg.CDsteps or 1
48 | 
49 |     -- regularisation
50 |     self.weightCost = arg.weightCost or -0.0000001 -- negative: the update *adds* weightCost * W, matching the -weightcost term in Hinton's formula quoted below
51 | 
52 |     -- some variables to save progress
53 |     self.epoch = 0
54 |     self.epochError = 0
55 | 
56 |     -- initialise weights
57 |     self:initWeights()
58 | 
59 |     -- set up sampling functions for visible and hidden units
60 |     self.binarySampler = function(input)
61 |         local a = nn.Sigmoid()(input)
62 |         local s = torch.gt(a, torch.Tensor(a:size()):rand(a:size())):type(torch.getdefaulttensortype())
63 |         return a,s
64 |     end
65 |     self.reluSampler = function(input)
66 |         local n = torch.Tensor(input:size()):randn(input:size())
67 |         local a = nn.ReLU()(input+n)
68 |         return a,a
69 |     end
70 |     self.gaussSampler = function(input)
71 |         return input, input+torch.Tensor(input:size()):randn(input:size())
72 |     end
73 | 
74 |     if self.h_activation == 'binary' then
75 |         self.h_sampler = self.binarySampler
76 |     elseif self.h_activation == 'relu' then
77 |         self.h_sampler = self.reluSampler
78 |     elseif self.h_activation == 'gaussian' then
79 |         self.h_sampler = self.gaussSampler
80 |     end
81 | 
82 |     if self.v_activation == 'binary' then
83 |         self.v_sampler = self.binarySampler
84 |     elseif self.v_activation == 'relu' then
85 |         self.v_sampler = self.reluSampler
86 |     elseif self.v_activation == 'gaussian' then
87 |         self.v_sampler = self.gaussSampler
88 |     end
89 | 
90 |     -- Set up an encoder and decoder (nn)
91 |     -- Weights in these networks point to the weight tensor in the rbm
92 |     self.encoder, self.decoder = self:getNN()
93 | 
94 |     return self
95 | end
96 | 
97 | function Rbm.initWeights(self)
98 |     -- Initialise weights
99 |     -- Basically reset the whole thing
100 |     local nV, nH = self.n_visible, self.n_hidden
101 | 
102 |     self.W = torch.randn(nV, nH):mul(0.1)
103 |     self.hbias = torch.zeros(1, nH)
104 |     self.vbias = torch.zeros(1, nV)
105 | 
106 |     self.Winc = torch.zeros(nV, nH)
107 |     self.hbiasinc = torch.zeros(1, nH) -- was assigned to self.hbias again, clobbering the init above; the *inc tensors hold the previous updates used by the momentum term
108 |     self.vbiasinc = torch.zeros(1, nV)
109 | end
110 | 
111 | function Rbm.HgivenV(self, v_sample)
112 |     -- sample hidden layer based on visible layer
113 |     local pre, post, states
114 |     -- h_in = v*W + h_bias
115 |     pre = torch.mm(v_sample, self.W):add(self.hbias:repeatTensor(v_sample:size(1),1))
116 |     post, states = self.h_sampler(pre)
117 | 
118 |     if self.useStates == false then
119 |         states = post
120 |     end
121 | 
122 |     return pre, post, states
123 | end
124 | 
125 | function Rbm.VgivenH(self, h_sample)
126 |     -- sample visible layer given hidden layer
127 |     local pre, post, states
128 |     -- v_in = h*W' + v_bias
129 |     pre = torch.mm(h_sample, self.W:t()):add(self.vbias:repeatTensor(h_sample:size(1),1))
130 |     post, states = self.v_sampler(pre)
131 | 
132 |     if self.useStates == false then
133 |         states = post
134 |     end
135 | 
136 |     return pre, post, states
137 | end
138 | 
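--[[
A minimal sketch of how the two conditionals above are used (sizes are
illustrative; assumes `require 'nn'` and the default DoubleTensor type):

    local rbm = Rbm.new{n_visible = 4, n_hidden = 8}
    local v = torch.rand(1, 4):gt(0.5):double()   -- one binary visible row vector
    local pre, p_h, h = rbm:HgivenV(v)            -- pre-activation, p(h=1|v), sampled states
    local _, p_v = rbm:VgivenH(h)                 -- reconstruction mean p(v=1|h)

Each conditional returns (pre, post, states): `post` is the activation mean and
`states` a stochastic sample from it; sampleChain() below alternates the two.
]]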
139 | function Rbm.sampleChain(self, h, CDsteps)
140 |     -- sample markov chain for contrastive divergence training (starting from hidden state h)
141 |     local start = h -- was an accidental global
142 |     local _, v_mean, v_sample, h_mean, h_sample
143 |     for i=1, CDsteps do
144 |         _, v_mean, v_sample = self:VgivenH(start)
145 |         _, h_mean, h_sample = self:HgivenV(v_sample)
146 |         start = h_sample -- reset
147 |     end
148 |     return v_mean, v_sample, h_mean, h_sample
149 | end
150 | 
151 | function Rbm.freeEnergy(self,sample)
152 |     -- calculate free energy (for convergence check if required)
153 |     -- This is just for binary-binary rbms! (so far)
154 |     local wx_b = torch.mm(sample, self.W):add(self.hbias:repeatTensor(sample:size(1),1))
155 |     local vbias_term = torch.mm(sample, self.vbias:t())
156 |     local hidden_term = torch.log(torch.add(wx_b:exp(),1)):sum(2)
157 |     local e = -hidden_term - vbias_term
158 |     return e
159 | end
160 | 
161 | function Rbm.updateParameters(self, v0)
162 |     -- calculate gradients for W, vbias, hbias and update weight matrices
163 | 
164 |     local momentum
165 | 
166 |     -- if we use momentum, then check which we want to use
167 |     if self.momentum then -- was self.momemtum (typo), so the first-phase momentum was never selected
168 |         momentum = self.momentum[1]
169 |     end
170 |     if self.momentum and self.epoch > self.momentumAfter then
171 |         momentum = self.momentum[2]
172 |     end
173 | 
174 |     -- sample first hidden layer
175 |     local _, h0_mean, h0_sample = self:HgivenV(v0)
176 | 
177 |     -- get sample from markov chain
178 |     local v_model_mean, v_model_sample, h_model_mean, h_model_sample = self:sampleChain(h0_sample, self.CDsteps)
179 |     local ww, vb, hb -- hoisted: as locals inside the if-block below they went out of scope before the momentum update could use them
180 |     if momentum then
181 |         -- if momentum is set then memorise the previous increments (soo much memory for this)
182 |         ww = self.Winc:clone()
183 |         vb = self.vbiasinc:clone() -- was self.vbias (the bias itself); momentum needs the previous increment, cf. the formulas quoted below
184 |         hb = self.hbiasinc:clone()
185 |     end
186 | 
187 |     -- calculate derivatives and update matrices
188 | 
189 |     -- calculate weight derivatives.
190 |     -- This looks a bit weird but splitting up the calculations is apparently
191 |     -- more memory efficient (still a mystery to me).
192 | 
193 |     -- formulas (from Geoff Hinton and Ruslan Salakhutdinov's matlab code):
194 |     -- vishidinc = momentum*vishidinc + epsilonw*( (posprods-negprods)/numcases - weightcost*vishid);
195 |     -- visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact-negvisact);
196 |     -- hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact-neghidact);
197 | 
198 |     self.Winc = torch.mm(v0:t(),h0_mean) -- posprods
199 |     self.Winc:add(torch.mm(v_model_mean:t(), h_model_mean):mul(-1)) -- -negprods
200 |     self.Winc:div(v0:size(1)) -- / numsamples
201 |     self.Winc:add(torch.mul(self.W, self.weightCost)) -- regularisation
202 |     self.Winc:mul(self.learningRate) -- * learning rate
203 | 
204 |     -- visible bias
205 |     self.vbiasinc = v0:sum(1)
206 |     self.vbiasinc:add(-v_model_mean:sum(1))
207 |     self.vbiasinc:mul(self.learningRate)
208 |     self.vbiasinc:div(v0:size(1))
209 | 
210 |     -- hidden bias
211 |     self.hbiasinc = h0_mean:sum(1)
212 |     self.hbiasinc:add(-h_model_mean:sum(1))
213 |     self.hbiasinc:mul(self.learningRate)
214 |     self.hbiasinc:div(v0:size(1))
215 | 
216 |     if self.momentum and self.epoch > 1 then
217 |         -- momentum? if so add derivatives*momentum
218 |         self.Winc:add(torch.mul(ww, momentum))
219 |         self.vbiasinc:add(torch.mul(vb, momentum))
220 |         self.hbiasinc:add(torch.mul(hb, momentum))
221 |     end
222 | 
223 |     -- update params
224 |     self.W:add(self.Winc)
225 |     self.vbias:add(self.vbiasinc)
226 |     self.hbias:add(self.hbiasinc)
227 | end
228 | 
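--[[
End-to-end training sketch (paths and sizes are illustrative; they match the
defaults in main.lua, where windows_size * input_size = 480 feeds a
feature_size of 240):

    local rbm = Rbm.new{n_visible = 480, n_hidden = 240, learningRate = 0.05}
    rbm:train(data, 10)              -- data: an N x 480 tensor, N a multiple of minibatchSize
    torch.save("rbm/wwr.rbm", rbm)   -- main.lua later loads such a file and reuses rbm.encoder
]]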
If data is large this may run 231 | -- into memory issues. If so use updateParameters() 232 | local e, a 233 | for e=1,epochs do 234 | xlua.progress(e, epochs) 235 | 236 | self.epoch = self.epoch + 1 237 | for i=1,data:size(1),self.minibatchSize do 238 | self:updateParameters(data[{{i,i+self.minibatchSize-1}, {}}]) 239 | end 240 | 241 | collectgarbage() -- this is needed, not sure where the leak is 242 | end 243 | end 244 | 245 | function Rbm.getNN(self) 246 | -- construct an encoder and decoder network 247 | -- these share the memory with the rbm so no just little overhead. This should 248 | -- make stacking rbms and fine-tuning with backprop much easier. 249 | 250 | local encoder, decoder 251 | 252 | -- encoder: visible -> hidden 253 | encoder = nn.Sequential() 254 | encoder:add(nn.Linear(self.n_visible, self.n_hidden)) 255 | if self.h_activation == 'binary' then 256 | encoder:add(nn.Sigmoid()) 257 | elseif self.h_activation == 'relu' then 258 | encoder:add(nn.ReLU()) 259 | elseif self.h_activation == 'gaussian' then 260 | -- linear is fine 261 | end 262 | encoder:get(1).weight = self.W:t() -- weight matrix is flipped in nn 263 | encoder:get(1).bias = self.hbias[1] 264 | 265 | -- decoder: hidden -> visible 266 | decoder = nn.Sequential() 267 | decoder:add(nn.Linear(self.n_hidden, self.n_visible)) 268 | if self.v_activation == 'binary' then 269 | decoder:add(nn.Sigmoid()) 270 | elseif self.v_activation == 'relu' then 271 | decoder:add(nn.ReLU()) 272 | elseif self.v_activation == 'gaussian' then 273 | -- linear is fine 274 | end 275 | decoder:get(1).weight = self.W 276 | decoder:get(1).bias = self.vbias[1] 277 | 278 | return encoder, decoder 279 | end 280 | 281 | function Rbm.fromNN(self, encoder, decoder) 282 | -- Parsing encoder (linear layer + non-linearity) for weights and hidden bias, 283 | -- and hidden unit actiation. Parsing decoder for visible bias (and activation). 
284 | --
285 | -- Assumes:
286 | -- encoder = nn.Sequential()
287 | -- encoder:add(nn.Linear())
288 | -- encoder:add(nn.Sigmoid())
289 | --
290 | self.W = encoder:get(1).weight:t()
291 | self.hbias[1] = encoder:get(1).bias
292 | -- note: this stores the activation *modules*, while Rbm.new() expects the
293 | -- strings 'binary'/'relu'/'gaussian'; convert before relying on the samplers
294 | self.h_activation = encoder:get(2)
295 | self.vbias[1] = decoder:get(1).bias
296 | self.v_activation = decoder:get(2)
297 | end
298 |
299 | return Rbm
300 |
-------------------------------------------------------------------------------- /test_ctc_large.lua: --------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'ctc_log'
3 |
4 | torch.setdefaulttensortype('torch.DoubleTensor')
5 |
6 | outputTable = torch.Tensor{
7 | {0.00851561,0.00838812,0.00831823,0.00827574,0.00824642,0.00822375,0.00820343,0.00819058,0.00818666,0.00818973,0.0081931,0.00819045,0.0081746,0.00815639,0.00813955,0.00813717,0.00814291,0.00815282,0.00816507,0.00818201,0.00819585,0.00820183,0.0082032,0.00820701,0.00821897,0.00823849,0.00827241,0.00831417,0.00834877,0.00838797},
8 | {0.00832554,0.0083547,0.00836662,0.0083672,0.00835925,0.00834443,0.00832733,0.00830804,0.00829558,0.00828797,0.00826878,0.00823972,0.00820705,0.00817149,0.00815277,0.00815138,0.00816869,0.00819128,0.00820619,0.00821728,0.00822713,0.00822409,0.00820866,0.00819896,0.00819222,0.00818621,0.00817647,0.00817657,0.00819549,0.0082471},
9 | {0.00804687,0.00828849,0.00844932,0.00854689,0.00860203,0.00863111,0.00864463,0.00864913,0.00864941,0.00864568,0.00864145,0.00864273,0.00864651,0.00864103,0.00862982,0.00861546,0.00860792,0.0086006,0.00859696,0.00859797,0.00860569,0.00861662,0.00862803,0.00864035,0.00864985,0.00866094,0.00866769,0.00867244,0.00866648,0.00860276},
10 | {0.0071507,0.00694943,0.00682586,0.00674856,0.0067018,0.00667533,0.00666434,0.00666094,0.00665562,0.00664758,0.00663679,0.00662265,0.00661063,0.00661422,0.00662202,0.00662611,0.00662495,0.00662458,0.00662262,0.00661391,0.00660734,0.00661144,0.00662187,0.0066317,0.00663687,0.00663818,0.00663882,0.00664628,0.00667254,0.00675394},
11 | {0.00754264,0.00749866,0.00748816,0.00748909,0.00748963,0.00748719,0.00748255,0.00747492,0.00747167,0.00747423,0.0074759,0.00747402,0.0074708,0.00746449,0.00746391,0.00747165,0.00748495,0.00749894,0.00751074,0.00751627,0.007513,0.00749486,0.00747238,0.00745048,0.00743328,0.00741408,0.0073878,0.00735739,0.00732742,0.00732945},
12 | {0.0073096,0.00723635,0.00719494,0.00716805,0.00715027,0.00713952,0.00713509,0.00712752,0.00711783,0.00711011,0.00710192,0.00708911,0.00708077,0.00708461,0.00708462,0.00707723,0.00707139,0.00706921,0.00706241,0.00705613,0.00706148,0.00707487,0.00708709,0.00709222,0.00708823,0.00708249,0.00708064,0.00708988,0.00710951,0.00714969},
13 | {0.00830771,0.00823418,0.00816136,0.00809586,0.00804167,0.00799979,0.00796758,0.00794364,0.00793313,0.00792998,0.00792908,0.00792848,0.00792855,0.00792866,0.00792361,0.00791404,0.00790926,0.00790524,0.00789946,0.00789935,0.00789862,0.0078946,0.00788624,0.00787816,0.00786684,0.00784255,0.00779934,0.00772262,0.00759376,0.00739087},
14 | {0.00727931,0.00734384,0.00737194,0.00738059,0.00737914,0.00737259,0.00736598,0.00737079,0.00738365,0.00739321,0.00740285,0.0074161,0.00742545,0.00742927,0.00743403,0.00744413,0.00745141,0.00745337,0.00746366,0.00746699,0.00745354,0.00743827,0.00743547,0.00744751,0.00748458,0.00755231,0.00764996,0.00777761,0.00796437,0.00824862},
15 | 
{0.00759865,0.00760253,0.007619,0.00763628,0.00764887,0.00765355,0.00765037,0.00764366,0.0076389,0.00763273,0.00762202,0.00761542,0.00761321,0.00761493,0.00761815,0.00761988,0.00762385,0.00761992,0.00761883,0.0076235,0.00763021,0.00763094,0.00762206,0.00760918,0.00758451,0.00755949,0.00753664,0.00751976,0.00751458,0.00752432}, 16 | {0.00807587,0.00800675,0.00794393,0.00790045,0.00787595,0.0078656,0.00786785,0.00787848,0.00788506,0.00788517,0.00788813,0.00789203,0.00789427,0.00790812,0.00792374,0.00793619,0.00793606,0.00793511,0.00793498,0.00792428,0.00790891,0.00790654,0.00791503,0.0079268,0.00794357,0.00796792,0.00800614,0.00807022,0.00817937,0.00835883}, 17 | {0.00902721,0.00923254,0.00935981,0.00942818,0.00945831,0.00946226,0.00945235,0.00943923,0.00943316,0.00942081,0.00939779,0.00937562,0.00936163,0.00935968,0.0093776,0.0094018,0.00943297,0.00944957,0.00945726,0.00946422,0.00946807,0.00945142,0.00941994,0.00939279,0.00935696,0.00931577,0.00925278,0.0091793,0.00908306,0.00896482}, 18 | {0.00817932,0.00824187,0.00828341,0.00831163,0.00833313,0.00835258,0.00837072,0.00838791,0.00839768,0.00840187,0.00840477,0.00840498,0.00839802,0.00839423,0.0083998,0.0084094,0.00840665,0.00839996,0.00839527,0.00838367,0.00837028,0.0083724,0.00839625,0.00843072,0.00846929,0.00850286,0.0085422,0.00858495,0.00866,0.00877247}, 19 | {0.00765023,0.00748373,0.00738717,0.00733217,0.00730354,0.00729194,0.00729252,0.00729467,0.00729358,0.00729548,0.00730373,0.00730991,0.00731084,0.0073158,0.00731402,0.00730786,0.00729353,0.00728685,0.00728254,0.00727386,0.00727117,0.00728254,0.00729897,0.00730843,0.00731725,0.0073286,0.0073535,0.00738615,0.00742871,0.00748284}, 20 | {0.00675057,0.0066293,0.00654755,0.0064934,0.00645808,0.0064347,0.00641783,0.00640552,0.0063952,0.00638518,0.00637609,0.00637163,0.00637494,0.00638893,0.00641063,0.00643255,0.00644374,0.00644555,0.00644544,0.00644605,0.00644555,0.00644547,0.0064497,0.00646288,0.00648797,0.00653274,0.00661251,0.00673724,0.00692173,0.00718667}, 21 | {0.00918335,0.00914577,0.00911366,0.00909089,0.0090778,0.00907167,0.00907246,0.00908205,0.00908736,0.00908863,0.00909847,0.00911673,0.00913672,0.00915208,0.0091591,0.00915776,0.00914384,0.00912792,0.00912464,0.00911718,0.00910646,0.00910203,0.00910907,0.00911052,0.00910743,0.00909927,0.00907481,0.00902575,0.00894358,0.00878026}, 22 | {0.00804851,0.00801377,0.00798571,0.00796503,0.00794885,0.00793622,0.00792406,0.00790549,0.00788894,0.00788346,0.00788096,0.0078791,0.00788259,0.00788236,0.00787847,0.00787234,0.00786953,0.00787329,0.00787384,0.00788077,0.00789635,0.00790469,0.00790562,0.00790307,0.00790476,0.00791308,0.00793394,0.00797176,0.0080062,0.00802328}, 23 | {0.00804263,0.00804189,0.00801113,0.00797653,0.00794703,0.00792481,0.00790932,0.00789841,0.00788874,0.00788058,0.0078668,0.00784778,0.00783389,0.00783016,0.00784207,0.00785935,0.00787323,0.00787845,0.00787904,0.00788003,0.0078825,0.00788204,0.00788413,0.00788901,0.00789046,0.00788951,0.00788775,0.00788351,0.00788097,0.00789731}, 24 | {0.00909428,0.00902128,0.00893768,0.00886991,0.00881941,0.00878495,0.00876315,0.00875592,0.00875657,0.00875402,0.0087504,0.00875375,0.00876459,0.00877769,0.00879957,0.00881432,0.00882092,0.0088098,0.00880371,0.00879922,0.00878471,0.00876368,0.0087541,0.00875095,0.00873373,0.00869797,0.00862629,0.00851564,0.00835177,0.00811844}, 25 | 
{0.00832823,0.00827849,0.00822974,0.00819151,0.00816408,0.00814953,0.0081415,0.00813423,0.00812959,0.00813012,0.00813613,0.00814694,0.00815946,0.00816672,0.00816999,0.00817311,0.00817281,0.00817262,0.00816917,0.00816872,0.008169,0.00817016,0.00817534,0.00818311,0.00819565,0.00820304,0.00820972,0.00819912,0.0081488,0.00802376}, 26 | {0.00893189,0.00876061,0.00863223,0.00854065,0.0084778,0.00843569,0.00840853,0.00839344,0.00838725,0.0083868,0.00838918,0.00839096,0.00839137,0.00838554,0.00837721,0.00836837,0.00836635,0.00836774,0.0083689,0.00837453,0.0083774,0.00837381,0.0083644,0.00835373,0.00834258,0.00833227,0.00831803,0.00829835,0.00827618,0.00824168}, 27 | {0.00938955,0.0097018,0.00986054,0.00993083,0.00995411,0.00995608,0.00994645,0.00993669,0.00993473,0.00993255,0.00992774,0.00992689,0.00992614,0.00991692,0.00991169,0.0099126,0.00992585,0.00993359,0.0099365,0.00994683,0.00994926,0.0099362,0.0099234,0.00992503,0.00993668,0.00995512,0.00998125,0.010009,0.0100375,0.0100574}, 28 | {0.00822628,0.00813483,0.00809142,0.00807144,0.00806462,0.00806396,0.00806866,0.00807756,0.00808577,0.00808872,0.00809288,0.00809777,0.0080966,0.0080989,0.00810496,0.00811333,0.0081189,0.00812571,0.00813616,0.00813678,0.00812748,0.00812333,0.00812213,0.00812465,0.00813816,0.00817017,0.00821983,0.00828945,0.00839643,0.00855793}, 29 | {0.0073821,0.00744486,0.00749707,0.00753813,0.00756922,0.00758955,0.00760586,0.007623,0.00763743,0.00764293,0.00763994,0.00763569,0.00762763,0.00762322,0.00763916,0.00765757,0.00767172,0.00767852,0.00768977,0.00768628,0.00766897,0.00765094,0.00763666,0.00763626,0.00764463,0.00767381,0.00770845,0.00776687,0.00787445,0.00806479}, 30 | {0.00816511,0.00835246,0.00846613,0.00853132,0.00856845,0.00859176,0.00860707,0.00862639,0.00865073,0.00866682,0.00867696,0.00868401,0.00868711,0.0086959,0.00871359,0.00873163,0.00874378,0.00873709,0.00873007,0.00872333,0.00870615,0.00868983,0.00868414,0.00868932,0.0086862,0.00866888,0.00863037,0.00856381,0.00846892,0.00834699}, 31 | {0.00672763,0.00663195,0.00658458,0.006564,0.00655871,0.00656313,0.00657208,0.00658167,0.00658574,0.00658562,0.00658811,0.00659119,0.00659055,0.00659394,0.00659564,0.0065974,0.00659346,0.00658937,0.00658611,0.0065807,0.00657191,0.0065736,0.00658544,0.00660277,0.00662215,0.00664676,0.0066961,0.0067745,0.00691051,0.00714066}, 32 | {0.00795527,0.00800272,0.00804013,0.00806538,0.00807822,0.0080822,0.0080804,0.00807383,0.00807003,0.00806981,0.00806731,0.00806454,0.00806575,0.00806101,0.00805787,0.00805649,0.00806262,0.00806512,0.00805976,0.00806065,0.00806109,0.00805094,0.0080299,0.00800483,0.00797333,0.00794074,0.00790277,0.00786305,0.00781187,0.00775677}, 33 | {0.00945979,0.00962338,0.00974115,0.0098231,0.00987218,0.00989487,0.00989837,0.00989753,0.00990285,0.00991213,0.00991727,0.00992351,0.00993047,0.009926,0.00991739,0.00991031,0.0099114,0.00991239,0.00991817,0.00992977,0.00993348,0.00991423,0.00988041,0.00983151,0.0097712,0.00969198,0.0095657,0.00937378,0.00909269,0.00873733}, 34 | {0.00826436,0.00824927,0.0082387,0.008228,0.00821899,0.00821281,0.00821159,0.0082172,0.00823028,0.0082419,0.00825111,0.00825754,0.00825667,0.00825381,0.0082589,0.00827044,0.00828027,0.00828464,0.00829193,0.00829053,0.00827096,0.0082466,0.00822509,0.00821277,0.00820766,0.00821798,0.00823905,0.00826873,0.00833336,0.00845363}, 35 | 
{0.0075062,0.00743378,0.00738513,0.0073574,0.00734223,0.00733433,0.00732783,0.00731749,0.00730607,0.00730071,0.00729411,0.00728141,0.00727171,0.00726773,0.00726458,0.00726396,0.00726647,0.00727288,0.00727558,0.00728078,0.00729355,0.00730374,0.00730832,0.0073007,0.0072831,0.00726053,0.00724689,0.00724662,0.0072625,0.00731086}, 36 | {0.00968033,0.00993844,0.0100757,0.010149,0.0101864,0.0102035,0.0102068,0.0102,0.0101931,0.0101945,0.0101948,0.0101885,0.0101922,0.0101943,0.0101904,0.0101824,0.0101789,0.0101818,0.0101821,0.0101906,0.0102047,0.0102023,0.0101858,0.0101446,0.0100914,0.0100147,0.00991667,0.00978903,0.00962167,0.00941938}, 37 | {0.00802416,0.00793638,0.0078765,0.00784234,0.00783091,0.00783533,0.00785072,0.00786933,0.00788445,0.00788942,0.00789295,0.00789344,0.00789678,0.00792134,0.00795406,0.00798069,0.00799008,0.0079884,0.00797783,0.0079578,0.00793712,0.00793903,0.0079651,0.00800462,0.00804883,0.00809369,0.00815544,0.00823091,0.00832762,0.00841157}, 38 | {0.00853555,0.00845061,0.00841409,0.00839778,0.00838938,0.00838442,0.00838369,0.00839311,0.00840735,0.00842228,0.00843841,0.00845327,0.00845323,0.00843668,0.00841933,0.00841044,0.00840988,0.00841204,0.00842002,0.00841658,0.00839951,0.00837829,0.008355,0.00833758,0.008326,0.00830916,0.0082714,0.00821081,0.00814739,0.00809059}, 39 | {0.00823466,0.00838423,0.00848976,0.0085607,0.00860948,0.00864258,0.00866689,0.00867614,0.00867126,0.00866224,0.00865255,0.00863827,0.00863,0.00862398,0.00861004,0.00859451,0.0085834,0.00858825,0.00858424,0.0085859,0.00860449,0.00863134,0.00864507,0.00864348,0.00864114,0.00863929,0.00864046,0.00863788,0.00862007,0.00858039}, 40 | {0.00711341,0.00697165,0.00685488,0.00676534,0.0067009,0.00665729,0.00663088,0.00662267,0.00661683,0.00660613,0.00659872,0.00659772,0.00660237,0.0066163,0.0066341,0.0066511,0.0066572,0.00665493,0.00665477,0.00665014,0.00663928,0.006635,0.00664726,0.0066699,0.00669894,0.00673031,0.00677044,0.00682263,0.00691122,0.00707113}, 41 | {0.00742586,0.00745445,0.00746212,0.00745735,0.00744753,0.00743919,0.00743399,0.00743093,0.00743029,0.00742986,0.00742801,0.00742402,0.00741539,0.0074071,0.00739963,0.0073974,0.00740106,0.0074051,0.00740493,0.00740502,0.0074026,0.00740373,0.00741326,0.00743393,0.00746545,0.00750544,0.00756233,0.00763864,0.00775392,0.00792699}, 42 | {0.00810503,0.0081006,0.00808862,0.00807403,0.00805891,0.00804319,0.00802594,0.00800353,0.00798498,0.00797421,0.00796382,0.00795446,0.00794679,0.00793222,0.00791804,0.00790546,0.00790797,0.00791482,0.00791741,0.0079289,0.00794628,0.00795505,0.00795088,0.00793918,0.00792036,0.00789843,0.00787731,0.00786474,0.00785879,0.00787226}, 43 | {0.00822779,0.00828209,0.00832029,0.00834997,0.00837197,0.00838898,0.00839849,0.00840119,0.00840609,0.00841978,0.00843184,0.00843696,0.00844149,0.00843852,0.00842428,0.00840258,0.00839032,0.00838749,0.00838231,0.00838937,0.0084062,0.00841595,0.00840907,0.0083867,0.0083598,0.00833249,0.00830573,0.00827547,0.00821749,0.00813466}, 44 | {0.00789461,0.00771518,0.00760839,0.00754152,0.00749961,0.00747244,0.00745673,0.00745248,0.00745463,0.00745627,0.00745707,0.00745238,0.00743568,0.00741627,0.00739588,0.00738404,0.00738154,0.00738538,0.00739056,0.00739596,0.0073974,0.00739869,0.00739632,0.00739773,0.00740556,0.00741644,0.00743037,0.00743574,0.00743097,0.00741283}, 45 | 
{0.00876863,0.00904295,0.00924974,0.00940216,0.00951036,0.00958663,0.00964111,0.00968051,0.00970447,0.0097231,0.00974793,0.00978039,0.00981876,0.00984351,0.00986112,0.00986468,0.00985487,0.0098403,0.00983136,0.0098175,0.00980372,0.00979265,0.00978307,0.00976711,0.00974415,0.00971031,0.00963568,0.00950979,0.00931236,0.00901985}, 46 | {0.00871832,0.00854662,0.00844365,0.00838554,0.00835444,0.00833796,0.00832952,0.00833101,0.00833804,0.00834504,0.00835484,0.00837061,0.00838562,0.0083905,0.00839182,0.00839232,0.00838792,0.00837944,0.00837924,0.00838121,0.00837853,0.00837633,0.00838156,0.00839518,0.00841876,0.00845119,0.00848353,0.00851105,0.00853026,0.00851655}, 47 | {0.00801276,0.00793451,0.00789482,0.00787856,0.00787733,0.00788392,0.00789587,0.0079079,0.00791389,0.00791485,0.00791713,0.00791908,0.0079217,0.00792697,0.00793,0.00793011,0.00792594,0.00792582,0.00792413,0.00791546,0.00790844,0.0079133,0.00792706,0.00794777,0.00797882,0.00802127,0.00808142,0.00817053,0.00830989,0.00850804}, 48 | {0.00804214,0.00784403,0.00770717,0.00761508,0.00755763,0.00752647,0.00751291,0.00750588,0.00749424,0.007482,0.00747053,0.00745554,0.0074342,0.00741969,0.00741156,0.00741046,0.00740804,0.00740765,0.00740777,0.0074052,0.00740895,0.00742947,0.00745608,0.00748796,0.00752065,0.00755832,0.0076262,0.00773727,0.0079192,0.00817534}, 49 | {0.00816911,0.00786699,0.00767915,0.00756933,0.00751176,0.00748827,0.00748391,0.00748982,0.00749754,0.00750365,0.00751046,0.00751735,0.00752354,0.00752534,0.00751566,0.00750559,0.00749742,0.00749461,0.00749046,0.00748722,0.00748122,0.00748013,0.00748176,0.00747817,0.00746909,0.00744615,0.007414,0.00736801,0.00731362,0.00724872}, 50 | {0.0100089,0.00996254,0.00991846,0.00988815,0.00987049,0.00986237,0.0098603,0.00986085,0.00985786,0.00985566,0.00986122,0.00987314,0.0098893,0.00990235,0.00991245,0.00991363,0.00990202,0.00989077,0.00988973,0.00988558,0.00988182,0.00988367,0.00989757,0.00990127,0.00990056,0.00989005,0.0098611,0.00979996,0.00967591,0.00941751}, 51 | {0.00668494,0.0066093,0.00656151,0.00653254,0.00651565,0.0065059,0.00650192,0.0065047,0.00651054,0.00651358,0.00652037,0.0065315,0.00654723,0.00655908,0.00656391,0.00656513,0.00656057,0.00655316,0.00654738,0.0065382,0.00652566,0.00651756,0.00651897,0.00653105,0.00655684,0.00659797,0.00665132,0.00672262,0.00684132,0.00702638}, 52 | {0.00832419,0.0083764,0.00841408,0.00843397,0.00844396,0.00844914,0.0084524,0.00845202,0.00844734,0.00843953,0.00842743,0.0084119,0.00839519,0.00839092,0.00838496,0.00838521,0.00838912,0.00839667,0.0083953,0.00839199,0.00839835,0.00841221,0.0084228,0.00843765,0.00845033,0.00844481,0.00843264,0.00841378,0.00839056,0.00835806}, 53 | {0.00980411,0.00997704,0.0100629,0.0101078,0.0101316,0.0101468,0.010155,0.0101574,0.0101518,0.0101527,0.0101595,0.0101657,0.0101689,0.0101625,0.010146,0.0101285,0.0101151,0.0101133,0.0101135,0.0101166,0.0101299,0.010142,0.0101523,0.0101412,0.0101123,0.0100586,0.00997469,0.00985054,0.00964292,0.00930616}, 54 | {0.00723069,0.00720363,0.00719636,0.00719824,0.00720376,0.00720901,0.00721101,0.00721184,0.00721337,0.00721119,0.00720243,0.00719308,0.00718602,0.00718407,0.00719144,0.00720408,0.00721733,0.00722411,0.00723238,0.00723989,0.00724529,0.00724981,0.00726272,0.00728227,0.00730474,0.00733328,0.0073847,0.00746027,0.00757091,0.00770956}, 55 | 
{0.00798611,0.00825946,0.00844139,0.00855443,0.00861921,0.008653,0.0086685,0.00867775,0.00868508,0.00869232,0.00869699,0.00869802,0.0086837,0.00865284,0.00863441,0.00863386,0.00864542,0.0086546,0.00866475,0.00867665,0.0086829,0.00867913,0.00867174,0.00867392,0.00868616,0.00871249,0.00875348,0.00881463,0.00891975,0.009083}, 56 | {0.00777662,0.00765043,0.00757949,0.00754096,0.00752041,0.00751139,0.00750619,0.00749737,0.00749119,0.0074864,0.00748194,0.00748002,0.0074845,0.00749105,0.00749899,0.00750288,0.00750555,0.007502,0.00749639,0.00749492,0.00750104,0.0075094,0.00752157,0.00754264,0.00756712,0.00759204,0.00762696,0.00768497,0.00778112,0.00791304}, 57 | {0.00781043,0.00788317,0.00798568,0.0080856,0.00816726,0.00822749,0.00826596,0.00828458,0.00829546,0.00830359,0.00830899,0.0083146,0.00831931,0.00831959,0.00832082,0.00832957,0.00834414,0.00835335,0.00835403,0.00836239,0.00837417,0.00837491,0.00836628,0.00835261,0.00833397,0.00830971,0.00828284,0.00823596,0.00815306,0.00801108}, 58 | {0.00800828,0.00788008,0.00777793,0.00771033,0.00767049,0.00765292,0.00764878,0.00764537,0.00764042,0.00763853,0.00764078,0.00764725,0.0076636,0.00767995,0.00769293,0.0076951,0.00768856,0.0076784,0.0076656,0.0076589,0.00766304,0.00766984,0.00767775,0.00767947,0.00767608,0.00766961,0.00765395,0.00762336,0.00754239,0.00736801}, 59 | {0.00819978,0.00817231,0.00815959,0.00816021,0.00816536,0.00817142,0.0081786,0.00818332,0.00818296,0.00818617,0.0081882,0.00818516,0.00818235,0.00817378,0.00816302,0.00814845,0.00813402,0.0081271,0.00812022,0.00811405,0.00811896,0.00812354,0.00812484,0.00810843,0.00807754,0.00802944,0.00795545,0.00786497,0.00775183,0.00763206}, 60 | {0.00731681,0.00728241,0.00722482,0.00717478,0.00714097,0.00712321,0.00711791,0.00712109,0.00712366,0.00712145,0.00711801,0.00711421,0.00710818,0.00710559,0.00711039,0.00712155,0.00712756,0.00712637,0.00712291,0.00711887,0.00710981,0.00710406,0.00710635,0.00711403,0.00712172,0.00713027,0.0071534,0.00719576,0.0072824,0.00744648}, 61 | {0.00781636,0.00782035,0.00784844,0.00788021,0.00790524,0.00792403,0.00793539,0.0079439,0.00795475,0.00797036,0.0079895,0.00801132,0.00802206,0.00801556,0.00800629,0.00799979,0.00799703,0.00799249,0.00799382,0.00799184,0.00798092,0.00796637,0.00795547,0.00795645,0.00796617,0.00798141,0.00799347,0.00800544,0.0080078,0.00797629}, 62 | {0.00826222,0.00811796,0.00804458,0.00800687,0.00798525,0.00796965,0.00795772,0.00795007,0.00793686,0.00792452,0.00791468,0.00790654,0.00789228,0.00787334,0.00785552,0.00784726,0.00784383,0.00785177,0.00786439,0.00787097,0.00787448,0.00787683,0.00788132,0.00787532,0.00786757,0.00785658,0.00784798,0.00784918,0.00786728,0.00790658}, 63 | {0.0070611,0.00697053,0.00692387,0.00690461,0.00689908,0.00689932,0.00689862,0.00689794,0.00690062,0.00690767,0.00691553,0.00692341,0.0069301,0.00692479,0.00691579,0.00691143,0.00691054,0.00691059,0.00691546,0.00692198,0.0069258,0.00692491,0.00692832,0.0069356,0.00695237,0.00698156,0.0070247,0.00709089,0.0071965,0.0073567}, 64 | {0.00771557,0.00747784,0.00736301,0.00731144,0.00729046,0.00728489,0.00728564,0.00728478,0.00728708,0.00729434,0.00729969,0.00730264,0.0073017,0.00729525,0.00729404,0.00729749,0.00730495,0.00731378,0.00731908,0.00732298,0.00732454,0.00732041,0.00730753,0.00729773,0.00728898,0.00727703,0.00725852,0.00723693,0.00720964,0.00718237}, 65 | 
{0.00853832,0.00840316,0.00830878,0.00824143,0.00819326,0.00816139,0.00814287,0.00813064,0.00812545,0.0081277,0.00813395,0.00813628,0.00813258,0.00813765,0.00814508,0.00815177,0.0081507,0.00814741,0.00814895,0.00814496,0.00813766,0.00813526,0.00813862,0.00814672,0.00816299,0.00818413,0.00820859,0.00822938,0.00823225,0.0082119}, 66 | {0.00822219,0.00829847,0.00833646,0.0083506,0.00835031,0.00834328,0.00833308,0.00832247,0.00831809,0.00831728,0.00831278,0.00830149,0.00828784,0.00827067,0.00825357,0.00824412,0.00825011,0.0082577,0.00825535,0.00825599,0.00826458,0.00827028,0.00827276,0.00827666,0.00828154,0.00828136,0.00828794,0.00831612,0.00838332,0.00853083}, 67 | {0.00869298,0.00906458,0.00925673,0.0093469,0.00938573,0.00940278,0.00941127,0.00941683,0.00942406,0.00942847,0.00943199,0.00943486,0.00944159,0.0094525,0.00946436,0.00947304,0.00947124,0.00945904,0.00944548,0.00943759,0.00943152,0.0094281,0.00942646,0.00942364,0.00941387,0.00939904,0.00937635,0.00932683,0.00921739,0.00898317}, 68 | {0.0086211,0.00869362,0.00872946,0.00874285,0.0087454,0.00874475,0.0087429,0.0087373,0.00873005,0.00872464,0.008719,0.00871485,0.00871191,0.0087067,0.00870161,0.00868901,0.00867611,0.00866208,0.00865066,0.00864396,0.0086459,0.0086541,0.00866582,0.00867875,0.00868162,0.00867414,0.00864994,0.00860523,0.00852967,0.00838322}, 69 | {0.00825319,0.00815454,0.0080952,0.00805894,0.00803427,0.00801861,0.00801025,0.00801161,0.00801628,0.00801877,0.00802114,0.00802666,0.00802408,0.00802013,0.00802771,0.0080464,0.00805697,0.00805162,0.00804745,0.00804627,0.00803141,0.00801095,0.0079978,0.00799876,0.00800138,0.00800113,0.00800173,0.00798558,0.00795278,0.00790058}, 70 | {0.0081735,0.00824767,0.00826074,0.00825164,0.00823852,0.00822809,0.00822136,0.00821686,0.00821063,0.00820379,0.00819931,0.00819392,0.00818884,0.00819263,0.0081967,0.00819663,0.0081883,0.00818162,0.00818152,0.00818318,0.00818771,0.00820696,0.00823524,0.00826814,0.00830234,0.00833964,0.00838687,0.00843108,0.00846586,0.0084886}, 71 | {0.00848142,0.00821787,0.00806688,0.00798364,0.00794071,0.00792177,0.00791469,0.00790952,0.00790406,0.00789865,0.00789461,0.00789499,0.00790412,0.00792205,0.0079312,0.00793717,0.00793625,0.00793359,0.00793159,0.0079307,0.00793126,0.00793799,0.00794849,0.00796077,0.00797745,0.0079994,0.00804048,0.0081007,0.00818465,0.00829025}, 72 | {0.00818051,0.00827486,0.008329,0.00836539,0.00839246,0.00841224,0.00842696,0.00843307,0.00842669,0.00841449,0.00840465,0.00839899,0.00840217,0.0084092,0.0084147,0.0084095,0.00840073,0.00839612,0.00838989,0.00838344,0.00839014,0.00840846,0.00842708,0.00843614,0.00843109,0.00841405,0.00838981,0.00837057,0.00835886,0.00835859}, 73 | {0.00880887,0.00887836,0.00898243,0.00907586,0.00914661,0.00919544,0.00922653,0.0092393,0.00923864,0.00923315,0.00922492,0.00921712,0.00921045,0.00921246,0.00921386,0.00921088,0.00920387,0.00919777,0.00919036,0.00918236,0.00918127,0.00918882,0.00919371,0.00920007,0.00920147,0.00920703,0.00922195,0.0092603,0.00932008,0.00938829}, 74 | {0.00754102,0.00736217,0.00725871,0.00719609,0.00715775,0.0071342,0.00712166,0.00711263,0.00710602,0.00709894,0.00708892,0.00707564,0.00705958,0.00704515,0.00703047,0.0070243,0.00702758,0.00703595,0.0070362,0.00703025,0.00702638,0.0070273,0.00703038,0.00704532,0.00706899,0.0070943,0.00712536,0.00716604,0.00722668,0.00728564}, 75 | 
{0.00893579,0.0090543,0.0091018,0.00911643,0.00912062,0.00912281,0.0091245,0.00912488,0.0091221,0.00911222,0.00909924,0.00908861,0.00908499,0.00909758,0.00912337,0.0091494,0.00916554,0.00916707,0.00916373,0.00916603,0.00916498,0.0091588,0.009156,0.00915255,0.00914059,0.00911409,0.00908396,0.00902925,0.00894576,0.00884455}, 76 | {0.00795009,0.00787562,0.00781586,0.00777209,0.00774413,0.00772622,0.00771701,0.00771738,0.00772131,0.00772258,0.00772617,0.00773135,0.00773437,0.00773341,0.00772742,0.00772256,0.00771818,0.0077152,0.00771655,0.00771802,0.0077123,0.00770612,0.00769773,0.00769435,0.00769407,0.00769766,0.00769694,0.00767691,0.00763532,0.00755163}, 77 | {0.00753475,0.00741378,0.00734505,0.00729876,0.00726232,0.00722954,0.00720305,0.00719288,0.00719587,0.00720134,0.00720898,0.00721803,0.007217,0.00720657,0.00719186,0.00718281,0.00718261,0.00718481,0.00719212,0.00719435,0.00718379,0.00716256,0.00713843,0.00711914,0.00710701,0.00709932,0.00708516,0.00707004,0.00705583,0.00705311}, 78 | {0.00972694,0.00966189,0.0096285,0.00960658,0.0095867,0.00956473,0.00954407,0.00953093,0.00952876,0.00953485,0.00954458,0.00955453,0.00955792,0.00954439,0.00951714,0.00949344,0.0094893,0.00949921,0.00950893,0.00951924,0.00952495,0.00951451,0.00947773,0.0094333,0.00939233,0.00935604,0.00930266,0.00922387,0.00909472,0.0089331}, 79 | {0.010209,0.0101152,0.0100646,0.0100462,0.0100437,0.0100505,0.0100579,0.0100555,0.0100477,0.0100524,0.0100579,0.0100577,0.0100622,0.0100619,0.0100457,0.010021,0.00999639,0.0099875,0.00997799,0.00997622,0.0099958,0.0100201,0.0100422,0.0100488,0.0100523,0.0100343,0.0099799,0.00985902,0.00962058,0.00922388}, 80 | {0.00751606,0.00764043,0.00770796,0.00774805,0.00777473,0.00779508,0.00780979,0.00781295,0.00780927,0.00780724,0.00781065,0.00781696,0.00783762,0.00786714,0.00788357,0.0078873,0.00788077,0.00787578,0.00786914,0.00786627,0.00787864,0.00790113,0.0079278,0.00794517,0.00796105,0.00797751,0.00800664,0.00805187,0.00809805,0.00813973}, 81 | {0.00858215,0.008676,0.00873108,0.00876239,0.00877919,0.0087875,0.00879069,0.00878931,0.00879203,0.00879587,0.00879699,0.00879436,0.00879199,0.00879429,0.00879478,0.00879655,0.00880643,0.00881525,0.00881456,0.00881576,0.00881875,0.00881574,0.0088051,0.00878748,0.00877078,0.00876033,0.00876851,0.00880772,0.00888213,0.00902303}, 82 | {0.00887,0.0088996,0.00895433,0.00900537,0.00904399,0.00906639,0.00907807,0.00908352,0.00909081,0.00909294,0.009091,0.00909393,0.00909977,0.00910475,0.00911661,0.00912832,0.00914199,0.00914648,0.00914708,0.00915016,0.00915423,0.00915338,0.00914647,0.00914758,0.00915036,0.00916169,0.00917281,0.00917764,0.00916145,0.00911837}, 83 | {0.0097771,0.0100795,0.010281,0.0104084,0.0104871,0.0105355,0.0105644,0.0105798,0.0105838,0.0105791,0.0105719,0.0105725,0.0105865,0.0106074,0.0106312,0.0106351,0.0106273,0.0106158,0.0106096,0.0105943,0.0105818,0.0105814,0.0105898,0.0105999,0.010595,0.0105737,0.0105237,0.0104488,0.0103547,0.0102176}, 84 | {0.00817792,0.00841369,0.0085629,0.00865848,0.00872168,0.00876396,0.00879216,0.00880552,0.00880708,0.00880392,0.00879695,0.0087869,0.00878167,0.0087771,0.00877138,0.00876147,0.00875286,0.00874815,0.00873892,0.00873676,0.00874923,0.0087715,0.00879201,0.00880605,0.0088178,0.00884375,0.00889946,0.00899429,0.00913764,0.00934017}, 85 | 
{0.00883274,0.00881958,0.00877983,0.00873783,0.0086995,0.00866809,0.00864146,0.00861868,0.00860058,0.00859111,0.00858875,0.00858879,0.00858471,0.00857597,0.00857376,0.00857434,0.00857695,0.00858153,0.00859018,0.00860069,0.00860243,0.00859379,0.00858601,0.00857947,0.00857596,0.00857748,0.00858308,0.00858707,0.00856886,0.0085154}, 86 | {0.0088984,0.0087711,0.00868456,0.00862534,0.00858961,0.00857446,0.00857526,0.00858537,0.00859031,0.0085918,0.0086018,0.00862037,0.0086402,0.00865727,0.00865856,0.00864848,0.00862688,0.00860968,0.00859777,0.00858204,0.00856926,0.0085781,0.00859572,0.00861775,0.00863784,0.00864626,0.0086348,0.0085915,0.00851014,0.00834124}, 87 | {0.0078012,0.00787306,0.0079136,0.00793376,0.00794202,0.00794268,0.00794225,0.00794615,0.00795168,0.0079533,0.00795634,0.00796418,0.00797299,0.00797962,0.00798292,0.00797732,0.00796628,0.00795188,0.00794479,0.00793433,0.00792366,0.00792059,0.00792011,0.00792502,0.00792577,0.00793248,0.00793628,0.00794842,0.00796284,0.00795301}, 88 | {0.00733179,0.00730508,0.00731856,0.00734958,0.00738298,0.00741198,0.00743514,0.00745041,0.00745285,0.00744778,0.00744374,0.00744238,0.00744759,0.00745825,0.00747035,0.00747833,0.00747753,0.00747203,0.00747,0.00746371,0.00745989,0.00746194,0.00747358,0.00748243,0.00748346,0.00748186,0.00748981,0.00752795,0.00762646,0.00782171}, 89 | {0.00900397,0.00909496,0.00912191,0.00912412,0.00911974,0.0091145,0.00911079,0.00911498,0.0091239,0.00912766,0.0091282,0.0091279,0.00912641,0.00912707,0.00913175,0.00913952,0.00914763,0.00914989,0.00915478,0.00915809,0.00915064,0.00913654,0.00912456,0.00911694,0.00911187,0.00911231,0.00911255,0.00912005,0.00915999,0.00927653}, 90 | {0.0102168,0.0104509,0.0106174,0.0107306,0.0107976,0.010832,0.0108406,0.0108231,0.0108001,0.010792,0.0107856,0.010777,0.0107701,0.0107493,0.0107287,0.0107144,0.0107148,0.0107294,0.0107322,0.0107525,0.010787,0.0107969,0.0107769,0.0107432,0.010711,0.0106741,0.0106244,0.0105448,0.0103752,0.0100836}, 91 | {0.00880721,0.00874214,0.00868002,0.00862884,0.0085899,0.00856332,0.00854509,0.00852939,0.00851365,0.00850797,0.00850584,0.00849693,0.00848222,0.00846427,0.00843901,0.00841872,0.00840821,0.00841105,0.00840822,0.00841566,0.00843501,0.00845498,0.00846784,0.00847318,0.00847615,0.00846992,0.00846906,0.00846446,0.00845019,0.00843142}, 92 | {0.00823238,0.00805571,0.00794075,0.00786734,0.00782074,0.00779092,0.00777002,0.0077606,0.00775626,0.00775359,0.00775263,0.00775312,0.00775441,0.00775543,0.00775428,0.00775643,0.00775416,0.00774997,0.0077459,0.00774244,0.00773851,0.00773724,0.00773923,0.00774027,0.00773609,0.00771592,0.00768241,0.0076344,0.0075742,0.00750311}, 93 | {0.00785017,0.00803968,0.00815903,0.00823172,0.00827303,0.00829173,0.0082952,0.00829042,0.00829372,0.00830078,0.00830541,0.00830995,0.00831017,0.00830297,0.00829319,0.00828597,0.00829114,0.00829876,0.0083061,0.00831449,0.00831972,0.00831277,0.00829488,0.00828322,0.0082821,0.00829245,0.00830789,0.00833511,0.00837079,0.00843173}, 94 | {0.00864061,0.00848069,0.00840366,0.0083701,0.0083579,0.00835421,0.00835471,0.00835788,0.00836544,0.0083769,0.00838657,0.00839435,0.00839611,0.00839455,0.00838911,0.00838436,0.00838498,0.00838779,0.00839308,0.00839969,0.00840699,0.00840776,0.00839765,0.00838804,0.00838355,0.00839317,0.00840148,0.00839753,0.00835346,0.00825402}, 95 | 
{0.00781415,0.00775764,0.00772198,0.00769899,0.00768532,0.00767788,0.00767567,0.00768323,0.00769188,0.00769896,0.00770706,0.00771184,0.00770526,0.0076939,0.00768276,0.00767926,0.0076789,0.00768325,0.00769152,0.00769179,0.0076809,0.00767061,0.00766299,0.00765555,0.00764943,0.00763572,0.007612,0.00757646,0.00754679,0.00752799}, 96 | {0.0077854,0.00783047,0.00785889,0.00787832,0.00789328,0.00790354,0.00791202,0.00792476,0.00793735,0.0079437,0.0079503,0.00795953,0.00796434,0.00796115,0.00795628,0.00795416,0.00795182,0.00794659,0.00794842,0.00794966,0.00794195,0.00793549,0.00793299,0.00793096,0.00792738,0.00793714,0.0079567,0.00800459,0.00809351,0.00823585}, 97 | {0.0074581,0.00737995,0.00737281,0.00739473,0.00742425,0.00745012,0.00747022,0.00747925,0.00748072,0.0074833,0.00748858,0.00749438,0.00750255,0.00750437,0.0074968,0.00748255,0.0074715,0.00746848,0.00746625,0.00746649,0.00747753,0.00749108,0.00749887,0.00749128,0.00747804,0.00746457,0.0074497,0.00744229,0.00744667,0.00746611}, 98 | {0.00923514,0.00908802,0.0090132,0.00898222,0.00897337,0.00897288,0.00897302,0.00897107,0.0089662,0.00896756,0.00896578,0.00895653,0.00893981,0.00891763,0.00890213,0.00889128,0.00889143,0.00890099,0.00890844,0.00891932,0.00893521,0.00894309,0.00893473,0.00891334,0.00888201,0.00884528,0.00879972,0.00875699,0.00870287,0.00864399}, 99 | {0.00804684,0.0080773,0.00810221,0.00811784,0.00812692,0.00813155,0.00813143,0.00812348,0.0081152,0.00810749,0.00810013,0.00809452,0.00809941,0.00811147,0.00812021,0.00812612,0.00813454,0.00813844,0.00813396,0.00814089,0.00815367,0.00815774,0.00815089,0.00813822,0.00812323,0.0081026,0.00808279,0.00804732,0.00798728,0.00789101}, 100 | {0.00944003,0.00960968,0.00969074,0.00973056,0.00975063,0.00976167,0.0097683,0.00976432,0.00975898,0.00975953,0.00975705,0.00974781,0.00974481,0.00974906,0.00975723,0.0097661,0.00977497,0.0097831,0.00978262,0.00978695,0.00980254,0.00980834,0.00980613,0.00978684,0.00976087,0.00972998,0.0096986,0.00966787,0.00961381,0.00952383}, 101 | {0.00836106,0.00862158,0.00875706,0.00882463,0.00885539,0.0088653,0.00886546,0.00886048,0.00884953,0.00883769,0.00882661,0.00881605,0.00880527,0.00879206,0.00878552,0.0087858,0.0087931,0.00880448,0.00881135,0.00881633,0.00882673,0.00883136,0.00883366,0.00882728,0.00881239,0.00879142,0.00876652,0.00874226,0.00869199,0.00858671}, 102 | {0.0100216,0.0103206,0.0104832,0.0105725,0.010624,0.0106578,0.0106793,0.0106857,0.0106825,0.0106808,0.0106809,0.0106777,0.010678,0.0106851,0.0106971,0.0106967,0.0106908,0.0106904,0.0106872,0.0106828,0.0106843,0.0106921,0.0106944,0.0106794,0.01064,0.0105753,0.0104788,0.0103522,0.0101757,0.0099435}, 103 | {0.00835019,0.00827548,0.00823937,0.00822678,0.00822673,0.00823369,0.00824387,0.0082522,0.00826008,0.00827237,0.00828891,0.00830022,0.00831149,0.00832118,0.0083242,0.00832241,0.00831848,0.00832296,0.00832748,0.00832794,0.00832959,0.00833447,0.00834021,0.0083434,0.00835946,0.00838259,0.00841168,0.0084363,0.00844884,0.00844881}, 104 | {0.00993882,0.0100932,0.0102173,0.0103054,0.0103617,0.0103976,0.0104177,0.0104252,0.0104271,0.0104279,0.0104277,0.0104265,0.0104258,0.0104301,0.0104249,0.0104137,0.0104024,0.010391,0.0103788,0.010371,0.0103669,0.0103637,0.0103559,0.0103309,0.0102823,0.010206,0.0101085,0.00998256,0.0098051,0.00956307}, 105 | 
{0.00877139,0.0087868,0.00878205,0.00877229,0.00876085,0.00874895,0.00873603,0.00872822,0.0087244,0.00872102,0.00871614,0.00871067,0.00870559,0.00870196,0.00869318,0.0086852,0.00867948,0.00867817,0.00868239,0.00869062,0.00869437,0.00869529,0.008696,0.00869319,0.00868935,0.00868718,0.00869571,0.00872546,0.00879578,0.0089522}, 106 | {0.00703551,0.00673188,0.00654358,0.00643147,0.00636422,0.00632443,0.00629901,0.00628193,0.00627375,0.00627655,0.0062798,0.00628016,0.00627811,0.00626445,0.00625057,0.00624491,0.00624871,0.00625425,0.00625643,0.00626337,0.006275,0.00627877,0.0062807,0.00628846,0.00630473,0.00632141,0.00633927,0.00637107,0.00642192,0.0065082}, 107 | {0.00879944,0.00872835,0.00868996,0.00867414,0.00867194,0.00867551,0.00867917,0.0086821,0.00868301,0.00868229,0.00868271,0.00868098,0.00868066,0.00868655,0.00868741,0.00868389,0.00867633,0.00867393,0.00867152,0.00867194,0.00867656,0.00868648,0.00869372,0.00869019,0.00868178,0.00866746,0.00865711,0.00865071,0.00866005,0.00871071}, 108 | {0.00742356,0.00749895,0.0075688,0.00763229,0.00768873,0.00773534,0.00777222,0.00780255,0.00782254,0.00783265,0.00784509,0.00786217,0.00787961,0.00789349,0.00790096,0.0079019,0.00789078,0.007877,0.0078693,0.0078603,0.00785235,0.00785871,0.00787782,0.00790632,0.00794494,0.00800399,0.00809359,0.00823337,0.00845837,0.00879183}, 109 | {0.00761389,0.00755138,0.0075057,0.00747341,0.00745167,0.00743847,0.00743249,0.00743337,0.00743738,0.00744086,0.00744306,0.00744167,0.00743566,0.0074332,0.00742549,0.00742107,0.00741894,0.00741914,0.00741922,0.00742005,0.00741969,0.00742152,0.00742122,0.00742055,0.00742615,0.00743788,0.00745817,0.00747779,0.00749896,0.00752313}, 110 | {0.00743572,0.00734547,0.00730547,0.0072921,0.00729068,0.00729336,0.00729622,0.00730156,0.00730859,0.0073133,0.00732015,0.00733484,0.00735538,0.00737185,0.0073799,0.00737887,0.0073741,0.00736769,0.00736736,0.00736237,0.00735058,0.00733808,0.00732973,0.00732835,0.0073297,0.00733532,0.00733584,0.00734183,0.00736087,0.00741251}, 111 | {0.00835918,0.00846782,0.00853272,0.00857191,0.00859747,0.00861366,0.00862468,0.00863292,0.00864041,0.0086437,0.00863792,0.00862351,0.00860143,0.00858589,0.00858524,0.00859089,0.0086001,0.008612,0.00862368,0.00862952,0.00863385,0.00864236,0.00865053,0.00866213,0.00867328,0.00869326,0.00873152,0.00880599,0.00893268,0.00914855}, 112 | {0.00814841,0.00808546,0.00800262,0.0079289,0.00787442,0.00783782,0.00781495,0.00780143,0.00779745,0.00779459,0.00778959,0.00778163,0.00777806,0.00778524,0.00779371,0.00780162,0.00780752,0.00780949,0.00780894,0.00780909,0.00780451,0.00780061,0.00780178,0.00780912,0.00782344,0.00784281,0.00786837,0.00789507,0.00792574,0.00796197}, 113 | {0.00817568,0.00832946,0.00841146,0.00845243,0.0084697,0.00847266,0.00846655,0.00845747,0.00845182,0.00844865,0.00844454,0.00844183,0.00843913,0.00843314,0.00843114,0.00843379,0.00844209,0.00844801,0.00845291,0.00846016,0.00846561,0.00845773,0.00844548,0.00843012,0.00840946,0.00838088,0.00834248,0.00828807,0.00821316,0.00811799}, 114 | {0.00824494,0.00821142,0.00825207,0.00830945,0.00836021,0.00839785,0.00842338,0.00843953,0.00844649,0.00844936,0.0084553,0.00846981,0.00848284,0.00848613,0.00848317,0.00847804,0.00846716,0.0084572,0.00845545,0.00845618,0.00845436,0.00845758,0.00846298,0.00847704,0.00850172,0.00854505,0.00860898,0.0086951,0.00881089,0.00894146}, 115 | 
{0.00752691,0.00738913,0.00730165,0.00724597,0.00721142,0.00719038,0.00717794,0.00718074,0.00719066,0.00719435,0.00719381,0.00719728,0.00719752,0.0071982,0.00720888,0.00722439,0.00723383,0.00722709,0.00722816,0.00722703,0.00721049,0.0071957,0.00719667,0.00721808,0.00724622,0.00728421,0.00732643,0.00737688,0.00745412,0.00756747}, 116 | {0.0095169,0.00952029,0.00948913,0.00945296,0.00942086,0.0093919,0.00936909,0.00936011,0.00935637,0.0093565,0.00935851,0.00935587,0.00934702,0.0093351,0.00932156,0.00931039,0.00930629,0.00931248,0.00932241,0.0093261,0.00932061,0.0093044,0.00927575,0.00923023,0.00917527,0.00910397,0.00900187,0.00886315,0.00869294,0.00851027}, 117 | {0.00905123,0.00903012,0.00902729,0.00902263,0.00901464,0.00900402,0.0089935,0.00898712,0.0089923,0.00899423,0.00898448,0.00896623,0.00893522,0.00891099,0.00890882,0.00892279,0.00894655,0.00896013,0.0089703,0.00897631,0.00896814,0.00894943,0.00893216,0.00892357,0.00891088,0.00888994,0.00886056,0.00882086,0.00878368,0.00876218}, 118 | {0.00870153,0.00865085,0.00864089,0.00864885,0.00866594,0.00868299,0.00869888,0.00870767,0.00871127,0.0087076,0.00870328,0.00870431,0.00871699,0.00874104,0.00876003,0.00876935,0.00877159,0.00877779,0.00877934,0.0087735,0.0087713,0.00877809,0.00877987,0.00878235,0.00878801,0.00879801,0.00880754,0.00881752,0.00884322,0.00889981}, 119 | {0.00744222,0.00743069,0.00744062,0.00745596,0.00746953,0.00747793,0.00748196,0.00748687,0.0074911,0.00749481,0.00750229,0.00751296,0.00751924,0.00751584,0.00750494,0.00749584,0.0074874,0.00748398,0.00748852,0.00748931,0.00748708,0.00748972,0.00750036,0.00752276,0.0075606,0.00761827,0.00769491,0.00779602,0.00793062,0.00806755}, 120 | {0.00809,0.00790053,0.00779089,0.00773157,0.0077002,0.00768372,0.00767606,0.00767583,0.00767405,0.00767025,0.00767024,0.00767454,0.00767712,0.00766906,0.00765887,0.00765255,0.00764722,0.00764259,0.00764154,0.00764236,0.00763985,0.00763388,0.00763014,0.0076242,0.00761697,0.00760156,0.00757892,0.00755775,0.00755124,0.00757799}, 121 | {0.00759478,0.00749539,0.00741224,0.00735161,0.00731049,0.00728552,0.00726922,0.00725733,0.00724808,0.00723846,0.00722923,0.00722364,0.007227,0.00723766,0.00725242,0.00726682,0.00727457,0.00727162,0.00726738,0.00726693,0.00726571,0.00726327,0.00726512,0.00726968,0.00727322,0.00727591,0.00728489,0.00729892,0.00731279,0.00730546}, 122 | {0.00744853,0.00742638,0.00742703,0.00743369,0.00743963,0.00744091,0.00743648,0.0074248,0.00741506,0.00741136,0.00740749,0.00740318,0.00739638,0.00739572,0.0073997,0.00740396,0.00740838,0.0074111,0.00741363,0.00742007,0.00742986,0.00743267,0.00742633,0.00741926,0.00741771,0.00743229,0.007463,0.0075053,0.00754021,0.00755862}, 123 | {0.00906735,0.00913575,0.00918241,0.00921007,0.00922639,0.00923505,0.00923994,0.00923932,0.00923589,0.00923101,0.00922777,0.00923069,0.00924465,0.00926818,0.00928288,0.00927806,0.00926404,0.00925131,0.00924003,0.00922675,0.00922418,0.00923658,0.00924557,0.00925607,0.00927108,0.00929973,0.00932547,0.00933896,0.00932132,0.00924601}, 124 | {0.00795064,0.00785978,0.00781961,0.00780583,0.00780626,0.00781233,0.00782176,0.00782596,0.00781982,0.00781003,0.00780691,0.0078067,0.00781409,0.00783141,0.00784565,0.00784977,0.00784457,0.00784462,0.00783997,0.0078327,0.00783083,0.0078358,0.00783595,0.00782099,0.00780112,0.00777269,0.0077334,0.00767268,0.00758734,0.00747448}, 125 | 
{0.00773852,0.00791484,0.00800954,0.00805249,0.00806788,0.00807033,0.00806968,0.00806869,0.00806535,0.00806054,0.00805808,0.00805545,0.00805031,0.00805259,0.0080587,0.00806986,0.00807823,0.00808334,0.00808693,0.00808449,0.00807867,0.00807372,0.00807475,0.00807931,0.00808575,0.00809284,0.00810902,0.00814298,0.00820787,0.00832869}, 126 | {0.00868237,0.00867878,0.00865667,0.00863023,0.00860903,0.00859407,0.00858633,0.00858444,0.00858327,0.008583,0.00858782,0.00859463,0.00860012,0.00860409,0.00859754,0.00858535,0.00856836,0.00856043,0.00855924,0.00855448,0.00855101,0.00855749,0.00856867,0.00858724,0.00862182,0.00867893,0.00875373,0.00884808,0.0089642,0.0090781}, 127 | {0.00889825,0.0092788,0.00951488,0.00965528,0.00973856,0.00978541,0.0098126,0.00983706,0.00985894,0.00986843,0.00986822,0.00986635,0.00986833,0.00987679,0.00988693,0.00989444,0.00990429,0.00991033,0.0099134,0.00990951,0.00989951,0.00989256,0.009886,0.00988598,0.00988427,0.00988131,0.00986968,0.00984868,0.00984009,0.00985286}, 128 | } 129 | 130 | target = {74, 119, 118, 81, 117, 95, 7} 131 | 132 | fvs = torch.Tensor{ 133 | {-4.7219,-9.40192,-14.0568,-18.6971,-23.3287,-27.9556,-32.5797,-37.2013,-41.8207,-46.4391,-51.0575,-55.6761,-60.2946,-64.9121,-69.5287,-74.1445,-78.7592,-83.3734,-87.9873,-92.6015,-97.2168,-101.833,-106.449,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 134 | {-4.89071,-8.9839,-13.3436,-17.8051,-22.3185,-26.8622,-31.4256,-36.0035,-40.5921,-45.1881,-49.7886,-54.3923,-58.9967,-63.6003,-68.2048,-72.8113,-77.4206,-82.0317,-86.6443,-91.2575,-95.8694,-100.479,-105.087,-109.694,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 135 | {-1e+100,-9.57074,-13.1966,-17.2145,-21.4053,-25.6948,-30.048,-34.4447,-38.8731,-43.3267,-47.8006,-52.2909,-56.7939,-61.3067,-65.8271,-70.3542,-74.8868,-79.4246,-83.9673,-88.5151,-93.0679,-97.6249,-102.186,-106.749,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 136 | {-1e+100,-9.72973,-13.1028,-16.9326,-20.9767,-25.1477,-29.4028,-33.7176,-38.0772,-42.4714,-46.8929,-51.3367,-55.7995,-60.2771,-64.7665,-69.2652,-73.7724,-78.2871,-82.8088,-87.3373,-91.8723,-96.4131,-100.958,-105.507,-110.059,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 137 | {-1e+100,-1e+100,-14.3846,-17.4981,-21.1144,-24.9769,-28.9896,-33.1035,-37.2904,-41.5336,-45.8217,-50.1457,-54.4987,-58.8755,-63.2719,-67.6853,-72.1128,-76.5528,-81.0041,-85.4661,-89.9382,-94.4191,-98.9081,-103.404,-107.905,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 138 | {-1e+100,-1e+100,-14.5808,-17.5465,-21.0391,-24.7953,-28.7152,-32.7489,-36.8671,-41.0501,-45.2833,-49.5574,-53.8646,-58.1985,-62.5557,-66.9341,-71.3316,-75.7445,-80.1715,-84.6112,-89.0616,-93.521,-97.9897,-102.469,-106.957,-111.455,-1e+100,-1e+100,-1e+100,-1e+100}, 139 | {-1e+100,-1e+100,-1e+100,-19.2211,-22.0064,-25.3438,-28.9633,-32.76,-36.6807,-40.6948,-44.782,-48.9271,-53.1186,-57.348,-61.6088,-65.8968,-70.2082,-74.5408,-78.8923,-83.2611,-87.6459,-92.0445,-96.4554,-100.877,-105.308,-109.75,-1e+100,-1e+100,-1e+100,-1e+100}, 140 | {-1e+100,-1e+100,-1e+100,-19.4175,-22.0884,-25.3263,-28.858,-32.5768,-36.4274,-40.3776,-44.4055,-48.4955,-52.6364,-56.82,-61.0398,-65.2911,-69.5703,-73.8739,-78.1979,-82.5408,-86.9008,-91.2754,-95.6631,-100.062,-104.473,-108.893,-113.323,-1e+100,-1e+100,-1e+100}, 141 | {-1e+100,-1e+100,-1e+100,-1e+100,-24.0491,-26.5836,-29.7001,-33.1214,-36.7387,-40.4962,-44.3605,-48.3082,-52.3227,-56.3917,-60.5065,-64.6607,-68.8486,-73.0668,-77.3117,-81.5808,-85.8719,-90.1823,-94.5099,-98.8522,-103.208,-107.576,-111.957,-1e+100,-1e+100,-1e+100}, 142 | 
{-1e+100,-1e+100,-1e+100,-1e+100,-24.1032,-26.5309,-29.5504,-32.8842,-36.4218,-40.1053,-43.8997,-47.7812,-51.7326,-55.7412,-59.7993,-63.9015,-68.0424,-72.2165,-76.4195,-80.6486,-84.8999,-89.17,-93.4581,-97.7624,-102.081,-106.411,-110.753,-115.107,-1e+100,-1e+100}, 143 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-28.73,-31.0498,-33.9705,-37.2128,-40.6663,-44.2723,-47.9942,-51.8073,-55.6937,-59.6405,-63.6394,-67.6835,-71.768,-75.888,-80.04,-84.2208,-88.4266,-92.6542,-96.901,-101.165,-105.446,-109.741,-114.052,-1e+100,-1e+100}, 144 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-28.8288,-31.0647,-33.9097,-37.0832,-40.4728,-44.0186,-47.6843,-51.4455,-55.2846,-59.1873,-63.1436,-67.1461,-71.1893,-75.2694,-79.3825,-83.5246,-87.6937,-91.887,-96.1028,-100.339,-104.595,-108.868,-113.157,-117.464,-1e+100}, 145 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-33.4529,-35.5985,-38.3595,-41.4555,-44.7733,-48.2518,-51.8536,-55.5533,-59.3333,-63.1803,-67.0834,-71.0353,-75.0301,-79.0637,-83.1326,-87.2323,-91.3601,-95.5128,-99.6886,-103.886,-108.104,-112.342,-116.597,-1e+100}, 146 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-33.6612,-35.7462,-38.4507,-41.4926,-44.759,-48.1895,-51.7467,-55.4053,-59.1477,-62.9601,-66.8311,-70.7523,-74.718,-78.7226,-82.7626,-86.8352,-90.9378,-95.0673,-99.2215,-103.4,-107.603,-111.831,-116.088,-120.378}, 147 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-38.2828,-40.2895,-42.9215,-45.8963,-49.0995,-52.4697,-55.9687,-59.5712,-63.2598,-67.0204,-70.8423,-74.717,-78.6386,-82.6018,-86.6018,-90.6352,-94.6986,-98.7897,-102.906,-107.048,-111.215,-115.404,-119.616}, 148 | } 149 | 150 | bvs = torch.Tensor{ 151 | {-114.948,-110.711,-106.51,-102.339,-98.1947,-94.0767,-89.9857,-85.923,-81.8907,-77.8915,-73.9287,-70.0069,-66.133,-62.3152,-58.5614,-54.8813,-51.2876,-47.7986,-44.4392,-41.2466,-38.2829,-35.6631,-33.6727,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 152 | {-115.38,-111.101,-106.854,-102.633,-98.4347,-94.2595,-90.1083,-85.982,-81.8822,-77.811,-73.7714,-69.7668,-65.8023,-61.8847,-58.0204,-54.217,-50.4845,-46.8361,-43.2896,-39.8702,-36.6159,-33.5881,-30.9008,-28.8375,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 153 | {-1e+100,-111.875,-107.605,-103.362,-99.1421,-94.9448,-90.7699,-86.619,-82.4942,-78.3975,-74.3314,-70.2994,-66.3058,-62.3561,-58.457,-54.6165,-50.8448,-47.1549,-43.5645,-40.0984,-36.7941,-33.712,-30.9654,-28.8375,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 154 | {-1e+100,-112.504,-108.19,-103.899,-99.627,-95.3746,-91.1418,-86.9298,-82.7399,-78.5736,-74.4334,-70.3223,-66.2439,-62.203,-58.2045,-54.2546,-50.3606,-46.5324,-42.7826,-39.1282,-35.594,-32.2179,-29.0603,-26.2336,-24.0199,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 155 | {-1e+100,-1e+100,-109.078,-104.761,-100.464,-96.1867,-91.9282,-87.69,-83.4733,-79.2792,-75.1096,-70.9674,-66.8565,-62.7816,-58.7477,-54.7599,-50.8253,-46.9537,-43.1572,-39.4523,-35.8635,-32.4284,-29.2071,-26.3104,-24.0199,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100}, 156 | {-1e+100,-1e+100,-109.874,-105.512,-101.165,-96.8339,-92.5192,-88.2222,-83.9432,-79.6829,-75.4435,-71.2271,-67.0369,-62.8771,-58.7513,-54.6633,-50.6182,-46.6234,-42.6879,-38.8233,-35.0462,-31.3812,-27.8641,-24.5534,-21.5597,-19.1627,-1e+100,-1e+100,-1e+100,-1e+100}, 157 | {-1e+100,-1e+100,-1e+100,-106.435,-102.067,-97.7134,-93.3751,-89.0531,-84.7487,-80.4628,-76.1968,-71.9528,-67.7335,-63.5422,-59.3824,-55.2577,-51.1735,-47.1364,-43.1555,-39.242,-35.4116,-31.6877,-28.1053,-24.7227,-21.649,-19.1627,-1e+100,-1e+100,-1e+100,-1e+100}, 158 | 
{-1e+100,-1e+100,-1e+100,-107.478,-103.064,-98.661,-94.2699,-89.8915,-85.5271,-81.1772,-76.8431,-72.5274,-68.2324,-63.9608,-59.7152,-55.4978,-51.3125,-47.1637,-43.0584,-39.0045,-35.0132,-31.1006,-27.2887,-23.6116,-20.1255,-16.9386,-14.3264,-1e+100,-1e+100,-1e+100}, 159 | {-1e+100,-1e+100,-1e+100,-1e+100,-104.11,-99.686,-95.2732,-90.8724,-86.4847,-82.1104,-77.7502,-73.406,-69.0802,-64.7756,-60.4945,-56.2391,-52.0133,-47.8214,-43.6693,-39.5646,-35.5178,-31.5442,-27.6645,-23.9112,-20.3387,-17.0531,-14.3264,-1e+100,-1e+100,-1e+100}, 160 | {-1e+100,-1e+100,-1e+100,-1e+100,-105.591,-101.121,-96.6601,-92.2075,-87.7639,-83.3294,-78.9049,-74.4922,-70.0939,-65.713,-61.3513,-57.0096,-52.6901,-48.3961,-44.1317,-39.9019,-35.7144,-31.5799,-27.5118,-23.5297,-19.6641,-15.9676,-12.5416,-9.65286,-1e+100,-1e+100}, 161 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-102.732,-98.236,-93.7487,-89.2701,-84.7999,-80.3381,-75.8854,-71.4434,-67.0139,-62.5991,-58.2011,-53.8226,-49.4665,-45.1366,-40.8372,-36.5744,-32.3572,-28.1977,-24.114,-20.1339,-16.3055,-12.7253,-9.65286,-1e+100,-1e+100}, 162 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-104.664,-100.121,-95.584,-91.0514,-86.5237,-82.0013,-77.4848,-72.9751,-68.4727,-63.9791,-59.496,-55.0258,-50.5706,-46.1325,-41.7144,-37.3209,-32.9571,-28.6308,-24.3518,-20.1354,-16.0058,-12.0059,-8.22485,-4.90751,-1e+100}, 163 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-101.882,-97.3157,-92.7548,-88.1981,-83.6451,-79.096,-74.5518,-70.0136,-65.4823,-60.9585,-56.4438,-51.9392,-47.4464,-42.9672,-38.504,-34.0611,-29.6441,-25.2617,-20.926,-16.6569,-12.4895,-8.499,-4.90751,-1e+100}, 164 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-104.578,-99.9631,-95.3479,-90.7335,-86.1199,-81.5075,-76.8967,-72.2878,-67.6805,-63.0745,-58.4704,-53.8686,-49.2695,-44.6744,-40.0847,-35.5019,-30.928,-26.3667,-21.8226,-17.3023,-12.8168,-8.38661,-4.06031,0}, 165 | {-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-1e+100,-101.579,-96.9593,-92.3409,-87.7225,-83.1038,-78.4854,-73.8678,-69.2513,-64.6355,-60.0207,-55.4066,-50.7927,-46.1784,-41.5632,-36.9472,-32.3305,-27.7139,-23.0971,-18.48,-13.8617,-9.24128,-4.61999,0}, 166 | } 167 | 168 | grad_contrast = torch.Tensor{ 169 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 170 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 171 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 172 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 173 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 174 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 175 | {-0,-0,-0,-0,-0,-0,-6.98458e-07,-8.79663e-06,-5.9524e-05,-0.000286949,-0.00110384,-0.00359937,-0.0103233,-0.0267012,-0.0634482,-0.140475,-0.292605,-0.578153,-1.09007,-1.9675,-3.40946,-5.68163,-9.11205,-14.0484,-20.7762,-29.3335,-39.1416,-48.3712,-52.716,-43.034}, 176 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 177 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 178 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 179 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 180 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0}, 
181 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
182 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
183 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
184 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
185 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
186 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
187 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
188 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
189 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
190 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
191 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
192 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
193 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
194 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
195 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
196 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
197 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
198 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
199 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
200 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
201 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
202 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
203 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
204 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
205 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
206 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
207 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
208 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
209 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
210 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
211 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
212 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
213 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
214 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
215 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
216 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
217 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
218 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
219 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
220 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
221 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
222 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
223 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
224 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
225 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
226 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
227 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
228 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
229 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
230 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
231 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
232 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
233 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
234 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
235 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
236 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
237 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
238 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
239 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
240 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
241 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
242 | {-47.1401,-55.8055,-49.4259,-38.6753,-28.1197,-19.3998,-12.8235,-8.16023,-5.00712,-2.96345,-1.69043,-0.927868,-0.488023,-0.244833,-0.116555,-0.0521782,-0.02173,-0.00830344,-0.00286166,-0.000867593,-0.000222867,-4.56863e-05,-6.67346e-06,-5.23058e-07,-0,-0,-0,-0,-0,-0},
243 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
244 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
245 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
246 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
247 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
248 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
249 | {-0,-0,-0,-0.0592168,-0.338283,-1.08436,-2.56112,-4.95071,-8.26973,-12.3312,-16.7415,-20.9628,-24.4263,-26.6491,-27.3386,-26.4425,-24.1082,-20.6879,-16.6384,-12.4789,-8.64289,-5.44738,-3.06292,-1.48671,-0.589928,-0.171724,-0.027879,-0,-0,-0},
250 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
251 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
252 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
253 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
254 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
255 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
256 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
257 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
258 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
259 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
260 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
261 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
262 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
263 | {-0,-0,-0,-0,-0,-7.23177e-05,-0.000726106,-0.003947,-0.0153826,-0.0480768,-0.127829,-0.29965,-0.634195,-1.23294,-2.22779,-3.77262,-6.01763,-9.07371,-12.9693,-17.5849,-22.5843,-27.4219,-31.3181,-33.3815,-32.768,-28.9531,-22.092,-13.3334,-4.98783,-0},
264 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
265 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
266 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
267 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
268 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
269 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
270 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
271 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
272 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
273 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
274 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
275 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
276 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
277 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
278 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
279 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
280 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
281 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
282 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
283 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
284 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
285 | {-0,-0,-0,-0,-0.00310071,-0.0238786,-0.100908,-0.308925,-0.764711,-1.62138,-3.04571,-5.17878,-8.08482,-11.7018,-15.8286,-20.1208,-24.0921,-27.2005,-28.9551,-29.0178,-27.2338,-23.7465,-19.0396,-13.7788,-8.74504,-4.62572,-1.84613,-0.425872,-0,-0},
286 | {-0,-0,-0.690029,-2.79347,-6.56235,-11.6537,-17.274,-22.4665,-26.4056,-28.5637,-28.7528,-27.141,-24.1215,-20.2229,-16.0158,-11.9723,-8.4223,-5.5444,-3.3936,-1.91122,-0.974964,-0.440236,-0.169993,-0.0529486,-0.0119061,-0.00146216,-0,-0,-0,-0},
287 | {-0,-6.28158,-15.9098,-25.1123,-31.4671,-34.1262,-33.3769,-30.1171,-25.4269,-20.2577,-15.296,-10.9693,-7.47322,-4.82769,-2.95264,-1.70331,-0.921611,-0.463538,-0.214133,-0.0893763,-0.0328761,-0.0102645,-0.00256195,-0.000457469,-4.41053e-05,-0,-0,-0,-0,-0},
288 | {-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0,-0},
289 | {-72.5639,-56.4625,-51.0991,-49.2839,-48.6275,-48.3875,-48.3172,-48.306,-48.3193,-48.3384,-48.2992,-48.1512,-47.8671,-47.4516,-46.9692,-46.4891,-46.0249,-45.6177,-45.2939,-45.1054,-45.0658,-45.1812,-45.503,-46.0166,-46.7356,-47.6678,-49,-51.3679,-56.5374,-69.2125},
290 | }
291 | function toMatrix(outputTable)  -- stack a table of 1-D tensors into a (#table x width) matrix
292 |     local net = nn.Sequential()
293 |     net:add(nn.JoinTable(1))
294 |     net:add(nn.Reshape(#outputTable, outputTable[1]:size(1)))
295 |     return net:forward(outputTable)
296 | end
297 |
298 | -- outputTable = nn.Log():forward(outputTable:t())
299 |
300 | T = outputTable:size(1)
301 | nrow = outputTable:size(2)
302 |
303 | splitedOutputTable = nn.SplitTable(1):forward(outputTable:t())  -- table form, as ctc.getCTCCostAndGrad expects
304 |
305 | c_pzx, c_grad = ctc.getCTCCostAndGrad(splitedOutputTable, target)  -- analytic CTC loss and gradient
306 |
307 | c_m = toMatrix(c_grad):float()
308 |
309 | -- print(c_m:t())
310 |
311 | -- print(torch.dist(c_m:t(), grad_contrast:float()))
312 |
313 |
314 |
315 | eps = 1e-5  -- finite-difference step size
316 | ctc_lua = false
317 |
318 | est_grad = torch.Tensor(nrow)  -- numerical gradient for the current time step
319 |
320 | for t = 1, T do
321 |     for i = 1, nrow do
322 |         outputTable[t][i] = outputTable[t][i] + eps  -- perturb one entry upward
323 |
324 |         splitedOutputTable = nn.SplitTable(1):forward(outputTable:t())
325 |         loss1, _ = ctc.getCTCCostAndGrad(splitedOutputTable, target)
326 |
327 |         outputTable[t][i] = outputTable[t][i] - 2 * eps  -- and downward
328 |         splitedOutputTable = nn.SplitTable(1):forward(outputTable:t())
329 |         loss2, _ = ctc.getCTCCostAndGrad(splitedOutputTable, target)
330 |
331 |
332 |         outputTable[t][i] = outputTable[t][i] + eps  -- restore the original value
333 |
334 |         est_grad[i] = (loss1 - loss2) / (2 * eps)  -- central difference
335 |
336 |         print(est_grad[i], grad_contrast[t][i])  -- numerical vs. reference gradient
337 |     end
338 | end
339 |
340 |
341 |
--------------------------------------------------------------------------------
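A note on the toMatrix helper above: it stacks a Lua table of equal-width 1-D tensors into a (#table x width) matrix by chaining nn.JoinTable and nn.Reshape. A minimal self-contained sketch of the same idea; the table rows and its sizes here are made up for illustration and are not part of the original test:

require 'nn'

-- A hypothetical table of three rows, each a 1-D tensor of width 4.
local rows = { torch.rand(4), torch.rand(4), torch.rand(4) }

-- nn.JoinTable(1) concatenates the table into one 12-element vector;
-- nn.Reshape then views that vector as a (#rows x width) matrix.
local stack = nn.Sequential()
stack:add(nn.JoinTable(1))
stack:add(nn.Reshape(#rows, rows[1]:size(1)))

local m = stack:forward(rows)
print(m:size())  -- 3 x 4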
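The loop at the end of the test is a standard central-difference gradient check: each output entry is perturbed by +eps and -eps, the CTC loss is recomputed on both sides, and (loss1 - loss2) / (2 * eps) is printed next to the reference gradient grad_contrast[t][i]. The same technique can be sanity-checked against any criterion whose analytic gradient Torch already provides; the sketch below uses nn.MSECriterion, and every name in it (crit, x, y, analytic, numeric) is illustrative rather than part of the original test:

require 'nn'

local crit = nn.MSECriterion()
local x = torch.randn(5)
local y = torch.randn(5)

-- Analytic gradient of the loss with respect to x.
crit:forward(x, y)
local analytic = crit:backward(x, y):clone()

-- Numerical gradient by central differences, mirroring the loop above.
local eps = 1e-5
local numeric = torch.Tensor(5)
for i = 1, 5 do
    local orig = x[i]
    x[i] = orig + eps
    local loss1 = crit:forward(x, y)
    x[i] = orig - eps
    local loss2 = crit:forward(x, y)
    x[i] = orig
    numeric[i] = (loss1 - loss2) / (2 * eps)
end

-- The distance should be tiny (around 1e-9 in double precision).
print(torch.dist(analytic, numeric))

The central difference is used instead of a one-sided difference because its truncation error is O(eps^2) rather than O(eps), which is what makes a step as coarse as eps = 1e-5 adequate for comparing against the reference gradients.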