├── .gitignore ├── MANIFEST.in ├── README.md ├── configs └── template.ini ├── download_model.sh ├── examples └── ja_example.py ├── requirements.txt ├── setup.py ├── test-requirements.txt ├── tests ├── __init__.py ├── test_interface.py ├── test_load_entity_model.py └── test_make_lattice.py ├── tox.ini └── word2vec_wikification_py ├── __init__.py ├── init_logger.py ├── initialize_mysql_connector.py ├── interface.py ├── load_entity_model.py ├── make_lattice.py ├── models.py └── search_wiki_pages.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | .venv/ 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | # appendix 94 | bin/* 95 | .idea/ 96 | configs/development.ini 97 | *~ 98 | .DS_Store 99 | tests/resources/jawiki-latest-page.sql 100 | tests/resources/jawiki-latest-redirect.sql -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | recursive-include examples * 4 | recursive-include tests * 5 | recursive-include word2vec_wikification_py * -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | wiki_node_disambiguation 2 | - - - 3 | 4 | # What's this ? 5 | 6 | - You can run "Wikification" as easy as possible. 7 | - According to wikipedia, [Wikification](https://en.wikipedia.org/wiki/Wikification) is `in computer science, entity linking with Wikipedia as the target knowledge base` 8 | - You can get disambiguated result with its score. 9 | 10 | Please visit [Github page](https://github.com/Kensuke-Mitsuzawa/word2vec_wikification_py) also. 11 | If you find any bugs and you report it to github issue, I'm glad. 12 | Any pull-requests are welcomed. 13 | 14 | 15 | # Requirement 16 | 17 | - Python3.x (checked under ) 18 | - I recommend to use "Anaconda" distribution. 
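# Quick example

Once the entity vector model and the wikipedia MySQL database are in place (see Setup below), the whole pipeline takes only a few lines. The model path, MySQL credentials and input tokens here are placeholders; `examples/ja_example.py` shows the complete flow including tokenization.

```
from word2vec_wikification_py.interface import load_entity_model, predict_japanese_wiki_names_with_wikidump
from word2vec_wikification_py.initialize_mysql_connector import initialize_pymysql_connector

# placeholder path/credentials: adjust to your environment
model = load_entity_model(path_entity_model='./bin/entity_vector/entity_vector.model.bin')
connector = initialize_pymysql_connector(hostname='localhost',
                                         user_name='your-user-name',
                                         password='your-password',
                                         dbname='wikipedia')

# tokens of one Japanese sentence (e.g. produced by a morphological analyzer)
tokens = ['ヤマハ', 'バイク', 'スズキ', 'ホンダ', '優勝']
results = predict_japanese_wiki_names_with_wikidump(input_tokens=tokens,
                                                    wikipedia_db_connector=connector,
                                                    entity_vector_model=model,
                                                    is_use_cache=True,
                                                    is_sort_object=True)
for rank, seq_obj in enumerate(results):
    print(rank, seq_obj.sequence_score, seq_obj.get_tokens())
```
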
19 | 20 | # Setup 21 | 22 | `python setup.py install` 23 | 24 | ## Get wikipedia entity vector model 25 | 26 | Go to [this page](http://www.cl.ecei.tohoku.ac.jp/~m-suzuki/jawiki_vector/) and download model file from [here](http://www.cl.ecei.tohoku.ac.jp/~m-suzuki/jawiki_vector/entity_vector.tar.bz2). 27 | Or run `download_model.sh` 28 | 29 | ## To those who uses interface.predict_japanese_wiki_names() 30 | 31 | You're supposed to have mysql somewhere. 32 | 33 | The step until using it. 34 | 35 | 1. start mysql server somewhere 36 | 2. download latest mysql dump files 37 | 3. initialize wikipedia database with mysql 38 | 39 | 40 | To download wikipedia dump files, execute following commands 41 | 42 | ``` 43 | wget https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-redirect.sql.gz 44 | wget https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-page.sql.gz 45 | gunzip jawiki-latest-redirect.sql.gz 46 | gunzip jawiki-latest-page.sql.gz 47 | ``` 48 | 49 | To initialize wikipedia database with mysql, 50 | 51 | ``` 52 | % CREATE DATABASE wikipedia; 53 | % mysql -u [user_name] -p[password] wikipedia < jawiki-latest-redirect.sql 54 | % mysql -u [user_name] -p[password] wikipedia < jawiki-latest-page.sql 55 | ``` 56 | 57 | # Change logs 58 | 59 | - version0.1 60 | - released 61 | - It supports only Japanese wikipedia 62 | -------------------------------------------------------------------------------- /configs/template.ini: -------------------------------------------------------------------------------- 1 | [Mysql] 2 | host=localhost 3 | port=3306 4 | user_name= 5 | password= 6 | database_name=wikipedia 7 | table_name_category=category 8 | table_name_categorylinks=categorylinks 9 | table_name_page=page 10 | table_name_redirect=redirect -------------------------------------------------------------------------------- /download_model.sh: -------------------------------------------------------------------------------- 1 | mkdir ./bin/ 2 | wget http://www.cl.ecei.tohoku.ac.jp/~m-suzuki/jawiki_vector/entity_vector.tar.bz2 -O ./bin/entity_vector.tar.bz2 3 | bzip2 -dc ./bin/entity_vector.tar.bz2 | tar xvf - 4 | mv entity_vector ./bin 5 | 6 | -------------------------------------------------------------------------------- /examples/ja_example.py: -------------------------------------------------------------------------------- 1 | from gensim.models import Word2Vec 2 | from word2vec_wikification_py.interface import load_entity_model, predict_japanese_wiki_names_with_wikidump 3 | from word2vec_wikification_py.initialize_mysql_connector import initialize_pymysql_connector 4 | from word2vec_wikification_py import init_logger 5 | # You're supposed to install "JapaneseTokenizer" pakcage beforehand 6 | from JapaneseTokenizer import MecabWrapper 7 | import logging 8 | logger = logging.getLogger(init_logger.LOGGER_NAME) 9 | logger.level = logging.INFO 10 | 11 | """In this example, you see how to get wikipedia-liked information from Japanese sentence 12 | """ 13 | 14 | # ------------------------------------------------------------ 15 | # PARAMETERS 16 | path_model_file = '../bin/entity_vector/entity_vector.model.bin' 17 | dict_type = 'neologd' 18 | path_mecab_config = '/usr/local/bin/' 19 | pos_condition = [('名詞', )] 20 | mysql_username = 'your-mysql-user-name-here' 21 | mysql_hostname = 'localhost' 22 | mysql_password = 'your-mysql-password-here' 23 | mysql_db_name = 'wikipedia' 24 | # ------------------------------------------------------------ 25 | entity_linking_model = load_entity_model(path_model_file) 
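# The remaining setup below: the MeCab tokenizer (using the neologd dictionary configured
# above) splits the input sentence and keeps only nouns via pos_condition; the entity-vector
# model and a pymysql connection to the `wikipedia` database (built from the dumps described
# in README.md) are then passed to predict_japanese_wiki_names_with_wikidump(), which returns
# SequenceScore objects sorted by sequence_score because is_sort_object=True.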
26 | mecab_tokenizer = MecabWrapper(dict_type, path_mecab_config=path_mecab_config) 27 | model_object = load_entity_model(path_entity_model=path_model_file, is_use_cache=True) # type: Word2Vec 28 | mysql_connector = initialize_pymysql_connector(hostname=mysql_hostname, 29 | user_name=mysql_username, 30 | password=mysql_password, 31 | dbname=mysql_db_name) 32 | 33 | input_sentence = "かつてはイルモア、WCMといったプライベーターがオリジナルマシンで参戦していたほか、カワサキがワークス・チームを送り込んでいたが、2016年現在出場しているのはヤマハ、ホンダ、スズキ、ドゥカティ、アプリリアの5メーカーと、ワークスマシンの貸与等を受けられるサテライトチームとなっている。" 34 | filtered_nouns = mecab_tokenizer.filter( 35 | parsed_sentence=mecab_tokenizer.tokenize(sentence=input_sentence,return_list=False), 36 | pos_condition=pos_condition).convert_list_object() 37 | 38 | sequence_score_ojects = predict_japanese_wiki_names_with_wikidump(input_tokens=filtered_nouns, 39 | wikipedia_db_connector=mysql_connector, 40 | entity_vector_model=entity_linking_model, 41 | is_use_cache=True, 42 | is_sort_object=True) 43 | for rank, sequence_obj in enumerate(sequence_score_ojects): 44 | print('Rank-{} with score={}'.format(rank, sequence_obj.sequence_score)) 45 | print(sequence_obj.get_tokens()) 46 | print('-'*30) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nose 2 | tox 3 | typing -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from setuptools import setup, find_packages 4 | 5 | name='word2vec_wikification_py' 6 | version='0.17' 7 | description='A package to run wikification' 8 | author='Kensuke Mitsuzawa' 9 | author_email='kensuke.mit@gmail.com' 10 | url='https://github.com/Kensuke-Mitsuzawa/word2vec_wikification_py' 11 | license_name='MIT' 12 | 13 | install_requires = [ 14 | 'gensim', 15 | 'pymysql', 16 | 'typing' 17 | ] 18 | 19 | 20 | dependency_links = [ 21 | ] 22 | 23 | short_description = '' 24 | 25 | try: 26 | import pypandoc 27 | long_description = pypandoc.convert('README.md', 'rst') 28 | except(IOError, ImportError): 29 | long_description = open('README.md').read() 30 | 31 | classifiers = [ 32 | "Development Status :: 5 - Production/Stable", 33 | "License :: OSI Approved :: MIT License", 34 | "Programming Language :: Python", 35 | "Natural Language :: Japanese", 36 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 37 | "Programming Language :: Python :: 3.5" 38 | ] 39 | 40 | setup( 41 | name=name, 42 | version=version, 43 | description=description, 44 | long_description=long_description, 45 | author=author, 46 | install_requires=install_requires, 47 | dependency_links=dependency_links, 48 | author_email=author_email, 49 | url=url, 50 | license=license_name, 51 | packages=find_packages(), 52 | classifiers=classifiers, 53 | test_suite='tests', 54 | include_package_data=True, 55 | zip_safe=False 56 | ) 57 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | typing -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kensuke-Mitsuzawa/word2vec-wikification-py/a56950edaef8a47ce76437104de0e7485946d857/tests/__init__.py 
-------------------------------------------------------------------------------- /tests/test_interface.py: -------------------------------------------------------------------------------- 1 | from word2vec_wikification_py import load_entity_model, make_lattice, interface, initialize_mysql_connector 2 | from word2vec_wikification_py.models import WikipediaArticleObject, SequenceScore, LatticeObject, IndexDictionaryObject 3 | import configparser 4 | import unittest 5 | import os 6 | 7 | class TestInterface(unittest.TestCase): 8 | @classmethod 9 | def setUpClass(cls): 10 | # procedures before tests are started. This code block is executed only once 11 | cls.path_model_file = '../bin/entity_vector/entity_vector.model.bin' 12 | if not os.path.exists(cls.path_model_file): 13 | cls.path_model_file = cls.path_model_file.replace('../', '') 14 | cls.model_object = load_entity_model.load_entity_model(path_entity_model=cls.path_model_file, is_use_cache=True) 15 | 16 | cls.path_config_file = '../configs/development.ini' 17 | if not os.path.exists(cls.path_config_file): 18 | cls.path_config_file = cls.path_config_file.replace('../', '') 19 | 20 | if not os.path.exists(cls.path_config_file): 21 | raise FileExistsError() 22 | 23 | cls.config_obj = configparser.ConfigParser(allow_no_value=True) 24 | cls.config_obj.read(cls.path_config_file) 25 | 26 | 27 | @classmethod 28 | def tearDownClass(cls): 29 | # procedures after tests are finished. This code block is executed only once 30 | pass 31 | 32 | def setUp(self): 33 | # procedures before every tests are started. This code block is executed every time 34 | pass 35 | 36 | def tearDown(self): 37 | # procedures after every tests are finished. This code block is executed every time 38 | pass 39 | 40 | 41 | def test_compute_wiki_node_probability_test1(self): 42 | seq_wikipedia_article_object = [ 43 | WikipediaArticleObject(page_title='ヤマハ', candidate_article_name=['[ヤマハ]', '[ヤマハ発動機]']), 44 | WikipediaArticleObject(page_title='スズキ', candidate_article_name=['[スズキ_(企業)]', '[スズキ_(魚)]']), 45 | WikipediaArticleObject(page_title='ドゥカティ', candidate_article_name=['[ドゥカティ]']) 46 | ] 47 | 48 | sequence_score_objects = interface.compute_wiki_node_probability( 49 | seq_wiki_article_name=seq_wikipedia_article_object, 50 | entity_vector_model=self.model_object, 51 | is_use_cache=True 52 | ) 53 | self.assertTrue(isinstance(sequence_score_objects, list)) 54 | for seq_obj in sequence_score_objects: 55 | self.assertTrue(isinstance(seq_obj, SequenceScore)) 56 | 57 | def test_compute_wiki_node_probability_test2(self): 58 | seq_wikipedia_article_object = [ 59 | WikipediaArticleObject(page_title='お笑いタレント', candidate_article_name=['[お笑いタレント]']), 60 | WikipediaArticleObject(page_title='ロバート', candidate_article_name=['[ロバート_(お笑いトリオ)]', '[ロバート]']), 61 | WikipediaArticleObject(page_title='山本博', candidate_article_name=['[山本博_(お笑い芸人)]', '[山本博_(アーチェリー選手)]', '山本博_(弁護士)', '[山本博_(柔道家)]']), 62 | WikipediaArticleObject(page_title='エンタの神様', candidate_article_name=['[エンタの神様]']), 63 | WikipediaArticleObject(page_title='日テレ', candidate_article_name=['[日本テレビ放送網]']), 64 | WikipediaArticleObject(page_title='さんま御殿', candidate_article_name=['[踊る!さんま御殿!!]']), 65 | ] 66 | 67 | sequence_score_objects = interface.compute_wiki_node_probability( 68 | seq_wiki_article_name=seq_wikipedia_article_object, 69 | entity_vector_model=self.model_object, 70 | is_use_cache=True 71 | ) 72 | self.assertTrue(isinstance(sequence_score_objects, list)) 73 | for seq_obj in sequence_score_objects: 74 | 
self.assertTrue(isinstance(seq_obj, SequenceScore)) 75 | import pprint 76 | pprint.pprint(seq_obj.__dict__()) 77 | 78 | def test_predict_japanese_wiki_names_partial(self): 79 | connector = initialize_mysql_connector.initialize_pymysql_connector( 80 | hostname=self.config_obj.get('Mysql', 'host'), 81 | user_name=self.config_obj.get('Mysql', 'user_name'), 82 | password=self.config_obj.get('Mysql', 'password'), 83 | dbname=self.config_obj.get('Mysql', 'database_name') 84 | ) 85 | 86 | test_input = ['お笑い', 'タレント', 'ロバート', '山本博', 'エンタ', 'の', '神様', '日テレ', '踊る!さんま御殿!!'] 87 | sequence_score_objects = interface.predict_japanese_wiki_names_with_wikidump( 88 | input_tokens=test_input, 89 | wikipedia_db_connector=connector, 90 | entity_vector_model=self.model_object, 91 | is_use_cache=True, 92 | is_sort_object=True 93 | ) 94 | self.assertTrue(isinstance(sequence_score_objects, list)) 95 | for seq_obj in sequence_score_objects: 96 | self.assertTrue(isinstance(seq_obj, SequenceScore)) 97 | import pprint 98 | pprint.pprint(seq_obj.__dict__()) 99 | 100 | 101 | def test_predict_japanese_wiki_names_complete(self): 102 | connector = initialize_mysql_connector.initialize_pymysql_connector( 103 | hostname=self.config_obj.get('Mysql', 'host'), 104 | user_name=self.config_obj.get('Mysql', 'user_name'), 105 | password=self.config_obj.get('Mysql', 'password'), 106 | dbname=self.config_obj.get('Mysql', 'database_name') 107 | ) 108 | 109 | #test_input = ['お笑いタレント', 'ロバート', '山本博', 'エンタの神様', '日テレ', '踊る!さんま御殿!!'] 110 | test_input = ['ヤマハ', 'バイク', 'スズキ', 'オーバーテイク', 'ホンダ', '優勝'] 111 | sequence_score_objects = interface.predict_japanese_wiki_names_with_wikidump( 112 | input_tokens=test_input, 113 | wikipedia_db_connector=connector, 114 | entity_vector_model=self.model_object, 115 | is_use_cache=True, 116 | is_sort_object=True, 117 | search_method='complete' 118 | ) 119 | self.assertTrue(isinstance(sequence_score_objects, list)) 120 | for seq_obj in sequence_score_objects: 121 | self.assertTrue(isinstance(seq_obj, SequenceScore)) 122 | import pprint 123 | pprint.pprint(seq_obj.__dict__()) 124 | 125 | 126 | if __name__ == '__main__': 127 | unittest.main() -------------------------------------------------------------------------------- /tests/test_load_entity_model.py: -------------------------------------------------------------------------------- 1 | from word2vec_wikification_py import load_entity_model 2 | import unittest 3 | import os 4 | 5 | class TestLoadEntityModel(unittest.TestCase): 6 | @classmethod 7 | def setUpClass(cls): 8 | # procedures before tests are started. This code block is executed only once 9 | cls.path_model_file = '../bin/entity_vector/entity_vector.model.bin' 10 | if not os.path.exists(cls.path_model_file): 11 | cls.path_model_file = cls.path_model_file.replace('../', '') 12 | 13 | 14 | @classmethod 15 | def tearDownClass(cls): 16 | # procedures after tests are finished. This code block is executed only once 17 | pass 18 | 19 | def setUp(self): 20 | # procedures before every tests are started. This code block is executed every time 21 | pass 22 | 23 | def tearDown(self): 24 | # procedures after every tests are finished. 
This code block is executed every time 25 | pass 26 | 27 | def test_load_entity_model(self): 28 | model_object = load_entity_model.load_entity_model(path_entity_model=self.path_model_file, 29 | is_use_cache=True) 30 | print(model_object.most_similar('[エン・ジャパン]')) 31 | 32 | if __name__ == '__main__': 33 | unittest.main() -------------------------------------------------------------------------------- /tests/test_make_lattice.py: -------------------------------------------------------------------------------- 1 | from word2vec_wikification_py import load_entity_model, make_lattice 2 | from word2vec_wikification_py.models import WikipediaArticleObject, LatticeObject, IndexDictionaryObject 3 | import unittest 4 | import os 5 | 6 | class TestMakeLatice(unittest.TestCase): 7 | @classmethod 8 | def setUpClass(cls): 9 | # procedures before tests are started. This code block is executed only once 10 | cls.path_model_file = '../bin/entity_vector/entity_vector.model.bin' 11 | if not os.path.exists(cls.path_model_file): 12 | cls.path_model_file = cls.path_model_file.replace('../', '') 13 | cls.model_object = load_entity_model.load_entity_model(path_entity_model=cls.path_model_file, is_use_cache=True) 14 | 15 | cls.seq_wikipedia_article_object = [ 16 | WikipediaArticleObject(page_title='ヤマハ', candidate_article_name=['[ヤマハ]', '[ヤマハ発動機]']), 17 | WikipediaArticleObject(page_title='スズキ', candidate_article_name=['[スズキ_(企業)]', '[スズキ_(魚)]']), 18 | WikipediaArticleObject(page_title='ドゥカティ', candidate_article_name=['[ドゥカティ]']), 19 | ] 20 | 21 | @classmethod 22 | def tearDownClass(cls): 23 | # procedures after tests are finished. This code block is executed only once 24 | pass 25 | 26 | def setUp(self): 27 | # procedures before every tests are started. This code block is executed every time 28 | pass 29 | 30 | def tearDown(self): 31 | # procedures after every tests are finished. 
This code block is executed every time 32 | pass 33 | 34 | def test_make_state_transition_matrix(self): 35 | """状態tから状態t+1への遷移行列を作成するテスト 36 | """ 37 | state2index_obj = IndexDictionaryObject(state2index={'row2index': {}, 'column2index': {}}, 38 | index2state={}) 39 | 40 | transition_edge = make_lattice.make_state_transition_edge( 41 | state_t_word_tuple=(0,'[ヤマハ]'), 42 | state_t_plus_word_tuple=(1, '[河合楽器製作所]'), 43 | state2index_obj=state2index_obj, 44 | entity_vector=self.model_object 45 | ) 46 | self.assertTrue(isinstance(transition_edge, tuple)) 47 | self.assertEqual(transition_edge[0].transition_score, self.model_object.similarity('[ヤマハ]', '[河合楽器製作所]')) 48 | 49 | def test_make_state_transition_sequence(self): 50 | """ 51 | """ 52 | state2index_obj = IndexDictionaryObject(state2index={'row2index': {}, 'column2index': {}}, 53 | index2state={}) 54 | 55 | make_lattice.make_state_transition_sequence( 56 | seq_wiki_article_name=self.seq_wikipedia_article_object, 57 | entity_vector_model=self.model_object, 58 | state2index_obj=state2index_obj, 59 | ) 60 | 61 | def test_make_lattice_object(self): 62 | lattice_object = make_lattice.make_lattice_object( 63 | seq_wiki_article_name=self.seq_wikipedia_article_object, 64 | entity_vector_model=self.model_object, 65 | is_use_cache=True 66 | ) 67 | self.assertTrue(isinstance(lattice_object, LatticeObject)) 68 | 69 | 70 | if __name__ == '__main__': 71 | unittest.main() -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py34,py35,py351,py352 3 | 4 | [testenv] 5 | deps = 6 | -U 7 | -r{toxinidir}/requirements.txt 8 | -r{toxinidir}/test-requirements.txt 9 | commands = 10 | nosetests -v -------------------------------------------------------------------------------- /word2vec_wikification_py/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kensuke-Mitsuzawa/word2vec-wikification-py/a56950edaef8a47ce76437104de0e7485946d857/word2vec_wikification_py/__init__.py -------------------------------------------------------------------------------- /word2vec_wikification_py/init_logger.py: -------------------------------------------------------------------------------- 1 | LOGGER_NAME = 'word2vec_wikification_py' 2 | 3 | import logging 4 | import os 5 | import sys 6 | from logging import getLogger, Formatter, Logger, StreamHandler 7 | from logging.handlers import SMTPHandler, RotatingFileHandler, TimedRotatingFileHandler 8 | 9 | # Formatter 10 | custmoFormatter = Formatter( 11 | fmt='[%(asctime)s]%(levelname)s - %(filename)s#%(funcName)s:%(lineno)d: %(message)s', 12 | datefmt='Y/%m/%d %H:%M:%S' 13 | ) 14 | 15 | # StreamHandler 16 | STREAM_LEVEL = logging.DEBUG 17 | STREAM_FORMATTER = custmoFormatter 18 | STREAM = sys.stderr 19 | 20 | st_handler = StreamHandler(stream=STREAM) 21 | st_handler.setLevel(STREAM_LEVEL) 22 | st_handler.setFormatter(STREAM_FORMATTER) 23 | 24 | 25 | def init_logger(logger:logging.Logger)->logging.Logger: 26 | logger.addHandler(st_handler) 27 | logger.propagate = False 28 | 29 | return logger -------------------------------------------------------------------------------- /word2vec_wikification_py/initialize_mysql_connector.py: -------------------------------------------------------------------------------- 1 | def initialize_mysql_connector(hostname:str, 2 | user_name:str, 3 | password:str, 4 | dbname:str): 5 | """ 6 | """ 7 
| import MySQLdb 8 | conn = MySQLdb.connect(hostname, user_name, password, dbname) 9 | return conn 10 | 11 | 12 | def initialize_pymysql_connector(hostname:str, 13 | user_name:str, 14 | password:str, 15 | dbname:str): 16 | import pymysql.cursors 17 | # MySQLに接続する 18 | connection = pymysql.connect(host=hostname, 19 | user=user_name, 20 | password=password, 21 | db=dbname, 22 | charset='utf8') 23 | return connection 24 | -------------------------------------------------------------------------------- /word2vec_wikification_py/interface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | try: 3 | from gensim.models import KeyedVectors 4 | from gensim.models import Word2Vec 5 | except ImportError: 6 | from gensim.models import Word2Vec 7 | from word2vec_wikification_py.models import WikipediaArticleObject, LatticeObject, SequenceScore 8 | from word2vec_wikification_py.make_lattice import make_lattice_object 9 | from word2vec_wikification_py.load_entity_model import load_entity_model 10 | from word2vec_wikification_py import search_wiki_pages 11 | from typing import List, Any, Union 12 | from functools import partial 13 | 14 | 15 | def string_normalization_function(input_str:str)->str: 16 | return input_str 17 | 18 | 19 | def add_article_symbol(input_str:str)->str: 20 | return '[{}]'.format(input_str) 21 | 22 | 23 | def predict_japanese_wiki_names_with_wikidump(input_tokens, 24 | wikipedia_db_connector, 25 | entity_vector_model, 26 | is_use_cache=True, 27 | is_sort_object=True, 28 | page_table_name='page', 29 | page_table_redirect='redirect', 30 | search_method='complete') -> List[SequenceScore]: 31 | """* What you can do 32 | - You can run "Wikification" over your tokenized text 33 | 34 | * Params 35 | - input_tokens: list of tokens 36 | - wikipedia_db_connector: mysql connector into wikipedia-dump database 37 | - entity_vector_model: wikipedia entity vector of word2vec model 38 | - is_use_cache: a boolean flag for keeping huge-object on disk 39 | - is_sort_object: a boolean flag for sorting SequenceScore object 40 | - page_table_name: the name of "page" table of wikipedia-dump database 41 | - page_table_redirect: the name of "redirect" table of wikipedia-dump database 42 | - search_method: a way to find candidates of wikipedia article name. 43 | - partial: It tries to find wikipedia article name by concatenating tokens 44 | - complete: It trusts the result of tokenizer. 
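        * Example
            - a minimal sketch; `mysql_connector` and `entity_vector_model` are assumed to be
              prepared beforehand as in examples/ja_example.py
            seq_scores = predict_japanese_wiki_names_with_wikidump(
                input_tokens=['ヤマハ', 'スズキ', 'ドゥカティ'],
                wikipedia_db_connector=mysql_connector,
                entity_vector_model=entity_vector_model)
            best_sequence = seq_scores[0]        # highest sequence_score first (is_sort_object=True)
            article_names = best_sequence.get_tokens()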
45 | """ 46 | # type: (List[str],Any,Union[Word2Vec,KeyedVectors],bool,bool,str,str,str)->List[SequenceScore] 47 | if search_method=='partial': 48 | search_function = partial(search_wiki_pages.search_function_from_wikipedia_database, 49 | wikipedia_db_connector=wikipedia_db_connector, 50 | page_table_name=page_table_name, 51 | page_table_redirect=page_table_redirect) 52 | 53 | search_result = search_wiki_pages.search_from_dictionary(target_tokens=input_tokens, 54 | string_normalization_function=string_normalization_function, 55 | partially_param_given_function=search_function) 56 | seq_wiki_article_name = [ 57 | WikipediaArticleObject(page_title=token_name, candidate_article_name=[add_article_symbol(string) for string in results]) 58 | for token_name, results in search_result.items() if not results == []] 59 | return compute_wiki_node_probability(seq_wiki_article_name=seq_wiki_article_name, 60 | entity_vector_model=entity_vector_model, 61 | is_use_cache=is_use_cache, 62 | is_sort_object=is_sort_object) 63 | elif search_method == 'complete': 64 | search_result = [ 65 | search_wiki_pages.search_function_from_wikipedia_database( 66 | token=token, 67 | wikipedia_db_connector=wikipedia_db_connector, 68 | page_table_name=page_table_name, 69 | page_table_redirect=page_table_redirect 70 | ) for token in input_tokens] 71 | seq_wiki_article_name = [ 72 | WikipediaArticleObject(page_title=token, candidate_article_name=[add_article_symbol(string) for string in results]) 73 | for token, results in zip(input_tokens, search_result) if not results == []] 74 | 75 | return compute_wiki_node_probability(seq_wiki_article_name=seq_wiki_article_name, 76 | entity_vector_model=entity_vector_model, 77 | is_use_cache=is_use_cache, 78 | is_sort_object=is_sort_object) 79 | else: 80 | raise Exception('There is no search method named {}'.format(search_method)) 81 | 82 | 83 | def compute_wiki_node_probability(seq_wiki_article_name, 84 | entity_vector_model, 85 | is_use_cache=True, 86 | is_sort_object=True): 87 | """* What you can do 88 | - You can get sequence of wikipedia-article-names with its sequence-score 89 | 90 | * Params 91 | - is_use_cache: a boolean flag for keeping huge-object on disk 92 | - is_sort_object: a boolean flag for sorting SequenceScore object 93 | 94 | * Caution 95 | - You must proper wikipedia-article-name on WikipediaArticleObject.candidate_article_name attribute 96 | """ 97 | # type: (List[WikipediaArticleObject],Union[Word2Vec,KeyedVectors],bool,bool)->List[SequenceScore] 98 | 99 | # step1 it constructs array of transition-matrix(from state-t until state-t+1) 100 | lattice_object = make_lattice_object( 101 | seq_wiki_article_name=seq_wiki_article_name, 102 | entity_vector_model=entity_vector_model, 103 | is_use_cache=is_use_cache 104 | ) # type: LatticeObject 105 | # step2 compute route-score on Lattice network 106 | sequence_score_objects = lattice_object.get_score_routes() 107 | if is_sort_object: sequence_score_objects.sort(key=lambda obj: obj.sequence_score, reverse=True) 108 | 109 | return sequence_score_objects -------------------------------------------------------------------------------- /word2vec_wikification_py/load_entity_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | try: 3 | from gensim.models import KeyedVectors 4 | from gensim.models import Word2Vec 5 | except: 6 | # to meet api interface of old gensim version 7 | from gensim.models import Word2Vec 8 | from word2vec_wikification_py import init_logger 9 
| from word2vec_wikification_py.models import PersistentDict 10 | from tempfile import mkdtemp 11 | from typing import Union 12 | import os 13 | 14 | 15 | 16 | def load_entity_model(path_entity_model:str, 17 | is_binary_file:bool=True, 18 | is_use_cache:bool=False, 19 | path_working_dir:str=None)->Union[KeyedVectors, Word2Vec]: 20 | """* What you can do 21 | - You load entity mode on memory. 22 | """ 23 | if not os.path.exists(path_entity_model): 24 | raise FileExistsError('There is no model file at {}'.format(path_entity_model)) 25 | if path_working_dir is None: path_working_dir = mkdtemp() 26 | 27 | try: 28 | if is_binary_file: 29 | model = Word2Vec.load_word2vec_format(path_entity_model, binary=True) 30 | else: 31 | model = Word2Vec.load_word2vec_format(path_entity_model, binary=False) 32 | except DeprecationWarning: 33 | if is_binary_file: 34 | model = KeyedVectors.load_word2vec_format(path_entity_model, binary=True) 35 | else: 36 | model = KeyedVectors.load_word2vec_format(path_entity_model, binary=False) 37 | 38 | 39 | if is_use_cache: 40 | cache_obj = PersistentDict(os.path.join(path_working_dir, 'entity_model'), flag='c', format='pickle') 41 | cache_obj = model 42 | return cache_obj 43 | else: 44 | return model -------------------------------------------------------------------------------- /word2vec_wikification_py/make_lattice.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # gensim 3 | try: 4 | from gensim.models import Word2Vec, KeyedVectors 5 | except ImportError: 6 | # for gensim older version 7 | from gensim.models import Word2Vec 8 | # matrix object 9 | from numpy import ndarray 10 | from word2vec_wikification_py import init_logger 11 | from word2vec_wikification_py.models import WikipediaArticleObject, PersistentDict, LatticeObject, IndexDictionaryObject, EdgeObject 12 | from typing import List, Tuple, Union, Any, Dict, Set 13 | from tempfile import mkdtemp 14 | from scipy.sparse import csr_matrix 15 | import os 16 | import logging 17 | logger = logging.getLogger(name=init_logger.LOGGER_NAME) 18 | 19 | 20 | class TransitionEdgeObject(object): 21 | __slots__ = ['row_index', 'column_index', 'transition_score'] 22 | 23 | def __init__(self, 24 | row_index:int, 25 | column_index:int, 26 | transition_score:float): 27 | self.row_index = row_index 28 | self.column_index = column_index 29 | self.transition_score = transition_score 30 | 31 | 32 | def __update_index_dictionary(key:Tuple[int,str], index_dictionary:Dict[Tuple[int,str],int])->Dict[Tuple[int,str],int]: 33 | """ 34 | """ 35 | if key in index_dictionary: 36 | raise Exception('The key is already existing in index_dictionary. 
key={}'.format(key)) 37 | else: 38 | if len(index_dictionary)==0: 39 | index_dictionary[key] = 0 40 | else: 41 | latest_index_number = max(index_dictionary.values()) 42 | index_dictionary[key] = latest_index_number + 1 43 | 44 | return index_dictionary 45 | 46 | 47 | def make_state_transition_edge(state_t_word_tuple, 48 | state_t_plus_word_tuple, 49 | state2index_obj, 50 | entity_vector): 51 | """* What you can do 52 | - tの単語xからt+1の単語x'への遷移スコアを計算する 53 | 54 | * Output 55 | - tuple object whose element is (transition_element, row2index, column2index) 56 | - transition_element is (row_index, column_index, transition_score) 57 | """ 58 | # type: (Tuple[int,str],Tuple[int,str],IndexDictionaryObject,Union[Word2Vec,KeyedVectors])->Tuple[TransitionEdgeObject, IndexDictionaryObject] 59 | if isinstance(entity_vector, Word2Vec): 60 | if not state_t_word_tuple[1] in entity_vector.wv.vocab: 61 | raise Exception('Element does not exist in entity_voctor model. element={}'.format(state_t_word_tuple)) 62 | if not state_t_plus_word_tuple[1] in entity_vector.wv.vocab: 63 | raise Exception('Element does not exist in entity_voctor model. element={}'.format(state_t_plus_word_tuple)) 64 | elif isinstance(entity_vector, KeyedVectors): 65 | if not state_t_word_tuple[1] in entity_vector.vocab: 66 | raise Exception('Element does not exist in entity_voctor model. element={}'.format(state_t_word_tuple)) 67 | if not state_t_plus_word_tuple[1] in entity_vector.vocab: 68 | raise Exception('Element does not exist in entity_voctor model. element={}'.format(state_t_plus_word_tuple)) 69 | else: 70 | raise Exception() 71 | 72 | transition_score = entity_vector.similarity(state_t_word_tuple[1], state_t_plus_word_tuple[1]) # type: float 73 | 74 | if state_t_word_tuple in state2index_obj.state2index['row2index']: 75 | row_index = state2index_obj.state2index['row2index'][state_t_word_tuple] # type: int 76 | else: 77 | state2index_obj.state2index['row2index'] = __update_index_dictionary(state_t_word_tuple, state2index_obj.state2index['row2index']) 78 | row_index = state2index_obj.state2index['row2index'][state_t_word_tuple] # type: int 79 | 80 | if state_t_plus_word_tuple in state2index_obj.state2index['column2index']: 81 | column_index = state2index_obj.state2index['column2index'][state_t_plus_word_tuple] # type: int 82 | else: 83 | state2index_obj.state2index['column2index'] = __update_index_dictionary(state_t_plus_word_tuple, state2index_obj.state2index['column2index']) 84 | column_index = state2index_obj.state2index['column2index'][state_t_plus_word_tuple] # type: int 85 | 86 | transition_edge_obj = TransitionEdgeObject(row_index=row_index, 87 | column_index=column_index, 88 | transition_score=transition_score) 89 | 90 | return (transition_edge_obj, state2index_obj) 91 | 92 | 93 | def make_state_transition(index:int, 94 | seq_wiki_article_name:List[WikipediaArticleObject], 95 | state2index_obj:IndexDictionaryObject, 96 | entity_vector_model:Word2Vec)->Tuple[List[EdgeObject], List[TransitionEdgeObject]]: 97 | """* What you can do 98 | - You make all state-information between state_index and state_index_plus_1 99 | """ 100 | edge_group = [] # type: List[EdgeObject] 101 | seq_transition_element = [] # type: List[TransitionEdgeObject] 102 | 103 | for candidate_wikipedia_article_name in seq_wiki_article_name[index].candidate_article_name: 104 | for candiate_wikipedia_article_name_state_plus in seq_wiki_article_name[index + 1].candidate_article_name: 105 | state_t_word_tuple = (index, candidate_wikipedia_article_name) 106 | 
state_t_plus_word_tuple = (index + 1, candiate_wikipedia_article_name_state_plus) 107 | transition_element, state2index_obj = make_state_transition_edge( 108 | state_t_word_tuple=state_t_word_tuple, 109 | state_t_plus_word_tuple=state_t_plus_word_tuple, 110 | state2index_obj=state2index_obj, 111 | entity_vector=entity_vector_model 112 | ) 113 | seq_transition_element.append(transition_element) 114 | edge_group.append( EdgeObject(state2index_obj.state2index['row2index'][state_t_word_tuple], 115 | state2index_obj.state2index['column2index'][state_t_plus_word_tuple]) 116 | ) 117 | 118 | return (edge_group, seq_transition_element) 119 | 120 | 121 | def make_state_transition_sequence(seq_wiki_article_name:List[WikipediaArticleObject], 122 | entity_vector_model:Word2Vec, 123 | state2index_obj: IndexDictionaryObject)->Tuple[IndexDictionaryObject, 124 | List[List[EdgeObject]], 125 | csr_matrix]: 126 | """系列での遷移行列を作成する 127 | """ 128 | # TODO 関数ごとcython化を検討 129 | # TODO sequence系列も作成する 130 | seq_transition_element = [] # type: List[TransitionEdgeObject] 131 | seq_edge_group = [] 132 | for index in range(0, len(seq_wiki_article_name)-1): 133 | edge_group, seq_transition_edge_object = make_state_transition( 134 | index=index, 135 | seq_wiki_article_name=seq_wiki_article_name, 136 | state2index_obj=state2index_obj, 137 | entity_vector_model=entity_vector_model) 138 | seq_edge_group.append(edge_group) 139 | seq_transition_element += seq_transition_edge_object 140 | 141 | # TODO cythonの場合は、numpyのまま処理してしまう 142 | data = [transition_tuple.transition_score for transition_tuple in seq_transition_element] 143 | row = [transition_tuple.row_index for transition_tuple in seq_transition_element] 144 | column = [transition_tuple.column_index for transition_tuple in seq_transition_element] 145 | 146 | transition_matrix = csr_matrix( 147 | (data, (row, column)), 148 | shape=(len(state2index_obj.state2index['row2index']), len(state2index_obj.state2index['column2index'])) 149 | ) 150 | return (state2index_obj, seq_edge_group, transition_matrix) 151 | 152 | 153 | def filter_out_of_vocabulary_word(wikipedia_article_obj: WikipediaArticleObject, vocabulary_words:Set)->Union[bool, WikipediaArticleObject]: 154 | """* What you can do 155 | - You remove out-of-vocabulary word from wikipedia_article_obj.candidate_article_name 156 | """ 157 | filtered_article_name = [] 158 | for article_name in wikipedia_article_obj.candidate_article_name: 159 | if article_name in vocabulary_words: 160 | filtered_article_name.append(article_name) 161 | else: 162 | logger.warning(msg='Out of vocabulary word. It removes. 
word = {}'.format(article_name)) 163 | 164 | if len(filtered_article_name)==0: 165 | return False 166 | else: 167 | wikipedia_article_obj.candidate_article_name = filtered_article_name 168 | return wikipedia_article_obj 169 | 170 | 171 | def make_lattice_object(seq_wiki_article_name, 172 | entity_vector_model, 173 | path_wordking_dir=None, 174 | is_use_cache=True): 175 | """* What you can do 176 | 177 | """ 178 | # type: (List[WikipediaArticleObject],Union[Word2Vec,KeyedVectors],str,bool)->LatticeObject 179 | if path_wordking_dir is None: path_wordking_dir = mkdtemp() 180 | if is_use_cache: 181 | persistent_state2index = PersistentDict(os.path.join(path_wordking_dir, 'column2index.json'), flag='c', format='json') 182 | persistent_state2index['row2index'] = {} 183 | persistent_state2index['column2index'] = {} 184 | else: 185 | persistent_state2index = {} 186 | persistent_state2index['row2index'] = {} 187 | persistent_state2index['column2index'] = {} 188 | 189 | state2dict_obj = IndexDictionaryObject( 190 | state2index=persistent_state2index, 191 | index2state={}) 192 | 193 | if isinstance(entity_vector_model, Word2Vec): 194 | vocabulary_words = set(entity_vector_model.wv.vocab.keys()) 195 | elif isinstance(entity_vector_model, KeyedVectors): 196 | vocabulary_words = set(entity_vector_model.vocab.keys()) 197 | else: 198 | raise Exception() 199 | 200 | seq_wiki_article_name = [ 201 | wiki_article_name 202 | for wiki_article_name in seq_wiki_article_name 203 | if not filter_out_of_vocabulary_word(wiki_article_name, vocabulary_words) is False] 204 | 205 | updated_state2dict_obj, seq_edge_group, transition_matrix = make_state_transition_sequence( 206 | seq_wiki_article_name=seq_wiki_article_name, 207 | entity_vector_model=entity_vector_model, 208 | state2index_obj=state2dict_obj 209 | ) 210 | 211 | if is_use_cache: 212 | """If is_use_cache is True, use disk-drive for keeping object 213 | """ 214 | index2state = PersistentDict(os.path.join(path_wordking_dir, 'index2row.json'), flag='c', format='json') 215 | updated_state2dict_obj.index2state = index2state 216 | else: 217 | updated_state2dict_obj.index2state = {} 218 | 219 | updated_state2dict_obj.index2state['index2row'] = {value: key for key, value in updated_state2dict_obj.state2index['row2index'].items()} 220 | updated_state2dict_obj.index2state['index2column'] = {value: key for key, value in updated_state2dict_obj.state2index['column2index'].items()} 221 | 222 | return LatticeObject( 223 | transition_matrix=transition_matrix, 224 | index_dictionary_obj=updated_state2dict_obj, 225 | seq_edge_groups=seq_edge_group, 226 | seq_wiki_article_name=seq_wiki_article_name 227 | ) -------------------------------------------------------------------------------- /word2vec_wikification_py/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Any, Union, Dict 2 | from numpy.core import ndarray 3 | from scipy.sparse import csr_matrix 4 | from itertools import product 5 | import pickle, json, csv, os, shutil 6 | import copy 7 | import itertools 8 | 9 | # this class is from https://code.activestate.com/recipes/576642/ 10 | class PersistentDict(dict): 11 | ''' Persistent dictionary with an API compatible with shelve and anydbm. 12 | The dict is kept in memory, so the dictionary operations run as fast as 13 | a regular dictionary. 14 | Write to disk is delayed until close or sync (similar to gdbm's fast mode). 15 | Input file format is automatically discovered. 
16 | Output file format is selectable between pickle, json, and csv. 17 | All three serialization formats are backed by fast C implementations. 18 | ''' 19 | 20 | def __init__(self, filename, flag='c', mode=None, format='pickle', *args, **kwds): 21 | self.flag = flag # r=readonly, c=create, or n=new 22 | self.mode = mode # None or an octal triple like 0644 23 | self.format = format # 'csv', 'json', or 'pickle' 24 | self.filename = filename 25 | if flag != 'n' and os.access(filename, os.R_OK): 26 | fileobj = open(filename, 'rb' if format=='pickle' else 'r') 27 | with fileobj: 28 | self.load(fileobj) 29 | dict.__init__(self, *args, **kwds) 30 | 31 | def sync(self): 32 | 'Write dict to disk' 33 | if self.flag == 'r': 34 | return 35 | filename = self.filename 36 | tempname = filename + '.tmp' 37 | fileobj = open(tempname, 'wb' if self.format=='pickle' else 'w') 38 | try: 39 | self.dump(fileobj) 40 | except Exception: 41 | os.remove(tempname) 42 | raise 43 | finally: 44 | fileobj.close() 45 | shutil.move(tempname, self.filename) # atomic commit 46 | if self.mode is not None: 47 | os.chmod(self.filename, self.mode) 48 | 49 | def close(self): 50 | self.sync() 51 | 52 | def __enter__(self): 53 | return self 54 | 55 | def __exit__(self, *exc_info): 56 | self.close() 57 | 58 | def dump(self, fileobj): 59 | if self.format == 'csv': 60 | csv.writer(fileobj).writerows(self.items()) 61 | elif self.format == 'json': 62 | json.dump(self, fileobj, separators=(',', ':')) 63 | elif self.format == 'pickle': 64 | pickle.dump(dict(self), fileobj, 2) 65 | else: 66 | raise NotImplementedError('Unknown format: ' + repr(self.format)) 67 | 68 | def load(self, fileobj): 69 | # try formats from most restrictive to least restrictive 70 | for loader in (pickle.load, json.load, csv.reader): 71 | fileobj.seek(0) 72 | try: 73 | return self.update(loader(fileobj)) 74 | except Exception: 75 | pass 76 | 77 | 78 | 79 | class WikipediaArticleObject(object): 80 | """Wikipediaの記事情報を記述するためのクラス 81 | """ 82 | __slots__ = ['page_title', 'candidate_article_name', 'article_name'] 83 | 84 | def __init__(self, 85 | page_title:str, 86 | candidate_article_name:List[str], 87 | article_name:str=None): 88 | self.page_title = page_title 89 | self.candidate_article_name = candidate_article_name 90 | self.article_name = article_name 91 | 92 | def __str__(self)->str: 93 | return self.page_title 94 | 95 | def __dict__(self)->Dict[str,Any]: 96 | return { 97 | 'page_title': self.page_title, 98 | 'candidate_article_name': self.candidate_article_name, 99 | 'article_name': self.article_name 100 | } 101 | 102 | @classmethod 103 | def from_dict(cls, dict_object:Dict[str,Any]): 104 | if 'article_name' in dict_object: 105 | article_name = dict_object['article_name'] 106 | else: 107 | article_name = None 108 | 109 | return WikipediaArticleObject( 110 | page_title=dict_object['page_title'], 111 | candidate_article_name=dict_object['candidate_article_name'], 112 | article_name=article_name 113 | ) 114 | 115 | 116 | class SequenceScore(object): 117 | """計算された記事系列のスコアを保持するオブジェクト 118 | 119 | """ 120 | def __init__(self, 121 | seq_words:List[WikipediaArticleObject], 122 | seq_transition_score:List[Tuple[str, str, float]], 123 | sequence_score:float): 124 | self.seq_words = seq_words 125 | self.seq_transition_score = seq_transition_score 126 | self.sequence_score = sequence_score 127 | 128 | def __dict__(self): 129 | return { 130 | 'seq_words': [wikipedia_obj.__dict__() for wikipedia_obj in self.seq_words], 131 | 'seq_transition_score': self.seq_transition_score, 
132 | 'sequence_score': self.sequence_score 133 | } 134 | 135 | def __str__(self): 136 | return """SequenceScore object with score={}""".format(self.sequence_score) 137 | 138 | def __generate_label_sequence(self, seq_score_tuple:List[Tuple[str, str, float]])->List[str]: 139 | """* What you can do 140 | - You generate list of label 141 | """ 142 | seq_label = [] 143 | for index in range(0, len(seq_score_tuple)): 144 | if index == 0: 145 | seq_label.append(seq_score_tuple[index][0]) 146 | elif index+1 == len(seq_score_tuple): 147 | seq_label.append(seq_score_tuple[index][0]) 148 | seq_label.append(seq_score_tuple[index][1]) 149 | else: 150 | seq_label.append(seq_score_tuple[index][0]) 151 | 152 | return seq_label 153 | 154 | def get_tokens(self)->List[str]: 155 | return self.__generate_label_sequence(seq_score_tuple=self.seq_transition_score) 156 | 157 | @classmethod 158 | def from_dict(cls, dict_object:Dict[str,Any]): 159 | seq_words = [ 160 | WikipediaArticleObject.from_dict(wikipedia_dict_obj) 161 | for wikipedia_dict_obj in dict_object['seq_words']] 162 | return SequenceScore( 163 | seq_words=seq_words, 164 | seq_transition_score=dict_object['seq_transition_score'], 165 | sequence_score=dict_object['sequence_score'] 166 | ) 167 | 168 | 169 | class EdgeObject(object): 170 | __slots__ = ['index_at_t', 'index_at_t_plus'] 171 | 172 | def __init__(self, index_at_t:int, index_at_t_plus:int): 173 | self.index_at_t = index_at_t 174 | self.index_at_t_plus = index_at_t_plus 175 | 176 | def to_tuple(self): 177 | return (self.index_at_t, self.index_at_t_plus) 178 | 179 | 180 | class IndexDictionaryObject(object): 181 | """Class object for keeping a relation of state_name and index. 182 | state2index attribute must have 2 key names. 183 | - row2index 184 | - column2index 185 | 186 | index2state attribute must have 2 key names. 
187 | - index2row 188 | - index2column 189 | 190 | """ 191 | __slots__ = ['state2index', 'index2state'] 192 | 193 | def __init__(self, 194 | state2index:Union[Dict[str, Dict], PersistentDict], 195 | index2state:Union[Dict[str, Dict], PersistentDict]): 196 | self.state2index = state2index 197 | self.index2state = index2state 198 | 199 | 200 | class LatticeObject(object): 201 | def __init__(self, 202 | transition_matrix:Union[csr_matrix, ndarray], 203 | index_dictionary_obj:IndexDictionaryObject, 204 | seq_edge_groups:List[List[EdgeObject]], 205 | seq_wiki_article_name: List[WikipediaArticleObject]=None): 206 | """* 207 | """ 208 | self.transition_matrix = transition_matrix 209 | self.index_dictionary_obj = index_dictionary_obj 210 | self.seq_edge_groups = seq_edge_groups 211 | self.index_tuple_route = self.__generate_edge_routes() 212 | self.seq_wiki_article_name = seq_wiki_article_name 213 | if not seq_wiki_article_name is None: 214 | ## It constructs dict of wiki-article-name <-> (index-in-list, wiki-article-object) ## 215 | function_key = lambda tuple_wikilabel_wikiobj: tuple_wikilabel_wikiobj[0] 216 | seq_tuple_wikilabel_wikiobj = [(wiki_article_name, wiki_obj_index, wiki_article_obj) 217 | for wiki_obj_index, wiki_article_obj in enumerate(self.seq_wiki_article_name) 218 | for wiki_article_name in wiki_article_obj.candidate_article_name] 219 | self.label2WikiArticleObj = {} # type: Dict[str,List[Tuple[int, WikipediaArticleObject]]] 220 | for wiki_label_name, g_obj in itertools.groupby(sorted(seq_tuple_wikilabel_wikiobj, key=function_key), key=function_key): 221 | self.label2WikiArticleObj[wiki_label_name] = [(tuple_wikilabel_wiki_obj[1], tuple_wikilabel_wiki_obj[2]) 222 | for tuple_wikilabel_wiki_obj in g_obj] 223 | else: 224 | self.label2WikiArticleObj = None 225 | 226 | 227 | def __generate_edge_routes(self)->List[Tuple[Tuple[int,int]]]: 228 | """* What you can do 229 | - You can generate route over lattice graph. 
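        - The routes are built by taking the cartesian product of the per-step edge groups
          (seq_edge_groups) and keeping only combinations in which the destination state of
          one edge equals the source state of the following edge (judge_proper_route below).
        - Because every combination of candidates is enumerated, the number of routes grows
          multiplicatively with the number of candidate article names per token.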
230 | 231 | * Output 232 | - [( (row_index_matrix, column_index_matrix) )] 233 | """ 234 | def judge_proper_route(index_tuple_route:Tuple[Tuple[int,int]])->bool: 235 | """It picks up only sequence whose states meet condition state_t == state_t_plus_1 236 | """ 237 | judge_flag = True 238 | for edge_index in range(0, len(index_tuple_route)-1): 239 | state_name_t_plus_at_now = self.index_dictionary_obj.index2state['index2column'][index_tuple_route[edge_index][1]] 240 | state_name_t_at_next = self.index_dictionary_obj.index2state['index2row'][index_tuple_route[edge_index+1][0]] 241 | if state_name_t_plus_at_now == state_name_t_at_next: 242 | pass 243 | else: 244 | judge_flag = False 245 | 246 | return judge_flag 247 | 248 | index_tuple_of_edge = [[edge_obj.to_tuple() for edge_obj in list_edge_candidate] 249 | for list_edge_candidate in self.seq_edge_groups] 250 | index_tuple_of_route_candidates = product(*index_tuple_of_edge) 251 | # select only a route where state_t_plus == state_t_next 252 | index_tuple_of_route = list(filter(judge_proper_route, index_tuple_of_route_candidates)) 253 | return index_tuple_of_route 254 | 255 | def __get_score(self, row:int, column:int)->float: 256 | return self.transition_matrix[row, column] 257 | 258 | def __compute_route_score(self, index_tuple_route:Tuple[Tuple[int,int]])->float: 259 | """* What you can do 260 | - You get score of a route 261 | """ 262 | seq_score = [self.__get_score(index_tuple[0], index_tuple[1]) for index_tuple in index_tuple_route] 263 | return sum(seq_score) 264 | 265 | def __generate_state_name_sequence(self, index_tuple_route:Tuple[Tuple[int,int]])->List[Tuple[str, str, float]]: 266 | """* What you can do 267 | - You get sequence of label & score tuple (label_t, label_t_plus_1, score) 268 | """ 269 | seq_state_name_score = [ 270 | (self.index_dictionary_obj.index2state['index2row'][index_tuple[0]][1], 271 | self.index_dictionary_obj.index2state['index2column'][index_tuple[1]][1], 272 | self.__get_score(index_tuple[0], index_tuple[1])) 273 | for index_tuple in index_tuple_route] 274 | return seq_state_name_score 275 | 276 | def __generate_label_sequence(self, seq_score_tuple:List[Tuple[str, str, float]])->List[str]: 277 | """* What you can do 278 | - You generate list of label 279 | """ 280 | seq_label = [] 281 | for index in range(0, len(seq_score_tuple)): 282 | if index == 0: 283 | seq_label.append(seq_score_tuple[index][0]) 284 | elif index+1 == len(seq_score_tuple): 285 | seq_label.append(seq_score_tuple[index][0]) 286 | seq_label.append(seq_score_tuple[index][1]) 287 | else: 288 | seq_label.append(seq_score_tuple[index][0]) 289 | 290 | return seq_label 291 | 292 | def __generate_wiki_article_object_sequence(self, seq_label_name:List[str])->List[WikipediaArticleObject]: 293 | """* What you can do 294 | - You generate list of WikipediaArticleObject. They are already disambiguated. 
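        - Each returned WikipediaArticleObject has its article_name attribute set to the
          label chosen for that position (looked up via label2WikiArticleObj).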
295 | """ 296 | seq_wiki_article_obj = [None] * len(seq_label_name) 297 | for l_index, label in enumerate(seq_label_name): 298 | seq_tuple_index_wikiobj = copy.deepcopy(self.label2WikiArticleObj[label]) 299 | wiki_article_obj_in_index = [tuple_index_wikiobj for tuple_index_wikiobj in seq_tuple_index_wikiobj if tuple_index_wikiobj[0]==l_index][0][1] # type: WikipediaArticleObject 300 | wiki_article_obj_in_index.article_name = label 301 | seq_wiki_article_obj[l_index] = wiki_article_obj_in_index 302 | 303 | return list(filter(lambda element: True if not element is None else False, seq_wiki_article_obj)) 304 | 305 | def get_score_routes(self)->List[SequenceScore]: 306 | """* What you can do 307 | - You generate list of SequenceScore. 308 | - Each SequenceScore has information of one-route and its score. 309 | """ 310 | ### make list beforehand to make this process faster ### 311 | sequence_score_objects = [None] * len(self.index_tuple_route) 312 | for l_index, route in enumerate(self.index_tuple_route): 313 | route_score = self.__compute_route_score(route) 314 | seq_score_tuple = self.__generate_state_name_sequence(route) 315 | seq_label_name = self.__generate_label_sequence(seq_score_tuple=seq_score_tuple) 316 | 317 | if not self.seq_wiki_article_name is None: 318 | label_object = self.__generate_wiki_article_object_sequence(seq_label_name) 319 | else: 320 | label_object = seq_label_name 321 | 322 | sequence_score_objects[l_index] = SequenceScore(seq_words=label_object, 323 | seq_transition_score=seq_score_tuple, 324 | sequence_score=route_score) 325 | 326 | seq_result_score_object = list(filter(lambda element_obj: True if not element_obj is None else False, sequence_score_objects)) 327 | return seq_result_score_object -------------------------------------------------------------------------------- /word2vec_wikification_py/search_wiki_pages.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict 2 | from pymysql import Connection, cursors 3 | from typing import Tuple 4 | 5 | def __generate_window_size(target_token:List[str])->List[int]: 6 | return [i for i in range(1, len(target_token)+1)] 7 | 8 | 9 | def __generate_index_range(list_index:List[int], window_size:int)->List[List[int]]: 10 | """generate set of list, which describes range-index of candidate tokens 11 | """ 12 | search_index = [] 13 | for start_i in range(0, len(list_index)): 14 | end_i = start_i + window_size 15 | if end_i <= len(list_index): 16 | search_index.append(list(range(start_i, end_i))) 17 | return search_index 18 | 19 | 20 | def search_function_from_wikipedia_database(token: str, 21 | wikipedia_db_connector: Connection, 22 | page_table_name: str = 'page', 23 | page_table_redirect: str = 'redirect') -> List[str]: 24 | """* 25 | 部分文字検索をするときに使う 26 | """ 27 | def decode_string(string): 28 | try: 29 | unicode_string = string.decode('utf-8') 30 | return unicode_string 31 | except: 32 | return None 33 | 34 | 35 | # It searches article name with exact same name as token 36 | cursor = wikipedia_db_connector.cursor() # type: cursors 37 | page_query = """SELECT page_id, page_title, page_is_redirect FROM {} WHERE (page_title = %s OR page_title LIKE %s) AND page_namespace = 0""".format(page_table_name) 38 | cursor.execute(page_query, (token, '{}\_(%)'.format(token))) 39 | fetched_records = list(cursor.fetchall()) 40 | page_names = [page_id_title[1] for page_id_title in fetched_records if page_id_title[2]==0] 41 | redirect_names = [page_id_title[0] for 
page_id_title in fetched_records if page_id_title[2]==1] 42 | cursor.close() 43 | 44 | if not redirect_names == []: 45 | cursor = wikipedia_db_connector.cursor() # type: cursors 46 | select_query = """SELECT rd_title FROM {} WHERE rd_from IN %s""".format(page_table_redirect) 47 | cursor.execute(select_query, (redirect_names,)) 48 | article_page_names = [page_id_title[0] for page_id_title in cursor.fetchall()] 49 | cursor.close() 50 | else: 51 | article_page_names = [] 52 | 53 | article_name_string = list(set([decode_string(article_name) for article_name in page_names+article_page_names 54 | if not decode_string(article_name) is None])) 55 | 56 | return article_name_string 57 | 58 | 59 | def search_from_dictionary(target_tokens:List[str], 60 | string_normalization_function, 61 | partially_param_given_function)->Dict[str,Any]: 62 | """* 63 | """ 64 | list_window_size = __generate_window_size(target_tokens) 65 | found_index = [] # type: List[int] 66 | found_token_tuple_object = {} 67 | 68 | all_index = [i for i in range(0, len(target_tokens))] 69 | 70 | for w_size in list_window_size[::-1]: 71 | list_index = list(range(0, len(target_tokens))) 72 | candidate_index_list = __generate_index_range(list_index, window_size=w_size) 73 | for candidate_indices in candidate_index_list: 74 | if len(set(candidate_indices).intersection(set(found_index)))>=1: continue 75 | 76 | start_index = candidate_indices[0] 77 | end_index = candidate_indices[-1] + 1 78 | # candidate token 79 | search_token = ''.join(target_tokens[start_index:end_index]) 80 | normalized_search_token = string_normalization_function(search_token) 81 | # search token from ontology 82 | search_result = partially_param_given_function(normalized_search_token) 83 | # search result check 84 | # if length is more than 0, this is true 85 | if len(search_result)>0: 86 | # 見つかったオブジェクトを追加する 87 | found_token_tuple_object.update({normalized_search_token: search_result}) 88 | if all_index[-1] in candidate_indices: 89 | found_index += candidate_indices 90 | else: 91 | if len(candidate_indices)==1: 92 | found_index += candidate_indices 93 | else: 94 | found_index += candidate_indices 95 | # end condition 96 | if set(found_index) == set(all_index): break 97 | 98 | return found_token_tuple_object 99 | 100 | 101 | def complete_search(): 102 | pass --------------------------------------------------------------------------------
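For reference, a self-contained sketch of how `search_from_dictionary` (in `word2vec_wikification_py/search_wiki_pages.py`) resolves token windows: it tries the longest run of consecutive tokens first and falls back to shorter windows. The lookup table below is a toy stand-in for the real wikipedia-database search function.

```
from word2vec_wikification_py.search_wiki_pages import search_from_dictionary

# toy lookup table standing in for search_function_from_wikipedia_database()
toy_pages = {'エンタの神様': ['エンタの神様'], '日テレ': ['日本テレビ放送網']}

result = search_from_dictionary(
    target_tokens=['エンタ', 'の', '神様', '日テレ'],
    string_normalization_function=lambda s: s,
    partially_param_given_function=lambda token: toy_pages.get(token, []))

# the longest matching window wins: 'エンタ' + 'の' + '神様' is resolved as one article name
print(result)  # {'エンタの神様': ['エンタの神様'], '日テレ': ['日本テレビ放送網']}
```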