├── README
├── corpus.txt
├── main.lua
└── word2vec.lua

/README:
--------------------------------------------------------------------------------
Word2Vec in Torch
Yoon Kim
yhk255@nyu.edu

Only the skip-gram architecture with negative sampling is implemented. See https://code.google.com/p/word2vec/ for more details.

Note: This is considerably slower than the word2vec toolkit and gensim implementations.

The input file is a text file where each line represents one sentence (see corpus.txt for an example).

Arguments are mostly self-explanatory (see main.lua for the default values):

-corpus : text file with the corpus
-window : max window size
-dim : dimensionality of word embeddings
-alpha : exponent to smooth out the unigram distribution (see the note after this list)
-table_size : unigram table size. If you have plenty of RAM, bring this up to 10^8
-neg_samples : number of negative samples for each valid word-context pair
-minfreq : minimum frequency for a word to be included in the vocab
-lr : starting learning rate
-min_lr : minimum learning rate -- lr will linearly decay to this value
-epochs : number of epochs to run
-stream : whether to stream text data from disk or store it in memory (1 = stream, 0 = store in memory)
-gpu : whether to use the GPU (1 = use GPU, 0 = use CPU)
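Note on -alpha: negative samples are drawn from the unigram distribution raised to this power,
i.e. a word w is sampled with probability proportional to count(w)^alpha. A quick standalone
Lua illustration (the counts below are made up, not taken from corpus.txt):

  local counts = {the = 10000, cat = 100}
  local alpha = 0.75
  local z = 0
  for _, c in pairs(counts) do z = z + c^alpha end
  for w, c in pairs(counts) do print(w, c^alpha / z) end
  -- prints roughly: the 0.97, cat 0.03 (vs. 0.99 and 0.01 for the raw frequencies)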

For example:

CPU:
th main.lua -corpus corpus.txt -window 3 -dim 100 -minfreq 10 -stream 1 -gpu 0

GPU:
th main.lua -corpus corpus.txt -window 3 -dim 100 -minfreq 10 -stream 0 -gpu 1

--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
--[[
Config and training script for skip-gram with negative sampling
--]]

require("io")
require("os")
require("paths")
require("torch")
dofile("word2vec.lua")

-- Default configuration
config = {}
config.corpus = "corpus.txt"  -- input data
config.window = 5             -- (maximum) window size
config.dim = 100              -- dimensionality of word embeddings
config.alpha = 0.75           -- exponent to smooth out unigram frequencies
config.table_size = 1e8       -- size of the table from which negative samples are drawn
config.neg_samples = 5        -- number of negative samples for each positive sample
config.minfreq = 10           -- minimum frequency threshold for the vocab
config.lr = 0.025             -- initial learning rate
config.min_lr = 0.001         -- minimum learning rate
config.epochs = 3             -- number of epochs to train
config.gpu = 0                -- 1 = use gpu, 0 = use cpu
config.stream = 1             -- 1 = stream from hard drive, 0 = copy to memory first

-- Parse input arguments
cmd = torch.CmdLine()
cmd:option("-corpus", config.corpus)
cmd:option("-window", config.window)
cmd:option("-minfreq", config.minfreq)
cmd:option("-dim", config.dim)
cmd:option("-alpha", config.alpha)
cmd:option("-lr", config.lr)
cmd:option("-min_lr", config.min_lr)
cmd:option("-neg_samples", config.neg_samples)
cmd:option("-table_size", config.table_size)
cmd:option("-epochs", config.epochs)
cmd:option("-gpu", config.gpu)
cmd:option("-stream", config.stream)
params = cmd:parse(arg)

for param, value in pairs(params) do
    config[param] = value
end

for i, j in pairs(config) do
    print(i .. ": " .. j)
end

-- Train model
m = Word2Vec(config)
m:build_vocab(config.corpus)
m:build_table()

for k = 1, config.epochs do
    m.lr = config.lr -- reset the learning rate at the start of each epoch
    m:train_model(config.corpus)
end
m:print_sim_words({"the", "he", "can"}, 5)
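-- Note: nothing is saved to disk above. If you want to keep the learned embeddings, one
-- option (an untested sketch, not part of the original script) is to serialize the word
-- embedding matrix and the vocabulary mapping after training, e.g.:
--   torch.save("word_vecs.t7", {vecs = m.word_vecs.weight:double(), word2index = m.word2index})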
--------------------------------------------------------------------------------
/word2vec.lua:
--------------------------------------------------------------------------------
--[[
Class for word2vec with skipgram and negative sampling
--]]

require("sys")
require("nn")

local Word2Vec = torch.class("Word2Vec")

function Word2Vec:__init(config)
    self.tensortype = torch.getdefaulttensortype()
    self.gpu = config.gpu -- 1 if train on gpu, otherwise cpu
    self.stream = config.stream -- 1 if stream from hard drive, 0 otherwise
    self.neg_samples = config.neg_samples
    self.minfreq = config.minfreq
    self.dim = config.dim
    self.criterion = nn.BCECriterion() -- logistic loss
    self.word = torch.IntTensor(1)
    self.contexts = torch.IntTensor(1 + self.neg_samples)
    self.labels = torch.zeros(1 + self.neg_samples); self.labels[1] = 1 -- first label is always the positive sample
    self.window = config.window
    self.lr = config.lr
    self.min_lr = config.min_lr
    self.alpha = config.alpha
    self.table_size = config.table_size
    self.vocab = {}
    self.index2word = {}
    self.word2index = {}
    self.total_count = 0
end

-- Move to CUDA
function Word2Vec:cuda()
    require("cunn")
    require("cutorch")
    cutorch.setDevice(1)
    self.word = self.word:cuda()
    self.contexts = self.contexts:cuda()
    self.labels = self.labels:cuda()
    self.criterion:cuda()
    self.w2v:cuda()
end

-- Build vocab frequencies, word2index, and index2word from the input file
function Word2Vec:build_vocab(corpus)
    print("Building vocabulary...")
    local start = sys.clock()
    local f = io.open(corpus, "r")
    local n = 0
    for line in f:lines() do
        for _, word in ipairs(self:split(line)) do
            self.total_count = self.total_count + 1
            if self.vocab[word] == nil then
                self.vocab[word] = 1
            else
                self.vocab[word] = self.vocab[word] + 1
            end
        end
        n = n + 1
    end
    f:close()
    -- Delete words that do not meet the minfreq threshold and create word indices
    for word, count in pairs(self.vocab) do
        if count >= self.minfreq then
            self.index2word[#self.index2word + 1] = word
            self.word2index[word] = #self.index2word
        else
            self.vocab[word] = nil
        end
    end
    self.vocab_size = #self.index2word
    print(string.format("%d words and %d sentences processed in %.2f seconds.", self.total_count, n, sys.clock() - start))
    print(string.format("Vocab size after eliminating words occurring less than %d times: %d", self.minfreq, self.vocab_size))
    -- Initialize word/context embeddings now that the vocab size is known
    self.word_vecs = nn.LookupTable(self.vocab_size, self.dim) -- word embeddings
    self.context_vecs = nn.LookupTable(self.vocab_size, self.dim) -- context embeddings
    self.word_vecs:reset(0.25); self.context_vecs:reset(0.25) -- rescale N(0,1)
    self.w2v = nn.Sequential()
    self.w2v:add(nn.ParallelTable())
    self.w2v.modules[1]:add(self.context_vecs)
    self.w2v.modules[1]:add(self.word_vecs)
    self.w2v:add(nn.MM(false, true)) -- dot products, then sigmoid to get probabilities
    self.w2v:add(nn.Sigmoid())
    self.decay = (self.min_lr - self.lr) / (self.total_count * self.window)
end

-- Build a table of (smoothed) unigram frequencies from which to draw negative samples
function Word2Vec:build_table()
    local start = sys.clock()
    local total_count_pow = 0
    print("Building a table of unigram frequencies... ")
    for _, count in pairs(self.vocab) do
        total_count_pow = total_count_pow + count^self.alpha
    end
    self.table = torch.IntTensor(self.table_size)
    local word_index = 1
    local word_prob = self.vocab[self.index2word[word_index]]^self.alpha / total_count_pow
    for idx = 1, self.table_size do
        self.table[idx] = word_index
        if idx / self.table_size > word_prob and word_index < self.vocab_size then
            word_index = word_index + 1
            word_prob = word_prob + self.vocab[self.index2word[word_index]]^self.alpha / total_count_pow
        end
    end
    print(string.format("Done in %.2f seconds.", sys.clock() - start))
end

-- Train on a single word-context pair (the positive context plus negative samples)
function Word2Vec:train_pair(word, contexts)
    local p = self.w2v:forward({contexts, word})
    local loss = self.criterion:forward(p, self.labels)
    local dl_dp = self.criterion:backward(p, self.labels)
    self.w2v:zeroGradParameters()
    self.w2v:backward({contexts, word}, dl_dp)
    self.w2v:updateParameters(self.lr)
end
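-- How this maps onto the skip-gram negative-sampling objective (explanatory note):
-- `contexts` holds 1 + neg_samples context indices and `word` holds a single word index.
-- The ParallelTable looks both up, nn.MM(false, true) produces the (1 + neg_samples) x 1
-- matrix of dot products context_i . word, and the Sigmoid turns these into probabilities
-- that BCECriterion compares against labels = [1, 0, ..., 0]. Up to BCECriterion's
-- averaging over the 1 + neg_samples outputs, the loss is the usual SGNS objective:
--   -log(sigmoid(c_pos . w)) - sum_k log(1 - sigmoid(c_neg_k . w))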
-- Sample negative contexts
function Word2Vec:sample_contexts(context)
    self.contexts[1] = context
    local i = 0
    while i < self.neg_samples do
        neg_context = self.table[torch.random(self.table_size)]
        if context ~= neg_context then
            self.contexts[i + 2] = neg_context
            i = i + 1
        end
    end
end

-- Train on sentences that are streamed from the hard drive
-- See train_mem to train from memory (after pre-loading the data into tensors)
function Word2Vec:train_stream(corpus)
    print("Training...")
    local start = sys.clock()
    local c = 0
    f = io.open(corpus, "r")
    for line in f:lines() do
        sentence = self:split(line)
        for i, word in ipairs(sentence) do
            word_idx = self.word2index[word]
            if word_idx ~= nil then -- word exists in vocab
                local reduced_window = torch.random(self.window) -- pick a random window size
                self.word[1] = word_idx -- update current word
                for j = i - reduced_window, i + reduced_window do -- loop through contexts
                    local context = sentence[j]
                    if context ~= nil and j ~= i then -- possible context
                        context_idx = self.word2index[context]
                        if context_idx ~= nil then -- valid context
                            self:sample_contexts(context_idx) -- update pos/neg contexts
                            self:train_pair(self.word, self.contexts) -- train on the word-context pair
                            c = c + 1
                            self.lr = math.max(self.min_lr, self.lr + self.decay)
                            if c % 100000 == 0 then
                                print(string.format("%d word-context pairs trained in %.2f seconds. Learning rate: %.4f", c, sys.clock() - start, self.lr))
                            end
                        end
                    end
                end
            end
        end
    end
    f:close()
end

-- Row-normalize a matrix
function Word2Vec:normalize(m)
    m_norm = torch.zeros(m:size())
    for i = 1, m:size(1) do
        m_norm[i] = m[i] / torch.norm(m[i])
    end
    return m_norm
end

-- Return the k nearest words to a word or a vector based on cosine similarity
-- w can be a string such as "king" or a vector for ("king" - "queen" + "man")
function Word2Vec:get_sim_words(w, k)
    if self.word_vecs_norm == nil then
        self.word_vecs_norm = self:normalize(self.word_vecs.weight:double())
    end
    if type(w) == "string" then
        if self.word2index[w] == nil then
            print("'" .. w .. "' does not exist in vocabulary.")
            return nil
        else
            w = self.word_vecs_norm[self.word2index[w]]
        end
    end
    local sim = torch.mv(self.word_vecs_norm, w)
    sim, idx = torch.sort(-sim)
    local r = {}
    for i = 1, k do
        r[i] = {self.index2word[idx[i]], -sim[i]}
    end
    return r
end
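-- Usage sketch (hypothetical words; each must have survived the minfreq cut to be in the
-- vocabulary for the lookups below to work):
--   m:get_sim_words("king", 5)  -- nearest neighbours of a word; also fills in m.word_vecs_norm
--   local n = m.word_vecs_norm
--   m:get_sim_words(n[m.word2index["king"]] - n[m.word2index["queen"]] + n[m.word2index["man"]], 5)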
-- Print similar words
function Word2Vec:print_sim_words(words, k)
    for i = 1, #words do
        r = self:get_sim_words(words[i], k)
        if r ~= nil then
            print("-------" .. words[i] .. "-------")
            for j = 1, k do
                print(string.format("%s, %.4f", r[j][1], r[j][2]))
            end
        end
    end
end

-- Split a string on a separator (defaults to whitespace)
function Word2Vec:split(input, sep)
    if sep == nil then
        sep = "%s"
    end
    local t = {}; local i = 1
    for str in string.gmatch(input, "([^" .. sep .. "]+)") do
        t[i] = str; i = i + 1
    end
    return t
end

-- Pre-load the data as torch tensors instead of streaming it. This requires a lot of memory,
-- so if the corpus is huge you should partition it into smaller sets
function Word2Vec:preload_data(corpus)
    print("Preloading training corpus into tensors (Warning: this takes a lot of memory)")
    local start = sys.clock()
    local c = 0
    f = io.open(corpus, "r")
    self.train_words = {}; self.train_contexts = {}
    for line in f:lines() do
        sentence = self:split(line)
        for i, word in ipairs(sentence) do
            word_idx = self.word2index[word]
            if word_idx ~= nil then -- word exists in vocab
                local reduced_window = torch.random(self.window) -- pick a random window size
                self.word[1] = word_idx -- update current word
                for j = i - reduced_window, i + reduced_window do -- loop through contexts
                    local context = sentence[j]
                    if context ~= nil and j ~= i then -- possible context
                        context_idx = self.word2index[context]
                        if context_idx ~= nil then -- valid context
                            c = c + 1
                            self:sample_contexts(context_idx) -- update pos/neg contexts
                            if self.gpu == 1 then
                                self.train_words[c] = self.word:clone():cuda()
                                self.train_contexts[c] = self.contexts:clone():cuda()
                            else
                                self.train_words[c] = self.word:clone()
                                self.train_contexts[c] = self.contexts:clone()
                            end
                        end
                    end
                end
            end
        end
    end
    f:close()
    print(string.format("%d word-context pairs processed in %.2f seconds", c, sys.clock() - start))
end

-- Train from memory; this is needed to speed up GPU training
function Word2Vec:train_mem()
    local start = sys.clock()
    for i = 1, #self.train_words do
        self:train_pair(self.train_words[i], self.train_contexts[i])
        self.lr = math.max(self.min_lr, self.lr + self.decay)
        if i % 100000 == 0 then
            print(string.format("%d word-context pairs trained in %.2f seconds. Learning rate: %.4f", i, sys.clock() - start, self.lr))
        end
    end
end

-- Train the model using the config parameters
function Word2Vec:train_model(corpus)
    if self.gpu == 1 then
        self:cuda()
    end
    if self.stream == 1 then
        self:train_stream(corpus)
    else
        self:preload_data(corpus)
        self:train_mem()
    end
end
--------------------------------------------------------------------------------