├── Procfile ├── .gitignore ├── static └── tsne.png ├── requirements.txt ├── templates ├── recent_queries.html ├── query.html ├── layout.html └── index.html ├── cache.py ├── instagram.py ├── README.md ├── web.py ├── flags.py ├── LICENSE.txt └── word2vec_optimized.py /Procfile: -------------------------------------------------------------------------------- 1 | web: gunicorn web:app --log-file - 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.sh 4 | .DS_Store 5 | train 6 | data 7 | -------------------------------------------------------------------------------- /static/tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muik/tag2vec/HEAD/static/tsne.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==0.11.1 2 | gunicorn==19.4.5 3 | python-binary-memcached==0.24.6 4 | pandas==0.18.1 5 | https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl 6 | -------------------------------------------------------------------------------- /templates/recent_queries.html: -------------------------------------------------------------------------------- 1 | {% extends "layout.html" %} 2 | {% block content %} 3 | 8 | {% endblock %} 9 | -------------------------------------------------------------------------------- /cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | 4 | import bmemcached 5 | 6 | """ 7 | for Heroku memcached 8 | """ 9 | class MemcachedCache: 10 | def __init__(self): 11 | self._cache = bmemcached.Client(os.environ.get('MEMCACHEDCLOUD_SERVERS').split(','), 
os.environ.get('MEMCACHEDCLOUD_USERNAME'), os.environ.get('MEMCACHEDCLOUD_PASSWORD')) 12 | 13 | def set(self, key, value, timeout=0): 14 | key = self._key(key) 15 | if timeout > 0: 16 | self._cache.set(key, value, time=timeout) 17 | else: 18 | self._cache.set(key, value) 19 | 20 | def get(self, key): 21 | key = self._key(key) 22 | return self._cache.get(key) 23 | 24 | def _key(self, key): 25 | if type(key) == unicode: 26 | key = key.encode('utf-8') 27 | return urllib.quote(key) 28 | -------------------------------------------------------------------------------- /instagram.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import re 3 | import json 4 | import logging 5 | import urllib2 6 | 7 | class Instagram: 8 | def __init__(self): 9 | pass 10 | 11 | def media(self, tag): 12 | url = 'https://www.instagram.com/explore/tags/%s/' % tag 13 | response = urllib2.urlopen(url.encode('utf-8')) 14 | html = response.read() 15 | return self.parse(html) 16 | 17 | def parse(self, content): 18 | s = content.index('{"country_code":') 19 | e = content.index(';', s) 20 | dumps = content[s:e] 21 | obj = json.loads(dumps) 22 | nodes = obj['entry_data']['TagPage'][0]['tag']['top_posts']['nodes'] \ 23 | + obj['entry_data']['TagPage'][0]['tag']['media']['nodes'] 24 | """ 25 | print(obj['entry_data']['TagPage'][0]['tag']['top_posts'].keys()) # [u'media', u'content_advisory', u'top_posts', u'name'] 26 | print(obj['entry_data']['TagPage'][0]['tag']['media'].keys()) # [u'count', u'page_info', u'nodes'] 27 | """ 28 | return nodes 29 | 30 | if __name__ == "__main__": 31 | media = Instagram().media(u'맛집') 32 | print(media[0]['date']) 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tag2vec 2 | 인스타그램 태그를 Word2vec으로 학습시킨 태그 벡터 공간입니다. 
https://tag2vec.herokuapp.com/ 3 | 4 | TensorFlow의 [word2vec_optimized.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/embedding/word2vec_optimized.py) 코드를 사용하여 구현하였습니다. 5 | 6 | ![tsne](./static/tsne.png) 7 | 8 | ## Requirements 9 | - Linux or Mac OS 10 | - [Python](https://www.python.org/) 2.7 11 | - [TensorFlow](https://www.tensorflow.org/) 0.9+ 12 | - [Flask](http://flask.pocoo.org/) 13 | - [scikit-learn](http://scikit-learn.org/) 14 | 15 | ## 학습 데이터 16 | 인스타그램에 올려진 글에서 태그를 추출하여, 한 글에 이어지는 태그들을 하나의 문장으로 취급하였습니다. 17 | 하지만 서로 다른 글의 태그끼리는 연관성이 없으므로, skip-gram model에서 window를 선택할 때 같은 문장 안에서만 선택하도록 학습하였습니다. 18 | 19 | 인스타그램 태그 수집기: https://github.com/muik/instagram-tag-crawler 20 | 21 | ## 기능 22 | 1. 유사 태그: word2vec의 nearby 23 | 2. 다중 유사 태그: 여러 태그와 공통으로 유사한 태그 조회 24 | 3. 부정 태그: 태그와 거리가 먼 태그 조회 25 | 4. 추론: word2vec의 analogy 26 | 5. 별개 태그: gensim word2vec의 doesnt_match 구현 27 | 28 | ## 설치 29 | ``` 30 | $ pip install -t lib -r requirements.txt 31 | $ pip install scikit-learn 32 | ``` 33 | ## 데모 웹 실행 34 | ``` 35 | $ python web.py 36 | ``` 37 | 38 | ## 학습 데이터 형식 39 | 학습시키려면 data/tags.txt와 data/questions-tags.txt 학습 데이터 파일이 필요합니다. 40 | #### data/tags.txt 41 | ``` 42 | 여름 먹방 시원한 냉면 냉면맛집 물냉면 43 | 물놀이 계곡 힐링 휴가 강원도 44 | ... 45 | ``` 46 | #### data/questions-tags.txt 47 | ``` 48 | 비냉 비빔냉면 물냉 물냉면 49 | 이태원 경리단길 신사동 가로수길 50 | ... 51 | ``` 52 | 53 | ## 기존 학습 결과(체크포인트) 삭제 54 | ``` 55 | $ rm train/model.ckpt* 56 | ``` 57 | ## 학습 실행 58 | ``` 59 | $ python word2vec_optimized.py 60 | ``` 61 | -------------------------------------------------------------------------------- /templates/query.html: -------------------------------------------------------------------------------- 1 | {% extends "layout.html" %} 2 | {% block content %} 3 | {% if data['no_words'] %} 4 |

등록되지 않은 태그: {{ data['no_words']|join(', ') }}

5 |

다른 태그로 검색해보세요.

6 | {% endif %} 7 | {% if 'analogy' in data %} 8 |

{{ data['words'][0] }} - {{ data['words'][1] }} ≈ {{ data['words'][2] }} - {{ data['analogy'] }}

9 | {% endif %} 10 | {% if 'nearby' in data %} 11 |

{{ query }}

12 |
13 |

비슷한 태그

14 | 15 | {% for word, distance in data['nearby'] %} 16 | 17 | 18 | 19 | 23 | 24 | {% endfor %} 25 |
{{ word }}{{ distance|round(2) }} 20 | + 21 | - 22 |
26 |
27 | {% if 'tag' in data %} 28 |
29 |

{{ data['tag'] }} 사진

30 |
31 | 인스타그램에서 더보기 32 |
33 | {% endif %} 34 | {% endif %} 35 | {% if 'doesnt_match' in data %} 36 |

{{ query }}

37 |

가장 거리 먼 태그: {{ data['doesnt_match'] }}

38 | {% endif %} 39 | {% endblock %} 40 | 41 | {% block bottom %} 42 | 59 | {% endblock %} 60 | -------------------------------------------------------------------------------- /templates/layout.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | Tag2vec Alpha 13 | 14 | 28 | {% block head %}{% endblock %} 29 | 30 | 31 |
32 | 41 |
42 | {% block content %}{% endblock %} 43 |
44 | 45 | 48 |
49 | 50 | {% block bottom %}{% endblock %} 51 | 52 | 53 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends "layout.html" %} 2 | 3 | {% block head %} 4 | 8 | {% endblock %} 9 | 10 | {% block content %} 11 |
12 |

인스타그램 태그를 Word2vec으로 학습시킨 태그 벡터 공간입니다. 데이터베이스가 아닌 {{ data['emb_dim'] }}차원의 벡터공간에 있는 태그를 검색해보세요!

13 |

학습 데이터가 아직 많지 않아서 없는 결과나 어설픈 결과가 나올 수 있습니다. 사전 태그 수: {{ "{:,.0f}".format(data['vocab_size']) }}개

14 |

검색 예

15 |

유사 태그 (태그1 태그2 ...)

16 | 21 |

유추 하기 (태그1 - 태그2 + 태그3)

22 | 28 |

유사 태그에 부정 추가 (태그1 -태그2 ...)

29 | 32 |

가장 거리가 먼 태그 (! 태그1 태그2 태그3 ...)

33 | 38 |
39 | 40 | {% endblock %} 41 | 42 | {% block bottom %} 43 | 44 | 105 | 106 | 114 | {% endblock %} 115 | -------------------------------------------------------------------------------- /web.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import logging 4 | import json 5 | import time 6 | 7 | import tensorflow as tf 8 | from flask import Flask, request, render_template, jsonify, send_from_directory 9 | from flask import Response 10 | from word2vec_optimized import Word2Vec 11 | from instagram import Instagram 12 | from flags import Options 13 | 14 | NEARBY_COUNT = 12 15 | 16 | def get_model(): 17 | opts = Options.web() 18 | session = tf.Session() 19 | return Word2Vec(opts, session) 20 | 21 | app = Flask(__name__) 22 | start_time = time.time() 23 | model = get_model() 24 | print("--- model load time: %.1f seconds ---" % (time.time() - start_time)) 25 | instagram = Instagram() 26 | 27 | if os.environ.get('MEMCACHEDCLOUD_SERVERS'): 28 | from cache import MemcachedCache 29 | cache = MemcachedCache() 30 | else: 31 | from werkzeug.contrib.cache import SimpleCache 32 | cache = SimpleCache() 33 | 34 | @app.route("/", methods=['GET']) 35 | def main(): 36 | q = request.args.get('q') or '' 37 | q = q.strip() 38 | 39 | if not q: 40 | data = {'vocab_size': model.get_vocab_size(), 'emb_dim': model.get_emb_dim() } 41 | return render_template('index.html', query='', data=data) 42 | _add_recent_queries(q) 43 | return query(q) 44 | 45 | def query(q): 46 | data = {} 47 | if q.startswith('!'): 48 | words = q[1:].strip().split() 49 | data['doesnt_match'] = model.get_doesnt_match(*words) 50 | else: 51 | words = q.split() 52 | count = len(words) 53 | m = re.search('([^\-]+)\-([^\+]+)\+(.+)', q) 54 | if m: 55 | words = map(lambda x: x.strip(), m.groups()) 56 | data['analogy'] = model.get_analogy(*words) 57 | elif count == 1 and not q.startswith('-'): 58 | data['no_words'] = model.get_no_words(words) 59 | if not 
data['no_words']: 60 | data['nearby'] = model.get_nearby([q], [], num=NEARBY_COUNT + count) 61 | data['tag'] = q 62 | else: 63 | negative_words = [word[1:] for word in words if word.startswith('-')] 64 | positive_words = [word for word in words if not word.startswith('-')] 65 | data['no_words'] = model.get_no_words(negative_words + positive_words) 66 | if not data['no_words']: 67 | data['nearby'] = model.get_nearby(positive_words, negative_words, num=NEARBY_COUNT + count) 68 | data['tag'] = data['nearby'][0][0] 69 | data['words'] = words 70 | return render_template('query.html', query=q, data=data) 71 | 72 | @app.route("/tags//media.js", methods=['GET']) 73 | def tag_media(tag_name): 74 | key = '/tags/%s/media.js' % tag_name 75 | data = cache.get(key) 76 | if not data: 77 | media = instagram.media(tag_name) 78 | media = {'media': media[:12]} 79 | data = json.dumps(media) 80 | cache.set(key, data, timeout=60*60) 81 | return Response(response=data, status=200, mimetype='application/json') 82 | 83 | @app.route("/tsne.js", methods=['GET']) 84 | def tsne_js(): 85 | return send_from_directory(model.get_save_path(), 'tsne.js') 86 | 87 | @app.route("/recent_queries", methods=['GET']) 88 | def recent_queries(): 89 | queries = _get_recent_queries() 90 | return render_template('recent_queries.html', queries=queries) 91 | 92 | MAX_RECENT_QUERIES_LENGTH = 500 93 | KEY_RECENT_QUERIES = 'recent_queries' 94 | 95 | def _add_recent_queries(q): 96 | recent_queries = cache.get(KEY_RECENT_QUERIES) or '' 97 | recent_queries += q + '\n' 98 | length = len(recent_queries) 99 | if length > MAX_RECENT_QUERIES_LENGTH: 100 | index = recent_queries.find('\n', length - MAX_RECENT_QUERIES_LENGTH) 101 | recent_queries = recent_queries[index+1] 102 | cache.set(KEY_RECENT_QUERIES, recent_queries) 103 | 104 | def _get_recent_queries(): 105 | return (cache.get(KEY_RECENT_QUERIES) or '').strip().split('\n') 106 | 107 | 108 | if __name__ == "__main__": 109 | app.debug = True 110 | 
app.run(host=os.getenv('IP', '0.0.0.0'),port=int(os.getenv('PORT', 8080))) 111 | -------------------------------------------------------------------------------- /flags.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tensorflow as tf 4 | 5 | flags = tf.app.flags 6 | 7 | flags.DEFINE_string("save_path", None, "Directory to write the model.") 8 | flags.DEFINE_string( 9 | "train_data", None, 10 | "Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.") 11 | flags.DEFINE_string( 12 | "eval_data", None, "Analogy questions. " 13 | "https://word2vec.googlecode.com/svn/trunk/questions-words.txt.") 14 | flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.") 15 | flags.DEFINE_integer( 16 | "epochs_to_train", 15, 17 | "Number of epochs to train. Each epoch processes the training data once " 18 | "completely.") 19 | flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.") 20 | flags.DEFINE_integer("num_neg_samples", 25, 21 | "Negative samples per training example.") 22 | flags.DEFINE_integer("batch_size", 500, 23 | "Numbers of training examples each step processes " 24 | "(no minibatching).") 25 | flags.DEFINE_integer("concurrent_steps", 12, 26 | "The number of concurrent training steps.") 27 | flags.DEFINE_integer("window_size", 5, 28 | "The number of words to predict to the left and right " 29 | "of the target word.") 30 | flags.DEFINE_integer("min_count", 5, 31 | "The minimum number of word occurrences for it to be " 32 | "included in the vocabulary.") 33 | flags.DEFINE_float("subsample", 1e-3, 34 | "Subsample threshold for word occurrence. Words that appear " 35 | "with higher frequency will be randomly down-sampled. Set " 36 | "to 0 to disable.") 37 | flags.DEFINE_boolean( 38 | "interactive", False, 39 | "If true, enters an IPython interactive session to play with the trained " 40 | "model. 
E.g., try model.analogy(b'france', b'paris', b'russia') and " 41 | "model.nearby([b'proton', b'elephant', b'maxwell'])") 42 | flags.DEFINE_string("emb_data", None, "Intial vector data.") 43 | 44 | FLAGS = flags.FLAGS 45 | 46 | class Options(object): 47 | """Options used by our word2vec model.""" 48 | 49 | def __init__(self): 50 | # Model options. 51 | 52 | # Embedding dimension. 53 | self.emb_dim = FLAGS.embedding_size 54 | 55 | # Training options. 56 | 57 | # The training text file. 58 | self.train_data = FLAGS.train_data 59 | 60 | # Number of negative samples per example. 61 | self.num_samples = FLAGS.num_neg_samples 62 | 63 | # The initial learning rate. 64 | self.learning_rate = FLAGS.learning_rate 65 | 66 | # Number of epochs to train. After these many epochs, the learning 67 | # rate decays linearly to zero and the training stops. 68 | self.epochs_to_train = FLAGS.epochs_to_train 69 | 70 | # Concurrent training steps. 71 | self.concurrent_steps = FLAGS.concurrent_steps 72 | 73 | # Number of examples for one training step. 74 | self.batch_size = FLAGS.batch_size 75 | 76 | # The number of words to predict to the left and right of the target word. 77 | self.window_size = FLAGS.window_size 78 | 79 | # The minimum number of word occurrences for it to be included in the 80 | # vocabulary. 81 | self.min_count = FLAGS.min_count 82 | 83 | # Subsampling threshold for word occurrence. 84 | self.subsample = FLAGS.subsample 85 | 86 | # Where to write out summaries. 87 | self.save_path = FLAGS.save_path 88 | 89 | # initial word embed data 90 | self.emb_data = FLAGS.emb_data 91 | 92 | # Eval options. 93 | 94 | # The text file for eval. 
95 | self.eval_data = FLAGS.eval_data 96 | 97 | self.interactive = FLAGS.interactive 98 | 99 | @classmethod 100 | def web(cls): 101 | opts = Options() 102 | opts.save_path = 'train' 103 | opts.emb_dim = 100 104 | opts.interactive = True 105 | 106 | emb_data = 'train/model.vec' 107 | if os.path.isfile(emb_data): 108 | opts.emb_data = emb_data 109 | else: 110 | opts.train_data = 'data/tags.txt' 111 | 112 | with open(os.devnull, 'w') as FNULL: 113 | if subprocess.call(['ls', opts.save_path], stdout=FNULL) != 0: 114 | if subprocess.call(['ls', opts.train_data], stdout=FNULL) == 0: 115 | subprocess.call(['mkdir', opts.save_path]) 116 | else: 117 | subprocess.call(['wget', 'https://muik-projects.firebaseapp.com/tf/tag2vec-train.tgz'], 118 | stdout=FNULL) 119 | subprocess.call(['tar', 'xvfz', 'tag2vec-train.tgz']) 120 | subprocess.call(['rm', 'tag2vec-train.tgz']) 121 | return opts 122 | 123 | @classmethod 124 | def train(cls): 125 | opts = Options() 126 | opts.train_data = 'data/tags.txt' 127 | opts.save_path = 'train' 128 | opts.eval_data = 'data/questions-tags.txt' 129 | opts.window_size = 5 130 | opts.min_count = 7 131 | opts.emb_dim = 100 132 | return opts 133 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /word2vec_optimized.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Multi-threaded word2vec unbatched skip-gram model. 17 | 18 | Trains the model described in: 19 | (Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space 20 | ICLR 2013. 
21 | http://arxiv.org/abs/1301.3781 22 | This model does true SGD (i.e. no minibatching). To do this efficiently, custom 23 | ops are used to sequentially process data within a 'batch'. 24 | 25 | The key ops used are: 26 | * skipgram custom op that does input processing. 27 | * neg_train custom op that efficiently calculates and applies the gradient using 28 | true SGD. 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import os 35 | import sys 36 | import threading 37 | import time 38 | import random 39 | 40 | from six.moves import xrange # pylint: disable=redefined-builtin 41 | 42 | import numpy as np 43 | import tensorflow as tf 44 | import pandas as pd 45 | 46 | from tensorflow.models.embedding import gen_word2vec as word2vec 47 | from flags import FLAGS, Options 48 | 49 | class Word2Vec(object): 50 | """Word2Vec model (Skipgram).""" 51 | def __init__(self, options, session): 52 | self._options = options 53 | self._session = session 54 | self._word2id = {} 55 | self._id2word = [] 56 | if options.emb_data or options.interactive: 57 | self.load_emb() 58 | else: 59 | self.build_graph() 60 | self.build_eval_graph() 61 | if options.eval_data: 62 | self._read_analogies() 63 | if not options.emb_data and not options.interactive: 64 | self.save_vocab() 65 | if not options.emb_data and options.train_data and not options.interactive: 66 | self._load_corpus() 67 | 68 | def _read_analogies(self): 69 | """Reads through the analogy question file. 70 | 71 | Returns: 72 | questions: a [n, 4] numpy array containing the analogy question's 73 | word ids. 74 | questions_skipped: questions skipped due to unknown words. 75 | """ 76 | questions = [] 77 | questions_skipped = 0 78 | with open(self._options.eval_data, "rb") as analogy_f: 79 | for line in analogy_f: 80 | if line.startswith(b":"): # Skip comments. 
81 | continue 82 | words = line.decode('utf-8').strip().lower().split(b" ") 83 | ids = [self._word2id.get(w.strip()) for w in words] 84 | if None in ids or len(ids) != 4: 85 | questions_skipped += 1 86 | else: 87 | questions.append(np.array(ids)) 88 | print("Eval analogy file: ", self._options.eval_data) 89 | print("Questions: ", len(questions)) 90 | print("Skipped: ", questions_skipped) 91 | self._analogy_questions = np.array(questions, dtype=np.int32) 92 | 93 | def get_no_words(self, words): 94 | return [word for word in words if word not in self._word2id] 95 | 96 | def get_vocab_size(self): 97 | return self._options.vocab_size 98 | 99 | def get_emb_dim(self): 100 | return self._options.emb_dim 101 | 102 | def load_emb(self): 103 | start_time = time.time() 104 | opts = self._options 105 | 106 | if opts.emb_data: 107 | with open(opts.emb_data) as f: 108 | opts.emb_dim = int(f.readline().split()[1]) 109 | self._id2word = pd.read_csv(opts.emb_data, delimiter=' ', 110 | skiprows=1, header=0, usecols=[0]).values 111 | self._id2word = np.transpose(self._id2word)[0] 112 | if self._id2word[0] == '': 113 | self._id2word[0] = 'UNK' 114 | else: 115 | self._id2word = np.loadtxt(os.path.join(opts.save_path, "vocab.txt"), 116 | 'str', unpack=True)[0] 117 | 118 | self._id2word = [str(x).decode('utf-8') for x in self._id2word] 119 | for i, w in enumerate(self._id2word): 120 | self._word2id[w] = i 121 | opts.vocab_size = len(self._id2word) 122 | 123 | if opts.emb_data: 124 | def initializer(shape, dtype): 125 | initial_value = pd.read_csv(opts.emb_data, delimiter=' ', 126 | skiprows=1, header=0, usecols=range(1, opts.emb_dim+1)).values 127 | if opts.save_path: 128 | path = os.path.join(opts.save_path, 'tsne.js') 129 | if not os.path.isfile(path): 130 | self._export_tsne(initial_value) 131 | return initial_value 132 | self._w_in = tf.get_variable('w_in', [opts.vocab_size, opts.emb_dim], 133 | initializer=initializer) 134 | else: 135 | self._w_in = tf.get_variable('w_in', 
[opts.vocab_size, opts.emb_dim]) 136 | print("--- embed data load time: %.1f seconds ---" % (time.time() - start_time)) 137 | 138 | def build_graph(self): 139 | """Build the model graph.""" 140 | opts = self._options 141 | 142 | # The training data. A text file. 143 | (words, counts, words_per_epoch, current_epoch, total_words_processed, 144 | examples, labels) = word2vec.skipgram(filename=opts.train_data, 145 | batch_size=opts.batch_size, 146 | window_size=opts.window_size, 147 | min_count=opts.min_count, 148 | subsample=opts.subsample) 149 | (opts.vocab_words, opts.vocab_counts, 150 | opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch]) 151 | opts.vocab_size = len(opts.vocab_words) 152 | print("Data file: ", opts.train_data) 153 | print("Vocab size: ", opts.vocab_size - 1, " + UNK") 154 | print("Words per epoch: ", opts.words_per_epoch) 155 | 156 | opts.vocab_words = map(lambda x: x.decode('utf-8'), opts.vocab_words) 157 | self._id2word = opts.vocab_words 158 | for i, w in enumerate(self._id2word): 159 | self._word2id[w] = i 160 | 161 | # Declare all variables we need. 162 | # Input words embedding: [vocab_size, emb_dim] 163 | w_in = tf.Variable( 164 | tf.random_uniform( 165 | [opts.vocab_size, 166 | opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim), 167 | name="w_in") 168 | 169 | # Global step: scalar, i.e., shape []. 170 | w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out") 171 | 172 | # Global step: [] 173 | global_step = tf.Variable(0, name="global_step") 174 | 175 | # Linear learning rate decay. 176 | words_to_train = float(opts.words_per_epoch * opts.epochs_to_train) 177 | lr = opts.learning_rate * tf.maximum( 178 | 0.0001, 179 | 1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train) 180 | 181 | examples = tf.placeholder(dtype=tf.int32) # [N] 182 | labels = tf.placeholder(dtype=tf.int32) # [N] 183 | 184 | # Training nodes. 
185 | inc = global_step.assign_add(1) 186 | with tf.control_dependencies([inc]): 187 | train = word2vec.neg_train(w_in, 188 | w_out, 189 | examples, 190 | labels, 191 | lr, 192 | vocab_count=opts.vocab_counts.tolist(), 193 | num_negative_samples=opts.num_samples) 194 | 195 | self._w_in = w_in 196 | self._examples = examples 197 | self._labels = labels 198 | self._lr = lr 199 | self._train = train 200 | self.step = global_step 201 | self._epoch = current_epoch 202 | self._words = total_words_processed 203 | 204 | def save_vocab(self): 205 | """Save the vocabulary to a file so the model can be reloaded.""" 206 | opts = self._options 207 | with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f: 208 | for i in xrange(opts.vocab_size): 209 | f.write("%s %d\n" % (tf.compat.as_text(opts.vocab_words[i]).encode('utf-8'), 210 | opts.vocab_counts[i])) 211 | 212 | def build_eval_graph(self): 213 | """Build the evaluation graph.""" 214 | # Eval graph 215 | opts = self._options 216 | 217 | # Each analogy task is to predict the 4th word (d) given three 218 | # words: a, b, c. E.g., a=italy, b=rome, c=france, we should 219 | # predict d=paris. 220 | 221 | # The eval feeds three vectors of word ids for a, b, c, each of 222 | # which is of size N, where N is the number of analogies we want to 223 | # evaluate in one batch. 224 | analogy_a = tf.placeholder(dtype=tf.int32) # [N] 225 | analogy_b = tf.placeholder(dtype=tf.int32) # [N] 226 | analogy_c = tf.placeholder(dtype=tf.int32) # [N] 227 | 228 | word_ids = tf.placeholder(dtype=tf.int32) # [N] 229 | negative_word_ids = tf.placeholder(dtype=tf.int32) # [N] 230 | 231 | # Normalized word embeddings of shape [vocab_size, emb_dim]. 232 | nemb = tf.nn.l2_normalize(self._w_in, 1) 233 | 234 | # Each row of a_emb, b_emb, c_emb is a word's embedding vector. 
235 | # They all have the shape [N, emb_dim] 236 | a_emb = tf.gather(nemb, analogy_a) # a's embs 237 | b_emb = tf.gather(nemb, analogy_b) # b's embs 238 | c_emb = tf.gather(nemb, analogy_c) # c's embs 239 | 240 | words_emb = tf.nn.embedding_lookup(nemb, word_ids) 241 | negative_words_emb = tf.nn.embedding_lookup(nemb, negative_word_ids) 242 | 243 | # We expect that d's embedding vectors on the unit hyper-sphere is 244 | # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim]. 245 | target = c_emb + (b_emb - a_emb) 246 | 247 | # Compute cosine distance between each pair of target and vocab. 248 | # dist has shape [N, vocab_size]. 249 | dist = tf.matmul(target, nemb, transpose_b=True) 250 | self._target = target 251 | self._dist = dist 252 | 253 | # For each question (row in dist), find the top 4 words. 254 | _, pred_idx = tf.nn.top_k(dist, 4) 255 | 256 | mean = tf.reduce_mean(words_emb, 0) 257 | mean = tf.reshape(mean, [-1, opts.emb_dim]) 258 | mean_dist = 1.0 - tf.matmul(mean, words_emb, transpose_b=True) 259 | _, self._mean_pred_idx = tf.nn.top_k(mean_dist, 1) 260 | 261 | joint_dist = tf.matmul(words_emb, nemb, transpose_b=True) 262 | n_joint_dist = tf.matmul(negative_words_emb, nemb, transpose_b=True) 263 | joint_dist = tf.reduce_sum(joint_dist, 0) - tf.reduce_sum(n_joint_dist, 0) 264 | self._joint_idx = tf.nn.top_k(joint_dist, min(1000, opts.vocab_size)) 265 | 266 | # Nodes in the construct graph which are used by training and 267 | # evaluation to run/feed/fetch. 
268 | self._analogy_a = analogy_a 269 | self._analogy_b = analogy_b 270 | self._analogy_c = analogy_c 271 | self._word_ids = word_ids 272 | self._negative_word_ids = negative_word_ids 273 | self._analogy_pred_idx = pred_idx 274 | 275 | ckpt = None 276 | self.saver = tf.train.Saver() 277 | if not opts.emb_data: 278 | ckpt = tf.train.latest_checkpoint(os.path.join(opts.save_path)) 279 | if ckpt: 280 | self.saver.restore(self._session, ckpt) 281 | print('loaded %s' % ckpt) 282 | else: 283 | # Properly initialize all variables. 284 | self._session.run(tf.initialize_all_variables()) 285 | 286 | def _load_corpus(self): 287 | corpus = [] 288 | with open(self._options.train_data, 'r') as f: 289 | unk_id = self._word2id['UNK'] 290 | def word2id(w): 291 | return w in self._word2id and self._word2id[w] or unk_id 292 | while True: 293 | line = f.readline().decode('utf-8') 294 | if not line: 295 | break 296 | corpus.append([word2id(w) for w in line.split()]) 297 | self._corpus = corpus 298 | self._corpus_lines_count = len(corpus) 299 | 300 | def _batch_data(self): 301 | examples = [] 302 | labels = [] 303 | batch_size = self._options.batch_size 304 | window_size = self._options.window_size 305 | unk_id = self._word2id['UNK'] 306 | count = 0 307 | while True: 308 | line = self._corpus[random.randrange(0,self._corpus_lines_count)] 309 | words_count = len(line) 310 | for i, center_id in enumerate(line): 311 | if center_id == unk_id: 312 | continue 313 | start_index = max(0, i-window_size) 314 | end_index = min(words_count, i + 1 + window_size) 315 | outputs = line[start_index:end_index] 316 | outputs = filter(lambda x: x != unk_id and x != center_id, outputs) 317 | outputs_count = len(outputs) 318 | examples += [center_id] * outputs_count 319 | labels += outputs 320 | count += outputs_count 321 | if count >= batch_size: 322 | return examples[:batch_size], labels[:batch_size] 323 | 324 | def _train_thread_body(self): 325 | initial_epoch, = self._session.run([self._epoch]) 326 | 
while True: 327 | examples, labels = self._batch_data() 328 | _, epoch = self._session.run([self._train, self._epoch], { 329 | self._examples: examples, 330 | self._labels: labels 331 | }) 332 | if epoch != initial_epoch: 333 | break 334 | # time.sleep(0.02) # for preventing notebook noise 335 | 336 | def train(self): 337 | """Train the model.""" 338 | opts = self._options 339 | 340 | initial_epoch, initial_words = self._session.run([self._epoch, self._words]) 341 | 342 | workers = [] 343 | for _ in xrange(opts.concurrent_steps): 344 | t = threading.Thread(target=self._train_thread_body) 345 | t.start() 346 | workers.append(t) 347 | 348 | last_words, last_time = initial_words, time.time() 349 | while True: 350 | time.sleep(2) # Reports our progress once a while. 351 | (epoch, step, words, 352 | lr) = self._session.run([self._epoch, self.step, self._words, self._lr]) 353 | now = time.time() 354 | last_words, last_time, rate = words, now, (words - last_words) / ( 355 | now - last_time) 356 | print("Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step, 357 | lr, rate), 358 | end="") 359 | sys.stdout.flush() 360 | if epoch != initial_epoch: 361 | break 362 | 363 | for t in workers: 364 | t.join() 365 | 366 | def _predict(self, analogy): 367 | """Predict the top 4 answers for analogy questions.""" 368 | idx, = self._session.run([self._analogy_pred_idx], { 369 | self._analogy_a: analogy[:, 0], 370 | self._analogy_b: analogy[:, 1], 371 | self._analogy_c: analogy[:, 2] 372 | }) 373 | return idx 374 | 375 | def eval(self): 376 | """Evaluate analogy questions and reports accuracy.""" 377 | 378 | # How many questions we get right at precision@1. 
379 | correct = 0 380 | 381 | total = self._analogy_questions.shape[0] 382 | start = 0 383 | while start < total: 384 | limit = start + 2500 385 | sub = self._analogy_questions[start:limit, :] 386 | idx = self._predict(sub) 387 | start = limit 388 | for question in xrange(sub.shape[0]): 389 | for j in xrange(4): 390 | if idx[question, j] == sub[question, 3]: 391 | # Bingo! We predicted correctly. E.g., [italy, rome, france, paris]. 392 | correct += 1 393 | break 394 | elif idx[question, j] in sub[question, :3]: 395 | # We need to skip words already in the question. 396 | continue 397 | else: 398 | # The correct label is not the precision@1 399 | break 400 | accuracy = correct * 100.0 / total 401 | print() 402 | print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total, accuracy)) 403 | return accuracy 404 | 405 | def get_nearby(self, words, negative_words, num=20): 406 | wids = [self._word2id.get(w, 0) for w in words] 407 | n_wids = [self._word2id.get(w, 0) for w in negative_words] 408 | idx = self._session.run(self._joint_idx, { 409 | self._word_ids: wids, 410 | self._negative_word_ids: n_wids 411 | }) 412 | results = [] 413 | for distance, i in zip(idx[0][:num], idx[1][:num]): 414 | if i in wids: 415 | continue 416 | results.append((self._id2word[i], distance)) 417 | return results 418 | 419 | def doesnt_match(self, *words): 420 | wids = [self._word2id.get(w, 0) for w in words] 421 | idx, = self._session.run(self._mean_pred_idx, { 422 | self._word_ids: wids 423 | }) 424 | print(words[idx[0]]) 425 | return 426 | 427 | def get_doesnt_match(self, *words): 428 | wids = [self._word2id.get(w, 0) for w in words] 429 | idx, = self._session.run(self._mean_pred_idx, { 430 | self._word_ids: wids 431 | }) 432 | return words[idx[0]] 433 | 434 | def get_analogy(self, w0, w1, w2): 435 | """Predict word w3 as in w0:w1 vs w2:w3.""" 436 | wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]]) 437 | idx = self._predict(wid) 438 | for c in [self._id2word[i] for i in idx[0, 
:]]: 439 | if c not in [w0, w1, w2, 'UNK']: 440 | return c 441 | return 442 | 443 | def save(self): 444 | opts = self._options 445 | self.saver.save(self._session, os.path.join(opts.save_path, "model.ckpt")) 446 | all_embs = self._session.run(self._w_in) 447 | self._export_tsne(all_embs) 448 | print('Saved') 449 | 450 | def _export_tsne(self, all_embs): 451 | from sklearn.manifold import TSNE 452 | import json 453 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) 454 | plot_only = min(500, all_embs.shape[0]) 455 | low_dim_embs = tsne.fit_transform(all_embs[:plot_only,:]) 456 | labels = [self._id2word[i] for i in xrange(plot_only)] 457 | embs = [list(e) for e in low_dim_embs] 458 | json_data = json.dumps({'embs': embs, 'labels': labels}) 459 | path = os.path.join(self._options.save_path, 'tsne.js') 460 | with open(path, 'w') as f: 461 | f.write(json_data) 462 | print('%s exported' % path) 463 | 464 | def get_save_path(self): 465 | return self._options.save_path 466 | 467 | 468 | def main(_): 469 | """Train a word2vec model.""" 470 | opts = Options() 471 | if not opts.train_data and opts.eval_data: 472 | with tf.Graph().as_default(), tf.Session() as session: 473 | model = Word2Vec(opts, session) 474 | model.eval() # Eval analogies. 475 | return 476 | 477 | if not opts.train_data or not opts.save_path or not opts.eval_data: 478 | print("--train_data --eval_data and --save_path must be specified.") 479 | sys.exit(1) 480 | 481 | with tf.Graph().as_default(), tf.Session() as session: 482 | model = Word2Vec(opts, session) 483 | for i in xrange(opts.epochs_to_train): 484 | model.train() # Process one epoch 485 | accuracy = model.eval() # Eval analogies. 486 | if (i+1) % 5 == 0: 487 | model.save() 488 | if opts.epochs_to_train % 5 != 0: 489 | model.save() 490 | 491 | 492 | if __name__ == "__main__": 493 | tf.app.run() 494 | --------------------------------------------------------------------------------