├── LICENSE ├── README.md ├── data_utils.py ├── eval ├── all_wordsim.py ├── data │ ├── EN-MC-30.txt │ └── EN-MTurk-287.txt ├── ranking.py ├── read_write.py └── wordsim.py ├── main.py ├── model.py ├── vector_handle.py └── word2vec.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A PyTorch implementation of the word2vec Skip-Gram Negative Sampling (SGNS) algorithm, following the paper [Distributed Representations of Words and Phrases and their Compositionality](https://arxiv.org/abs/1310.4546v1).
2 | 
3 | ## Dependencies
4 | - python 3.6
5 | - pytorch 0.4+
6 | 
7 | ## Usage
8 | Run `main.py`.
9 | 
10 | Initialize the dataset and model.
11 | 
12 | ```python
13 | # init dataset and model
14 | word2vec = Word2Vec(data_path='text8',
15 | vocabulary_size=50000,
16 | embedding_size=300)
17 | 
18 | # the whole corpus, encoded as word indices
19 | print(word2vec.data[:10])
20 | 
21 | # word_count is a list like [['word', count], ...];
22 | # its list index is the word's vocabulary index
23 | print(word2vec.word_count[:10])
24 | 
25 | # index to word
26 | print(word2vec.index2word[34])
27 | 
28 | # word to index
29 | print(word2vec.word2index['hello'])
30 | ```
31 | 
32 | 
33 | Train the model and get the vectors.
34 | 
35 | ```python
36 | # train the model
37 | word2vec.train(train_steps=200000,
38 | skip_window=1,
39 | num_skips=2,
40 | num_neg=20,
41 | output_dir='out/run-1')
42 | 
43 | # save the vectors to a text file
44 | word2vec.save_vector_txt(path_dir='out/run-1')
45 | 
46 | # get the vector list
47 | vector = word2vec.get_list_vector()
48 | print(vector[123])
49 | print(vector[word2vec.word2index['hello']])
50 | 
51 | # get the top-k most similar words
52 | sim_list = word2vec.most_similar('one', top_k=8)
53 | print(sim_list)
54 | 
55 | # load a pre-trained model
56 | word2vec.load_model('out/run-1/model_step200000.pt')
57 | ```
58 | 
59 | 
60 | ## Evaluate
61 | See the repository [eval-word-vectors](https://github.com/mfaruqui/eval-word-vectors).
62 | Like this: 63 | ``` 64 | eval/wordsim.py vector.txt eval/data/EN-MTurk-287.txt 65 | ``` 66 | ``` 67 | eval/wordsim.py vector.txt eval/data/EN-MC-30.txt 68 | ``` 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import pickle 4 | import random 5 | import urllib 6 | from io import open 7 | import numpy as np 8 | 9 | def maybe_download(filename, expected_bytes): 10 | """ 11 | download text8.zip 12 | :param filename: 13 | :param expected_bytes: 14 | :return: 15 | """ 16 | url = 'http://mattmahoney.net/dc/' 17 | if not os.path.exists(filename): 18 | print('start downloading...') 19 | filename, _ = urllib.request.urlretrieve(url + filename, filename) 20 | statinfo = os.stat(filename) 21 | if statinfo.st_size == expected_bytes: 22 | print('Found and verified', filename) 23 | else: 24 | print(statinfo.st_size) 25 | raise Exception( 26 | 'Failed to verify ' + filename + '. Can you get to it with a browser?') 27 | return filename 28 | 29 | 30 | def read_own_data(filename): 31 | """ 32 | read your own data. 33 | :param filename: 34 | :return: 35 | """ 36 | print('reading data...') 37 | with open(filename, 'r', encoding='utf-8') as f: 38 | data = f.read().split() 39 | print('corpus size', len(data)) 40 | return data 41 | 42 | 43 | def build_dataset(words, n_words): 44 | """ 45 | build dataset 46 | :param words: corpus 47 | :param n_words: learn most common n_words 48 | :return: 49 | - data: [word_index] 50 | - count: [ [word_index, word_count], ] 51 | - dictionary: {word_str: word_index} 52 | - reversed_dictionary: {word_index: word_str} 53 | """ 54 | count = [['UNK', -1]] 55 | count.extend(collections.Counter(words).most_common(n_words - 1)) 56 | dictionary = dict() 57 | for word, _ in count: 58 | dictionary[word] = len(dictionary) 59 | data = list() 60 | unk_count = 0 61 | for word in words: 62 | if word in dictionary: 63 | index = dictionary[word] 64 | else: 65 | index = 0 # UNK index is 0 66 | unk_count += 1 67 | data.append(index) 68 | count[0][1] = unk_count 69 | reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 70 | return data, count, dictionary, reversed_dictionary 71 | 72 | def dataset_tofile(data, count, dictionary, reversed_dictionary): 73 | pickle.dump(data, open("data/data.list", "wb")) 74 | pickle.dump(count, open("data/count.list", "wb")) 75 | pickle.dump(dictionary, open("data/word2index.dict", "wb")) 76 | pickle.dump(reversed_dictionary, open("data/index2word.dict", "wb")) 77 | 78 | def read_fromfile(): 79 | data = pickle.load(open("data/data.list", "rb")) 80 | count = pickle.load(open("data/count.list", "rb")) 81 | dictionary = pickle.load(open("data/word2index.dict", "rb")) 82 | reversed_dictionary = pickle.load(open("data/index2word.dict", "rb")) 83 | return data, count, dictionary, reversed_dictionary 84 | 85 | def noise(vocabs, word_count): 86 | """ 87 | generate noise distribution 88 | :param vocabs: 89 | :param word_count: 90 | :return: 91 | """ 92 | Z = 0.001 93 | unigram_table = [] 94 | num_total_words = sum([c for w, c in word_count]) 95 | for vo in vocabs: 96 | unigram_table.extend([vo] * int(((word_count[vo][1]/num_total_words)**0.75)/Z)) 97 | 98 | print("vocabulary size", len(vocabs)) 99 | print("unigram_table size:", len(unigram_table)) 100 | return unigram_table 101 | 102 | 103 | class DataPipeline: 104 | def __init__(self, data, vocabs, 
word_count, data_index=0, use_noise_neg=True): 105 | self.data = data 106 | self.data_index = data_index 107 | if use_noise_neg: 108 | self.unigram_table = noise(vocabs, word_count) 109 | else: 110 | self.unigram_table = vocabs 111 | 112 | def get_neg_data(self, batch_size, num, target_inputs): 113 | """ 114 | sample the negative data. Don't use np.random.choice(), it is very slow. 115 | :param batch_size: int 116 | :param num: int 117 | :param target_inputs: [] 118 | :return: 119 | """ 120 | neg = np.zeros((num)) 121 | for i in range(batch_size): 122 | delta = random.sample(self.unigram_table, num) 123 | while target_inputs[i] in delta: 124 | delta = random.sample(self.unigram_table, num) 125 | neg = np.vstack([neg, delta]) 126 | return neg[1: batch_size + 1] 127 | 128 | def generate_batch(self, batch_size, num_skips, skip_window): 129 | """ 130 | get the data batch 131 | :param batch_size: 132 | :param num_skips: 133 | :param skip_window: 134 | :return: target batch and context batch 135 | """ 136 | assert batch_size % num_skips == 0 137 | assert num_skips <= 2 * skip_window 138 | batch = np.ndarray(shape=(batch_size), dtype=np.int32) 139 | labels = np.ndarray(shape=(batch_size), dtype=np.int32) 140 | span = 2 * skip_window + 1 # [ skip_window, target, skip_window ] 141 | buffer = collections.deque(maxlen=span) 142 | for _ in range(span): 143 | buffer.append(self.data[self.data_index]) 144 | self.data_index = (self.data_index + 1) % len(self.data) 145 | for i in range(batch_size // num_skips): 146 | target = skip_window 147 | targets_to_avoid = [skip_window] 148 | for j in range(num_skips): 149 | while target in targets_to_avoid: 150 | target = random.randint(0, span - 1) 151 | targets_to_avoid.append(target) 152 | batch[i * num_skips + j] = buffer[skip_window] 153 | labels[i * num_skips + j] = buffer[target] 154 | buffer.append(self.data[self.data_index]) 155 | self.data_index = (self.data_index + 1) % len(self.data) 156 | self.data_index = (self.data_index + len(self.data) - span) % len(self.data) 157 | return batch, labels -------------------------------------------------------------------------------- /eval/all_wordsim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/mfaruqui/eval-word-vectors 3 | """ 4 | 5 | import sys 6 | import os 7 | 8 | from read_write import read_word_vectors 9 | from ranking import * 10 | 11 | if __name__=='__main__': 12 | word_vec_file = sys.argv[1] 13 | word_sim_dir = sys.argv[2] 14 | 15 | word_vecs = read_word_vectors(word_vec_file) 16 | print('=================================================================================') 17 | print("%6s" %"Serial", "%20s" % "Dataset", "%15s" % "Num Pairs", "%15s" % "Not found", "%15s" % "Rho") 18 | print('=================================================================================') 19 | 20 | for i, filename in enumerate(os.listdir(word_sim_dir)): 21 | manual_dict, auto_dict = ({}, {}) 22 | not_found, total_size = (0, 0) 23 | for line in open(os.path.join(word_sim_dir, filename),'r'): 24 | line = line.strip().lower() 25 | word1, word2, val = line.split() 26 | if word1 in word_vecs and word2 in word_vecs: 27 | manual_dict[(word1, word2)] = float(val) 28 | auto_dict[(word1, word2)] = cosine_sim(word_vecs[word1], word_vecs[word2]) 29 | else: 30 | not_found += 1 31 | total_size += 1 32 | print("%6s" % str(i+1), "%20s" % filename, "%15s" % str(total_size), end=' ') 33 | print("%15s" % str(not_found), end=' ') 34 | print("%15.4f" % 
spearmans_rho(assign_ranks(manual_dict), assign_ranks(auto_dict))) 35 | -------------------------------------------------------------------------------- /eval/data/EN-MC-30.txt: -------------------------------------------------------------------------------- 1 | car automobile 3.92 2 | gem jewel 3.84 3 | journey voyage 3.84 4 | boy lad 3.76 5 | coast shore 3.70 6 | asylum madhouse 3.61 7 | magician wizard 3.50 8 | midday noon 3.42 9 | furnace stove 3.11 10 | food fruit 3.08 11 | bird cock 3.05 12 | bird crane 2.97 13 | tool implement 2.95 14 | brother monk 2.82 15 | lad brother 1.66 16 | crane implement 1.68 17 | journey car 1.16 18 | monk oracle 1.10 19 | cemetery woodland 0.95 20 | food rooster 0.89 21 | coast hill 0.87 22 | forest graveyard 0.84 23 | shore woodland 0.63 24 | monk slave 0.55 25 | coast forest 0.42 26 | lad wizard 0.42 27 | chord smile 0.13 28 | glass magician 0.11 29 | rooster voyage 0.08 30 | noon string 0.08 31 | -------------------------------------------------------------------------------- /eval/data/EN-MTurk-287.txt: -------------------------------------------------------------------------------- 1 | episcopal russia 2.75 2 | water shortage 2.714285714 3 | horse wedding 2.266666667 4 | plays losses 3.2 5 | classics advertiser 2.25 6 | latin credit 2.0625 7 | ship ballots 2.3125 8 | mistake error 4.352941176 9 | disease plague 4.117647059 10 | sake shade 2.529411765 11 | saints observatory 1.9375 12 | treaty wheat 1.8125 13 | texas death 1.533333333 14 | republicans challenge 2.3125 15 | body peaceful 2.058823529 16 | admiralty intensity 2.647058824 17 | body improving 2.117647059 18 | heroin marijuana 3.375 19 | scottish commuters 2.6875 20 | apollo myth 2.6 21 | film cautious 2.125 22 | exhibition art 4.117647059 23 | chocolate candy 3.764705882 24 | republic candidate 2.8125 25 | gospel church 4.0625 26 | momentum desirable 2.4 27 | singapore sanctions 2.117647059 28 | english french 3.823529412 29 | exile church 2.941176471 30 | navy coordinator 2.235294118 31 | adventure flood 2.4375 32 | radar plane 3.235294118 33 | pacific ocean 4.266666667 34 | scotch liquor 4.571428571 35 | kennedy gun 3 36 | garfield cat 2.866666667 37 | scale budget 3.5 38 | rhythm blues 3.071428571 39 | rich privileges 3.2 40 | navy withdrawn 1.571428571 41 | marble marching 2.615384615 42 | polo charged 2.125 43 | mark missing 2.333333333 44 | battleship army 4.235294118 45 | medium organization 2.5625 46 | pennsylvania writer 1.466666667 47 | hamlet poet 3.882352941 48 | battle prisoners 3.705882353 49 | guild smith 2.75 50 | mud soil 4.235294118 51 | crime assaulted 3.941176471 52 | mussolini stability 2.133333333 53 | lincoln division 2.4375 54 | slaves insured 2.2 55 | summer winter 4.375 56 | integration dignity 3.058823529 57 | money quota 2.5 58 | honolulu vacation 3.6875 59 | libya forged 2.461538462 60 | cheers musician 2.823529412 61 | session surprises 1.8125 62 | billion campaigning 2.571428571 63 | perjury soybean 2.0625 64 | forswearing perjury 3.3125 65 | costume halloween 3.4375 66 | bulgarian nurses 1.941176471 67 | costume ultimate 2.5 68 | faith judging 2.235294118 69 | france bridges 2.235294118 70 | citizenship casey 2.2 71 | recreation dish 1.4 72 | intelligence troubles 1.625 73 | germany worst 1.4375 74 | chaos death 2.75 75 | sydney hancock 2.857142857 76 | sabbath stevenson 2.214285714 77 | espionage passport 2.3125 78 | political today 1.6875 79 | pipe convertible 2 80 | scouting demonstrate 2.5625 81 | salute patterns 2.235294118 82 | reichstag germany 
2.285714286 83 | radiation costumes 1.5625 84 | horace grief 1.764705882 85 | sale rental 3.470588235 86 | open close 4.058823529 87 | photography proving 2.375 88 | propaganda germany 1.705882353 89 | assassination forbes 2.071428571 90 | mirror duel 1.928571429 91 | probability hanging 2.058823529 92 | africa theater 1.5 93 | hell heaven 4.117647059 94 | mussolini italy 3 95 | composer beethoven 3.647058824 96 | minister forthcoming 1.764705882 97 | brussels sweden 3.176470588 98 | neutral parish 1.6 99 | emotion taxation 1.733333333 100 | louisiana simple 2 101 | quarantine disease 3 102 | cannon imprisoned 2.625 103 | bronze suspicion 2 104 | pearl interim 2.352941176 105 | artist paint 4.117647059 106 | relay family 2.0625 107 | art mortality 2.294117647 108 | food investment 2.25 109 | alt tenor 2.692307692 110 | catholics protestant 3.5625 111 | militia landlord 3.0625 112 | battle warships 4.176470588 113 | alcohol fleeing 2.5625 114 | coil ashes 3.117647059 115 | poland russia 4 116 | explosive builders 2.4375 117 | aeronautics plane 4.277777778 118 | charge sentence 3.133333333 119 | pet retiring 2 120 | drink alcohol 4.352941176 121 | stability species 2.375 122 | colonies depression 2 123 | easter preference 2.0625 124 | genius intellect 4.090909091 125 | diamond killed 1.555555556 126 | slavery african 2.8 127 | jurisdiction law 4.454545455 128 | saints repeal 1.555555556 129 | conspiracy campaign 2.166666667 130 | operator extracts 2.214285714 131 | physician action 2.153846154 132 | electronics guess 1.916666667 133 | slavery diamond 2.285714286 134 | quarterback sport 3.142857143 135 | assassination killed 4.285714286 136 | slavery klan 2.230769231 137 | heroin shoot 2.692307692 138 | birds disturbances 1.692307692 139 | palestinians turks 2.5 140 | citizenship court 2.5 141 | immunity violation 2.076923077 142 | alternative contend 2.461538462 143 | chile plates 2.692307692 144 | abraham stranger 1.846153846 145 | kansas city 3.769230769 146 | month year 3.857142857 147 | month day 3.857142857 148 | amateur actor 2.333333333 149 | afghanistan war 3.384615385 150 | transmission maxwell 2.25 151 | manchester ambitious 1.923076923 152 | program battered 1.928571429 153 | drawing music 2.583333333 154 | exile pledges 2.307692308 155 | adventure sixteen 1.538461538 156 | exile threats 2.166666667 157 | concrete wings 1.428571429 158 | seizure bishops 2 159 | submarine sea 3.857142857 160 | villa mayor 2.25 161 | trade farley 2.375 162 | nature forest 3.636363636 163 | chronicle young 1.9 164 | radical bishops 1.818181818 165 | pakistan radical 2.875 166 | fire water 4.266666667 167 | gossip nuisance 3.0625 168 | con examiner 2.266666667 169 | satellite space 3.75 170 | essay boston 2 171 | miniature statue 3.6 172 | spill pollution 3.5 173 | minister council 3.5625 174 | landscape mountain 3.5625 175 | religion remedy 2.5625 176 | ship storm 3.5 177 | college scientist 2.8125 178 | crystal oldest 2.5625 179 | afghanistan wise 2.066666667 180 | trinity religion 3.133333333 181 | homer odyssey 2.857142857 182 | parish clue 2.4375 183 | actress actor 4.0625 184 | patent professionals 2.375 185 | chaos horrible 3.066666667 186 | acre earthquake 2.125 187 | goverment immunity 2 188 | football justice 1.8 189 | gambling money 3.75 190 | corruption nervous 1.875 191 | cardinals villages 2.375 192 | life death 4.103448276 193 | artillery sanctions 2.428571429 194 | jerusalem murdered 2.357142857 195 | cell brick 3.285714286 196 | knowledge promoter 2.642857143 197 | adventure rails 
2.571428571 198 | houston crash 2.357142857 199 | oxford subcommittee 2.642857143 200 | militia weapon 3.785714286 201 | manufacturer meat 1.857142857 202 | damages reaction 3.071428571 203 | sea fishing 4.357142857 204 | atomic clash 2.785714286 205 | broadcasting athletics 3 206 | mystery expedition 2.538461538 207 | kremlin soviets 3.166666667 208 | pig blaze 1.75 209 | riverside vietnamese 2.25 210 | bitter protective 1.923076923 211 | disaster announced 2.384615385 212 | pork blaze 2.230769231 213 | feet international 1.916666667 214 | radical uniform 2.5 215 | gossip condemned 2.692307692 216 | mozart wagner 3.166666667 217 | soccer boxing 3.4 218 | radical roles 2.75 219 | rescued slaying 3 220 | researchers tested 3.538461538 221 | sales season 2.307692308 222 | homeless refugees 3.615384615 223 | pakistan repair 1.75 224 | athens painting 2.294117647 225 | tiger woods 3.375 226 | aircraft plane 4.473684211 227 | solar carbon 2.842105263 228 | enterprise bankruptcy 2.5 229 | homer springfield 2.833333333 230 | coin awards 2.166666667 231 | rhodes native 2.25 232 | soccer curator 2.125 233 | gasoline stock 2.888888889 234 | guilt extended 2.105263158 235 | rapid singapore 1.764705882 236 | coin banker 3.631578947 237 | london correspondence 1.944444444 238 | pop sex 2.6 239 | medicine bread 2.176470588 240 | asia animal 1.555555556 241 | pop clubhouse 3.210526316 242 | nazi defensive 2.055555556 243 | earth poles 3.421052632 244 | thailand crowded 2.166666667 245 | day independence 3.473684211 246 | controversy pitch 2.375 247 | stock gasoline 3.166666667 248 | composers mozart 3.833333333 249 | tone piano 3.722222222 250 | paris chef 2.111111111 251 | profession responsible 2.722222222 252 | bankruptcy chronicle 2 253 | lebanon war 2.722222222 254 | israel terror 3.055555556 255 | angola military 2.941176471 256 | chemistry patients 2.357142857 257 | munich constitution 3.071428571 258 | piano theater 3.266666667 259 | poetry artist 3.8 260 | acre burned 1.769230769 261 | religion abortion 2.076923077 262 | jazz music 4.533333333 263 | government transportation 3 264 | color wine 2.533333333 265 | jackson quota 1.692307692 266 | shariff deputy 3.642857143 267 | boat negroes 2 268 | shooting sentenced 2.933333333 269 | republicans friedman 2.416666667 270 | politics brokerage 2.5 271 | russian stalin 3.357142857 272 | love philip 2.5 273 | nuclear plant 3.733333333 274 | jamaica queens 3.076923077 275 | dollar asylum 1.846153846 276 | bridge rowing 2.785714286 277 | berlin germany 4 278 | funeral death 4.714285714 279 | albert einstein 4.266666667 280 | gulf shore 3.857142857 281 | ecuador argentina 3.266666667 282 | britain france 3.714285714 283 | sports score 3.866666667 284 | socialism capitalism 3.785714286 285 | treaty peace 4.166666667 286 | exchange market 4.266666667 287 | marriage anniversary 4.333333333 -------------------------------------------------------------------------------- /eval/ranking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/mfaruqui/eval-word-vectors 3 | """ 4 | 5 | import math 6 | import numpy 7 | from operator import itemgetter 8 | from numpy.linalg import norm 9 | 10 | EPSILON = 1e-6 11 | 12 | def euclidean(vec1, vec2): 13 | diff = vec1 - vec2 14 | return math.sqrt(diff.dot(diff)) 15 | 16 | def cosine_sim(vec1, vec2): 17 | vec1 += EPSILON * numpy.ones(len(vec1)) 18 | vec2 += EPSILON * numpy.ones(len(vec1)) 19 | return vec1.dot(vec2)/(norm(vec1)*norm(vec2)) 20 | 21 | def 
assign_ranks(item_dict): 22 | ranked_dict = {} 23 | sorted_list = [(key, val) for (key, val) in sorted(item_dict.items(), 24 | key=itemgetter(1), 25 | reverse=True)] 26 | for i, (key, val) in enumerate(sorted_list): 27 | same_val_indices = [] 28 | for j, (key2, val2) in enumerate(sorted_list): 29 | if val2 == val: 30 | same_val_indices.append(j+1) 31 | if len(same_val_indices) == 1: 32 | ranked_dict[key] = i+1 33 | else: 34 | ranked_dict[key] = 1.*sum(same_val_indices)/len(same_val_indices) 35 | return ranked_dict 36 | 37 | def correlation(dict1, dict2): 38 | avg1 = 1.*sum([val for key, val in dict1.iteritems()])/len(dict1) 39 | avg2 = 1.*sum([val for key, val in dict2.iteritems()])/len(dict2) 40 | numr, den1, den2 = (0., 0., 0.) 41 | for val1, val2 in zip(dict1.itervalues(), dict2.itervalues()): 42 | numr += (val1 - avg1) * (val2 - avg2) 43 | den1 += (val1 - avg1) ** 2 44 | den2 += (val2 - avg2) ** 2 45 | return numr / math.sqrt(den1 * den2) 46 | 47 | def spearmans_rho(ranked_dict1, ranked_dict2): 48 | assert len(ranked_dict1) == len(ranked_dict2) 49 | if len(ranked_dict1) == 0 or len(ranked_dict2) == 0: 50 | return 0. 51 | x_avg = 1.*sum([val for val in ranked_dict1.values()])/len(ranked_dict1) 52 | y_avg = 1.*sum([val for val in ranked_dict2.values()])/len(ranked_dict2) 53 | num, d_x, d_y = (0., 0., 0.) 54 | for key in ranked_dict1.keys(): 55 | xi = ranked_dict1[key] 56 | yi = ranked_dict2[key] 57 | num += (xi-x_avg)*(yi-y_avg) 58 | d_x += (xi-x_avg)**2 59 | d_y += (yi-y_avg)**2 60 | return num/(math.sqrt(d_x*d_y)) 61 | -------------------------------------------------------------------------------- /eval/read_write.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/mfaruqui/eval-word-vectors 3 | """ 4 | 5 | import sys 6 | import gzip 7 | import numpy 8 | import math 9 | from collections import Counter 10 | from operator import itemgetter 11 | 12 | ''' Read all the word vectors and normalize them ''' 13 | def read_word_vectors(filename): 14 | word_vecs = {} 15 | if filename.endswith('.gz'): file_object = gzip.open(filename, 'r') 16 | else: file_object = open(filename, 'r') 17 | 18 | for line_num, line in enumerate(file_object): 19 | line = line.strip().lower() 20 | word = line.split()[0] 21 | word_vecs[word] = numpy.zeros(len(line.split())-1, dtype=float) 22 | for index, vec_val in enumerate(line.split()[1:]): 23 | word_vecs[word][index] = float(vec_val) 24 | ''' normalize weight vector ''' 25 | word_vecs[word] /= math.sqrt((word_vecs[word]**2).sum() + 1e-6) 26 | 27 | sys.stderr.write("Vectors read from: "+filename+" \n") 28 | return word_vecs 29 | -------------------------------------------------------------------------------- /eval/wordsim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/mfaruqui/eval-word-vectors 3 | """ 4 | 5 | import sys 6 | from read_write import read_word_vectors 7 | from ranking import * 8 | 9 | if __name__=='__main__': 10 | word_vec_file = sys.argv[1] 11 | word_sim_file = sys.argv[2] 12 | 13 | word_vecs = read_word_vectors(word_vec_file) 14 | print('=================================================================================') 15 | print("%15s" % "Num Pairs", "%15s" % "Not found", "%15s" % "Rho") 16 | print('=================================================================================') 17 | 18 | manual_dict, auto_dict = ({}, {}) 19 | not_found, total_size = (0, 0) 20 | for line in 
open(word_sim_file,'r'): 21 | line = line.strip().lower() 22 | word1, word2, val = line.split() 23 | if word1 in word_vecs and word2 in word_vecs: 24 | manual_dict[(word1, word2)] = float(val) 25 | auto_dict[(word1, word2)] = cosine_sim(word_vecs[word1], word_vecs[word2]) 26 | else: 27 | not_found += 1 28 | total_size += 1 29 | print("%15s" % str(total_size), "%15s" % str(not_found), end=' ') 30 | print("%15.4f" % spearmans_rho(assign_ranks(manual_dict), assign_ranks(auto_dict))) 31 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from word2vec import Word2Vec 2 | 3 | # init dataset and model 4 | word2vec = Word2Vec(data_path='text8', 5 | vocabulary_size=50000, 6 | embedding_size=300) 7 | 8 | # the index of the whole corpus 9 | print(word2vec.data[:10]) 10 | 11 | # word_count like this [['word', word_count], ...] 12 | # the index of list correspond index of word 13 | print(word2vec.word_count[:10]) 14 | 15 | # index to word 16 | print(word2vec.index2word[34]) 17 | 18 | # word to index 19 | print(word2vec.word2index['hello']) 20 | 21 | # train model 22 | word2vec.train(train_steps=200000, 23 | skip_window=1, 24 | num_skips=2, 25 | num_neg=20, 26 | output_dir='out/run-1') 27 | 28 | 29 | # save vector txt file 30 | word2vec.save_vector_txt(path_dir='out/run-1') 31 | 32 | # get vector list 33 | vector = word2vec.get_list_vector() 34 | print(vector[123]) 35 | print(vector[word2vec.word2index['hello']]) 36 | 37 | # get top k similar word 38 | sim_list = word2vec.most_similar('one', top_k=8) 39 | print(sim_list) 40 | 41 | # load pre-train model 42 | # word2vec.load_model('out/run-1/model_step200000.pt') 43 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SkipGramNeg(nn.Module): 6 | def __init__(self, vocab_size, emb_dim): 7 | super(SkipGramNeg, self).__init__() 8 | self.input_emb = nn.Embedding(vocab_size, emb_dim) 9 | self.output_emb = nn.Embedding(vocab_size, emb_dim) 10 | self.log_sigmoid = nn.LogSigmoid() 11 | 12 | initrange = (2.0 / (vocab_size + emb_dim)) ** 0.5 # Xavier init 13 | self.input_emb.weight.data.uniform_(-initrange, initrange) 14 | self.output_emb.weight.data.uniform_(-0, 0) 15 | 16 | 17 | def forward(self, target_input, context, neg): 18 | """ 19 | :param target_input: [batch_size] 20 | :param context: [batch_size] 21 | :param neg: [batch_size, neg_size] 22 | :return: 23 | """ 24 | # u,v: [batch_size, emb_dim] 25 | v = self.input_emb(target_input) 26 | u = self.output_emb(context) 27 | # positive_val: [batch_size] 28 | positive_val = self.log_sigmoid(torch.sum(u * v, dim=1)).squeeze() 29 | 30 | # u_hat: [batch_size, neg_size, emb_dim] 31 | u_hat = self.output_emb(neg) 32 | # [batch_size, neg_size, emb_dim] x [batch_size, emb_dim, 1] = [batch_size, neg_size, 1] 33 | # neg_vals: [batch_size, neg_size] 34 | neg_vals = torch.bmm(u_hat, v.unsqueeze(2)).squeeze(2) 35 | # neg_val: [batch_size] 36 | neg_val = self.log_sigmoid(-torch.sum(neg_vals, dim=1)).squeeze() 37 | 38 | loss = positive_val + neg_val 39 | return -loss.mean() 40 | 41 | def predict(self, inputs): 42 | return self.input_emb(inputs) 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /vector_handle.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def model_to_vector(model, emb_layer_name='input_emb'):
5 |     """
6 |     get the word-vector weights from the model
7 |     :param model: trained model
8 |     :param emb_layer_name: name of the embedding layer to export
9 |     :return: list of embedding vectors
10 |     """
11 |     sd = model.state_dict()
12 |     return sd[emb_layer_name + '.weight'].cpu().numpy().tolist()
13 | 
14 | def save_embedding(file_name, embeddings, id2word):
15 |     """
16 |     save the word vectors to a text file
17 |     :param file_name: output file path
18 |     :param embeddings: list of embedding vectors
19 |     :param id2word: {index: word} dictionary
20 |     :return:
21 |     """
22 |     fo = open(file_name, 'w')
23 |     for idx in range(len(embeddings)):
24 |         word = id2word[idx]
25 |         embed = embeddings[idx]
26 |         embed_list = [str(i) for i in embed]
27 |         line_str = ' '.join(embed_list)
28 |         fo.write(word + ' ' + line_str + '\n')
29 | 
30 |     fo.close()
31 | 
32 | def nearest(model, vali_examples, vali_size, id2word_dict, top_k=8):
33 |     """
34 |     find the nearest words of vali_examples
35 |     :param model: model
36 |     :param vali_examples: list of word indices to validate
37 |     :param vali_size: int
38 |     :param id2word_dict: {index: word} dictionary
39 |     :param top_k: int
40 |     :return:
41 |     """
42 |     vali_examples = torch.tensor(vali_examples, dtype=torch.long).cuda()
43 |     vali_emb = model.predict(vali_examples)
44 |     # sim: [batch_size, vocab_size]
45 |     sim = torch.mm(vali_emb, model.input_emb.weight.transpose(0, 1))
46 |     for i in range(vali_size):
47 |         vali_word = id2word_dict[vali_examples[i].item()]
48 |         nearest = (-sim[i, :]).sort()[1][1: top_k + 1]
49 |         log_str = 'Nearest to %s:' % vali_word
50 |         for k in range(top_k):
51 |             close_word = id2word_dict[nearest[k].item()]
52 |             log_str = '%s %s,' % (log_str, close_word)
53 |         print(log_str)
--------------------------------------------------------------------------------
/word2vec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | from torch.optim import SGD
5 | from data_utils import read_own_data, build_dataset, DataPipeline
6 | from model import SkipGramNeg
7 | from vector_handle import nearest
8 | 
9 | class Word2Vec:
10 |     def __init__(self, data_path, vocabulary_size, embedding_size, learning_rate=1.0):
11 | 
12 |         self.corpus = read_own_data(data_path)
13 | 
14 |         self.data, self.word_count, self.word2index, self.index2word = build_dataset(self.corpus,
15 |                                                                                       vocabulary_size)
16 |         self.vocabs = list(set(self.data))
17 | 
18 |         self.model: SkipGramNeg = SkipGramNeg(vocabulary_size, embedding_size).cuda()
19 |         self.model_optim = SGD(self.model.parameters(), lr=learning_rate)
20 | 
21 | 
22 |     def train(self, train_steps, skip_window=1, num_skips=2, num_neg=20, batch_size=128, data_offset=0, vali_size=3, output_dir='out'):
23 |         os.makedirs(output_dir, exist_ok=True)  # create the output directory (including parents) if needed
24 |         self.outputdir = output_dir
25 |         avg_loss = 0
26 |         pipeline = DataPipeline(self.data, self.vocabs, self.word_count, data_offset)
27 |         vali_examples = random.sample(self.vocabs, vali_size)
28 | 
29 |         for step in range(train_steps):
30 |             batch_inputs, batch_labels = pipeline.generate_batch(batch_size, num_skips, skip_window)
31 |             batch_neg = pipeline.get_neg_data(batch_size, num_neg, batch_inputs)
32 | 
33 |             batch_inputs = torch.tensor(batch_inputs, dtype=torch.long).cuda()
34 |             batch_labels = torch.tensor(batch_labels, dtype=torch.long).cuda()
35 |             batch_neg = torch.tensor(batch_neg, dtype=torch.long).cuda()
36 | 
37 |             loss = self.model(batch_inputs, batch_labels, batch_neg)
38 |             self.model_optim.zero_grad()
39 |             loss.backward()
40 |             self.model_optim.step()
41 | 
42 |             avg_loss += loss.item()
43 | 
44 |             if step % 2000 == 0 and step > 0:
45 |                 avg_loss /= 2000
46 |                 print('Average loss at step ', step, ': ', avg_loss)
47 |                 avg_loss = 0
48 | 
49 |             if step % 10000 == 0 and vali_size > 0:
50 |                 nearest(self.model, vali_examples, vali_size, self.index2word, top_k=8)
51 | 
52 |             # checkpoint
53 |             if step % 100000 == 0 and step > 0:
54 |                 torch.save(self.model.state_dict(), self.outputdir + '/model_step%d.pt' % step)
55 | 
56 |         # save model at last
57 |         torch.save(self.model.state_dict(), self.outputdir + '/model_step%d.pt' % train_steps)
58 | 
59 |     def save_model(self, out_path):
60 |         torch.save(self.model.state_dict(), out_path + '/model.pt')
61 | 
62 |     def get_list_vector(self):
63 |         sd = self.model.state_dict()
64 |         return sd['input_emb.weight'].tolist()
65 | 
66 |     def save_vector_txt(self, path_dir):
67 |         embeddings = self.get_list_vector()
68 |         fo = open(path_dir + '/vector.txt', 'w')
69 |         for idx in range(len(embeddings)):
70 |             word = self.index2word[idx]
71 |             embed = embeddings[idx]
72 |             embed_list = [str(i) for i in embed]
73 |             line_str = ' '.join(embed_list)
74 |             fo.write(word + ' ' + line_str + '\n')
75 |         fo.close()
76 | 
77 |     def load_model(self, model_path):
78 |         self.model.load_state_dict(torch.load(model_path))
79 | 
80 |     def vector(self, index):
81 |         return self.model.predict(index)
82 | 
83 |     def most_similar(self, word, top_k=8):
84 |         index = self.word2index[word]
85 |         index = torch.tensor(index, dtype=torch.long).cuda().unsqueeze(0)
86 |         emb = self.model.predict(index)
87 |         sim = torch.mm(emb, self.model.input_emb.weight.transpose(0, 1))
88 |         nearest = (-sim[0]).sort()[1][1: top_k + 1]
89 |         top_list = []
90 |         for k in range(top_k):
91 |             close_word = self.index2word[nearest[k].item()]
92 |             top_list.append(close_word)
93 |         return top_list
94 | 
95 | 
96 | 
--------------------------------------------------------------------------------
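Quick sanity check (not part of the repository): the `forward()` docstring in `model.py` documents the expected shapes — `target_input` and `context` are `[batch_size]`, `neg` is `[batch_size, neg_size]`. The sketch below is a minimal CPU-only example with made-up toy sizes; it builds a fake batch of those shapes, runs one optimization step, and queries `predict()`. It assumes `model.py` is importable and, unlike the training pipeline above, does not require a GPU.

```python
import torch
from torch.optim import SGD
from model import SkipGramNeg

# toy sizes, chosen only for illustration
vocab_size, emb_dim, batch_size, num_neg = 50, 16, 8, 5

model = SkipGramNeg(vocab_size, emb_dim)   # kept on CPU for this check
optim = SGD(model.parameters(), lr=1.0)

# fake batch with the shapes documented in SkipGramNeg.forward()
targets = torch.randint(0, vocab_size, (batch_size,), dtype=torch.long)
contexts = torch.randint(0, vocab_size, (batch_size,), dtype=torch.long)
negatives = torch.randint(0, vocab_size, (batch_size, num_neg), dtype=torch.long)

loss = model(targets, contexts, negatives)  # scalar SGNS loss
optim.zero_grad()
loss.backward()
optim.step()

print(loss.item())
print(model.predict(targets).shape)         # [batch_size, emb_dim] input embeddings
```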