38 |
39 |
40 | {% endblock %}
41 |
42 | {% block bottom %}
43 |
44 |
105 |
106 |
114 | {% endblock %}
115 |
--------------------------------------------------------------------------------
/web.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import logging
4 | import json
5 | import time
6 |
7 | import tensorflow as tf
8 | from flask import Flask, request, render_template, jsonify, send_from_directory
9 | from flask import Response
10 | from word2vec_optimized import Word2Vec
11 | from instagram import Instagram
12 | from flags import Options
13 |
# Number of nearby tags to request per query (plus the query words themselves).
NEARBY_COUNT = 12
15 |
def get_model():
    """Build and return the Word2Vec model configured for the web app."""
    options = Options.web()
    sess = tf.Session()
    return Word2Vec(options, sess)
20 |
# Flask application setup. The word2vec model is loaded once at import time
# so all requests share a single in-memory model (loading is expensive).
app = Flask(__name__)
start_time = time.time()
model = get_model()
print("--- model load time: %.1f seconds ---" % (time.time() - start_time))
instagram = Instagram()

# Cache backend: Memcached Cloud when its env var is present (hosted
# deployments), otherwise an in-process SimpleCache for local development.
if os.environ.get('MEMCACHEDCLOUD_SERVERS'):
    from cache import MemcachedCache
    cache = MemcachedCache()
else:
    # NOTE(review): werkzeug.contrib was removed in Werkzeug 1.0 — confirm
    # the pinned werkzeug version still ships SimpleCache.
    from werkzeug.contrib.cache import SimpleCache
    cache = SimpleCache()
33 |
@app.route("/", methods=['GET'])
def main():
    """Front page: show vocab stats for an empty query, else run the query."""
    q = (request.args.get('q') or '').strip()
    if not q:
        data = {'vocab_size': model.get_vocab_size(),
                'emb_dim': model.get_emb_dim()}
        return render_template('index.html', query='', data=data)
    _add_recent_queries(q)
    return query(q)
44 |
def query(q):
    """Run a tag query and render the result page.

    Supported query forms:
      * "!a b c"     -> find the word that doesn't match the rest
      * "a - b + c"  -> analogy query
      * "a b -c"     -> nearby tags for positive words minus "-"-negated words
    """
    data = {}
    if q.startswith('!'):
        words = q[1:].strip().split()
        data['doesnt_match'] = model.get_doesnt_match(*words)
    else:
        words = q.split()
        count = len(words)
        # Raw string so the backslash escapes are literal regex escapes
        # (non-raw '\-'/'\+' are invalid escape sequences on Python 3).
        m = re.search(r'([^\-]+)\-([^\+]+)\+(.+)', q)
        if m:
            # List comprehension instead of map(): the result is unpacked
            # and stored in data, and map() is a one-shot iterator on Py3.
            words = [x.strip() for x in m.groups()]
            data['analogy'] = model.get_analogy(*words)
        elif count == 1 and not q.startswith('-'):
            data['no_words'] = model.get_no_words(words)
            if not data['no_words']:
                data['nearby'] = model.get_nearby([q], [], num=NEARBY_COUNT + count)
                data['tag'] = q
        else:
            negative_words = [word[1:] for word in words if word.startswith('-')]
            positive_words = [word for word in words if not word.startswith('-')]
            data['no_words'] = model.get_no_words(negative_words + positive_words)
            if not data['no_words']:
                data['nearby'] = model.get_nearby(positive_words, negative_words,
                                                 num=NEARBY_COUNT + count)
                data['tag'] = data['nearby'][0][0]
    # Words parsed from the query, passed through to the template.
    data['words'] = words
    return render_template('query.html', query=q, data=data)
71 |
# Bug fix: the route was "/tags//media.js" with no URL variable, so Flask
# could never bind the handler's tag_name parameter.
@app.route("/tags/<tag_name>/media.js", methods=['GET'])
def tag_media(tag_name):
    """Return recent Instagram media for a tag as JSON, cached for an hour."""
    key = '/tags/%s/media.js' % tag_name
    data = cache.get(key)
    if not data:
        media = instagram.media(tag_name)
        media = {'media': media[:12]}
        data = json.dumps(media)
        cache.set(key, data, timeout=60*60)
    return Response(response=data, status=200, mimetype='application/json')
82 |
@app.route("/tsne.js", methods=['GET'])
def tsne_js():
    """Serve the precomputed t-SNE coordinates file from the model's save dir."""
    directory = model.get_save_path()
    return send_from_directory(directory, 'tsne.js')
86 |
@app.route("/recent_queries", methods=['GET'])
def recent_queries():
    """Render the list of recently submitted queries."""
    return render_template('recent_queries.html',
                           queries=_get_recent_queries())
91 |
# Cap (in characters) on the newline-joined recent-query log kept in cache.
MAX_RECENT_QUERIES_LENGTH = 500
# Cache key under which the recent-query log is stored.
KEY_RECENT_QUERIES = 'recent_queries'
94 |
def _add_recent_queries(q):
    """Append query *q* to the cached recent-query log, trimming old entries.

    The log is a single newline-separated string; when it grows beyond
    MAX_RECENT_QUERIES_LENGTH characters, the oldest whole lines are dropped.
    """
    recent_queries = cache.get(KEY_RECENT_QUERIES) or ''
    recent_queries += q + '\n'
    length = len(recent_queries)
    if length > MAX_RECENT_QUERIES_LENGTH:
        # Cut at the first newline inside the allowed tail so a partial line
        # is never kept. Bug fix: slice [index+1:] instead of indexing
        # [index+1], which replaced the whole log with a single character.
        index = recent_queries.find('\n', length - MAX_RECENT_QUERIES_LENGTH)
        recent_queries = recent_queries[index+1:]
    cache.set(KEY_RECENT_QUERIES, recent_queries)
103 |
def _get_recent_queries():
    """Return the cached recent-query log as a list of query strings."""
    log = cache.get(KEY_RECENT_QUERIES) or ''
    return log.strip().split('\n')
106 |
107 |
# Development entry point: Flask's builtin server with debug enabled.
# IP/PORT env vars allow hosted environments to override the bind address.
if __name__ == "__main__":
    app.debug = True
    app.run(host=os.getenv('IP', '0.0.0.0'),port=int(os.getenv('PORT', 8080)))
111 |
--------------------------------------------------------------------------------
/flags.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import tensorflow as tf
4 |
# All hyperparameters are declared as TF flags so they can be overridden from
# the command line; the Options class below snapshots them into an object.
flags = tf.app.flags

flags.DEFINE_string("save_path", None, "Directory to write the model.")
flags.DEFINE_string(
    "train_data", None,
    "Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
flags.DEFINE_string(
    "eval_data", None, "Analogy questions. "
    "https://word2vec.googlecode.com/svn/trunk/questions-words.txt.")
flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
flags.DEFINE_integer(
    "epochs_to_train", 15,
    "Number of epochs to train. Each epoch processes the training data once "
    "completely.")
flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
flags.DEFINE_integer("num_neg_samples", 25,
                     "Negative samples per training example.")
flags.DEFINE_integer("batch_size", 500,
                     "Numbers of training examples each step processes "
                     "(no minibatching).")
flags.DEFINE_integer("concurrent_steps", 12,
                     "The number of concurrent training steps.")
flags.DEFINE_integer("window_size", 5,
                     "The number of words to predict to the left and right "
                     "of the target word.")
flags.DEFINE_integer("min_count", 5,
                     "The minimum number of word occurrences for it to be "
                     "included in the vocabulary.")
flags.DEFINE_float("subsample", 1e-3,
                   "Subsample threshold for word occurrence. Words that appear "
                   "with higher frequency will be randomly down-sampled. Set "
                   "to 0 to disable.")
flags.DEFINE_boolean(
    "interactive", False,
    "If true, enters an IPython interactive session to play with the trained "
    "model. E.g., try model.analogy(b'france', b'paris', b'russia') and "
    "model.nearby([b'proton', b'elephant', b'maxwell'])")
flags.DEFINE_string("emb_data", None, "Intial vector data.")

# Parsed flag values; attribute access triggers parsing on first use.
FLAGS = flags.FLAGS
45 |
class Options(object):
    """Options used by our word2vec model.

    Snapshots the TF flag values into plain attributes; the classmethods
    web() and train() build preconfigured option sets.
    """

    def __init__(self):
        # Model options.

        # Embedding dimension.
        self.emb_dim = FLAGS.embedding_size

        # Training options.

        # The training text file.
        self.train_data = FLAGS.train_data

        # Number of negative samples per example.
        self.num_samples = FLAGS.num_neg_samples

        # The initial learning rate.
        self.learning_rate = FLAGS.learning_rate

        # Number of epochs to train. After these many epochs, the learning
        # rate decays linearly to zero and the training stops.
        self.epochs_to_train = FLAGS.epochs_to_train

        # Concurrent training steps.
        self.concurrent_steps = FLAGS.concurrent_steps

        # Number of examples for one training step.
        self.batch_size = FLAGS.batch_size

        # The number of words to predict to the left and right of the target word.
        self.window_size = FLAGS.window_size

        # The minimum number of word occurrences for it to be included in the
        # vocabulary.
        self.min_count = FLAGS.min_count

        # Subsampling threshold for word occurrence.
        self.subsample = FLAGS.subsample

        # Where to write out summaries.
        self.save_path = FLAGS.save_path

        # initial word embed data
        self.emb_data = FLAGS.emb_data

        # Eval options.

        # The text file for eval.
        self.eval_data = FLAGS.eval_data

        self.interactive = FLAGS.interactive

    @classmethod
    def web(cls):
        """Build options for serving the web app.

        Prefers a pretrained vector file when present; otherwise falls back
        to the raw training corpus, downloading a prebuilt training archive
        when neither exists locally.
        """
        opts = cls()
        opts.save_path = 'train'
        opts.emb_dim = 100
        opts.interactive = True

        emb_data = 'train/model.vec'
        if os.path.isfile(emb_data):
            opts.emb_data = emb_data
        else:
            opts.train_data = 'data/tags.txt'

        # Ensure the save directory exists. Fix: use os.path/os.makedirs
        # instead of shelling out to `ls`/`mkdir` (which also crashed when
        # train_data was None).
        if not os.path.isdir(opts.save_path):
            if opts.train_data and os.path.exists(opts.train_data):
                os.makedirs(opts.save_path)
            else:
                with open(os.devnull, 'w') as FNULL:
                    subprocess.call(['wget', 'https://muik-projects.firebaseapp.com/tf/tag2vec-train.tgz'],
                                    stdout=FNULL)
                subprocess.call(['tar', 'xvfz', 'tag2vec-train.tgz'])
                subprocess.call(['rm', 'tag2vec-train.tgz'])
        return opts

    @classmethod
    def train(cls):
        """Build options for local training on the tags corpus."""
        opts = cls()
        opts.train_data = 'data/tags.txt'
        opts.save_path = 'train'
        opts.eval_data = 'data/questions-tags.txt'
        opts.window_size = 5
        opts.min_count = 7
        opts.emb_dim = 100
        return opts
133 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/word2vec_optimized.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Multi-threaded word2vec unbatched skip-gram model.
17 |
18 | Trains the model described in:
19 | (Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space
20 | ICLR 2013.
21 | http://arxiv.org/abs/1301.3781
22 | This model does true SGD (i.e. no minibatching). To do this efficiently, custom
23 | ops are used to sequentially process data within a 'batch'.
24 |
25 | The key ops used are:
26 | * skipgram custom op that does input processing.
27 | * neg_train custom op that efficiently calculates and applies the gradient using
28 | true SGD.
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 |
34 | import os
35 | import sys
36 | import threading
37 | import time
38 | import random
39 |
40 | from six.moves import xrange # pylint: disable=redefined-builtin
41 |
42 | import numpy as np
43 | import tensorflow as tf
44 | import pandas as pd
45 |
46 | from tensorflow.models.embedding import gen_word2vec as word2vec
47 | from flags import FLAGS, Options
48 |
49 | class Word2Vec(object):
50 | """Word2Vec model (Skipgram)."""
  def __init__(self, options, session):
    """Initialize the model.

    In web/interactive mode (emb_data or interactive set) only the
    embeddings are loaded; otherwise the full skip-gram training graph
    is built and the corpus is read into memory.
    """
    self._options = options
    self._session = session
    self._word2id = {}
    self._id2word = []
    # Pretrained/interactive path skips the expensive training graph.
    if options.emb_data or options.interactive:
      self.load_emb()
    else:
      self.build_graph()
    self.build_eval_graph()
    if options.eval_data:
      self._read_analogies()
    # Training mode only: persist the vocabulary and load the corpus.
    if not options.emb_data and not options.interactive:
      self.save_vocab()
    if not options.emb_data and options.train_data and not options.interactive:
      self._load_corpus()
67 |
68 | def _read_analogies(self):
69 | """Reads through the analogy question file.
70 |
71 | Returns:
72 | questions: a [n, 4] numpy array containing the analogy question's
73 | word ids.
74 | questions_skipped: questions skipped due to unknown words.
75 | """
76 | questions = []
77 | questions_skipped = 0
78 | with open(self._options.eval_data, "rb") as analogy_f:
79 | for line in analogy_f:
80 | if line.startswith(b":"): # Skip comments.
81 | continue
82 | words = line.decode('utf-8').strip().lower().split(b" ")
83 | ids = [self._word2id.get(w.strip()) for w in words]
84 | if None in ids or len(ids) != 4:
85 | questions_skipped += 1
86 | else:
87 | questions.append(np.array(ids))
88 | print("Eval analogy file: ", self._options.eval_data)
89 | print("Questions: ", len(questions))
90 | print("Skipped: ", questions_skipped)
91 | self._analogy_questions = np.array(questions, dtype=np.int32)
92 |
93 | def get_no_words(self, words):
94 | return [word for word in words if word not in self._word2id]
95 |
  def get_vocab_size(self):
    """Return the number of words in the vocabulary (including UNK)."""
    return self._options.vocab_size
98 |
  def get_emb_dim(self):
    """Return the embedding dimensionality."""
    return self._options.emb_dim
101 |
  def load_emb(self):
    """Load vocabulary and the embedding variable without training ops.

    Web/interactive mode only. Words come either from the pretrained
    vector file (opts.emb_data, word2vec text format) or from the saved
    vocab.txt; w_in is initialized from the vector file when one is given.
    """
    start_time = time.time()
    opts = self._options

    if opts.emb_data:
      # First line of the word2vec text format is "<vocab_size> <emb_dim>".
      with open(opts.emb_data) as f:
        opts.emb_dim = int(f.readline().split()[1])
      # Column 0 of the remaining lines holds the words themselves.
      self._id2word = pd.read_csv(opts.emb_data, delimiter=' ',
                                  skiprows=1, header=0, usecols=[0]).values
      self._id2word = np.transpose(self._id2word)[0]
      # An empty leading token stands in for the unknown-word bucket.
      if self._id2word[0] == '':
        self._id2word[0] = 'UNK'
    else:
      self._id2word = np.loadtxt(os.path.join(opts.save_path, "vocab.txt"),
                                 'str', unpack=True)[0]

    # NOTE(review): str.decode is Python 2 only — this module appears to
    # target Python 2 / TF 0.x throughout.
    self._id2word = [str(x).decode('utf-8') for x in self._id2word]
    for i, w in enumerate(self._id2word):
      self._word2id[w] = i
    opts.vocab_size = len(self._id2word)

    if opts.emb_data:
      # Lazy initializer: the vector matrix is only read when TF
      # materializes the variable; also exports t-SNE coordinates on
      # first load if the tsne.js file is missing.
      def initializer(shape, dtype):
        initial_value = pd.read_csv(opts.emb_data, delimiter=' ',
                                    skiprows=1, header=0,
                                    usecols=range(1, opts.emb_dim+1)).values
        if opts.save_path:
          path = os.path.join(opts.save_path, 'tsne.js')
          if not os.path.isfile(path):
            self._export_tsne(initial_value)
        return initial_value
      self._w_in = tf.get_variable('w_in', [opts.vocab_size, opts.emb_dim],
                                   initializer=initializer)
    else:
      self._w_in = tf.get_variable('w_in', [opts.vocab_size, opts.emb_dim])
    print("--- embed data load time: %.1f seconds ---" % (time.time() - start_time))
137 |
  def build_graph(self):
    """Build the model graph.

    Runs the skipgram input op once to materialize the vocabulary, then
    declares the embedding variables and the neg_train true-SGD training
    op fed from Python-generated batches.
    """
    opts = self._options

    # The training data. A text file.
    (words, counts, words_per_epoch, current_epoch, total_words_processed,
     examples, labels) = word2vec.skipgram(filename=opts.train_data,
                                           batch_size=opts.batch_size,
                                           window_size=opts.window_size,
                                           min_count=opts.min_count,
                                           subsample=opts.subsample)
    (opts.vocab_words, opts.vocab_counts,
     opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
    opts.vocab_size = len(opts.vocab_words)
    print("Data file: ", opts.train_data)
    print("Vocab size: ", opts.vocab_size - 1, " + UNK")
    print("Words per epoch: ", opts.words_per_epoch)

    # NOTE(review): map() returns a list on Python 2 (which this module
    # targets); on Python 3 it would break save_vocab's later indexing.
    opts.vocab_words = map(lambda x: x.decode('utf-8'), opts.vocab_words)
    self._id2word = opts.vocab_words
    for i, w in enumerate(self._id2word):
      self._word2id[w] = i

    # Declare all variables we need.
    # Input words embedding: [vocab_size, emb_dim]
    w_in = tf.Variable(
        tf.random_uniform(
            [opts.vocab_size,
             opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim),
        name="w_in")

    # Output (context) weights: [vocab_size, emb_dim], zero-initialized.
    w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out")

    # Global step: []
    global_step = tf.Variable(0, name="global_step")

    # Linear learning rate decay.
    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
    lr = opts.learning_rate * tf.maximum(
        0.0001,
        1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train)

    # The skipgram op's example/label outputs are discarded: batches are
    # generated in Python by _batch_data and fed via these placeholders.
    examples = tf.placeholder(dtype=tf.int32)  # [N]
    labels = tf.placeholder(dtype=tf.int32)  # [N]

    # Training nodes.
    inc = global_step.assign_add(1)
    with tf.control_dependencies([inc]):
      train = word2vec.neg_train(w_in,
                                 w_out,
                                 examples,
                                 labels,
                                 lr,
                                 vocab_count=opts.vocab_counts.tolist(),
                                 num_negative_samples=opts.num_samples)

    self._w_in = w_in
    self._examples = examples
    self._labels = labels
    self._lr = lr
    self._train = train
    self.step = global_step
    self._epoch = current_epoch
    self._words = total_words_processed
203 |
  def save_vocab(self):
    """Save the vocabulary to a file so the model can be reloaded.

    Writes "word count" lines to <save_path>/vocab.txt, which load_emb
    reads back when no pretrained vector file is available.
    """
    opts = self._options
    with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
      for i in xrange(opts.vocab_size):
        # NOTE(review): .encode on the compat text is Python 2 style; on
        # Python 3 this would write a bytes repr into a text-mode file.
        f.write("%s %d\n" % (tf.compat.as_text(opts.vocab_words[i]).encode('utf-8'),
                             opts.vocab_counts[i]))
211 |
  def build_eval_graph(self):
    """Build the evaluation graph.

    Creates placeholder-fed ops over the L2-normalized embeddings for:
    analogy prediction (a:b :: c:?), odd-one-out detection, and joint
    nearest-neighbour lookup. Finally restores a checkpoint when training
    locally, or initializes variables otherwise.
    """
    # Eval graph
    opts = self._options

    # Each analogy task is to predict the 4th word (d) given three
    # words: a, b, c. E.g., a=italy, b=rome, c=france, we should
    # predict d=paris.

    # The eval feeds three vectors of word ids for a, b, c, each of
    # which is of size N, where N is the number of analogies we want to
    # evaluate in one batch.
    analogy_a = tf.placeholder(dtype=tf.int32)  # [N]
    analogy_b = tf.placeholder(dtype=tf.int32)  # [N]
    analogy_c = tf.placeholder(dtype=tf.int32)  # [N]

    word_ids = tf.placeholder(dtype=tf.int32)  # [N]
    negative_word_ids = tf.placeholder(dtype=tf.int32)  # [N]

    # Normalized word embeddings of shape [vocab_size, emb_dim].
    nemb = tf.nn.l2_normalize(self._w_in, 1)

    # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
    # They all have the shape [N, emb_dim]
    a_emb = tf.gather(nemb, analogy_a)  # a's embs
    b_emb = tf.gather(nemb, analogy_b)  # b's embs
    c_emb = tf.gather(nemb, analogy_c)  # c's embs

    words_emb = tf.nn.embedding_lookup(nemb, word_ids)
    negative_words_emb = tf.nn.embedding_lookup(nemb, negative_word_ids)

    # We expect that d's embedding vectors on the unit hyper-sphere is
    # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
    target = c_emb + (b_emb - a_emb)

    # Compute cosine distance between each pair of target and vocab.
    # dist has shape [N, vocab_size].
    dist = tf.matmul(target, nemb, transpose_b=True)
    self._target = target
    self._dist = dist

    # For each question (row in dist), find the top 4 words.
    _, pred_idx = tf.nn.top_k(dist, 4)

    # Odd-one-out: find the word farthest from the mean of the given words.
    mean = tf.reduce_mean(words_emb, 0)
    mean = tf.reshape(mean, [-1, opts.emb_dim])
    mean_dist = 1.0 - tf.matmul(mean, words_emb, transpose_b=True)
    _, self._mean_pred_idx = tf.nn.top_k(mean_dist, 1)

    # Joint similarity over the whole vocab: sum of similarities to the
    # positive words minus the sum for the negative words.
    joint_dist = tf.matmul(words_emb, nemb, transpose_b=True)
    n_joint_dist = tf.matmul(negative_words_emb, nemb, transpose_b=True)
    joint_dist = tf.reduce_sum(joint_dist, 0) - tf.reduce_sum(n_joint_dist, 0)
    self._joint_idx = tf.nn.top_k(joint_dist, min(1000, opts.vocab_size))

    # Nodes in the construct graph which are used by training and
    # evaluation to run/feed/fetch.
    self._analogy_a = analogy_a
    self._analogy_b = analogy_b
    self._analogy_c = analogy_c
    self._word_ids = word_ids
    self._negative_word_ids = negative_word_ids
    self._analogy_pred_idx = pred_idx

    # When training locally (no pretrained emb_data), resume from the
    # latest checkpoint if one exists; otherwise initialize fresh.
    ckpt = None
    self.saver = tf.train.Saver()
    if not opts.emb_data:
      ckpt = tf.train.latest_checkpoint(os.path.join(opts.save_path))
    if ckpt:
      self.saver.restore(self._session, ckpt)
      print('loaded %s' % ckpt)
    else:
      # Properly initialize all variables.
      self._session.run(tf.initialize_all_variables())
285 |
286 | def _load_corpus(self):
287 | corpus = []
288 | with open(self._options.train_data, 'r') as f:
289 | unk_id = self._word2id['UNK']
290 | def word2id(w):
291 | return w in self._word2id and self._word2id[w] or unk_id
292 | while True:
293 | line = f.readline().decode('utf-8')
294 | if not line:
295 | break
296 | corpus.append([word2id(w) for w in line.split()])
297 | self._corpus = corpus
298 | self._corpus_lines_count = len(corpus)
299 |
300 | def _batch_data(self):
301 | examples = []
302 | labels = []
303 | batch_size = self._options.batch_size
304 | window_size = self._options.window_size
305 | unk_id = self._word2id['UNK']
306 | count = 0
307 | while True:
308 | line = self._corpus[random.randrange(0,self._corpus_lines_count)]
309 | words_count = len(line)
310 | for i, center_id in enumerate(line):
311 | if center_id == unk_id:
312 | continue
313 | start_index = max(0, i-window_size)
314 | end_index = min(words_count, i + 1 + window_size)
315 | outputs = line[start_index:end_index]
316 | outputs = filter(lambda x: x != unk_id and x != center_id, outputs)
317 | outputs_count = len(outputs)
318 | examples += [center_id] * outputs_count
319 | labels += outputs
320 | count += outputs_count
321 | if count >= batch_size:
322 | return examples[:batch_size], labels[:batch_size]
323 |
324 | def _train_thread_body(self):
325 | initial_epoch, = self._session.run([self._epoch])
326 | while True:
327 | examples, labels = self._batch_data()
328 | _, epoch = self._session.run([self._train, self._epoch], {
329 | self._examples: examples,
330 | self._labels: labels
331 | })
332 | if epoch != initial_epoch:
333 | break
334 | # time.sleep(0.02) # for preventing notebook noise
335 |
  def train(self):
    """Train the model for one epoch.

    Spawns opts.concurrent_steps worker threads that each run
    _train_thread_body, polls every 2 seconds to print a progress line
    until the epoch counter advances, then joins the workers.
    """
    opts = self._options

    initial_epoch, initial_words = self._session.run([self._epoch, self._words])

    workers = []
    for _ in xrange(opts.concurrent_steps):
      t = threading.Thread(target=self._train_thread_body)
      t.start()
      workers.append(t)

    last_words, last_time = initial_words, time.time()
    while True:
      time.sleep(2)  # Report progress once in a while.
      # NOTE(review): fetches `self.step` (no underscore), unlike the other
      # private attributes — presumably set by the graph builder; confirm.
      (epoch, step, words,
       lr) = self._session.run([self._epoch, self.step, self._words, self._lr])
      now = time.time()
      # The RHS is evaluated before assignment, so `rate` uses the previous
      # last_words/last_time: words processed per second since last report.
      last_words, last_time, rate = words, now, (words - last_words) / (
          now - last_time)
      # '\r' + end="" keeps the progress line updating in place.
      print("Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
                                                                    lr, rate),
            end="")
      sys.stdout.flush()
      if epoch != initial_epoch:
        break

    for t in workers:
      t.join()
365 |
366 | def _predict(self, analogy):
367 | """Predict the top 4 answers for analogy questions."""
368 | idx, = self._session.run([self._analogy_pred_idx], {
369 | self._analogy_a: analogy[:, 0],
370 | self._analogy_b: analogy[:, 1],
371 | self._analogy_c: analogy[:, 2]
372 | })
373 | return idx
374 |
375 | def eval(self):
376 | """Evaluate analogy questions and reports accuracy."""
377 |
378 | # How many questions we get right at precision@1.
379 | correct = 0
380 |
381 | total = self._analogy_questions.shape[0]
382 | start = 0
383 | while start < total:
384 | limit = start + 2500
385 | sub = self._analogy_questions[start:limit, :]
386 | idx = self._predict(sub)
387 | start = limit
388 | for question in xrange(sub.shape[0]):
389 | for j in xrange(4):
390 | if idx[question, j] == sub[question, 3]:
391 | # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
392 | correct += 1
393 | break
394 | elif idx[question, j] in sub[question, :3]:
395 | # We need to skip words already in the question.
396 | continue
397 | else:
398 | # The correct label is not the precision@1
399 | break
400 | accuracy = correct * 100.0 / total
401 | print()
402 | print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total, accuracy))
403 | return accuracy
404 |
405 | def get_nearby(self, words, negative_words, num=20):
406 | wids = [self._word2id.get(w, 0) for w in words]
407 | n_wids = [self._word2id.get(w, 0) for w in negative_words]
408 | idx = self._session.run(self._joint_idx, {
409 | self._word_ids: wids,
410 | self._negative_word_ids: n_wids
411 | })
412 | results = []
413 | for distance, i in zip(idx[0][:num], idx[1][:num]):
414 | if i in wids:
415 | continue
416 | results.append((self._id2word[i], distance))
417 | return results
418 |
419 | def doesnt_match(self, *words):
420 | wids = [self._word2id.get(w, 0) for w in words]
421 | idx, = self._session.run(self._mean_pred_idx, {
422 | self._word_ids: wids
423 | })
424 | print(words[idx[0]])
425 | return
426 |
427 | def get_doesnt_match(self, *words):
428 | wids = [self._word2id.get(w, 0) for w in words]
429 | idx, = self._session.run(self._mean_pred_idx, {
430 | self._word_ids: wids
431 | })
432 | return words[idx[0]]
433 |
434 | def get_analogy(self, w0, w1, w2):
435 | """Predict word w3 as in w0:w1 vs w2:w3."""
436 | wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
437 | idx = self._predict(wid)
438 | for c in [self._id2word[i] for i in idx[0, :]]:
439 | if c not in [w0, w1, w2, 'UNK']:
440 | return c
441 | return
442 |
443 | def save(self):
444 | opts = self._options
445 | self.saver.save(self._session, os.path.join(opts.save_path, "model.ckpt"))
446 | all_embs = self._session.run(self._w_in)
447 | self._export_tsne(all_embs)
448 | print('Saved')
449 |
450 | def _export_tsne(self, all_embs):
451 | from sklearn.manifold import TSNE
452 | import json
453 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
454 | plot_only = min(500, all_embs.shape[0])
455 | low_dim_embs = tsne.fit_transform(all_embs[:plot_only,:])
456 | labels = [self._id2word[i] for i in xrange(plot_only)]
457 | embs = [list(e) for e in low_dim_embs]
458 | json_data = json.dumps({'embs': embs, 'labels': labels})
459 | path = os.path.join(self._options.save_path, 'tsne.js')
460 | with open(path, 'w') as f:
461 | f.write(json_data)
462 | print('%s exported' % path)
463 |
464 | def get_save_path(self):
465 | return self._options.save_path
466 |
467 |
def main(_):
  """Train a word2vec model, or only evaluate when no train data is given.

  Requires --train_data, --eval_data and --save_path for training; with
  only --eval_data set, runs a single analogy evaluation and exits.
  """
  opts = Options()

  # Evaluation-only mode: no training corpus but an analogy file.
  if not opts.train_data and opts.eval_data:
    with tf.Graph().as_default(), tf.Session() as session:
      Word2Vec(opts, session).eval()  # Eval analogies.
    return

  if not opts.train_data or not opts.save_path or not opts.eval_data:
    print("--train_data --eval_data and --save_path must be specified.")
    sys.exit(1)

  with tf.Graph().as_default(), tf.Session() as session:
    model = Word2Vec(opts, session)
    for epoch in range(opts.epochs_to_train):
      model.train()  # Process one epoch
      accuracy = model.eval()  # Eval analogies.
      # Checkpoint every fifth epoch.
      if (epoch + 1) % 5 == 0:
        model.save()
    # Save the final state when the loop above didn't just save it.
    if opts.epochs_to_train % 5 != 0:
      model.save()
490 |
491 |
if __name__ == "__main__":
  # tf.app.run parses command-line flags and then invokes main(_) above.
  tf.app.run()
494 |
--------------------------------------------------------------------------------