├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── configs
│   └── config.ini
├── evaluation
│   ├── __init__.py
│   ├── evaluate_el.py
│   ├── evaluate_inference.py
│   └── evaluate_types.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── batch_normalizer.py
│   └── figer_model
│       ├── __init__.py
│       ├── coherence_model.py
│       ├── coldStart.py
│       ├── context_encoder.py
│       ├── el_model.py
│       ├── entity_posterior.py
│       ├── joint_context.py
│       ├── labeling_model.py
│       ├── loss_optim.py
│       └── wiki_desc.py
├── neuralel.py
├── neuralel_jsonl.py
├── neuralel_tadir.py
├── overview.png
├── readers
│   ├── Mention.py
│   ├── __init__.py
│   ├── config.py
│   ├── crosswikis_test.py
│   ├── inference_reader.py
│   ├── test_reader.py
│   ├── textanno_test_reader.py
│   ├── utils.py
│   └── vocabloader.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | deprecated/
3 | test.py
4 | readers/crosswikis_test.py
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # IPython Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 |
93 | # Rope project settings
94 | .ropeproject
95 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Neural Entity Linking
2 | =====================
3 | Code for paper
4 | "[Entity Linking via Joint Encoding of Types, Descriptions, and Context](http://cogcomp.org/page/publication_view/817)", EMNLP '17
5 |
6 |
7 |
8 | ## Abstract
9 | For accurate entity linking, we need to capture the various information aspects of an entity, such as its description in a KB, contexts in which it is mentioned, and structured knowledge. Further, a linking system should work on texts from different domains without requiring domain-specific training data or hand-engineered features.
10 | In this work we present a neural, modular entity linking system that learns a unified dense representation for each entity using multiple sources of information, such as its description, contexts around its mentions, and fine-grained types. We show that the resulting entity linking system is effective at combining these sources, and performs competitively, sometimes out-performing current state-of-the-art systems across datasets, without requiring any domain-specific training data or hand-engineered features. We also show that our model can effectively "embed" entities that are new to the KB, and is able to link their mentions accurately.
11 |
12 | ### Requirements
13 | * Python 3.4
14 | * Tensorflow 0.11 / 0.12
15 | * numpy
16 | * [CogComp-NLPy](https://github.com/CogComp/cogcomp-nlpy)
17 | * [Resources](https://drive.google.com/open?id=0Bz-t37BfgoTuSEtXOTI1SEF3VnM) - Pretrained models, vectors, etc.
18 |
19 | ### How to run inference
20 | 1. Clone the [code repository](https://github.com/nitishgupta/neural-el/).
21 | 2. Download the [resources folder](https://drive.google.com/open?id=0Bz-t37BfgoTuSEtXOTI1SEF3VnM).
22 | 3. In `configs/config.ini`, set the correct path to the resources folder you just downloaded.
23 | 4. Run using:
24 | ```
25 | python3 neuralel.py --config=configs/config.ini --model_path=PATH_TO_MODEL_IN_RESOURCES --mode=inference
26 | ```
27 | The file `sampletest.txt` in the resources folder contains the text to be entity-linked. Currently we only support linking a single document at a time, so make sure the text in `sampletest.txt` is a single document on a single line.
28 |
29 | ### Installing cogcomp-nlpy
30 | [CogComp-NLPy](https://github.com/CogComp/cogcomp-nlpy) is needed to detect named-entity mentions using NER. To install:
31 | ```
32 | pip install cython
33 | pip install ccg_nlpy
34 | ```
35 |
36 | ### Installing TensorFlow (CPU version)
37 | To install TensorFlow 0.12:
38 | ```
39 | export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
40 | (Regular) pip install --upgrade $TF_BINARY_URL
41 | (Conda) pip install --ignore-installed --upgrade $TF_BINARY_URL
42 | ```
43 |
--------------------------------------------------------------------------------
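As a concrete illustration of the single-document, single-line input format described in the README above, a hypothetical `sampletest.txt` could contain:

```
Barack Obama was the 44th President of the United States. Obama was born in Honolulu, Hawaii.
```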
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/__init__.py
--------------------------------------------------------------------------------
/configs/config.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | resources_dir: /shared/bronte/ngupta19/neural-el_resources
3 |
4 | vocab_dir: ${resources_dir}/vocab
5 |
6 | word_vocab_pkl:${vocab_dir}/word_vocab.pkl
7 | kwnwid_vocab_pkl:${vocab_dir}/knwn_wid_vocab.pkl
8 | label_vocab_pkl:${vocab_dir}/label_vocab.pkl
9 | cohstringG9_vocab_pkl:${vocab_dir}/cohstringG9_vocab.pkl
10 | widWiktitle_pkl:${vocab_dir}/wid2Wikititle.pkl
11 |
12 | # Only one crosswikis path is needed
13 | crosswikis_pruned_pkl: ${resources_dir}/crosswikis.pruned.pkl
14 |
15 |
16 | glove_pkl: ${resources_dir}/glove.pkl
17 | glove_word_vocab_pkl:${vocab_dir}/glove_word_vocab.pkl
18 |
19 | # Should be removed
20 | test_file: ${resources_dir}/sampletest.txt
21 |
--------------------------------------------------------------------------------
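The `${...}` references in `configs/config.ini` above are configparser-style extended interpolation. A minimal sketch of resolving them with Python's standard `configparser` (assuming this mirrors what `readers/config.py` does; that reader is not shown in this dump):

```
import configparser

# Minimal sketch: resolve the ${...} references in configs/config.ini using the
# standard library's ExtendedInterpolation. Assumed to be equivalent to what
# readers/config.py does; that file is not included above.
parser = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
parser.read("configs/config.ini")

cfg = parser["DEFAULT"]
print(cfg["resources_dir"])          # base directory for all downloaded resources
print(cfg["word_vocab_pkl"])         # ${vocab_dir} expands to <resources_dir>/vocab/...
print(cfg["crosswikis_pruned_pkl"])  # pruned crosswikis pickle
```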
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/evaluate_el.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | coarsetypes = set(["location", "person", "organization", "event"])
6 | coarseTypeIds = set([1,5,10,25])
7 |
8 |
9 | def computeMaxPriorContextJointEntities(
10 | WIDS_list, wikiTitles_list, condProbs_list, contextProbs_list,
11 | condContextJointProbs_list, verbose):
12 |
13 | assert (len(wikiTitles_list) == len(condProbs_list) ==
14 | len(contextProbs_list) == len(condContextJointProbs_list))
15 | numMens = len(wikiTitles_list)
16 | numWithCorrectInCand = 0
17 | accCond = 0
18 | accCont = 0
19 | accJoint = 0
20 |
21 | # [[(trueWT, maxPrWT, maxContWT, maxJWT), (trueWID, maxPrWID, maxContWID, maxJWID)]]
22 | evaluationWikiTitles = []
23 |
24 | sortedContextWTs = []
25 |
26 | for (WIDS, wTs, cProbs, contProbs, jointProbs) in zip(WIDS_list,
27 | wikiTitles_list,
28 | condProbs_list,
29 | contextProbs_list,
30 | condContextJointProbs_list):
31 | if wTs[0] == "":
32 | evaluationWikiTitles.append([tuple([""]*4), tuple([""]*4)])
33 |
34 | else:
35 | numWithCorrectInCand += 1
36 | trueWID = WIDS[0]
37 | trueEntity = wTs[0]
38 | tCondProb = cProbs[0]
39 | tContProb = contProbs[0]
40 | tJointProb = jointProbs[0]
41 |
42 | maxCondEntity_idx = np.argmax(cProbs)
43 | maxCondWID = WIDS[maxCondEntity_idx]
44 | maxCondEntity = wTs[maxCondEntity_idx]
45 | maxCondProb = cProbs[maxCondEntity_idx]
46 | if trueEntity == maxCondEntity and maxCondProb!=0.0:
47 | accCond+= 1
48 |
49 | maxContEntity_idx = np.argmax(contProbs)
50 | maxContWID = WIDS[maxContEntity_idx]
51 | maxContEntity = wTs[maxContEntity_idx]
52 | maxContProb = contProbs[maxContEntity_idx]
53 | if maxContEntity == trueEntity and maxContProb!=0.0:
54 | accCont+= 1
55 |
56 | contProbs_sortIdxs = np.argsort(contProbs).tolist()[::-1]
57 | sortContWTs = [wTs[i] for i in contProbs_sortIdxs]
58 | sortedContextWTs.append(sortContWTs)
59 |
60 | maxJointEntity_idx = np.argmax(jointProbs)
61 | maxJointWID = WIDS[maxJointEntity_idx]
62 | maxJointEntity = wTs[maxJointEntity_idx]
63 | maxJointProb = jointProbs[maxJointEntity_idx]
64 | maxJointCprob = cProbs[maxJointEntity_idx]
65 | maxJointContP = contProbs[maxJointEntity_idx]
66 | if maxJointEntity == trueEntity and maxJointProb!=0:
67 | accJoint+= 1
68 |
69 | predWTs = (trueEntity, maxCondEntity, maxContEntity, maxJointEntity)
70 | predWIDs = (trueWID, maxCondWID, maxContWID, maxJointWID)
71 | evaluationWikiTitles.append([predWTs, predWIDs])
72 |
73 | if verbose:
74 | print("True: {} c:{:.3f} cont:{:.3f} J:{:.3f}".format(
75 | trueEntity, tCondProb, tContProb, tJointProb))
76 | print("Pred: {} c:{:.3f} cont:{:.3f} J:{:.3f}".format(
77 | maxJointEntity, maxJointCprob, maxJointContP, maxJointProb))
78 | print("maxPrior: {} p:{:.3f} maxCont:{} p:{:.3f}".format(
79 | maxCondEntity, maxCondProb, maxContEntity, maxContProb))
80 | #AllMentionsProcessed
81 | if numWithCorrectInCand != 0:
82 | accCond = accCond/float(numWithCorrectInCand)
83 | accCont = accCont/float(numWithCorrectInCand)
84 | accJoint = accJoint/float(numWithCorrectInCand)
85 | else:
86 | accCond = 0.0
87 | accCont = 0.0
88 | accJoint = 0.0
89 |
90 | print("Total Mentions : {} In Knwn Mentions : {}".format(
91 | numMens, numWithCorrectInCand))
92 | print("Priors Accuracy: {:.3f} Context Accuracy: {:.3f} Joint Accuracy: {:.3f}".format(
93 | (accCond), accCont, accJoint))
94 |
95 | assert len(evaluationWikiTitles) == numMens
96 |
97 | return (evaluationWikiTitles, sortedContextWTs)
98 |
99 |
100 | def convertWidIdxs2WikiTitlesAndWIDs(widIdxs_list, idx2knwid, wid2WikiTitle):
101 | wikiTitles_list = []
102 | WIDS_list = []
103 | for widIdxs in widIdxs_list:
104 | wids = [idx2knwid[wididx] for wididx in widIdxs]
105 | wikititles = [wid2WikiTitle[idx2knwid[wididx]] for wididx in widIdxs]
106 | WIDS_list.append(wids)
107 | wikiTitles_list.append(wikititles)
108 |
109 | return (WIDS_list, wikiTitles_list)
110 |
111 |
112 | def _normalizeProbList(probList):
113 | norm_probList = []
114 | for probs in probList:
115 | s = sum(probs)
116 | if s != 0.0:
117 | n_p = [p/s for p in probs]
118 | norm_probList.append(n_p)
119 | else:
120 | norm_probList.append(probs)
121 | return norm_probList
122 |
123 |
124 | def computeFinalEntityProbs(condProbs_list, contextProbs_list, alpha=0.5):
125 | condContextJointProbs_list = []
126 | condProbs_list = _normalizeProbList(condProbs_list)
127 | contextProbs_list = _normalizeProbList(contextProbs_list)
128 |
129 | for (cprobs, contprobs) in zip(condProbs_list, contextProbs_list):
130 | #condcontextprobs = [(alpha*x + (1-alpha)*y) for (x,y) in zip(cprobs, contprobs)]
131 | condcontextprobs = [(x + y - x*y) for (x,y) in zip(cprobs, contprobs)]
132 | sum_condcontextprobs = sum(condcontextprobs)
133 | if sum_condcontextprobs != 0.0:
134 | condcontextprobs = [float(x)/sum_condcontextprobs for x in condcontextprobs]
135 | condContextJointProbs_list.append(condcontextprobs)
136 | return condContextJointProbs_list
137 |
138 |
139 | def computeFinalEntityScores(condProbs_list, contextProbs_list, alpha=0.5):
140 | condContextJointProbs_list = []
141 | condProbs_list = _normalizeProbList(condProbs_list)
142 | #contextProbs_list = _normalizeProbList(contextProbs_list)
143 |
144 | for (cprobs, contprobs) in zip(condProbs_list, contextProbs_list):
145 | condcontextprobs = [(alpha*x + (1-alpha)*y) for (x,y) in zip(cprobs, contprobs)]
146 | sum_condcontextprobs = sum(condcontextprobs)
147 | if sum_condcontextprobs != 0.0:
148 | condcontextprobs = [float(x)/sum_condcontextprobs for x in condcontextprobs]
149 | condContextJointProbs_list.append(condcontextprobs)
150 | return condContextJointProbs_list
151 |
152 |
153 | ##############################################################################
154 |
155 | def evaluateEL(condProbs_list, widIdxs_list, contextProbs_list,
156 | idx2knwid, wid2WikiTitle, verbose=False):
157 | ''' Prior entity prob, True and candidate entity WIDs, Predicted ent. prob.
158 |     using context for each of the 30 candidates. First element in the candidates is
159 | the true entity.
160 | Args:
161 | For each mention:
162 | condProbs_list: List of prior probs for 30 candidates.
163 |         widIdxs_list: List of candidate widIdxs for 30 candidates.
164 |         contextProbs_list: List of candidate probs. using context
165 | idx2knwid: Map for widIdx -> WID
166 | wid2WikiTitle: Map from WID -> WikiTitle
167 | wid2TypeLabels: Map from WID -> List of Types
168 | '''
169 | print("Evaluating E-Linking ... ")
170 | (WIDS_list, wikiTitles_list) = convertWidIdxs2WikiTitlesAndWIDs(
171 | widIdxs_list, idx2knwid, wid2WikiTitle)
172 |
173 | alpha = 0.5
174 | #for alpha in alpha_range:
175 | print("Alpha : {}".format(alpha))
176 | jointProbs_list = computeFinalEntityProbs(
177 | condProbs_list, contextProbs_list, alpha=alpha)
178 |
179 | # evaluationWikiTitles:
180 | # For each mention [(trWT, maxPWT, maxCWT, maxJWT), (trWID, ...)]
181 | (evaluationWikiTitles,
182 | sortedContextWTs) = computeMaxPriorContextJointEntities(
183 | WIDS_list, wikiTitles_list, condProbs_list, contextProbs_list,
184 | jointProbs_list, verbose)
185 |
186 |
187 | '''
188 | condContextJointScores_list = computeFinalEntityScores(
189 | condProbs_list, contextProbs_list, alpha=alpha)
190 |
191 | evaluationWikiTitles = computeMaxPriorContextJointEntities(
192 | WIDS_list, wikiTitles_list, condProbs_list, contextProbs_list,
193 | condContextJointScores_list, verbose)
194 | '''
195 |
196 | return (jointProbs_list, evaluationWikiTitles, sortedContextWTs)
197 |
198 | ##############################################################################
199 |
200 |
201 | def f1(p,r):
202 | if p == 0.0 and r == 0.0:
203 | return 0.0
204 | return (float(2*p*r))/(p + r)
205 |
206 |
207 | def strict_pred(true_label_batch, pred_score_batch):
208 | ''' Calculates strict precision/recall/f1 given truth and predicted scores
209 | args
210 | true_label_batch: Binary Numpy matrix of [num_instances, num_labels]
211 | pred_score_batch: Real [0,1] numpy matrix of [num_instances, num_labels]
212 |
213 | return:
214 | correct_preds: Number of correct strict preds
215 | precision : correct_preds / num_instances
216 | '''
217 | (true_labels, pred_labels) = types_convert_mat_to_sets(
218 | true_label_batch, pred_score_batch)
219 |
220 | num_instanes = len(true_labels)
221 | correct_preds = 0
222 | for i in range(0, num_instanes):
223 | if true_labels[i] == pred_labels[i]:
224 | correct_preds += 1
225 | #endfor
226 | precision = recall = float(correct_preds)/num_instanes
227 |
228 | return correct_preds, precision
229 |
230 |
231 | def correct_context_prediction(entity_posterior_scores, batch_size):
232 | bool_array = np.equal(np.argmax(entity_posterior_scores, axis=1),
233 | [0]*batch_size)
234 | correct_preds = np.sum(bool_array)
235 | return correct_preds
236 |
--------------------------------------------------------------------------------
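The joint score in `computeFinalEntityProbs` of `evaluation/evaluate_el.py` above combines the normalized prior and context probabilities with a noisy-OR (`x + y - x*y`) and renormalizes per mention. A small self-contained sketch of that combination, with made-up candidate scores:

```
# Self-contained sketch of the prior/context combination used in
# computeFinalEntityProbs: normalize each list, apply noisy-OR, renormalize.
def combine(prior_probs, context_probs):
    def normalize(probs):
        s = sum(probs)
        return [p / s for p in probs] if s != 0.0 else probs

    prior_probs = normalize(prior_probs)
    context_probs = normalize(context_probs)
    joint = [x + y - x * y for x, y in zip(prior_probs, context_probs)]
    s = sum(joint)
    return [x / s for x in joint] if s != 0.0 else joint

# Hypothetical scores for three candidates of one mention: the prior favours
# candidate 0, the context favours candidate 1; the joint score balances both.
print(combine([0.6, 0.3, 0.1], [0.2, 0.7, 0.1]))
```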
/evaluation/evaluate_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 |
6 | def computeMaxPriorContextJointEntities(
7 | WIDS_list, wikiTitles_list, condProbs_list, contextProbs_list,
8 | condContextJointProbs_list, verbose):
9 |
10 | assert (len(wikiTitles_list) == len(condProbs_list) ==
11 | len(contextProbs_list) == len(condContextJointProbs_list))
12 | numMens = len(wikiTitles_list)
13 |
14 | evaluationWikiTitles = []
15 | sortedContextWTs = []
16 | for (WIDS, wTs,
17 | cProbs, contProbs, jointProbs) in zip(WIDS_list,
18 | wikiTitles_list,
19 | condProbs_list,
20 | contextProbs_list,
21 | condContextJointProbs_list):
22 | # if wTs[0] == "":
23 | # evaluationWikiTitles.append([tuple([""]*3),
24 | # tuple([""]*3)])
25 | # else:
26 | maxCondEntity_idx = np.argmax(cProbs)
27 | maxCondWID = WIDS[maxCondEntity_idx]
28 | maxCondEntity = wTs[maxCondEntity_idx]
29 | maxCondProb = cProbs[maxCondEntity_idx]
30 |
31 | maxContEntity_idx = np.argmax(contProbs)
32 | maxContWID = WIDS[maxContEntity_idx]
33 | maxContEntity = wTs[maxContEntity_idx]
34 | maxContProb = contProbs[maxContEntity_idx]
35 |
36 | contProbs_sortIdxs = np.argsort(contProbs).tolist()[::-1]
37 | sortContWTs = [wTs[i] for i in contProbs_sortIdxs]
38 | sortedContextWTs.append(sortContWTs)
39 |
40 | maxJointEntity_idx = np.argmax(jointProbs)
41 | maxJointWID = WIDS[maxJointEntity_idx]
42 | maxJointEntity = wTs[maxJointEntity_idx]
43 | maxJointProb = jointProbs[maxJointEntity_idx]
44 | maxJointCprob = cProbs[maxJointEntity_idx]
45 | maxJointContP = contProbs[maxJointEntity_idx]
46 |
47 | predWTs = (maxCondEntity, maxContEntity, maxJointEntity)
48 | predWIDs = (maxCondWID, maxContWID, maxJointWID)
49 | predProbs = (maxCondProb, maxContProb, maxJointProb)
50 | evaluationWikiTitles.append([predWTs, predWIDs, predProbs])
51 |
52 | assert len(evaluationWikiTitles) == numMens
53 |
54 | return (evaluationWikiTitles, sortedContextWTs)
55 |
56 | def convertWidIdxs2WikiTitlesAndWIDs(widIdxs_list, idx2knwid, wid2WikiTitle):
57 | wikiTitles_list = []
58 | WIDS_list = []
59 | for widIdxs in widIdxs_list:
60 | wids = [idx2knwid[wididx] for wididx in widIdxs]
61 | wikititles = [wid2WikiTitle[idx2knwid[wididx]] for wididx in widIdxs]
62 | WIDS_list.append(wids)
63 | wikiTitles_list.append(wikititles)
64 |
65 | return (WIDS_list, wikiTitles_list)
66 |
67 | def _normalizeProbList(probList):
68 | norm_probList = []
69 | for probs in probList:
70 | s = sum(probs)
71 | if s != 0.0:
72 | n_p = [p/s for p in probs]
73 | norm_probList.append(n_p)
74 | else:
75 | norm_probList.append(probs)
76 | return norm_probList
77 |
78 |
79 | def computeFinalEntityProbs(condProbs_list, contextProbs_list):
80 | condContextJointProbs_list = []
81 | condProbs_list = _normalizeProbList(condProbs_list)
82 | contextProbs_list = _normalizeProbList(contextProbs_list)
83 |
84 | for (cprobs, contprobs) in zip(condProbs_list, contextProbs_list):
85 | condcontextprobs = [(x + y - x*y) for (x,y) in zip(cprobs, contprobs)]
86 | sum_condcontextprobs = sum(condcontextprobs)
87 | if sum_condcontextprobs != 0.0:
88 | condcontextprobs = [float(x)/sum_condcontextprobs for x in condcontextprobs]
89 | condContextJointProbs_list.append(condcontextprobs)
90 | return condContextJointProbs_list
91 |
92 |
93 | def computeFinalEntityScores(condProbs_list, contextProbs_list, alpha=0.5):
94 | condContextJointProbs_list = []
95 | condProbs_list = _normalizeProbList(condProbs_list)
96 | #contextProbs_list = _normalizeProbList(contextProbs_list)
97 |
98 | for (cprobs, contprobs) in zip(condProbs_list, contextProbs_list):
99 | condcontextprobs = [(alpha*x + (1-alpha)*y) for (x,y) in zip(cprobs, contprobs)]
100 | sum_condcontextprobs = sum(condcontextprobs)
101 | if sum_condcontextprobs != 0.0:
102 | condcontextprobs = [float(x)/sum_condcontextprobs for x in condcontextprobs]
103 | condContextJointProbs_list.append(condcontextprobs)
104 | return condContextJointProbs_list
105 |
106 |
107 | #############################################################################
108 |
109 |
110 | def evaluateEL(condProbs_list, widIdxs_list, contextProbs_list,
111 | idx2knwid, wid2WikiTitle, verbose=False):
112 | ''' Prior entity prob, True and candidate entity WIDs, Predicted ent. prob.
113 |     using context for each of the 30 candidates. First element in the candidates
114 | is the true entity.
115 | Args:
116 | For each mention:
117 | condProbs_list: List of prior probs for 30 candidates.
118 | widIdxs_list: List of candidate widIdxs probs for 30 candidates.
119 |         contextProbs_list: List of candidate probs. using context
120 | idx2knwid: Map for widIdx -> WID
121 | wid2WikiTitle: Map from WID -> WikiTitle
122 | '''
123 | # print("Evaluating E-Linking ... ")
124 | (WIDS_list, wikiTitles_list) = convertWidIdxs2WikiTitlesAndWIDs(
125 | widIdxs_list, idx2knwid, wid2WikiTitle)
126 |
127 | jointProbs_list = computeFinalEntityProbs(condProbs_list,
128 | contextProbs_list)
129 |
130 | (evaluationWikiTitles,
131 | sortedContextWTs) = computeMaxPriorContextJointEntities(
132 | WIDS_list, wikiTitles_list, condProbs_list, contextProbs_list,
133 | jointProbs_list, verbose)
134 |
135 | return (jointProbs_list, evaluationWikiTitles, sortedContextWTs)
136 |
137 | ##############################################################################
138 |
139 |
140 | def f1(p,r):
141 | if p == 0.0 and r == 0.0:
142 | return 0.0
143 | return (float(2*p*r))/(p + r)
144 |
145 |
146 | def strict_pred(true_label_batch, pred_score_batch):
147 | ''' Calculates strict precision/recall/f1 given truth and predicted scores
148 | args
149 | true_label_batch: Binary Numpy matrix of [num_instances, num_labels]
150 | pred_score_batch: Real [0,1] numpy matrix of [num_instances, num_labels]
151 |
152 | return:
153 | correct_preds: Number of correct strict preds
154 | precision : correct_preds / num_instances
155 | '''
156 | (true_labels, pred_labels) = types_convert_mat_to_sets(
157 | true_label_batch, pred_score_batch)
158 |
159 | num_instanes = len(true_labels)
160 | correct_preds = 0
161 | for i in range(0, num_instanes):
162 | if true_labels[i] == pred_labels[i]:
163 | correct_preds += 1
164 | #endfor
165 | precision = recall = float(correct_preds)/num_instanes
166 |
167 | return correct_preds, precision
168 |
169 |
170 | def correct_context_prediction(entity_posterior_scores, batch_size):
171 | bool_array = np.equal(np.argmax(entity_posterior_scores, axis=1),
172 | [0]*batch_size)
173 | correct_preds = np.sum(bool_array)
174 | return correct_preds
175 |
--------------------------------------------------------------------------------
/evaluation/evaluate_types.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | coarsetypes = set(["location", "person", "organization", "event"])
6 | coarseTypeIds = set([1,5,10,25])
7 |
8 | def _convertTypeMatToTypeSets(typesscore_mat, idx2label, threshold):
9 |     ''' Converts a matrix of per-type scores into sets of predicted type labels.
10 |     args
11 |         typesscore_mat: Real [0,1] numpy matrix of [num_instances, num_labels]
12 |         idx2label: Map from label idx -> label string
13 |         threshold: Scores >= threshold are predicted as labels
14 |
15 |     return:
16 |         labels: List of sets of predicted type labels, one set per instance.
17 |         If no score clears the threshold, the arg-max label is predicted.
18 |     '''
19 | labels = []
20 | for i in typesscore_mat:
21 |         # i is an array of label_vals for the i-th example
22 | labels_i = []
23 | max_idx = -1
24 | max_val = -1
25 | for (label_idx, val) in enumerate(i):
26 | if val >= threshold:
27 | labels_i.append(idx2label[label_idx])
28 | if val > max_val:
29 | max_idx = label_idx
30 | max_val = val
31 | if len(labels_i) == 0:
32 | labels_i.append(idx2label[max_idx])
33 | labels.append(set(labels_i))
34 |
35 | '''
36 | assert 0.0 < threshold <= 1.0
37 | boolmat = typesscore_mat >= threshold
38 | boollist = boolmat.tolist()
39 | num_instanes = len(boollist)
40 | labels = []
41 | for i in range(0, num_instanes):
42 | labels_i = [idx2label[i] for i, x in enumerate(boollist[i]) if x]
43 | labels.append(set(labels_i))
44 | ##
45 | '''
46 | return labels
47 |
48 | def convertTypesScoreMatLists_TypeSets(typeScoreMat_list, idx2label, threshold):
49 | '''
50 |     Takes a list of type-score numpy matrices (one per batch) as output from Tensorflow.
51 |     Converts it into a list of type sets, one per mention, based on the threshold.
52 |
53 | Return:
54 | typeSets_list: Size=num_instances. Each instance is set of type labels for mention
55 | '''
56 |
57 | typeSets_list = []
58 | for typeScoreMat in typeScoreMat_list:
59 | typeLabels_list = _convertTypeMatToTypeSets(typeScoreMat,
60 | idx2label, threshold)
61 | typeSets_list.extend(typeLabels_list)
62 | return typeSets_list
63 |
64 |
65 | def typesPredictionStats(pred_labels, true_labels):
66 | '''
67 | args
68 |         pred_labels: List of sets of predicted type labels, one per instance
69 |         true_labels: List of sets of true type labels, one per instance
70 | '''
71 |
72 |     # t_hat intersect t
73 | t_intersect = 0
74 | t_hat_count = 0
75 | t_count = 0
76 | t_t_hat_exact = 0
77 | loose_macro_p = 0.0
78 | loose_macro_r = 0.0
79 | num_instances = len(true_labels)
80 | for i in range(0, num_instances):
81 | intersect = len(true_labels[i].intersection(pred_labels[i]))
82 | t_h_c = len(pred_labels[i])
83 | t_c = len(true_labels[i])
84 | t_intersect += intersect
85 | t_hat_count += t_h_c
86 | t_count += t_c
87 | exact = 1 if (true_labels[i] == pred_labels[i]) else 0
88 | t_t_hat_exact += exact
89 | if len(pred_labels[i]) > 0:
90 | loose_macro_p += intersect / float(t_h_c)
91 | if len(true_labels[i]) > 0:
92 | loose_macro_r += intersect / float(t_c)
93 |
94 | return (t_intersect, t_t_hat_exact, t_hat_count, t_count,
95 | loose_macro_p, loose_macro_r)
96 |
97 | def typesEvaluationMetrics(pred_TypeSetsList, true_TypeSetsList):
98 | num_instances = len(true_TypeSetsList)
99 | (t_i, t_th_exact, t_h_c, t_c, l_m_p, l_m_r) = typesPredictionStats(
100 | pred_labels=pred_TypeSetsList, true_labels=true_TypeSetsList)
101 | strict = float(t_th_exact)/float(num_instances)
102 | loose_macro_p = l_m_p / float(num_instances)
103 | loose_macro_r = l_m_r / float(num_instances)
104 | loose_macro_f = f1(loose_macro_p, loose_macro_r)
105 | if t_h_c > 0:
106 | loose_micro_p = float(t_i)/float(t_h_c)
107 | else:
108 | loose_micro_p = 0
109 | if t_c > 0:
110 | loose_micro_r = float(t_i)/float(t_c)
111 | else:
112 | loose_micro_r = 0
113 | loose_micro_f = f1(loose_micro_p, loose_micro_r)
114 |
115 | return (strict, loose_macro_p, loose_macro_r, loose_macro_f, loose_micro_p,
116 | loose_micro_r, loose_micro_f)
117 |
118 |
119 | def performTypingEvaluation(predLabelScoresnumpymat_list, idx2label):
120 | '''
121 |     Args: List of numpy mat, one for each batch, of predicted type scores
122 |         predLabelScoresnumpymat_list: List of score matrices output by tensorflow
123 |         idx2label: Map from label idx -> label string
124 | '''
125 | pred_TypeSetsList = convertTypesScoreMatLists_TypeSets(
126 | typeScoreMat_list=predLabelScoresnumpymat_list, idx2label=idx2label,
127 | threshold=0.75)
128 |
129 | return pred_TypeSetsList
130 |
131 |
132 | def evaluate(predLabelScoresnumpymat_list, idx2label):
133 | # print("Evaluating Typing ... ")
134 | pred_TypeSetsList = convertTypesScoreMatLists_TypeSets(
135 | typeScoreMat_list=predLabelScoresnumpymat_list, idx2label=idx2label,
136 | threshold=0.75)
137 |
138 | return pred_TypeSetsList
139 |
140 |
141 | def f1(p,r):
142 | if p == 0.0 and r == 0.0:
143 | return 0.0
144 | return (float(2*p*r))/(p + r)
145 |
146 |
147 | def strict_pred(true_label_batch, pred_score_batch):
148 | ''' Calculates strict precision/recall/f1 given truth and predicted scores
149 | args
150 | true_label_batch: Binary Numpy matrix of [num_instances, num_labels]
151 | pred_score_batch: Real [0,1] numpy matrix of [num_instances, num_labels]
152 |
153 | return:
154 | correct_preds: Number of correct strict preds
155 | precision : correct_preds / num_instances
156 | '''
157 | (true_labels, pred_labels) = types_convert_mat_to_sets(
158 | true_label_batch, pred_score_batch)
159 |
160 | num_instanes = len(true_labels)
161 | correct_preds = 0
162 | for i in range(0, num_instanes):
163 | if true_labels[i] == pred_labels[i]:
164 | correct_preds += 1
165 | #endfor
166 | precision = recall = float(correct_preds)/num_instanes
167 |
168 | return correct_preds, precision
169 |
170 | def correct_context_prediction(entity_posterior_scores, batch_size):
171 | bool_array = np.equal(np.argmax(entity_posterior_scores, axis=1),
172 | [0]*batch_size)
173 | correct_preds = np.sum(bool_array)
174 | return correct_preds
175 |
--------------------------------------------------------------------------------
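The thresholding in `_convertTypeMatToTypeSets` of `evaluation/evaluate_types.py` above keeps every type whose score clears the threshold and falls back to the arg-max type when none does. A small illustration with made-up scores:

```
import numpy as np

# Toy illustration of the thresholding in _convertTypeMatToTypeSets:
# keep all types scoring >= threshold, else fall back to the arg-max type.
idx2label = {0: "person", 1: "location", 2: "organization"}
scores = np.array([[0.90, 0.80, 0.10],   # two types clear the threshold
                   [0.30, 0.20, 0.40]])  # none does -> arg-max ("organization")
threshold = 0.75

for row in scores:
    kept = {idx2label[i] for i, v in enumerate(row) if v >= threshold}
    if not kept:
        kept = {idx2label[int(np.argmax(row))]}
    print(kept)
```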
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/models/__init__.py
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import tensorflow as tf
4 |
5 | class Model(object):
6 |     """Abstract object representing a Reader model."""
7 | def __init__(self):
8 | pass
9 |
10 | # def get_model_dir(self):
11 | # model_dir = self.dataset
12 | # for attr in self._attrs:
13 | # if hasattr(self, attr):
14 | # model_dir += "/%s=%s" % (attr, getattr(self, attr))
15 | # return model_dir
16 |
17 | def get_model_dir(self, attrs=None):
18 | model_dir = self.dataset
19 | if attrs == None:
20 | attrs = self._attrs
21 | for attr in attrs:
22 | if hasattr(self, attr):
23 | model_dir += "/%s=%s" % (attr, getattr(self, attr))
24 | return model_dir
25 |
26 | def get_log_dir(self, root_log_dir, attrs=None):
27 | model_dir = self.get_model_dir(attrs=attrs)
28 | log_dir = os.path.join(root_log_dir, model_dir)
29 | if not os.path.exists(log_dir):
30 | os.makedirs(log_dir)
31 | return log_dir
32 |
33 | def save(self, saver, checkpoint_dir, attrs=None, global_step=None):
34 | print(" [*] Saving checkpoints...")
35 | model_name = type(self).__name__
36 | model_dir = self.get_model_dir(attrs=attrs)
37 |
38 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
39 | if not os.path.exists(checkpoint_dir):
40 | os.makedirs(checkpoint_dir)
41 | saver.save(self.sess, os.path.join(checkpoint_dir, model_name),
42 | global_step=global_step)
43 | print(" [*] Saving done...")
44 |
45 | def initialize(self, log_dir="./logs"):
46 | self.merged_sum = tf.merge_all_summaries()
47 | self.writer = tf.train.SummaryWriter(log_dir, self.sess.graph_def)
48 |
49 | tf.initialize_all_variables().run()
50 | self.load(self.checkpoint_dir)
51 |
52 | start_iter = self.step.eval()
53 |
54 | def load(self, saver, checkpoint_dir, attrs=None):
55 | print(" [*] Loading checkpoints...")
56 | model_dir = self.get_model_dir(attrs=attrs)
57 | # /checkpointdir/attrs=values/
58 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
59 | print(" [#] Checkpoint Dir : {}".format(checkpoint_dir))
60 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
61 | if ckpt and ckpt.model_checkpoint_path:
62 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
63 | print("ckpt_name: {}".format(ckpt_name))
64 | saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
65 | print(" [*] Load SUCCESS")
66 | return True
67 | else:
68 | print(" [!] Load failed...")
69 | return False
70 |
71 | def loadCKPTPath(self, saver, ckptPath=None):
72 | assert ckptPath != None
73 | print(" [#] CKPT Path : {}".format(ckptPath))
74 | if os.path.exists(ckptPath):
75 | saver.restore(self.sess, ckptPath)
76 | print(" [*] Load SUCCESS")
77 | return True
78 | else:
79 | print(" [*] CKPT Path doesn't exist")
80 | return False
81 |
82 | def loadSpecificCKPT(self, saver, checkpoint_dir, ckptName=None, attrs=None):
83 | assert ckptName != None
84 | model_dir = self.get_model_dir(attrs=attrs)
85 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
86 | checkpoint_path = os.path.join(checkpoint_dir, ckptName)
87 | print(" [#] CKPT Path : {}".format(checkpoint_path))
88 | if os.path.exists(checkpoint_path):
89 | saver.restore(self.sess, checkpoint_path)
90 |
91 | print(" [*] Load SUCCESS")
92 | return True
93 | else:
94 | print(" [*] CKPT Path doesn't exist")
95 | return False
96 |
97 |
98 |
99 | def collect_scope(self, scope_name, graph=None, var_type=tf.GraphKeys.VARIABLES):
100 | if graph == None:
101 | graph = tf.get_default_graph()
102 |
103 | var_list = graph.get_collection(var_type, scope=scope_name)
104 |
105 | assert_str = "No variable exists with name_scope '{}'".format(scope_name)
106 | assert len(var_list) != 0, assert_str
107 |
108 | return var_list
109 |
110 | def get_scope_var_name_set(self, var_name):
111 | clean_var_num = var_name.split(":")[0]
112 | scopes_names = clean_var_num.split("/")
113 | return set(scopes_names)
114 |
115 |
116 | def scope_vars_list(self, scope_name, var_list):
117 | scope_var_list = []
118 | for var in var_list:
119 | scope_var_name = self.get_scope_var_name_set(var.name)
120 | if scope_name in scope_var_name:
121 | scope_var_list.append(var)
122 | return scope_var_list
123 |
--------------------------------------------------------------------------------
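`scope_vars_list` in `models/base.py` above filters variables by checking whether the scope name appears among the '/'-separated components of a variable's name. The same matching in plain Python, with hypothetical variable names:

```
# Plain-Python illustration of the name matching in get_scope_var_name_set /
# scope_vars_list: a variable belongs to a scope if the scope name appears
# among the '/'-separated components of its name (the ':0' suffix is dropped).
def name_components(var_name):
    return set(var_name.split(":")[0].split("/"))

names = ["cold/cold_entity_embeddings:0",      # hypothetical variable names
         "context_encoder/lstm/weights:0",
         "labeling_model/label_weights:0"]
print([n for n in names if "cold" in name_components(n)])
# -> ['cold/cold_entity_embeddings:0']
```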
/models/batch_normalizer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | #from tensorflow.python import control_flow_ops
3 | from tensorflow.python.ops import control_flow_ops
4 |
5 | class BatchNorm():
6 | def __init__(self,
7 | input,
8 | training,
9 | decay=0.95,
10 | epsilon=1e-4,
11 | name='bn',
12 | reuse_vars=False):
13 |
14 | self.decay = decay
15 | self.epsilon = epsilon
16 | self.batchnorm(input, training, name, reuse_vars)
17 |
18 | def batchnorm(self, input, training, name, reuse_vars):
19 | with tf.variable_scope(name, reuse=reuse_vars) as bn:
20 | rank = len(input.get_shape().as_list())
21 | in_dim = input.get_shape().as_list()[-1]
22 |
23 | if rank == 2:
24 | self.axes = [0]
25 | elif rank == 4:
26 | self.axes = [0, 1, 2]
27 | else:
28 | raise ValueError('Input tensor must have rank 2 or 4.')
29 |
30 | self.offset = tf.get_variable(
31 | 'offset',
32 | shape=[in_dim],
33 | initializer=tf.constant_initializer(0.0))
34 |
35 | self.scale = tf.get_variable(
36 | 'scale',
37 | shape=[in_dim],
38 | initializer=tf.constant_initializer(1.0))
39 |
40 | self.ema = tf.train.ExponentialMovingAverage(decay=self.decay)
41 |
42 | self.output = tf.cond(training,
43 | lambda: self.get_normalizer(input, True),
44 | lambda: self.get_normalizer(input, False))
45 |
46 | def get_normalizer(self, input, train_flag):
47 | if train_flag:
48 | self.mean, self.variance = tf.nn.moments(input, self.axes)
49 | # Fixes numerical instability if variance ~= 0, and it goes negative
50 | v = tf.nn.relu(self.variance)
51 | ema_apply_op = self.ema.apply([self.mean, self.variance])
52 | with tf.control_dependencies([ema_apply_op]):
53 | self.output_training = tf.nn.batch_normalization(
54 | input, self.mean, v, self.offset, self.scale,
55 |                     self.epsilon, 'normalizer_train')
56 | return self.output_training
57 | else:
58 | self.output_test = tf.nn.batch_normalization(
59 | input, self.ema.average(self.mean),
60 | self.ema.average(self.variance), self.offset, self.scale,
61 | self.epsilon, 'normalizer_test')
62 | return self.output_test
63 |
64 | def get_batch_moments(self):
65 | return self.mean, self.variance
66 |
67 | def get_ema_moments(self):
68 | return self.ema.average(self.mean), self.ema.average(self.variance)
69 |
70 | def get_offset_scale(self):
71 | return self.offset, self.scale
72 |
--------------------------------------------------------------------------------
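At test time, `BatchNorm` in `models/batch_normalizer.py` above normalizes with exponential moving averages of the batch moments (`decay=0.95`). A quick numeric sketch of that update rule, `shadow = decay * shadow + (1 - decay) * value`, with made-up batch means:

```
# Numeric sketch of the exponential-moving-average update that tracks the
# test-time moments: shadow = decay * shadow + (1 - decay) * value.
# The starting value and batch means below are made up for illustration.
decay = 0.95
shadow_mean = 0.0
for batch_mean in [1.0, 1.2, 0.8, 1.1]:
    shadow_mean = decay * shadow_mean + (1 - decay) * batch_mean
    print(round(shadow_mean, 4))
```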
/models/figer_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/models/figer_model/__init__.py
--------------------------------------------------------------------------------
/models/figer_model/coherence_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from models.base import Model
6 |
7 | class CoherenceModel(Model):
8 | '''
9 | Input is sparse tensor of mention strings in mention's document.
10 | Pass through feed forward and get a coherence representation
11 | (keep same as context_encoded_dim)
12 | '''
13 |
14 | def __init__(self, num_layers, batch_size, input_size,
15 | coherence_indices, coherence_values, coherence_matshape,
16 | context_encoded_dim, scope_name, device,
17 | dropout_keep_prob=1.0):
18 |
19 | # Num of layers in the encoder and decoder network
20 | self.num_layers = num_layers
21 | self.input_size = input_size
22 | self.context_encoded_dim = context_encoded_dim
23 | self.dropout_keep_prob = dropout_keep_prob
24 | self.batch_size = batch_size
25 |
26 | with tf.variable_scope(scope_name) as s, tf.device(device) as d:
27 | coherence_inp_tensor = tf.SparseTensor(coherence_indices,
28 | coherence_values,
29 | coherence_matshape)
30 |
31 | # Feed-forward Net for coherence_representation
32 | # Layer 1
33 | self.trans_weights = tf.get_variable(
34 | name="coherence_layer_0",
35 | shape=[self.input_size, self.context_encoded_dim],
36 | initializer=tf.random_normal_initializer(
37 | mean=0.0,
38 | stddev=1.0/(100.0)))
39 |
40 | # [B, context_encoded_dim]
41 | coherence_encoded = tf.sparse_tensor_dense_matmul(
42 | coherence_inp_tensor, self.trans_weights)
43 | coherence_encoded = tf.nn.relu(coherence_encoded)
44 |
45 | # Hidden Layers. NumLayers >= 2
46 | self.hidden_layers = []
47 | for i in range(1, self.num_layers):
48 | weight_matrix = tf.get_variable(
49 | name="coherence_layer_"+str(i),
50 | shape=[self.context_encoded_dim, self.context_encoded_dim],
51 | initializer=tf.random_normal_initializer(
52 | mean=0.0,
53 | stddev=1.0/(100.0)))
54 | self.hidden_layers.append(weight_matrix)
55 |
56 | for i in range(1, self.num_layers):
57 | coherence_encoded = tf.nn.dropout(
58 | coherence_encoded, keep_prob=self.dropout_keep_prob)
59 | coherence_encoded = tf.matmul(coherence_encoded,
60 | self.hidden_layers[i-1])
61 | coherence_encoded = tf.nn.relu(coherence_encoded)
62 |
63 | self.coherence_encoded = tf.nn.dropout(
64 | coherence_encoded, keep_prob=self.dropout_keep_prob)
65 |
--------------------------------------------------------------------------------
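`CoherenceModel` in `models/figer_model/coherence_model.py` above receives the co-occurring mention strings of a document as a `tf.SparseTensor` built from (indices, values, shape). A plain-Python illustration of that sparse bag-of-mentions encoding, with a hypothetical vocabulary size:

```
# Hypothetical sparse encoding of co-occurring mention strings for a batch of
# two mentions over a coherence vocabulary of size 5, in the
# (indices, values, shape) form that CoherenceModel feeds to tf.SparseTensor.
coherence_indices = [[0, 1], [0, 3], [1, 2]]  # (row, column) pairs
coherence_values = [1.0, 1.0, 1.0]            # presence indicators
coherence_matshape = [2, 5]                   # [batch_size, input_size]

# Dense equivalent, for intuition only:
dense = [[0.0] * 5 for _ in range(2)]
for (row, col), val in zip(coherence_indices, coherence_values):
    dense[row][col] = val
print(dense)  # [[0.0, 1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0]]
```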
/models/figer_model/coldStart.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | import readers.utils as utils
7 | from evaluation import evaluate
8 | from evaluation import evaluate_el
9 | from evaluation import evaluate_types
10 | from models.base import Model
11 | from models.figer_model.context_encoder import ContextEncoderModel
12 | from models.figer_model.coherence_model import CoherenceModel
13 | from models.figer_model.wiki_desc import WikiDescModel
14 | from models.figer_model.joint_context import JointContextModel
15 | from models.figer_model.labeling_model import LabelingModel
16 | from models.figer_model.entity_posterior import EntityPosterior
17 | from models.figer_model.loss_optim import LossOptim
18 |
19 |
20 | class ColdStart(object):
21 | def __init__(self, figermodel):
22 | print("###### ENTERED THE COLD WORLD OF THE UNKNOWN ##############")
23 | # Object of the WikiELModel Class
24 | self.fm = figermodel
25 | self.coldDir = self.fm.reader.coldDir
26 | coldWid2DescVecs_pkl = os.path.join(self.coldDir, "coldwid2descvecs.pkl")
27 | self.coldWid2DescVecs = utils.load(coldWid2DescVecs_pkl)
28 | self.num_cold_entities = self.fm.reader.num_cold_entities
29 | self.batch_size = self.fm.batch_size
30 | (self.coldwid2idx,
31 | self.idx2coldwid) = (self.fm.reader.coldwid2idx, self.fm.reader.idx2coldwid)
32 |
33 | def _makeDescLossGraph(self):
34 | with tf.variable_scope("cold") as s:
35 | with tf.device(self.fm.device_placements['gpu']) as d:
36 | tf.set_random_seed(1)
37 |
38 | self.coldEnEmbsToAssign = tf.placeholder(
39 | tf.float32, [self.num_cold_entities, 200], name="coldEmbsAssignment")
40 |
41 | self.coldEnEmbs = tf.get_variable(
42 | name="cold_entity_embeddings",
43 | shape=[self.num_cold_entities, 200],
44 | initializer=tf.random_normal_initializer(mean=-0.25,
45 | stddev=1.0/(100.0)))
46 |
47 | self.assignColdEmbs = self.coldEnEmbs.assign(self.coldEnEmbsToAssign)
48 |
49 | self.trueColdEnIds = tf.placeholder(
50 | tf.int32, [self.batch_size], name="true_entities_idxs")
51 |
52 | # Should be a list of zeros
53 | self.softTrueIdxs = tf.placeholder(
54 | tf.int32, [self.batch_size], name="softmaxTrueEnsIdxs")
55 |
56 | # [B, D]
57 | self.trueColdEmb = tf.nn.embedding_lookup(
58 | self.coldEnEmbs, self.trueColdEnIds)
59 | # [B, 1, D]
60 | self.trueColdEmb_exp = tf.expand_dims(
61 | input=self.trueColdEmb, dim=1)
62 |
63 | self.label_scores = tf.matmul(self.trueColdEmb,
64 | self.fm.labeling_model.label_weights)
65 |
66 | self.labeling_losses = tf.nn.sigmoid_cross_entropy_with_logits(
67 | logits=self.label_scores,
68 | targets=self.fm.labels_batch,
69 | name="labeling_loss")
70 |
71 | self.labelingLoss = tf.reduce_sum(
72 | self.labeling_losses) / tf.to_float(self.batch_size)
73 |
74 | # [B, D]
75 | self.descEncoded = self.fm.wikidescmodel.desc_encoded
76 |
77 | ## Maximize sigmoid of dot-prod between true emb. and desc encoding
78 | descLosses = -tf.sigmoid(tf.reduce_sum(tf.mul(self.trueColdEmb, self.descEncoded), 1))
79 | self.descLoss = tf.reduce_sum(descLosses)/tf.to_float(self.batch_size)
80 |
81 |
82 | # L-2 Norm Loss
83 | self.trueEmbNormLoss = tf.reduce_sum(
84 | tf.square(self.trueColdEmb))/(tf.to_float(self.batch_size))
85 |
86 |
87 | ''' Concat trueColdEmb_exp to negKnownEmbs so that 0 is the true entity.
88 | Dotprod this emb matrix with descEncoded to get scores and apply softmax
89 | '''
90 |
91 | self.trcoldvars = self.fm.scope_vars_list(scope_name="cold",
92 | var_list=tf.trainable_variables())
93 |
94 | print("Vars in Training")
95 | for var in self.trcoldvars:
96 | print(var.name)
97 |
98 |
99 | self.optimizer = tf.train.AdamOptimizer(
100 | learning_rate=self.fm.learning_rate,
101 | name='AdamCold_')
102 |
103 | self.total_labeling_loss = self.labelingLoss + self.trueEmbNormLoss
104 | self.label_gvs = self.optimizer.compute_gradients(
105 | loss=self.total_labeling_loss, var_list=self.trcoldvars)
106 | self.labeling_optim_op = self.optimizer.apply_gradients(self.label_gvs)
107 |
108 | self.total_loss = self.labelingLoss + 100*self.descLoss + self.trueEmbNormLoss
109 | self.comb_gvs = self.optimizer.compute_gradients(
110 | loss=self.total_loss, var_list=self.trcoldvars)
111 | self.combined_optim_op = self.optimizer.apply_gradients(self.comb_gvs)
112 |
113 |
114 | self.allcoldvars = self.fm.scope_vars_list(scope_name="cold",
115 | var_list=tf.all_variables())
116 |
117 | print("All Vars in Cold")
118 | for var in self.allcoldvars:
119 | print(var.name)
120 |
121 | print("Loaded and graph made")
122 | ### GRAPH COMPLETE ###
123 |
124 | #############################################################################
125 | def _trainColdEmbFromTypes(self, epochsToTrain=5):
126 | print("Training Cold Entity Embeddings from Typing Info")
127 |
128 | epochsDone = self.fm.reader.val_epochs
129 |
130 | while self.fm.reader.val_epochs < epochsToTrain:
131 | (left_batch, left_lengths,
132 | right_batch, right_lengths,
133 | wids_batch,
134 | labels_batch, coherence_batch,
135 | wid_idxs_batch, wid_cprobs_batch) = self.fm.reader._next_padded_batch(data_type=1)
136 |
137 | trueColdWidIdxsBatch = []
138 | trueColdWidDescWordVecBatch = []
139 | for wid in wids_batch:
140 | trueColdWidIdxsBatch.append(self.coldwid2idx[wid])
141 | trueColdWidDescWordVecBatch.append(self.coldWid2DescVecs[wid])
142 |
143 | feed_dict = {self.trueColdEnIds: trueColdWidIdxsBatch,
144 | self.fm.labels_batch: labels_batch}
145 |
146 | fetch_tensor = [self.labelingLoss, self.trueEmbNormLoss]
147 |
148 | (fetches, _) = self.fm.sess.run([fetch_tensor,
149 | self.labeling_optim_op],
150 | feed_dict=feed_dict)
151 |
152 | labelingLoss = fetches[0]
153 | trueEmbNormLoss = fetches[1]
154 |
155 | print("LL : {} NormLoss : {}".format(labelingLoss, trueEmbNormLoss))
156 |
157 | newedone = self.fm.reader.val_epochs
158 | if newedone > epochsDone:
159 | print("Epochs : {}".format(newedone))
160 | epochsDone = newedone
161 |
162 | #############################################################################
163 | def _trainColdEmbFromTypesAndDesc(self, epochsToTrain=5):
164 | print("Training Cold Entity Embeddings from Typing Info")
165 |
166 | epochsDone = self.fm.reader.val_epochs
167 |
168 | while self.fm.reader.val_epochs < epochsToTrain:
169 | (left_batch, left_lengths,
170 | right_batch, right_lengths,
171 | wids_batch,
172 | labels_batch, coherence_batch,
173 | wid_idxs_batch, wid_cprobs_batch) = self.fm.reader._next_padded_batch(data_type=1)
174 |
175 | trueColdWidIdxsBatch = []
176 | trueColdWidDescWordVecBatch = []
177 | for wid in wids_batch:
178 | trueColdWidIdxsBatch.append(self.coldwid2idx[wid])
179 | trueColdWidDescWordVecBatch.append(self.coldWid2DescVecs[wid])
180 |
181 | feed_dict = {self.fm.wikidesc_batch: trueColdWidDescWordVecBatch,
182 | self.trueColdEnIds: trueColdWidIdxsBatch,
183 | self.fm.labels_batch: labels_batch}
184 |
185 | fetch_tensor = [self.labelingLoss, self.descLoss, self.trueEmbNormLoss]
186 |
187 | (fetches,_) = self.fm.sess.run([fetch_tensor,
188 | self.combined_optim_op],
189 | feed_dict=feed_dict)
190 |
191 | labelingLoss = fetches[0]
192 | descLoss = fetches[1]
193 | normLoss = fetches[2]
194 |
195 | print("L : {} D : {} NormLoss : {}".format(labelingLoss, descLoss, normLoss))
196 |
197 | newedone = self.fm.reader.val_epochs
198 | if newedone > epochsDone:
199 | print("Epochs : {}".format(newedone))
200 | epochsDone = newedone
201 |
202 | #############################################################################
203 |
204 | def runEval(self):
205 | print("Running Evaluations")
206 | self.fm.reader.reset_validation()
207 | correct = 0
208 | total = 0
209 | totnew = 0
210 | correctnew = 0
211 | while self.fm.reader.val_epochs < 1:
212 | (left_batch, left_lengths,
213 | right_batch, right_lengths,
214 | wids_batch,
215 | labels_batch, coherence_batch,
216 | wid_idxs_batch,
217 | wid_cprobs_batch) = self.fm.reader._next_padded_batch(data_type=1)
218 |
219 | trueColdWidIdxsBatch = []
220 |
221 | for wid in wids_batch:
222 | trueColdWidIdxsBatch.append(self.coldwid2idx[wid])
223 |
224 | feed_dict = {self.fm.sampled_entity_ids: wid_idxs_batch,
225 | self.fm.left_context_embeddings: left_batch,
226 | self.fm.right_context_embeddings: right_batch,
227 | self.fm.left_lengths: left_lengths,
228 | self.fm.right_lengths: right_lengths,
229 | self.fm.coherence_indices: coherence_batch[0],
230 | self.fm.coherence_values: coherence_batch[1],
231 | self.fm.coherence_matshape: coherence_batch[2],
232 | self.trueColdEnIds: trueColdWidIdxsBatch}
233 |
234 | fetch_tensor = [self.trueColdEmb,
235 | self.fm.joint_context_encoded,
236 | self.fm.posterior_model.sampled_entity_embeddings,
237 | self.fm.posterior_model.entity_scores]
238 |
239 | fetched_vals = self.fm.sess.run(fetch_tensor, feed_dict=feed_dict)
240 | [trueColdEmbs, # [B, D]
241 | context_encoded, # [B, D]
242 | neg_entity_embeddings, # [B, N, D]
243 | neg_entity_scores] = fetched_vals # [B, N]
244 |
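            # Slot 0 of the candidate list is reserved for the true entity, so
            # its score is overwritten with the dot product between the cold
            # entity embedding and the encoded context before applying a
            # row-wise softmax over the candidates.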
245 | # [B]
246 | trueColdWidScores = np.sum(trueColdEmbs*context_encoded, axis=1)
247 | entity_scores = neg_entity_scores
248 | entity_scores[:,0] = trueColdWidScores
249 |             context_entity_scores = np.exp(entity_scores)/np.sum(np.exp(entity_scores), axis=1, keepdims=True)
250 |
251 | maxIdxs = np.argmax(context_entity_scores, axis=1)
252 | for i in range(0, self.batch_size):
253 | total += 1
254 | if maxIdxs[i] == 0:
255 | correct += 1
256 |
257 | scores_withpriors = context_entity_scores + wid_cprobs_batch
258 |
259 | maxIdxs = np.argmax(scores_withpriors, axis=1)
260 | for i in range(0, self.batch_size):
261 | totnew += 1
262 | if maxIdxs[i] == 0:
263 | correctnew += 1
264 |
265 | print("Context T : {} C : {}".format(total, correct))
266 | print("WPriors T : {} C : {}".format(totnew, correctnew))
267 |
268 | ##############################################################################
269 |
270 | def typeBasedColdEmbExp(self, ckptName="FigerModel-20001"):
271 |         ''' Train cold entity embeddings using the typing loss only.
272 |         '''
273 | saver = tf.train.Saver(var_list=tf.all_variables())
274 |
275 | print("Loading Model ... ")
276 |         if ckptName is None:
277 |             print("No checkpoint name given. Exiting.")
278 |             sys.exit()
279 | else:
280 | load_status = self.fm.loadSpecificCKPT(
281 | saver=saver, checkpoint_dir=self.fm.checkpoint_dir,
282 | ckptName=ckptName, attrs=self.fm._attrs)
283 | if not load_status:
284 | print("No model to load. Exiting")
285 | sys.exit(0)
286 |
287 | self._makeDescLossGraph()
288 | self.fm.sess.run(tf.initialize_variables(self.allcoldvars))
289 | self._trainColdEmbFromTypes(epochsToTrain=5)
290 |
291 | self.runEval()
292 |
293 | ##############################################################################
294 |
295 | def typeAndWikiDescBasedColdEmbExp(self, ckptName="FigerModel-20001"):
296 |         ''' Train cold entity embeddings using the typing and wiki description losses.
297 |         '''
298 | saver = tf.train.Saver(var_list=tf.all_variables())
299 |
300 | print("Loading Model ... ")
301 |         if ckptName is None:
302 |             print("No checkpoint name given. Exiting.")
303 |             sys.exit()
304 | else:
305 | load_status = self.fm.loadSpecificCKPT(
306 | saver=saver, checkpoint_dir=self.fm.checkpoint_dir,
307 | ckptName=ckptName, attrs=self.fm._attrs)
308 | if not load_status:
309 | print("No model to load. Exiting")
310 | sys.exit(0)
311 |
312 | self._makeDescLossGraph()
313 | self.fm.sess.run(tf.initialize_variables(self.allcoldvars))
314 | self._trainColdEmbFromTypesAndDesc(epochsToTrain=5)
315 |
316 | self.runEval()
317 |
318 | # EVALUATION FOR COLD START WHEN INITIALIZING COLD EMB FROM WIKI DESC ENCODING
319 | def wikiDescColdEmbExp(self, ckptName="FigerModel-20001"):
320 | ''' Assign cold entity embeddings as wiki desc encoding
321 | '''
322 | assert self.batch_size == 1
323 | print("Loaded Cold Start Class. ")
324 | print("Size of cold entities : {}".format(len(self.coldWid2DescVecs)))
325 |
326 | saver = tf.train.Saver(var_list=tf.all_variables(), max_to_keep=5)
327 |
328 | print("Loading Model ... ")
329 |         if ckptName is None:
330 |             print("No checkpoint name given. Exiting.")
331 |             sys.exit()
332 | else:
333 | load_status = self.fm.loadSpecificCKPT(
334 | saver=saver, checkpoint_dir=self.fm.checkpoint_dir,
335 | ckptName=ckptName, attrs=self.fm._attrs)
336 | if not load_status:
337 | print("No model to load. Exiting")
338 | sys.exit(0)
339 |
340 | iter_done = self.fm.global_step.eval()
341 | print("[#] Model loaded with iterations done: %d" % iter_done)
342 |
343 | self._makeDescLossGraph()
344 | self.fm.sess.run(tf.initialize_variables(self.allcoldvars))
345 |
346 | # Fill with encoded desc. in order of idx2coldwid
347 | print("Getting Encoded Description Vectors")
348 | descEncodedMatrix = []
349 | for idx in range(0, len(self.idx2coldwid)):
350 | wid = self.idx2coldwid[idx]
351 | desc_vec = self.coldWid2DescVecs[wid]
352 | feed_dict = {self.fm.wikidesc_batch: [desc_vec]}
353 | desc_encoded = self.fm.sess.run(self.fm.wikidescmodel.desc_encoded,
354 | feed_dict=feed_dict)
355 | descEncodedMatrix.append(desc_encoded[0])
356 |
357 | print("Initialization Experiment")
358 | self.runEval()
359 |
360 | print("Assigning Cold Embeddings from Wiki Desc Encoder ...")
361 | self.fm.sess.run(self.assignColdEmbs,
362 | feed_dict={self.coldEnEmbsToAssign:descEncodedMatrix})
363 |
364 | print("After assigning based on Wiki Encoder")
365 | self.runEval()
366 |
367 | ##############################################################################
368 |
--------------------------------------------------------------------------------
/models/figer_model/context_encoder.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | from models.base import Model
6 |
7 | class ContextEncoderModel(Model):
  8 |     """Run separate LSTMs over the left and right context and concatenate
  9 |     their last outputs to get the context representation"""
10 |
11 | def __init__(self, num_layers, batch_size, lstm_size,
12 | left_embed_batch, left_lengths, right_embed_batch, right_lengths,
13 | context_encoded_dim, scope_name, device, dropout_keep_prob=1.0):
14 |
 15 |         self.num_layers = num_layers  # Num of FF layers applied on top of the LSTM context encoder
16 | self.num_lstm_layers = 1
17 |
18 | # Left / Right Context Dim.
19 | # Context Representation Dim : 2*lstm_size
20 | self.lstm_size = lstm_size
21 | self.dropout_keep_prob = dropout_keep_prob
22 | self.batch_size = batch_size
23 | self.context_encoded_dim = context_encoded_dim
24 | self.left_context_embeddings = left_embed_batch
25 | self.right_context_embeddings = right_embed_batch
26 |
27 | with tf.variable_scope(scope_name) as sc, tf.device(device) as d:
28 | with tf.variable_scope("left_encoder") as s:
29 | l_encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(
30 | self.lstm_size, state_is_tuple=True)
31 |
32 | l_dropout_cell = tf.nn.rnn_cell.DropoutWrapper(
33 | cell=l_encoder_cell,
34 | input_keep_prob=self.dropout_keep_prob,
35 | output_keep_prob=self.dropout_keep_prob)
36 |
37 | self.left_encoder = tf.nn.rnn_cell.MultiRNNCell(
38 | [l_dropout_cell] * self.num_lstm_layers, state_is_tuple=True)
39 |
40 | self.left_outputs, self.left_states = tf.nn.dynamic_rnn(
41 | cell=self.left_encoder, inputs=self.left_context_embeddings,
42 | sequence_length=left_lengths, dtype=tf.float32)
43 |
44 | with tf.variable_scope("right_encoder") as s:
45 | r_encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(
46 | self.lstm_size, state_is_tuple=True)
47 |
48 | r_dropout_cell = tf.nn.rnn_cell.DropoutWrapper(
49 | cell=r_encoder_cell,
50 | input_keep_prob=self.dropout_keep_prob,
51 | output_keep_prob=self.dropout_keep_prob)
52 |
53 | self.right_encoder = tf.nn.rnn_cell.MultiRNNCell(
54 | [r_dropout_cell] * self.num_lstm_layers, state_is_tuple=True)
55 |
56 | self.right_outputs, self.right_states = tf.nn.dynamic_rnn(
57 | cell=self.right_encoder,
58 | inputs=self.right_context_embeddings,
59 | sequence_length=right_lengths, dtype=tf.float32)
60 |
61 | # Left Context Encoded
62 | # [B, LSTM_DIM]
63 | self.left_last_output = self.get_last_output(
64 | outputs=self.left_outputs, lengths=left_lengths,
65 | name="left_context_encoded")
66 |
67 | # Right Context Encoded
68 | # [B, LSTM_DIM]
69 | self.right_last_output = self.get_last_output(
70 | outputs=self.right_outputs, lengths=right_lengths,
71 | name="right_context_encoded")
72 |
73 | # Context Encoded Vector
74 | self.context_lstm_encoded = tf.concat(
75 | 1, [self.left_last_output, self.right_last_output],
76 | name='context_lstm_encoded')
77 |
78 | # Linear Transformation to get context_encoded_dim
79 | # Layer 1
80 | self.trans_weights = tf.get_variable(
81 | name="context_trans_weights",
82 | shape=[2*self.lstm_size, self.context_encoded_dim],
83 | initializer=tf.random_normal_initializer(
84 | mean=0.0,
85 | stddev=1.0/(100.0)))
86 |
87 | # [B, context_encoded_dim]
88 | context_encoded = tf.matmul(self.context_lstm_encoded,
89 | self.trans_weights)
90 | context_encoded = tf.nn.relu(context_encoded)
91 |
92 | self.hidden_layers = []
93 | for i in range(1, self.num_layers):
94 | weight_matrix = tf.get_variable(
95 | name="context_hlayer_"+str(i),
96 | shape=[self.context_encoded_dim, self.context_encoded_dim],
97 | initializer=tf.random_normal_initializer(
98 | mean=0.0,
99 | stddev=1.0/(100.0)))
100 | self.hidden_layers.append(weight_matrix)
101 |
102 | for i in range(1, self.num_layers):
103 | context_encoded = tf.nn.dropout(
104 | context_encoded, keep_prob=self.dropout_keep_prob)
105 | context_encoded = tf.matmul(context_encoded,
106 | self.hidden_layers[i-1])
107 | context_encoded = tf.nn.relu(context_encoded)
108 |
109 | self.context_encoded = tf.nn.dropout(
110 | context_encoded, keep_prob=self.dropout_keep_prob)
111 |
112 | def get_last_output(self, outputs, lengths, name):
113 | reverse_output = tf.reverse_sequence(input=outputs,
114 | seq_lengths=tf.to_int64(lengths),
115 | seq_dim=1,
116 | batch_dim=0)
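        # Reversing each sequence by its true length moves the last valid
        # timestep to position 0, so slicing [:, 0, :] below picks the final
        # LSTM output regardless of padding.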
117 | en_last_output = tf.slice(input_=reverse_output,
118 | begin=[0,0,0],
119 | size=[self.batch_size, 1, -1])
120 | # [batch_size, h_dim]
121 | encoder_last_output = tf.reshape(en_last_output,
122 | shape=[self.batch_size, -1],
123 | name=name)
124 |
125 | return encoder_last_output
126 |
--------------------------------------------------------------------------------
/models/figer_model/entity_posterior.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | from models.base import Model
6 |
7 | class EntityPosterior(Model):
8 | """Entity Embeddings and Posterior Calculation"""
9 |
10 | def __init__(self, batch_size, num_knwn_entities, context_encoded_dim,
11 | context_encoded, entity_ids, scope_name,
12 | device_embeds, device_gpu):
13 |
14 | ''' Defines the entity posterior estimation graph.
15 | Makes entity embeddings variable, gets encoded context and candidate
16 | entities and gets entity scores using dot-prod.
17 |
18 | Input :
19 | context_encoded_dim: D - dims in which context is encoded
20 | context_encoded: [B, D]
21 | entity_ids: [B, N]. If supervised, first is the correct entity, and
22 | rest N-1 the candidates. For unsupervised, all N are candidates.
23 | num_knwn_entities: Number of entities with supervised data.
24 | Defines:
25 | entity_embeddings: [num_knwn_entities, D]
26 | Output:
27 | entity_scores: [B,N] matrix of context scores
28 | '''
29 |
30 | self.batch_size = batch_size
31 | self.num_knwn_entities = num_knwn_entities
32 |
33 | with tf.variable_scope(scope_name) as s:
34 | with tf.device(device_embeds) as d:
35 | self.knwn_entity_embeddings = tf.get_variable(
36 | name="known_entity_embeddings",
37 | shape=[self.num_knwn_entities, context_encoded_dim],
38 | initializer=tf.random_normal_initializer(mean=0.0,
39 | stddev=1.0/(100.0)))
40 | with tf.device(device_gpu) as g:
41 | # [B, N, D]
42 | self.sampled_entity_embeddings = tf.nn.embedding_lookup(
43 | self.knwn_entity_embeddings, entity_ids)
44 |
 45 |                 # True entity embeddings (target of the Description CNN) - [B, D]
46 | trueentity_embeddings = tf.slice(
47 | self.sampled_entity_embeddings, [0,0,0],
48 | [self.batch_size, 1, context_encoded_dim])
49 | self.trueentity_embeddings = tf.reshape(
50 | trueentity_embeddings, [self.batch_size,
51 | context_encoded_dim])
52 |
53 | # Negative Samples for Description CNN
54 | negentity_embeddings = tf.slice(
55 | self.sampled_entity_embeddings, [0,1,0],
56 | [self.batch_size, 1, context_encoded_dim])
57 | self.negentity_embeddings = tf.reshape(
58 | negentity_embeddings, [self.batch_size,
59 | context_encoded_dim])
60 |
61 | # [B, 1, D]
62 | context_encoded_expanded = tf.expand_dims(
63 | input=context_encoded, dim=1)
64 |
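                # Scores are batched dot products:
                # entity_scores[b, n] = sum_d entity_emb[b, n, d] * context[b, d]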
65 | # [B, N]
66 | self.entity_scores = tf.reduce_sum(tf.mul(
67 | self.sampled_entity_embeddings, context_encoded_expanded), 2)
68 |
69 | # SOFTMAX
70 | # [B, N]
71 | self.entity_posteriors = tf.nn.softmax(
72 | self.entity_scores, name="entity_post_softmax")
73 |
74 | def loss_graph(self, true_entity_ids, scope_name, device_gpu):
 75 |         ''' true_entity_ids : [B] indices of the true entity in the sampled [B,N] matrix.
 76 |             In entity_ids, [?, 0] is the true entity, therefore this should be
 77 |             a vector of zeros.
 78 |         '''
79 | with tf.variable_scope(scope_name) as s, tf.device(device_gpu) as d:
80 | # CROSS ENTROPY LOSS
81 | self.crossentropy_losses = \
82 | tf.nn.sparse_softmax_cross_entropy_with_logits(
83 | logits=self.entity_scores,
84 | labels=true_entity_ids,
85 | name="entity_posterior_loss")
86 |
87 | self.posterior_loss = tf.reduce_sum(
88 | self.crossentropy_losses) / tf.to_float(self.batch_size)
89 |
--------------------------------------------------------------------------------
/models/figer_model/joint_context.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | from models.base import Model
6 |
7 | class JointContextModel(Model):
  8 |     """Combine the text (LSTM) and coherence context encodings into a joint context representation"""
9 |
10 | def __init__(self, num_layers, context_encoded_dim, text_encoded,
11 | coherence_encoded, scope_name, device, dropout_keep_prob):
12 |
13 | ''' Get context text and coherence encoded and combine into one repr.
14 | Input:
15 | text_encoded: Encoded vector for bi-LSTM. [context_encoded_dim]
16 | coherence_encoded: Encoded vector from sparse coherence FF [context_encoded_dim]
17 |
18 | Output:
19 | joint_encoded_vector: [context_encoded_dim]
20 | '''
21 | self.num_layers = num_layers
22 | self.dropout_keep_prob = dropout_keep_prob
23 | with tf.variable_scope(scope_name) as s:
24 | with tf.device(device) as d:
25 | self.joint_weights = tf.get_variable(
26 | name="joint_context_layer",
27 | shape=[2*context_encoded_dim, context_encoded_dim],
28 | initializer=tf.random_normal_initializer(mean=0.0,
29 | stddev=1.0/(100.0)))
30 |
31 | self.text_coh_concat = tf.concat(
32 | 1, [text_encoded, coherence_encoded], name='text_coh_concat')
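                # text_coh_concat has shape [B, 2*context_encoded_dim]; the
                # matmul below projects it back down to context_encoded_dim.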
33 |
34 | context_encoded = tf.matmul(self.text_coh_concat, self.joint_weights)
35 | context_encoded = tf.nn.relu(context_encoded)
36 |
37 | self.hidden_layers = []
38 | for i in range(1, self.num_layers):
39 | weight_matrix = tf.get_variable(
40 | name="joint_context_hlayer_"+str(i),
41 | shape=[context_encoded_dim, context_encoded_dim],
42 | initializer=tf.random_normal_initializer(
43 | mean=0.0,
44 | stddev=1.0/(100.0)))
45 | self.hidden_layers.append(weight_matrix)
46 |
47 | for i in range(1, self.num_layers):
48 | context_encoded = tf.nn.dropout(context_encoded, keep_prob=self.dropout_keep_prob)
49 | context_encoded = tf.matmul(context_encoded, self.hidden_layers[i-1])
50 | context_encoded = tf.nn.relu(context_encoded)
51 |
52 | self.context_encoded = tf.nn.dropout(context_encoded, keep_prob=self.dropout_keep_prob)
53 |
--------------------------------------------------------------------------------
/models/figer_model/labeling_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | from models.base import Model
6 |
7 | class LabelingModel(Model):
  8 |     """Multi-label (FIGER) type prediction from the encoded context and entity embeddings"""
9 |
10 | def __init__(self, batch_size, num_labels, context_encoded_dim,
11 | true_entity_embeddings,
12 | word_embed_dim, context_encoded, mention_embed, scope_name, device):
13 |
14 | self.batch_size = batch_size
15 | self.num_labels = num_labels
16 | self.word_embed_dim = word_embed_dim
17 |
18 | with tf.variable_scope(scope_name) as s, tf.device(device) as d:
 19 |             if mention_embed is None:
20 | self.label_weights = tf.get_variable(
21 | name="label_weights",
22 | shape=[context_encoded_dim, num_labels],
23 | initializer=tf.random_normal_initializer(mean=0.0,
24 | stddev=1.0/(100.0)))
25 | else:
26 | context_encoded = tf.concat(
27 | 1, [context_encoded, mention_embed], name='con_ment_repr')
28 | self.label_weights = tf.get_variable(
29 | name="label_weights",
30 | shape=[context_encoded_dim+word_embed_dim, num_labels],
31 | initializer=tf.random_normal_initializer(mean=0.0,
32 | stddev=1.0/(100.0)))
33 |
34 | # [B, L]
35 | self.label_scores = tf.matmul(context_encoded, self.label_weights)
36 | self.label_probs = tf.sigmoid(self.label_scores)
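            # Each of the num_labels FIGER types is scored independently; the
            # sigmoid treats typing as a multi-label (not softmax) prediction.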
37 |
38 | ### PREDICT TYPES FROM ENTITIES
39 | #true_entity_embeddings = tf.nn.dropout(true_entity_embeddings, keep_prob=0.5)
40 | self.entity_label_scores = tf.matmul(true_entity_embeddings, self.label_weights)
 41 |             self.entity_label_probs = tf.sigmoid(self.entity_label_scores)
42 |
43 |
44 | def loss_graph(self, true_label_ids, scope_name, device_gpu):
45 | with tf.variable_scope(scope_name) as s, tf.device(device_gpu) as d:
46 | # [B, L]
47 | self.cross_entropy_losses = tf.nn.sigmoid_cross_entropy_with_logits(
48 | logits=self.label_scores,
49 | targets=true_label_ids,
50 | name="labeling_loss")
51 |
52 | self.labeling_loss = tf.reduce_sum(
53 | self.cross_entropy_losses) / tf.to_float(self.batch_size)
54 |
55 |
56 | self.enlabel_cross_entropy_losses = tf.nn.sigmoid_cross_entropy_with_logits(
57 | logits=self.entity_label_scores,
58 | targets=true_label_ids,
59 | name="entity_labeling_loss")
60 |
61 | self.entity_labeling_loss = tf.reduce_sum(
62 | self.enlabel_cross_entropy_losses) / tf.to_float(self.batch_size)
63 |
--------------------------------------------------------------------------------
/models/figer_model/loss_optim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from models.base import Model
6 |
7 |
8 | class LossOptim(object):
9 | def __init__(self, figermodel):
10 | ''' Houses utility functions to facilitate training/pre-training'''
11 |
12 | # Object of the WikiELModel Class
13 | self.figermodel = figermodel
14 |
15 | def make_loss_graph(self):
16 | self.figermodel.labeling_model.loss_graph(
17 | true_label_ids=self.figermodel.labels_batch,
18 | scope_name=self.figermodel.labeling_loss_scope,
19 | device_gpu=self.figermodel.device_placements['gpu'])
20 |
21 | self.figermodel.posterior_model.loss_graph(
22 | true_entity_ids=self.figermodel.true_entity_ids,
23 | scope_name=self.figermodel.posterior_loss_scope,
24 | device_gpu=self.figermodel.device_placements['gpu'])
25 |
26 | if self.figermodel.useCNN:
27 | self.figermodel.wikidescmodel.loss_graph(
28 | true_entity_ids=self.figermodel.true_entity_ids,
29 | scope_name=self.figermodel.wikidesc_loss_scope,
30 | device_gpu=self.figermodel.device_placements['gpu'])
31 |
32 | def optimizer(self, optimizer_name, name):
33 | if optimizer_name == 'adam':
34 | optimizer = tf.train.AdamOptimizer(
35 | learning_rate=self.figermodel.learning_rate,
36 | name='Adam_'+name)
37 | elif optimizer_name == 'adagrad':
38 | optimizer = tf.train.AdagradOptimizer(
39 | learning_rate=self.figermodel.learning_rate,
40 | name='Adagrad_'+name)
41 | elif optimizer_name == 'adadelta':
42 | optimizer = tf.train.AdadeltaOptimizer(
43 | learning_rate=self.figermodel.learning_rate,
44 | name='Adadelta_'+name)
45 | elif optimizer_name == 'sgd':
46 | optimizer = tf.train.GradientDescentOptimizer(
47 | learning_rate=self.figermodel.learning_rate,
48 | name='SGD_'+name)
49 | elif optimizer_name == 'momentum':
50 | optimizer = tf.train.MomentumOptimizer(
51 | learning_rate=self.figermodel.learning_rate,
52 | momentum=0.9,
53 | name='Momentum_'+name)
 54 |         else:
 55 |             raise ValueError(
 56 |                 "Unknown optimizer: {}".format(optimizer_name))
57 | return optimizer
58 |
59 | def weight_regularization(self, trainable_vars):
60 | vars_to_regularize = []
61 | regularization_loss = 0
62 | for var in trainable_vars:
63 | if "_weights" in var.name:
64 | regularization_loss += tf.nn.l2_loss(var)
65 | vars_to_regularize.append(var)
66 |
67 | print("L2 - Regularization for Variables:")
68 | self.figermodel.print_variables_in_collection(vars_to_regularize)
69 | return regularization_loss
70 |
71 | def label_optimization(self, trainable_vars, optim_scope):
72 | # Typing Loss
73 | if self.figermodel.typing:
74 | self.labeling_loss = self.figermodel.labeling_model.labeling_loss
75 | else:
76 | self.labeling_loss = tf.constant(0.0)
77 |
78 | if self.figermodel.entyping:
79 | self.entity_labeling_loss = \
80 | self.figermodel.labeling_model.entity_labeling_loss
81 | else:
82 | self.entity_labeling_loss = tf.constant(0.0)
83 |
84 | # Posterior Loss
85 | if self.figermodel.el:
86 | self.posterior_loss = \
87 | self.figermodel.posterior_model.posterior_loss
88 | else:
89 | self.posterior_loss = tf.constant(0.0)
90 |
91 | if self.figermodel.useCNN:
92 | self.wikidesc_loss = self.figermodel.wikidescmodel.wikiDescLoss
93 | else:
94 | self.wikidesc_loss = tf.constant(0.0)
95 |
96 | # _ = tf.scalar_summary("loss_typing", self.labeling_loss)
97 | # _ = tf.scalar_summary("loss_posterior", self.posterior_loss)
98 | # _ = tf.scalar_summary("loss_wikidesc", self.wikidesc_loss)
99 |
100 | self.total_loss = (self.labeling_loss + self.posterior_loss +
101 | self.wikidesc_loss + self.entity_labeling_loss)
102 |
103 | # Weight Regularization
104 | # self.regularization_loss = self.weight_regularization(
105 | # trainable_vars)
106 | # self.total_loss += (self.figermodel.reg_constant *
107 | # self.regularization_loss)
108 |
109 | # Scalar Summaries
110 | # _ = tf.scalar_summary("loss_regularized", self.total_loss)
111 | # _ = tf.scalar_summary("loss_labeling", self.labeling_loss)
112 |
113 | with tf.variable_scope(optim_scope) as s, \
114 | tf.device(self.figermodel.device_placements['gpu']) as d:
115 | self.optimizer = self.optimizer(
116 | optimizer_name=self.figermodel.optimizer, name="opt")
117 | self.gvs = self.optimizer.compute_gradients(
118 | loss=self.total_loss, var_list=trainable_vars)
119 | # self.clipped_gvs = self.clip_gradients(self.gvs)
120 | self.optim_op = self.optimizer.apply_gradients(self.gvs)
121 |
122 | def clip_gradients(self, gvs):
123 | clipped_gvs = []
124 | for (g,v) in gvs:
125 | if self.figermodel.embeddings_scope in v.name:
126 | clipped_gvalues = tf.clip_by_norm(g.values, 30)
127 | clipped_index_slices = tf.IndexedSlices(
128 | values=clipped_gvalues,
129 | indices=g.indices)
130 | clipped_gvs.append((clipped_index_slices, v))
131 | else:
132 | clipped_gvs.append((tf.clip_by_norm(g, 1), v))
133 | return clipped_gvs
134 |
--------------------------------------------------------------------------------
/models/figer_model/wiki_desc.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from models.base import Model
6 |
7 | class WikiDescModel(Model):
8 | '''
  9 |     Input is a batch of Wikipedia-description word embeddings.
 10 |     A single convolution layer with global average pooling encodes each
 11 |     description into a vector (kept the same size as context_encoded_dim)
12 | '''
13 |
14 | def __init__(self, desc_batch, trueentity_embs, negentity_embs, allentity_embs,
15 | batch_size, doclength, wordembeddim, filtersize, desc_encoded_dim,
16 | scope_name, device, dropout_keep_prob=1.0):
17 |
18 | # [B, doclength, wordembeddim]
19 | self.desc_batch = desc_batch
20 | self.batch_size = batch_size
21 | self.doclength = doclength
22 | self.wordembeddim = wordembeddim
23 | self.filtersize = filtersize
24 | self.desc_encoded_dim = desc_encoded_dim # Output dim of desc
25 | self.dropout_keep_prob = dropout_keep_prob
26 |
27 | # [B, K] - target of the CNN network and Negative sampled Entities
28 | self.trueentity_embs = trueentity_embs
29 | self.negentity_embs = negentity_embs
30 | self.allentity_embs = allentity_embs
31 |
32 | # [B, DL, WD, 1] - 1 to specify one channel
33 | self.desc_batch_expanded = tf.expand_dims(self.desc_batch, -1)
34 | # [F, WD, 1, K]
35 | self.filter_shape = [self.filtersize, self.wordembeddim, 1, self.desc_encoded_dim]
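        # Each convolution filter spans `filtersize` consecutive words over the
        # full word-embedding dimension, yielding desc_encoded_dim feature maps.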
36 |
37 |
38 | with tf.variable_scope(scope_name) as scope, tf.device(device) as device:
39 | W = tf.Variable(tf.truncated_normal(self.filter_shape, stddev=0.1), name="W_conv")
40 | conv = tf.nn.conv2d(self.desc_batch_expanded,
41 | W,
42 | strides=[1, 1, 1, 1],
43 | padding="VALID",
44 | name="desc_conv")
45 |
46 | conv = tf.nn.relu(conv, name="conv_relu")
47 | conv = tf.nn.dropout(conv, keep_prob=self.dropout_keep_prob)
48 |
49 | # [B, (doclength-F+1), 1, K]
50 | # [B,K] - Global Average Pooling
51 | self.desc_encoded = tf.reduce_mean(conv, reduction_indices=[1,2])
52 |
53 | # [B, 1, K]
54 | self.desc_encoded_expand = tf.expand_dims(
55 | input=self.desc_encoded, dim=1)
56 |
57 | # [B, N]
58 | self.desc_scores = tf.reduce_sum(tf.mul(
59 | self.allentity_embs, self.desc_encoded_expand), 2)
60 |
61 | self.desc_posteriors = tf.nn.softmax(self.desc_scores,
62 | name="entity_post_softmax")
63 |
64 | ########### end def __init__ ##########################################
65 |
66 | def loss_graph(self, true_entity_ids, scope_name, device_gpu):
67 |
68 | with tf.variable_scope(scope_name) as s, tf.device(device_gpu) as d:
69 | self.crossentropy_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
70 | logits=self.desc_scores,
71 | labels=true_entity_ids,
72 | name="desc_posterior_loss")
73 |
74 | self.wikiDescLoss = tf.reduce_sum(
75 | self.crossentropy_losses) / tf.to_float(self.batch_size)
76 |
77 |
78 | '''
79 | # Maximize cosine distance between true_entity_embeddings and encoded_description
80 | # max CosineDis(self.trueentity_embs, self.desc_encoded)
81 |
82 | # [B, 1] - NOT [B] Due to keep_dims
83 | trueen_emb_norm = tf.sqrt(
84 | tf.reduce_sum(tf.square(self.trueentity_embs), 1, keep_dims=True))
85 | # [B, K]
86 | true_emb_normalized = self.trueentity_embs / trueen_emb_norm
87 |
88 | # [B, 1] - NOT [B] Due to keep_dims
89 | negen_emb_norm = tf.sqrt(
90 | tf.reduce_sum(tf.square(self.negentity_embs), 1, keep_dims=True))
91 | # [B, K]
92 | neg_emb_normalized = self.negentity_embs / negen_emb_norm
93 |
94 |
95 | # [B, 1] - NOT [B] Due to keep_dims
96 | desc_enc_norm = tf.sqrt(
97 | tf.reduce_sum(tf.square(self.desc_encoded), 1, keep_dims=True))
98 | # [B, K]
99 | desc_end_normalized = self.desc_encoded / desc_enc_norm
100 |
101 | # [B]
102 | self.true_cosDist = tf.reduce_mean(tf.mul(true_emb_normalized, desc_end_normalized))
103 |
104 | self.neg_cosDist = tf.reduce_mean(tf.mul(neg_emb_normalized, desc_end_normalized))
105 |
106 | # Loss = -ve dot_prod because distance has to be decreased
107 | #self.wikiDescLoss = self.neg_cosDist - self.true_cosDist
108 | '''
109 |
--------------------------------------------------------------------------------
/neuralel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import pprint
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from readers.inference_reader import InferenceReader
9 | from readers.test_reader import TestDataReader
10 | from models.figer_model.el_model import ELModel
11 | from readers.config import Config
12 | from readers.vocabloader import VocabLoader
13 |
14 | np.set_printoptions(threshold=np.inf)
15 | np.set_printoptions(precision=7)
16 |
17 | pp = pprint.PrettyPrinter()
18 |
19 | flags = tf.app.flags
20 | flags.DEFINE_integer("max_steps", 32000, "Maximum number of training iterations")
21 | flags.DEFINE_integer("pretraining_steps", 32000, "Number of steps to run pretraining")
22 | flags.DEFINE_float("learning_rate", 0.005, "Learning rate of adam optimizer [0.001]")
23 | flags.DEFINE_string("model_path", "", "Path to trained model")
24 | flags.DEFINE_string("dataset", "el-figer", "The name of dataset [ptb]")
25 | flags.DEFINE_string("checkpoint_dir", "/tmp",
26 | "Directory name to save the checkpoints [checkpoints]")
27 | flags.DEFINE_integer("batch_size", 1, "Batch Size for training and testing")
28 | flags.DEFINE_integer("word_embed_dim", 300, "Word Embedding Size")
29 | flags.DEFINE_integer("context_encoded_dim", 100, "Context Encoded Dim")
30 | flags.DEFINE_integer("context_encoder_num_layers", 1, "Num of Layers in context encoder network")
31 | flags.DEFINE_integer("context_encoder_lstmsize", 100, "Size of context encoder hidden layer")
32 | flags.DEFINE_integer("coherence_numlayers", 1, "Number of layers in the Coherence FF")
33 | flags.DEFINE_integer("jointff_numlayers", 1, "Number of layers in the joint context FF")
34 | flags.DEFINE_integer("num_cand_entities", 30, "Num CrossWikis entity candidates")
35 | flags.DEFINE_float("reg_constant", 0.00, "Regularization constant for NN weight regularization")
36 | flags.DEFINE_float("dropout_keep_prob", 0.6, "Dropout Keep Probability")
37 | flags.DEFINE_float("wordDropoutKeep", 0.6, "Word Dropout Keep Probability")
38 | flags.DEFINE_float("cohDropoutKeep", 0.4, "Coherence Dropout Keep Probability")
39 | flags.DEFINE_boolean("decoder_bool", True, "Decoder bool")
40 | flags.DEFINE_string("mode", 'inference', "Mode to run")
41 | flags.DEFINE_boolean("strict_context", False, "Strict context excludes the mention surface")
42 | flags.DEFINE_boolean("pretrain_wordembed", True, "Use Word2Vec Embeddings")
43 | flags.DEFINE_boolean("coherence", True, "Use Coherence")
44 | flags.DEFINE_boolean("typing", True, "Perform joint typing")
45 | flags.DEFINE_boolean("el", True, "Perform joint entity linking")
46 | flags.DEFINE_boolean("textcontext", True, "Use text context from LSTM")
47 | flags.DEFINE_boolean("useCNN", False, "Use wiki description CNN")
48 | flags.DEFINE_boolean("glove", True, "Use Glove Embeddings")
49 | flags.DEFINE_boolean("entyping", False, "Use Entity Type Prediction")
50 | flags.DEFINE_integer("WDLength", 100, "Length of wiki description")
51 | flags.DEFINE_integer("Fsize", 5, "For CNN filter size")
52 |
53 | flags.DEFINE_string("optimizer", 'adam', "Optimizer to use. adagrad, adadelta or adam")
54 |
55 | flags.DEFINE_string("config", 'configs/config.ini',
56 | "VocabConfig Filepath")
57 | flags.DEFINE_string("test_out_fp", "", "Write Test Prediction Data")
58 |
59 | FLAGS = flags.FLAGS
60 |
61 |
62 | def FLAGS_check(FLAGS):
63 | if not (FLAGS.textcontext and FLAGS.coherence):
64 | print("*** Local and Document context required ***")
65 | sys.exit(0)
66 | assert os.path.exists(FLAGS.model_path), "Model path doesn't exist."
67 |
68 |
69 | def main(_):
70 | pp.pprint(flags.FLAGS.__flags)
71 |
72 | FLAGS_check(FLAGS)
73 |
74 | config = Config(FLAGS.config, verbose=False)
75 | vocabloader = VocabLoader(config)
76 |
77 | if FLAGS.mode == 'inference':
78 | FLAGS.dropout_keep_prob = 1.0
79 | FLAGS.wordDropoutKeep = 1.0
80 | FLAGS.cohDropoutKeep = 1.0
81 |
82 | reader = InferenceReader(config=config,
83 | vocabloader=vocabloader,
84 | test_mens_file=config.test_file,
85 | num_cands=FLAGS.num_cand_entities,
86 | batch_size=FLAGS.batch_size,
87 | strict_context=FLAGS.strict_context,
88 | pretrain_wordembed=FLAGS.pretrain_wordembed,
89 | coherence=FLAGS.coherence)
90 | docta = reader.ccgdoc
91 | model_mode = 'inference'
92 |
93 | elif FLAGS.mode == 'test':
94 | FLAGS.dropout_keep_prob = 1.0
95 | FLAGS.wordDropoutKeep = 1.0
96 | FLAGS.cohDropoutKeep = 1.0
97 |
98 | reader = TestDataReader(config=config,
99 | vocabloader=vocabloader,
100 | test_mens_file=config.test_file,
101 | num_cands=30,
102 | batch_size=FLAGS.batch_size,
103 | strict_context=FLAGS.strict_context,
104 | pretrain_wordembed=FLAGS.pretrain_wordembed,
105 | coherence=FLAGS.coherence)
106 | model_mode = 'test'
107 |
108 | else:
109 | print("MODE in FLAGS is incorrect : {}".format(FLAGS.mode))
110 | sys.exit()
111 |
112 | config_proto = tf.ConfigProto()
113 | config_proto.allow_soft_placement = True
114 | config_proto.gpu_options.allow_growth=True
115 | sess = tf.Session(config=config_proto)
116 |
117 |
118 | with sess.as_default():
119 | model = ELModel(
120 | sess=sess, reader=reader, dataset=FLAGS.dataset,
121 | max_steps=FLAGS.max_steps,
122 | pretrain_max_steps=FLAGS.pretraining_steps,
123 | word_embed_dim=FLAGS.word_embed_dim,
124 | context_encoded_dim=FLAGS.context_encoded_dim,
125 | context_encoder_num_layers=FLAGS.context_encoder_num_layers,
126 | context_encoder_lstmsize=FLAGS.context_encoder_lstmsize,
127 | coherence_numlayers=FLAGS.coherence_numlayers,
128 | jointff_numlayers=FLAGS.jointff_numlayers,
129 | learning_rate=FLAGS.learning_rate,
130 | dropout_keep_prob=FLAGS.dropout_keep_prob,
131 | reg_constant=FLAGS.reg_constant,
132 | checkpoint_dir=FLAGS.checkpoint_dir,
133 | optimizer=FLAGS.optimizer,
134 | mode=model_mode,
135 | strict=FLAGS.strict_context,
136 | pretrain_word_embed=FLAGS.pretrain_wordembed,
137 | typing=FLAGS.typing,
138 | el=FLAGS.el,
139 | coherence=FLAGS.coherence,
140 | textcontext=FLAGS.textcontext,
141 | useCNN=FLAGS.useCNN,
142 | WDLength=FLAGS.WDLength,
143 | Fsize=FLAGS.Fsize,
144 | entyping=FLAGS.entyping)
145 |
146 | if FLAGS.mode == 'inference':
147 | print("Doing inference")
148 | (predTypScNPmat_list,
149 | widIdxs_list,
150 | priorProbs_list,
151 | textProbs_list,
152 | jointProbs_list,
153 | evWTs_list,
154 | pred_TypeSetsList) = model.inference(ckptpath=FLAGS.model_path)
155 |
156 | numMentionsInference = len(widIdxs_list)
157 | numMentionsReader = 0
158 | for sent_idx in reader.sentidx2ners:
159 | numMentionsReader += len(reader.sentidx2ners[sent_idx])
160 | assert numMentionsInference == numMentionsReader
161 |
162 | mentionnum = 0
163 | entityTitleList = []
164 | for sent_idx in reader.sentidx2ners:
165 | nerDicts = reader.sentidx2ners[sent_idx]
166 | sentence = ' '.join(reader.sentences_tokenized[sent_idx])
167 | for s, ner in nerDicts:
168 | [evWTs, evWIDS, evProbs] = evWTs_list[mentionnum]
169 | predTypes = pred_TypeSetsList[mentionnum]
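                # evWTs / evProbs hold the best entity and its probability under
                # the prior, the text context, and the joint model, in that order.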
170 | print(reader.bracketMentionInSentence(sentence, ner))
171 | print("Prior: {} {}, Context: {} {}, Joint: {} {}".format(
172 | evWTs[0], evProbs[0], evWTs[1], evProbs[1],
173 | evWTs[2], evProbs[2]))
174 |
175 | entityTitleList.append(evWTs[2])
176 | print("Predicted Entity Types : {}".format(predTypes))
177 | print("\n")
178 | mentionnum += 1
179 |
180 | elview = copy.deepcopy(docta.view_dictionary['NER_CONLL'])
181 | elview.view_name = 'ENG_NEURAL_EL'
182 | for i, cons in enumerate(elview.cons_list):
183 | cons['label'] = entityTitleList[i]
184 |
185 | docta.view_dictionary['ENG_NEURAL_EL'] = elview
186 |
187 | print("elview.cons_list")
188 | print(elview.cons_list)
189 | print("\n")
190 |
191 | for v in docta.as_json['views']:
192 | print(v)
193 | print("\n")
194 |
195 | elif FLAGS.mode == 'test':
196 | print("Testing on Data ")
197 | (widIdxs_list, condProbs_list, contextProbs_list,
198 | condContextJointProbs_list, evWTs,
199 | sortedContextWTs) = model.dataset_test(ckptpath=FLAGS.model_path)
200 |
201 | print(len(widIdxs_list))
202 | print(len(condProbs_list))
203 | print(len(contextProbs_list))
204 | print(len(condContextJointProbs_list))
205 | print(len(reader.mentions))
206 |
207 |
208 | print("Writing Test Predictions: {}".format(FLAGS.test_out_fp))
209 | with open(FLAGS.test_out_fp, 'w') as f:
210 | for (wididxs, pps, mps, jps) in zip(widIdxs_list,
211 | condProbs_list,
212 | contextProbs_list,
213 | condContextJointProbs_list):
214 |
215 | mentionPred = ""
216 |
217 | for (wididx, prp, mp, jp) in zip(wididxs, pps, mps, jps):
218 | wit = reader.widIdx2WikiTitle(wididx)
219 | mentionPred += wit + " " + str(prp) + " " + \
220 | str(mp) + " " + str(jp)
221 | mentionPred += "\t"
222 |
223 | mentionPred = mentionPred.strip() + "\n"
224 |
225 | f.write(mentionPred)
226 |
227 | print("Done writing. Can Exit.")
228 |
229 | else:
230 | print("WRONG MODE!")
231 | sys.exit(0)
232 |
233 |
234 |
235 |
236 |
237 | sys.exit()
238 |
239 | if __name__ == '__main__':
240 | tf.app.run()
241 |
--------------------------------------------------------------------------------
/neuralel_jsonl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import pprint
6 | import numpy as np
7 | import tensorflow as tf
8 | from os import listdir
9 | from os.path import isfile, join
10 |
11 | from ccg_nlpy.core.view import View
12 | from ccg_nlpy.core.text_annotation import TextAnnotation
13 | from ccg_nlpy.local_pipeline import LocalPipeline
14 | from readers.textanno_test_reader import TextAnnoTestReader
15 | from models.figer_model.el_model import ELModel
16 | from readers.config import Config
17 | from readers.vocabloader import VocabLoader
18 |
19 | np.set_printoptions(threshold=np.inf)
20 | np.set_printoptions(precision=7)
21 |
22 | pp = pprint.PrettyPrinter()
23 |
24 | flags = tf.app.flags
25 | flags.DEFINE_integer("max_steps", 32000, "Maximum number of training iterations")
26 | flags.DEFINE_integer("pretraining_steps", 32000, "Number of steps to run pretraining")
27 | flags.DEFINE_float("learning_rate", 0.005, "Learning rate of adam optimizer [0.001]")
28 | flags.DEFINE_string("model_path", "", "Path to trained model")
29 | flags.DEFINE_string("dataset", "el-figer", "The name of dataset [ptb]")
30 | flags.DEFINE_string("checkpoint_dir", "/tmp",
31 | "Directory name to save the checkpoints [checkpoints]")
32 | flags.DEFINE_integer("batch_size", 1, "Batch Size for training and testing")
33 | flags.DEFINE_integer("word_embed_dim", 300, "Word Embedding Size")
34 | flags.DEFINE_integer("context_encoded_dim", 100, "Context Encoded Dim")
35 | flags.DEFINE_integer("context_encoder_num_layers", 1, "Num of Layers in context encoder network")
36 | flags.DEFINE_integer("context_encoder_lstmsize", 100, "Size of context encoder hidden layer")
37 | flags.DEFINE_integer("coherence_numlayers", 1, "Number of layers in the Coherence FF")
38 | flags.DEFINE_integer("jointff_numlayers", 1, "Number of layers in the joint context FF")
39 | flags.DEFINE_integer("num_cand_entities", 30, "Num CrossWikis entity candidates")
40 | flags.DEFINE_float("reg_constant", 0.00, "Regularization constant for NN weight regularization")
41 | flags.DEFINE_float("dropout_keep_prob", 0.6, "Dropout Keep Probability")
42 | flags.DEFINE_float("wordDropoutKeep", 0.6, "Word Dropout Keep Probability")
43 | flags.DEFINE_float("cohDropoutKeep", 0.4, "Coherence Dropout Keep Probability")
44 | flags.DEFINE_boolean("decoder_bool", True, "Decoder bool")
45 | flags.DEFINE_string("mode", 'inference', "Mode to run")
46 | flags.DEFINE_boolean("strict_context", False, "Strict context excludes the mention surface")
47 | flags.DEFINE_boolean("pretrain_wordembed", True, "Use Word2Vec Embeddings")
48 | flags.DEFINE_boolean("coherence", True, "Use Coherence")
49 | flags.DEFINE_boolean("typing", True, "Perform joint typing")
50 | flags.DEFINE_boolean("el", True, "Perform joint entity linking")
51 | flags.DEFINE_boolean("textcontext", True, "Use text context from LSTM")
52 | flags.DEFINE_boolean("useCNN", False, "Use wiki description CNN")
53 | flags.DEFINE_boolean("glove", True, "Use Glove Embeddings")
54 | flags.DEFINE_boolean("entyping", False, "Use Entity Type Prediction")
55 | flags.DEFINE_integer("WDLength", 100, "Length of wiki description")
56 | flags.DEFINE_integer("Fsize", 5, "For CNN filter size")
57 |
58 | flags.DEFINE_string("optimizer", 'adam', "Optimizer to use. adagrad, adadelta or adam")
59 |
60 | flags.DEFINE_string("config", 'configs/config.ini',
61 | "VocabConfig Filepath")
62 | flags.DEFINE_string("test_out_fp", "", "Write Test Prediction Data")
63 |
64 | flags.DEFINE_string("input_jsonl", "", "Input containing documents in jsonl")
65 | flags.DEFINE_string("output_jsonl", "", "Output in jsonl format")
66 | flags.DEFINE_string("doc_key", "", "Key in input_jsonl containing documents")
67 | flags.DEFINE_boolean("pretokenized", False, "Is the input text pretokenized")
68 |
69 | FLAGS = flags.FLAGS
70 |
71 | localpipeline = LocalPipeline()
72 |
73 |
74 | def FLAGS_check(FLAGS):
75 | if not (FLAGS.textcontext and FLAGS.coherence):
76 | print("*** Local and Document context required ***")
77 | sys.exit(0)
78 | assert os.path.exists(FLAGS.model_path), "Model path doesn't exist."
79 |
80 | assert(FLAGS.mode == 'ta'), "Only mode == ta allowed"
81 |
82 |
83 | def main(_):
84 | pp.pprint(flags.FLAGS.__flags)
85 |
86 | FLAGS_check(FLAGS)
87 |
88 | config = Config(FLAGS.config, verbose=False)
89 | vocabloader = VocabLoader(config)
90 |
91 | FLAGS.dropout_keep_prob = 1.0
92 | FLAGS.wordDropoutKeep = 1.0
93 | FLAGS.cohDropoutKeep = 1.0
94 |
95 | input_jsonl = FLAGS.input_jsonl
96 | output_jsonl = FLAGS.output_jsonl
97 | doc_key = FLAGS.doc_key
98 |
99 | reader = TextAnnoTestReader(
100 | config=config,
101 | vocabloader=vocabloader,
102 | num_cands=30,
103 | batch_size=FLAGS.batch_size,
104 | strict_context=FLAGS.strict_context,
105 | pretrain_wordembed=FLAGS.pretrain_wordembed,
106 | coherence=FLAGS.coherence)
107 | model_mode = 'test'
108 |
109 | config_proto = tf.ConfigProto()
110 | config_proto.allow_soft_placement = True
111 | config_proto.gpu_options.allow_growth=True
112 | sess = tf.Session(config=config_proto)
113 |
114 |
115 | with sess.as_default():
116 | model = ELModel(
117 | sess=sess, reader=reader, dataset=FLAGS.dataset,
118 | max_steps=FLAGS.max_steps,
119 | pretrain_max_steps=FLAGS.pretraining_steps,
120 | word_embed_dim=FLAGS.word_embed_dim,
121 | context_encoded_dim=FLAGS.context_encoded_dim,
122 | context_encoder_num_layers=FLAGS.context_encoder_num_layers,
123 | context_encoder_lstmsize=FLAGS.context_encoder_lstmsize,
124 | coherence_numlayers=FLAGS.coherence_numlayers,
125 | jointff_numlayers=FLAGS.jointff_numlayers,
126 | learning_rate=FLAGS.learning_rate,
127 | dropout_keep_prob=FLAGS.dropout_keep_prob,
128 | reg_constant=FLAGS.reg_constant,
129 | checkpoint_dir=FLAGS.checkpoint_dir,
130 | optimizer=FLAGS.optimizer,
131 | mode=model_mode,
132 | strict=FLAGS.strict_context,
133 | pretrain_word_embed=FLAGS.pretrain_wordembed,
134 | typing=FLAGS.typing,
135 | el=FLAGS.el,
136 | coherence=FLAGS.coherence,
137 | textcontext=FLAGS.textcontext,
138 | useCNN=FLAGS.useCNN,
139 | WDLength=FLAGS.WDLength,
140 | Fsize=FLAGS.Fsize,
141 | entyping=FLAGS.entyping)
142 |
143 | model.load_ckpt_model(ckptpath=FLAGS.model_path)
144 |
145 | erroneous_files = 0
146 |
147 | outf = open(output_jsonl, 'w')
148 | inpf = open(input_jsonl, 'r')
149 |
150 | for line in inpf:
151 | jsonobj = json.loads(line)
152 | doctext = jsonobj[doc_key]
153 | ta = localpipeline.doc(doctext, pretokenized=FLAGS.pretokenized)
154 | _ = ta.get_ner_conll
155 |
156 | # Make instances for this document
157 | reader.new_ta(ta)
158 |
159 | (predTypScNPmat_list,
160 | widIdxs_list,
161 | priorProbs_list,
162 | textProbs_list,
163 | jointProbs_list,
164 | evWTs_list,
165 | pred_TypeSetsList) = model.inference_run()
166 |
167 | wiki_view = copy.deepcopy(reader.textanno.get_view("NER_CONLL"))
168 | docta = reader.textanno
169 |
170 | el_cons_list = wiki_view.cons_list
171 | numMentionsInference = len(widIdxs_list)
172 |
173 | assert len(el_cons_list) == numMentionsInference
174 |
175 | out_dict = {doc_key: doctext}
176 | el_mentions = []
177 |
178 | mentionnum = 0
179 | for ner_cons in el_cons_list:
180 | # ner_cons is a dict
181 | mentiondict = {}
182 | mentiondict['tokens'] = ner_cons['tokens']
183 | mentiondict['end'] = ner_cons['end']
184 | mentiondict['start'] = ner_cons['start']
185 |
186 | priorScoreMap = {}
187 | contextScoreMap = {}
188 | jointScoreMap = {}
189 |
190 | (wididxs, pps, mps, jps) = (widIdxs_list[mentionnum],
191 | priorProbs_list[mentionnum],
192 | textProbs_list[mentionnum],
193 | jointProbs_list[mentionnum])
194 |
195 | maxJointProb = 0.0
196 | maxJointEntity = ""
197 | for (wididx, prp, mp, jp) in zip(wididxs, pps, mps, jps):
198 | wT = reader.widIdx2WikiTitle(wididx)
199 | priorScoreMap[wT] = prp
200 | contextScoreMap[wT] = mp
201 | jointScoreMap[wT] = jp
202 |
203 | if jp > maxJointProb:
204 | maxJointProb = jp
205 | maxJointEntity = wT
206 |
207 | mentiondict["jointScoreMap"] = jointScoreMap
208 | mentiondict["contextScoreMap"] = contextScoreMap
209 | mentiondict["priorScoreMap"] = priorScoreMap
210 |
211 | # add max scoring entity as label
212 | mentiondict["label"] = maxJointEntity
213 | mentiondict["score"] = maxJointProb
214 |
215 | mentionnum += 1
216 |
217 | el_mentions.append(mentiondict)
218 |
219 | out_dict['nel'] = el_mentions
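            # Each output line is one JSON object of the (illustrative) form:
            # {<doc_key>: <text>, "nel": [{"tokens": ..., "start": ..., "end": ...,
            #   "priorScoreMap": {...}, "contextScoreMap": {...}, "jointScoreMap": {...},
            #   "label": <best joint entity>, "score": <joint prob>}, ...]}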
220 | outstr = json.dumps(out_dict)
221 | outf.write(outstr)
222 | outf.write("\n")
223 |
224 | outf.close()
225 | inpf.close()
226 |
227 | print("Number of erroneous files: {}".format(erroneous_files))
228 | print("Annotation completed. Program can be exited safely.")
229 | sys.exit()
230 |
231 | if __name__ == '__main__':
232 | tf.app.run()
233 |
--------------------------------------------------------------------------------
/neuralel_tadir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import pprint
6 | import numpy as np
7 | import tensorflow as tf
8 | from os import listdir
9 | from os.path import isfile, join
10 |
11 | from ccg_nlpy.core.view import View
12 | from ccg_nlpy.core.text_annotation import TextAnnotation
13 |
14 | from readers.inference_reader import InferenceReader
15 | from readers.test_reader import TestDataReader
16 | from readers.textanno_test_reader import TextAnnoTestReader
17 | from models.figer_model.el_model import ELModel
18 | from readers.config import Config
19 | from readers.vocabloader import VocabLoader
20 |
21 | np.set_printoptions(threshold=np.inf)
22 | np.set_printoptions(precision=7)
23 |
24 | pp = pprint.PrettyPrinter()
25 |
26 | flags = tf.app.flags
27 | flags.DEFINE_integer("max_steps", 32000, "Maximum number of training iterations")
28 | flags.DEFINE_integer("pretraining_steps", 32000, "Number of steps to run pretraining")
29 | flags.DEFINE_float("learning_rate", 0.005, "Learning rate of adam optimizer [0.001]")
30 | flags.DEFINE_string("model_path", "", "Path to trained model")
31 | flags.DEFINE_string("dataset", "el-figer", "The name of dataset [ptb]")
32 | flags.DEFINE_string("checkpoint_dir", "/tmp",
33 | "Directory name to save the checkpoints [checkpoints]")
34 | flags.DEFINE_integer("batch_size", 1, "Batch Size for training and testing")
35 | flags.DEFINE_integer("word_embed_dim", 300, "Word Embedding Size")
36 | flags.DEFINE_integer("context_encoded_dim", 100, "Context Encoded Dim")
37 | flags.DEFINE_integer("context_encoder_num_layers", 1, "Num of Layers in context encoder network")
38 | flags.DEFINE_integer("context_encoder_lstmsize", 100, "Size of context encoder hidden layer")
39 | flags.DEFINE_integer("coherence_numlayers", 1, "Number of layers in the Coherence FF")
40 | flags.DEFINE_integer("jointff_numlayers", 1, "Number of layers in the joint context FF")
41 | flags.DEFINE_integer("num_cand_entities", 30, "Num CrossWikis entity candidates")
42 | flags.DEFINE_float("reg_constant", 0.00, "Regularization constant for NN weight regularization")
43 | flags.DEFINE_float("dropout_keep_prob", 0.6, "Dropout Keep Probability")
44 | flags.DEFINE_float("wordDropoutKeep", 0.6, "Word Dropout Keep Probability")
45 | flags.DEFINE_float("cohDropoutKeep", 0.4, "Coherence Dropout Keep Probability")
46 | flags.DEFINE_boolean("decoder_bool", True, "Decoder bool")
47 | flags.DEFINE_string("mode", 'inference', "Mode to run")
48 | flags.DEFINE_boolean("strict_context", False, "Strict context excludes the mention surface")
49 | flags.DEFINE_boolean("pretrain_wordembed", True, "Use Word2Vec Embeddings")
50 | flags.DEFINE_boolean("coherence", True, "Use Coherence")
51 | flags.DEFINE_boolean("typing", True, "Perform joint typing")
52 | flags.DEFINE_boolean("el", True, "Perform joint entity linking")
53 | flags.DEFINE_boolean("textcontext", True, "Use text context from LSTM")
54 | flags.DEFINE_boolean("useCNN", False, "Use wiki description CNN")
55 | flags.DEFINE_boolean("glove", True, "Use Glove Embeddings")
56 | flags.DEFINE_boolean("entyping", False, "Use Entity Type Prediction")
57 | flags.DEFINE_integer("WDLength", 100, "Length of wiki description")
58 | flags.DEFINE_integer("Fsize", 5, "For CNN filter size")
59 |
60 | flags.DEFINE_string("optimizer", 'adam', "Optimizer to use. adagrad, adadelta or adam")
61 |
62 | flags.DEFINE_string("config", 'configs/config.ini',
63 | "VocabConfig Filepath")
64 | flags.DEFINE_string("test_out_fp", "", "Write Test Prediction Data")
65 |
66 | flags.DEFINE_string("tadirpath", "", "Directory containing all the input text-annotations")
67 | flags.DEFINE_string("taoutdirpath", "", "Directory to write the annotated text-annotations to")
68 |
69 |
70 |
71 | FLAGS = flags.FLAGS
72 |
73 |
74 | def FLAGS_check(FLAGS):
75 | if not (FLAGS.textcontext and FLAGS.coherence):
76 | print("*** Local and Document context required ***")
77 | sys.exit(0)
78 | assert os.path.exists(FLAGS.model_path), "Model path doesn't exist."
79 |
80 | assert(FLAGS.mode == 'ta'), "Only mode == ta allowed"
81 |
82 |
83 | def getAllTAFilePaths(FLAGS):
84 | tadir = FLAGS.tadirpath
85 | taoutdirpath = FLAGS.taoutdirpath
86 | onlyfiles = [f for f in listdir(tadir) if isfile(join(tadir, f))]
87 | ta_files = [os.path.join(tadir, fname) for fname in onlyfiles]
88 |
89 | output_ta_files = [os.path.join(taoutdirpath, fname) for fname in onlyfiles]
90 |
91 | return (ta_files, output_ta_files)
92 |
93 |
94 | def main(_):
95 | pp.pprint(flags.FLAGS.__flags)
96 |
97 | FLAGS_check(FLAGS)
98 |
99 | config = Config(FLAGS.config, verbose=False)
100 | vocabloader = VocabLoader(config)
101 |
102 | FLAGS.dropout_keep_prob = 1.0
103 | FLAGS.wordDropoutKeep = 1.0
104 | FLAGS.cohDropoutKeep = 1.0
105 |
106 |     (input_ta_files, output_ta_files) = getAllTAFilePaths(FLAGS)
107 |
108 |     print("TOTAL NUMBER OF TAS : {}".format(len(input_ta_files)))
109 |
110 | reader = TextAnnoTestReader(
111 | config=config,
112 | vocabloader=vocabloader,
113 | num_cands=30,
114 | batch_size=FLAGS.batch_size,
115 | strict_context=FLAGS.strict_context,
116 | pretrain_wordembed=FLAGS.pretrain_wordembed,
117 | coherence=FLAGS.coherence,
118 | nerviewname="NER")
119 |
120 | model_mode = 'test'
121 |
122 | config_proto = tf.ConfigProto()
123 | config_proto.allow_soft_placement = True
124 | config_proto.gpu_options.allow_growth=True
125 | sess = tf.Session(config=config_proto)
126 |
127 |
128 | with sess.as_default():
129 | model = ELModel(
130 | sess=sess, reader=reader, dataset=FLAGS.dataset,
131 | max_steps=FLAGS.max_steps,
132 | pretrain_max_steps=FLAGS.pretraining_steps,
133 | word_embed_dim=FLAGS.word_embed_dim,
134 | context_encoded_dim=FLAGS.context_encoded_dim,
135 | context_encoder_num_layers=FLAGS.context_encoder_num_layers,
136 | context_encoder_lstmsize=FLAGS.context_encoder_lstmsize,
137 | coherence_numlayers=FLAGS.coherence_numlayers,
138 | jointff_numlayers=FLAGS.jointff_numlayers,
139 | learning_rate=FLAGS.learning_rate,
140 | dropout_keep_prob=FLAGS.dropout_keep_prob,
141 | reg_constant=FLAGS.reg_constant,
142 | checkpoint_dir=FLAGS.checkpoint_dir,
143 | optimizer=FLAGS.optimizer,
144 | mode=model_mode,
145 | strict=FLAGS.strict_context,
146 | pretrain_word_embed=FLAGS.pretrain_wordembed,
147 | typing=FLAGS.typing,
148 | el=FLAGS.el,
149 | coherence=FLAGS.coherence,
150 | textcontext=FLAGS.textcontext,
151 | useCNN=FLAGS.useCNN,
152 | WDLength=FLAGS.WDLength,
153 | Fsize=FLAGS.Fsize,
154 | entyping=FLAGS.entyping)
155 |
156 | model.load_ckpt_model(ckptpath=FLAGS.model_path)
157 |
158 | print("Total files: {}".format(len(output_ta_files)))
159 | erroneous_files = 0
160 |         for in_ta_path, out_ta_path in zip(input_ta_files, output_ta_files):
161 | # print("Running the inference for : {}".format(in_ta_path))
162 | try:
163 | reader.new_test_file(in_ta_path)
164 | except:
165 | print("Error reading : {}".format(in_ta_path))
166 | erroneous_files += 1
167 | continue
168 |
169 | (predTypScNPmat_list,
170 | widIdxs_list,
171 | priorProbs_list,
172 | textProbs_list,
173 | jointProbs_list,
174 | evWTs_list,
175 | pred_TypeSetsList) = model.inference_run()
176 |
177 | # model.inference(ckptpath=FLAGS.model_path)
178 |
179 | wiki_view = copy.deepcopy(reader.textanno.get_view("NER"))
180 | docta = reader.textanno
181 |
182 | el_cons_list = wiki_view.cons_list
183 | numMentionsInference = len(widIdxs_list)
184 |
185 | # print("Number of mentions in model: {}".format(len(widIdxs_list)))
186 | # print("Number of NER mention: {}".format(len(el_cons_list)))
187 |
188 | assert len(el_cons_list) == numMentionsInference
189 |
190 | mentionnum = 0
191 | for ner_cons in el_cons_list:
192 | priorScoreMap = {}
193 | contextScoreMap = {}
194 | jointScoreMap = {}
195 |
196 | (wididxs, pps, mps, jps) = (widIdxs_list[mentionnum],
197 | priorProbs_list[mentionnum],
198 | textProbs_list[mentionnum],
199 | jointProbs_list[mentionnum])
200 |
201 | maxJointProb = 0.0
202 | maxJointEntity = ""
203 | for (wididx, prp, mp, jp) in zip(wididxs, pps, mps, jps):
204 | wT = reader.widIdx2WikiTitle(wididx)
205 | priorScoreMap[wT] = prp
206 | contextScoreMap[wT] = mp
207 | jointScoreMap[wT] = jp
208 |
209 | if jp > maxJointProb:
210 | maxJointProb = jp
211 | maxJointEntity = wT
212 |
213 |
214 | ''' add labels2score map here '''
215 | ner_cons["jointScoreMap"] = jointScoreMap
216 | ner_cons["contextScoreMap"] = contextScoreMap
217 | ner_cons["priorScoreMap"] = priorScoreMap
218 |
219 | # add max scoring entity as label
220 | ner_cons["label"] = maxJointEntity
221 | ner_cons["score"] = maxJointProb
222 |
223 | mentionnum += 1
224 |
225 | wiki_view.view_name = "NEUREL"
226 | docta.view_dictionary["NEUREL"] = wiki_view
227 |
228 | docta_json = docta.as_json
229 | json.dump(docta_json, open(out_ta_path, "w"), indent=True)
230 |
231 | print("Number of erroneous files: {}".format(erroneous_files))
232 | print("Annotation completed. Program can be exited safely.")
233 | sys.exit()
234 |
235 | if __name__ == '__main__':
236 | tf.app.run()
237 |
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/overview.png
--------------------------------------------------------------------------------
/readers/Mention.py:
--------------------------------------------------------------------------------
1 | start_word = ""
2 | end_word = ""
3 | unk_word = ""
4 |
5 |
6 | class Mention(object):
7 | def __init__(self, mention_line):
8 | ''' mention_line : Is the string line stored for each mention
9 | mid wid wikititle start_token end_token surface tokenized_sentence
10 | all_types
11 | '''
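        # A (hypothetical) mention line is tab-separated, e.g.:
        #   mid <TAB> wid <TAB> Wiki_Title <TAB> 2 <TAB> 3 <TAB> Obama
        #   <TAB> tokenized sentence ... <TAB> /person /person/politician
        #   [<TAB> coherence mention strings] [<TAB> docid]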
12 | mention_line = mention_line.strip()
13 | split = mention_line.split("\t")
14 | (self.mid, self.wid, self.wikititle) = split[0:3]
15 |         self.start_token = int(split[3]) + 1    # +1 accounts for the prepended start_word token
16 | self.end_token = int(split[4]) + 1
17 | self.surface = split[5]
18 | self.sent_tokens = [start_word]
19 | self.sent_tokens.extend(split[6].split(" "))
20 | self.sent_tokens.append(end_word)
21 | self.types = split[7].split(" ")
22 | if len(split) > 8: # If no mention surface words in coherence
23 | if split[8].strip() == "":
24 | self.coherence = [unk_word]
25 | else:
26 | self.coherence = split[8].split(" ")
27 | if len(split) == 10:
28 | self.docid = split[9]
29 |
30 | assert self.end_token <= (len(self.sent_tokens) - 1), "Line : %s" % mention_line
31 | #enddef
32 |
33 | def toString(self):
34 | outstr = self.wid + "\t"
35 | outstr += self.wikititle + "\t"
36 | for i in range(1, len(self.sent_tokens)):
37 | outstr += self.sent_tokens[i] + " "
38 |
39 | outstr = outstr.strip()
40 | return outstr
41 |
42 | #endclass
43 |
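# Illustrative usage sketch (added for exposition): the record below is a
# made-up example of the tab-separated format documented in __init__; the
# mid/wid/type values are hypothetical.
if __name__ == '__main__':
    example_line = "\t".join([
        "/m/02mjmr", "534366", "Barack_Obama",    # mid, wid, wikititle
        "0", "1",                                 # start_token, end_token
        "Barack Obama",                           # surface
        "Barack Obama was the 44th president .",  # tokenized sentence
        "person politician"])                     # types
    m = Mention(example_line)
    print(m.toString())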
--------------------------------------------------------------------------------
/readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitishgupta/neural-el/8c7c278acefa66238a75e805511ff26a567fd4e0/readers/__init__.py
--------------------------------------------------------------------------------
/readers/config.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | import configparser
3 |
4 | pp = pprint.PrettyPrinter()
5 |
6 | class Config(object):
7 | def __init__(self, paths_config, verbose=False):
8 | config = configparser.ConfigParser()
9 | config._interpolation = configparser.ExtendedInterpolation()
10 | config.read(paths_config)
11 | print(paths_config)
12 |
13 | c = config['DEFAULT']
14 |
15 | d = {}
16 | for k in c:
17 | d[k] = c[k]
18 |
19 | self.resources_dir = d['resources_dir']
20 |
21 | self.vocab_dir = d['vocab_dir']
22 |
23 | # Word2Vec Vocab to Idxs
24 | self.word_vocab_pkl = d['word_vocab_pkl']
25 | # Wid2Idx for Known Entities ~ 620K (readers.train.vocab.py)
26 | self.kwnwid_vocab_pkl = d['kwnwid_vocab_pkl']
27 | # FIGER Type label 2 idx (readers.train.vocab.py)
28 | self.label_vocab_pkl = d['label_vocab_pkl']
29 | # EntityWid: [FIGER Type Labels]
30 | # CoherenceStr2Idx at various thresholds (readers.train.vocab.py)
31 | self.cohstringG9_vocab_pkl = d['cohstringg9_vocab_pkl']
32 |
33 | # wid2Wikititle for whole KB ~ 3.18M (readers.train.vocab.py)
34 | self.widWiktitle_pkl = d['widwiktitle_pkl']
35 |
36 | self.crosswikis_pruned_pkl = d['crosswikis_pruned_pkl']
37 |
38 | self.glove_pkl = d['glove_pkl']
39 | self.glove_word_vocab_pkl = d['glove_word_vocab_pkl']
40 |
41 | self.test_file = d['test_file']
42 |
43 | if verbose:
44 | pp.pprint(d)
45 |
46 | #endinit
47 |
48 | if __name__=='__main__':
49 | c = Config("configs/allnew_mentions_config.ini", verbose=True)
50 |
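# Illustrative sketch (added for exposition): a minimal ini file in the shape
# this class expects. The keys mirror the attributes read above; every path
# and file name below is hypothetical. ExtendedInterpolation lets later keys
# reuse ${resources_dir} and ${vocab_dir}.
_EXAMPLE_INI = """
[DEFAULT]
resources_dir = /path/to/resources
vocab_dir = ${resources_dir}/vocab
word_vocab_pkl = ${vocab_dir}/word_vocab.pkl
kwnwid_vocab_pkl = ${vocab_dir}/knwn_wid_vocab.pkl
label_vocab_pkl = ${vocab_dir}/label_vocab.pkl
cohstringg9_vocab_pkl = ${vocab_dir}/cohstring_g9_vocab.pkl
widwiktitle_pkl = ${resources_dir}/wid2wikititle.pkl
crosswikis_pruned_pkl = ${resources_dir}/crosswikis.pruned.pkl
glove_pkl = ${resources_dir}/glove.pkl
glove_word_vocab_pkl = ${resources_dir}/glove_word_vocab.pkl
test_file = ${resources_dir}/sample_test_doc.txt
"""
# To try it, write _EXAMPLE_INI to a temporary .ini file and pass that path to
# Config(..., verbose=True).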
--------------------------------------------------------------------------------
/readers/crosswikis_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | from readers.config import Config
5 | from readers.vocabloader import VocabLoader
6 | from readers import utils
7 |
8 |
9 |
10 | class CrosswikisTest(object):
11 | def __init__(self, config, vocabloader):
12 | print("Loading Crosswikis")
13 | # self.crosswikis = vocabloader.loadCrosswikis()
14 |
15 | stime = time.time()
16 | self.crosswikis = utils.load(config.crosswikis_pruned_pkl)
17 | ttime = time.time() - stime
18 | print("Crosswikis Loaded. Size : {}".format(len(self.crosswikis)))
19 | print("Time taken : {} secs".format(ttime))
20 |
21 | (self.knwid2idx, self.idx2knwid) = vocabloader.getKnwnWidVocab()
22 | print("Size of known wids : {}".format(len(self.knwid2idx)))
23 |
24 | def test(self):
25 | print("Test starting")
26 | maxCands = 0
27 |         minCands = sys.maxsize
28 | notKnownWid = False
29 | smallSurface = 0
30 | notSortedProbs = 0
31 | for surface, c_cprobs in self.crosswikis.items():
32 | notSorted = False
33 | numCands = len(c_cprobs)
34 | if numCands < minCands:
35 | minCands = numCands
36 | if numCands > maxCands:
37 | maxCands = numCands
38 |
39 | prob_prv = 10.0
40 | for (wid, prob) in c_cprobs:
41 | if wid not in self.knwid2idx:
42 | notKnownWid = True
43 | if prob_prv < prob:
44 | notSorted = True
45 | prob_prv = prob
46 |
47 | if notSorted:
48 | notSortedProbs += 1
49 |
50 | if len(surface) <= 1:
51 | smallSurface += 1
52 |
53 | print("Max Cands : {}".format(maxCands))
54 | print("min Cands : {}".format(minCands))
55 | print("Not Known Wid : {}".format(notKnownWid))
56 | print("small surfaces {}".format(smallSurface))
57 | print("Not Sorted Probs {}".format(notSortedProbs))
58 |
59 | def test_pruned(self):
60 | print("Test starting")
61 | maxCands = 0
62 | minCands = 30
63 | notKnownWid = False
64 | smallSurface = 0
65 | notSortedProbs = 0
66 | for surface, c_cprobs in self.crosswikis.items():
67 | notSorted = False
68 | (wids, probs) = c_cprobs
69 | numCands = len(wids)
70 | if numCands < minCands:
71 | minCands = numCands
72 | if numCands > maxCands:
73 | maxCands = numCands
74 |
75 | prob_prv = 10.0
76 | for (wid, prob) in zip(wids, probs):
77 | if wid not in self.knwid2idx:
78 | notKnownWid = True
79 | if prob_prv < prob:
80 | notSorted = True
81 | prob_prv = prob
82 |
83 | if notSorted:
84 | notSortedProbs += 1
85 |
86 | if len(surface) <= 1:
87 | smallSurface += 1
88 |
89 | print("Max Cands : {}".format(maxCands))
90 | print("min Cands : {}".format(minCands))
91 | print("Not Known Wid : {}".format(notKnownWid))
92 | print("small surfaces {}".format(smallSurface))
93 | print("Not Sorted Probs {}".format(notSortedProbs))
94 |
95 | def makeCWKnown(self, cwOutPath):
96 | cw = {}
97 | MAXCAND = 30
98 | surfacesProcessed = 0
99 | for surface, c_cprobs in self.crosswikis.items():
100 | surfacesProcessed += 1
101 | if surfacesProcessed % 1000000 == 0:
102 | print("Surfaces Processed : {}".format(surfacesProcessed))
103 |
104 | if len(c_cprobs) == 0:
105 | continue
106 | if len(surface) <= 1:
107 | continue
108 | candsAdded = 0
109 | c_probs = ([], [])
110 | # cw[surface] = ([], [])
111 | for (wid, prob) in c_cprobs:
112 |                 if candsAdded == MAXCAND:
113 | break
114 | if wid in self.knwid2idx:
115 | c_probs[0].append(wid)
116 | c_probs[1].append(prob)
117 | candsAdded += 1
118 | if candsAdded != 0:
119 | cw[surface] = c_probs
120 | print("Processed")
121 | print("Size of CW : {}".format(len(cw)))
122 | utils.save(cwOutPath, cw)
123 | print("Saved pruned CW")
124 |
125 |
126 |
127 | if __name__ == '__main__':
128 | configpath = "configs/config.ini"
129 | config = Config(configpath, verbose=False)
130 | vocabloader = VocabLoader(config)
131 | cwikistest = CrosswikisTest(config, vocabloader)
132 | cwikistest.test_pruned()
133 | # cwikistest.makeCWKnown(os.path.join(config.resources_dir,
134 | # "crosswikis.pruned.pkl"))
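    # Illustrative addition (exposition only): the pruned crosswikis object is
    # a dict mapping a normalized surface string to parallel (wids, probs)
    # lists sorted by decreasing prior, which is exactly what test_pruned()
    # iterates. Print a few entries to show the structure.
    for surface in list(cwikistest.crosswikis.keys())[:3]:
        wids, probs = cwikistest.crosswikis[surface]
        print(surface, list(zip(wids, probs))[:5])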
--------------------------------------------------------------------------------
/readers/inference_reader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import numpy as np
4 | import readers.utils as utils
5 | from readers.Mention import Mention
6 | from readers.config import Config
7 | from readers.vocabloader import VocabLoader
8 | from ccg_nlpy import remote_pipeline
9 |
10 | start_word = ""
11 | end_word = ""
12 |
13 | class InferenceReader(object):
14 | def __init__(self, config, vocabloader, test_mens_file,
15 | num_cands, batch_size, strict_context=True,
16 | pretrain_wordembed=True, coherence=True):
17 | self.pipeline = remote_pipeline.RemotePipeline(
18 | server_api='http://macniece.seas.upenn.edu:4001')
19 | self.typeOfReader = "inference"
20 | self.start_word = start_word
21 | self.end_word = end_word
22 | self.unk_word = 'unk' # In tune with word2vec
23 | self.unk_wid = ""
24 | self.tr_sup = 'tr_sup'
25 | self.tr_unsup = 'tr_unsup'
26 | self.pretrain_wordembed = pretrain_wordembed
27 | self.coherence = coherence
28 |
29 | # Word Vocab
30 | (self.word2idx, self.idx2word) = vocabloader.getGloveWordVocab()
31 | self.num_words = len(self.idx2word)
32 |
33 | # Label Vocab
34 | (self.label2idx, self.idx2label) = vocabloader.getLabelVocab()
35 | self.num_labels = len(self.idx2label)
36 |
37 | # Known WID Vocab
38 | (self.knwid2idx, self.idx2knwid) = vocabloader.getKnwnWidVocab()
39 | self.num_knwn_entities = len(self.idx2knwid)
40 |
41 | # Wid2Wikititle Map
42 | self.wid2WikiTitle = vocabloader.getWID2Wikititle()
43 |
44 | # Coherence String Vocab
45 | print("Loading Coherence Strings Dicts ... ")
46 | (self.cohG92idx, self.idx2cohG9) = utils.load(
47 | config.cohstringG9_vocab_pkl)
48 | self.num_cohstr = len(self.idx2cohG9)
49 |
50 | # Crosswikis
51 | print("Loading Crosswikis dict. (takes ~2 mins to load)")
52 | self.crosswikis = utils.load(config.crosswikis_pruned_pkl)
53 | print("Crosswikis loaded. Size: {}".format(len(self.crosswikis)))
54 |
55 | if self.pretrain_wordembed:
56 | stime = time.time()
57 | self.word2vec = vocabloader.loadGloveVectors()
58 | print("[#] Glove Vectors loaded!")
59 | ttime = (time.time() - stime)/float(60)
60 |             print("[#] Time to load vectors : {} mins".format(ttime))
61 |
62 | print("[#] Test Mentions File : {}".format(test_mens_file))
63 |
64 | print("[#] Loading test file and preprocessing ... ")
65 | self.processTestDoc(test_mens_file)
66 | self.mention_lines = self.convertSent2NerToMentionLines()
67 | self.mentions = []
68 | for line in self.mention_lines:
69 | m = Mention(line)
70 | self.mentions.append(m)
71 |
72 | self.men_idx = 0
73 | self.num_mens = len(self.mentions)
74 | self.epochs = 0
75 | print( "[#] Test Mentions : {}".format(self.num_mens))
76 |
77 | self.batch_size = batch_size
78 | print("[#] Batch Size: %d" % self.batch_size)
79 | self.num_cands = num_cands
80 | self.strict_context = strict_context
81 |
82 | print("\n[#]LOADING COMPLETE")
83 | #******************* END __init__ *********************************
84 |
85 | def get_vector(self, word):
86 | if word in self.word2vec:
87 | return self.word2vec[word]
88 | else:
89 | return self.word2vec['unk']
90 |
91 | def reset_test(self):
92 | self.men_idx = 0
93 | self.epochs = 0
94 |
95 | def processTestDoc(self, test_mens_file):
96 | with open(test_mens_file, 'r') as f:
97 | lines = f.read().strip().split("\n")
98 | assert len(lines) == 1, "Only support inference for single doc"
99 | self.doctext = lines[0].strip()
100 | self.ccgdoc = self.pipeline.doc(self.doctext)
101 | # List of tokens
102 | self.doc_tokens = self.ccgdoc.get_tokens
103 | # sent_end_token_indices : contains index for the starting of the
104 | # next sentence.
105 | self.sent_end_token_indices = \
106 | self.ccgdoc.get_sentence_end_token_indices
107 | # List of tokenized sentences
108 | self.sentences_tokenized = []
109 | for i in range(0, len(self.sent_end_token_indices)):
110 | start = self.sent_end_token_indices[i-1] if i != 0 else 0
111 | end = self.sent_end_token_indices[i]
112 | sent_tokens = self.doc_tokens[start:end]
113 | self.sentences_tokenized.append(sent_tokens)
114 |
115 | # List of ner dicts from ccg pipeline
116 | self.ner_cons_list = []
117 | try:
118 | self.ner_cons_list = self.ccgdoc.get_ner_conll.cons_list
119 |         except Exception:
120 | print("NO NAMED ENTITIES IN THE DOC. EXITING")
121 |
122 | # SentIdx : [(tokenized_sent, ner_dict)]
123 | self.sentidx2ners = {}
124 | for ner in self.ner_cons_list:
125 | found = False
126 | # idx = sentIdx, j = sentEndTokenIdx
127 | for idx, j in enumerate(self.sent_end_token_indices):
128 | sent_start_token = self.sent_end_token_indices[idx-1] \
129 | if idx != 0 else 0
130 | # ner['end'] is the idx of the token after ner
131 | if ner['end'] < j:
132 | if idx not in self.sentidx2ners:
133 | self.sentidx2ners[idx] = []
134 | ner['start'] = ner['start'] - sent_start_token
135 | ner['end'] = ner['end'] - sent_start_token - 1
136 | self.sentidx2ners[idx].append(
137 | (self.sentences_tokenized[idx], ner))
138 | found = True
139 | if found:
140 | break
141 |
142 | def convertSent2NerToMentionLines(self):
143 | '''Convert NERs from document to list of mention strings'''
144 | mentions = []
145 | # Make Document Context String for whole document
146 | cohStr = ""
147 | for sent_idx, s_nerDicts in self.sentidx2ners.items():
148 | for s, ner in s_nerDicts:
149 | cohStr += ner['tokens'].replace(' ', '_') + ' '
150 |
151 | cohStr = cohStr.strip()
152 |
153 | for idx in range(0, len(self.sentences_tokenized)):
154 | if idx in self.sentidx2ners:
155 | sentence = ' '.join(self.sentences_tokenized[idx])
156 | s_nerDicts = self.sentidx2ners[idx]
157 | for s, ner in s_nerDicts:
158 | mention = "%s\t%s\t%s" % ("unk_mid", "unk_wid", "unkWT")
159 | mention = mention + str('\t') + str(ner['start'])
160 | mention = mention + '\t' + str(ner['end'])
161 | mention = mention + '\t' + str(ner['tokens'])
162 | mention = mention + '\t' + sentence
163 | mention = mention + '\t' + "UNK_TYPES"
164 | mention = mention + '\t' + cohStr
165 | mentions.append(mention)
166 | return mentions
167 |
168 | def bracketMentionInSentence(self, s, nerDict):
169 | tokens = s.split(" ")
170 | start = nerDict['start']
171 | end = nerDict['end']
172 | tokens.insert(start, '[[')
173 | tokens.insert(end + 2, ']]')
174 | return ' '.join(tokens)
175 |
176 | def _read_mention(self):
177 | mention = self.mentions[self.men_idx]
178 | self.men_idx += 1
179 | if self.men_idx == self.num_mens:
180 | self.men_idx = 0
181 | self.epochs += 1
182 | return mention
183 |
184 | def _next_batch(self):
185 |         ''' Data : mid \t wid \t wikititle \t start \t end \t tokens \t labels
186 | start and end are inclusive
187 | '''
188 | # Sentence = s1 ... m1 ... mN, ... sN.
189 | # Left Batch = s1 ... m1 ... mN
190 | # Right Batch = sN ... mN ... m1
191 | (left_batch, right_batch) = ([], [])
192 |
193 | coh_indices = []
194 | coh_values = []
195 | if self.coherence:
196 | coh_matshape = [self.batch_size, self.num_cohstr]
197 | else:
198 | coh_matshape = []
199 |
200 | # Candidate WID idxs and their cprobs
201 | # First element is always true wid
202 | (wid_idxs_batch, wid_cprobs_batch) = ([], [])
203 |
204 | while len(left_batch) < self.batch_size:
205 | batch_el = len(left_batch)
206 | m = self._read_mention()
207 |
208 | # for label in m.types:
209 | # if label in self.label2idx:
210 | # labelidx = self.label2idx[label]
211 | # labels_batch[batch_el][labelidx] = 1.0
212 |
213 | cohFound = False # If no coherence mention is found, add unk
214 | if self.coherence:
215 | cohidxs = [] # Indexes in the [B, NumCoh] matrix
216 | cohvals = [] # 1.0 to indicate presence
217 | for cohstr in m.coherence:
218 | if cohstr in self.cohG92idx:
219 | cohidx = self.cohG92idx[cohstr]
220 | cohidxs.append([batch_el, cohidx])
221 | cohvals.append(1.0)
222 | cohFound = True
223 | if cohFound:
224 | coh_indices.extend(cohidxs)
225 | coh_values.extend(cohvals)
226 | else:
227 | cohidx = self.cohG92idx[self.unk_word]
228 | coh_indices.append([batch_el, cohidx])
229 | coh_values.append(1.0)
230 |
231 | # Left and Right context includes mention surface
232 | left_tokens = m.sent_tokens[0:m.end_token+1]
233 | right_tokens = m.sent_tokens[m.start_token:][::-1]
234 |
235 | # Strict left and right context
236 | if self.strict_context:
237 | left_tokens = m.sent_tokens[0:m.start_token]
238 | right_tokens = m.sent_tokens[m.end_token+1:][::-1]
239 | # Left and Right context includes mention surface
240 | else:
241 | left_tokens = m.sent_tokens[0:m.end_token+1]
242 | right_tokens = m.sent_tokens[m.start_token:][::-1]
243 |
244 | if not self.pretrain_wordembed:
245 | left_idxs = [self.convert_word2idx(word)
246 | for word in left_tokens]
247 | right_idxs = [self.convert_word2idx(word)
248 | for word in right_tokens]
249 | else:
250 | left_idxs = left_tokens
251 | right_idxs = right_tokens
252 |
253 | left_batch.append(left_idxs)
254 | right_batch.append(right_idxs)
255 |
256 | # wids : [true_knwn_idx, cand1_idx, cand2_idx, ..., unk_idx]
257 | # wid_cprobs : [cwikis probs or 0.0 for unks]
258 | (wid_idxs, wid_cprobs) = self.make_candidates_cprobs(m)
259 | wid_idxs_batch.append(wid_idxs)
260 | wid_cprobs_batch.append(wid_cprobs)
261 |
262 | coherence_batch = (coh_indices, coh_values, coh_matshape)
263 |
264 | return (left_batch, right_batch,
265 | coherence_batch, wid_idxs_batch, wid_cprobs_batch)
266 |
267 | def print_test_batch(self, mention, wid_idxs, wid_cprobs):
268 | print("Surface : {} WID : {} WT: {}".format(
269 | mention.surface, mention.wid, self.wid2WikiTitle[mention.wid]))
270 | print(mention.wid in self.knwid2idx)
271 | for (idx,cprob) in zip(wid_idxs, wid_cprobs):
272 | print("({} : {:0.5f})".format(
273 | self.wid2WikiTitle[self.idx2knwid[idx]], cprob), end=" ")
274 | print("\n")
275 |
276 | def make_candidates_cprobs(self, m):
277 |         # Fill num_cands candidates; default to empty if surface unseen
278 |         (wid_idxs, wid_cprobs) = ([], [])
279 |         surface = utils._getLnrm(m.surface)
280 |         if surface in self.crosswikis:
281 |             # Pruned crosswikis has only known wids and 30 cands at max
282 |             (wids, wid_cprobs) = self.crosswikis[surface][0:self.num_cands-1]
283 |             wid_idxs = [self.knwid2idx[wid] for wid in wids]
284 |
285 | # All possible candidates added now. Pad with unks
286 | assert len(wid_idxs) == len(wid_cprobs)
287 | remain = self.num_cands - len(wid_idxs)
288 | wid_idxs.extend([0]*remain)
289 | wid_cprobs.extend([0.0]*remain)
290 |
291 | return (wid_idxs, wid_cprobs)
292 |
293 | def embed_batch(self, batch):
294 | ''' Input is a padded batch of left or right contexts containing words
295 | Dimensions should be [B, padded_length]
296 | Output:
297 | Embed the word idxs using pretrain word embedding
298 | '''
299 | output_batch = []
300 | for sent in batch:
301 | word_embeddings = [self.get_vector(word) for word in sent]
302 | output_batch.append(word_embeddings)
303 | return output_batch
304 |
305 | def embed_mentions_batch(self, mentions_batch):
306 | ''' Input is batch of mention tokens as a list of list of tokens.
307 | Output: For each mention, average word embeddings '''
308 | embedded_mentions_batch = []
309 | for m_tokens in mentions_batch:
310 | outvec = np.zeros(300, dtype=float)
311 | for word in m_tokens:
312 | outvec += self.get_vector(word)
313 | outvec = outvec / len(m_tokens)
314 | embedded_mentions_batch.append(outvec)
315 | return embedded_mentions_batch
316 |
317 | def pad_batch(self, batch):
318 | if not self.pretrain_wordembed:
319 | pad_unit = self.word2idx[self.unk_word]
320 | else:
321 | pad_unit = self.unk_word
322 |
323 | lengths = [len(i) for i in batch]
324 | max_length = max(lengths)
325 | for i in range(0, len(batch)):
326 | batch[i].extend([pad_unit]*(max_length - lengths[i]))
327 | return (batch, lengths)
328 |
329 | def _next_padded_batch(self):
330 | (left_batch, right_batch,
331 | coherence_batch,
332 | wid_idxs_batch, wid_cprobs_batch) = self._next_batch()
333 |
334 | (left_batch, left_lengths) = self.pad_batch(left_batch)
335 | (right_batch, right_lengths) = self.pad_batch(right_batch)
336 |
337 | if self.pretrain_wordembed:
338 | left_batch = self.embed_batch(left_batch)
339 | right_batch = self.embed_batch(right_batch)
340 |
341 | return (left_batch, left_lengths, right_batch, right_lengths,
342 | coherence_batch, wid_idxs_batch, wid_cprobs_batch)
343 |
344 | def convert_word2idx(self, word):
345 | if word in self.word2idx:
346 | return self.word2idx[word]
347 | else:
348 | return self.word2idx[self.unk_word]
349 |
350 | def next_test_batch(self):
351 | return self._next_padded_batch()
352 |
353 | def widIdx2WikiTitle(self, widIdx):
354 | wid = self.idx2knwid[widIdx]
355 | wikiTitle = self.wid2WikiTitle[wid]
356 | return wikiTitle
357 |
358 | if __name__ == '__main__':
359 | sttime = time.time()
360 | batch_size = 2
361 | num_batch = 1000
362 | configpath = "configs/all_mentions_config.ini"
363 | config = Config(configpath, verbose=False)
364 | vocabloader = VocabLoader(config)
365 | b = InferenceReader(config=config,
366 | vocabloader=vocabloader,
367 | test_mens_file=config.test_file,
368 | num_cands=30,
369 | batch_size=batch_size,
370 | strict_context=False,
371 | pretrain_wordembed=True,
372 | coherence=True)
373 |
374 | stime = time.time()
375 |
376 | i = 0
377 | total_instances = 0
378 | while b.epochs < 1:
379 | (left_batch, left_lengths, right_batch, right_lengths,
380 | coherence_batch, wid_idxs_batch,
381 | wid_cprobs_batch) = b.next_test_batch()
382 | if i % 100 == 0:
383 | etime = time.time()
384 | t=etime-stime
385 | print("{} done. Time taken : {} seconds".format(i, t))
386 | i += 1
387 | etime = time.time()
388 | t=etime-stime
389 | tt = etime - sttime
390 | print("Total Instances : {}".format(total_instances))
391 | print("Batching time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, t))
392 | print("Total time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, tt))
393 |
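# Illustrative helper (added for exposition; not used by the model code): the
# coherence_batch returned above is a sparse (indices, values, shape) triple,
# where each index is [batch_element, coherence_string_idx] and each value is
# 1.0. Densifying it with numpy makes the layout easy to inspect.
def _densify_coherence_batch(coherence_batch):
    indices, values, shape = coherence_batch
    dense = np.zeros(shape if shape else (0, 0))
    for (row, col), val in zip(indices, values):
        dense[row, col] = val
    return dense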
--------------------------------------------------------------------------------
/readers/test_reader.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import readers.utils as utils
4 | from readers.Mention import Mention
5 | from readers.config import Config
6 | from readers.vocabloader import VocabLoader
7 |
8 | start_word = ""
9 | end_word = ""
10 |
11 | class TestDataReader(object):
12 | def __init__(self, config, vocabloader, test_mens_file,
13 | num_cands, batch_size, strict_context=True,
14 | pretrain_wordembed=True, coherence=True,
15 | glove=True):
16 | print("Loading Test Reader: {}".format(test_mens_file))
17 | self.typeOfReader="test"
18 | self.start_word = start_word
19 | self.end_word = end_word
20 | self.unk_word = 'unk' # In tune with word2vec
21 | self.unk_wid = ""
22 | # self.useKnownEntitesOnly = True
23 | self.pretrain_wordembed = pretrain_wordembed
24 | self.coherence = coherence
25 |
26 | # Word Vocab
27 | (self.word2idx, self.idx2word) = vocabloader.getGloveWordVocab()
28 | self.num_words = len(self.idx2word)
29 | print(" [#] Word vocab loaded. Size of vocab : {}".format(
30 | self.num_words))
31 |
32 | # Label Vocab
33 | (self.label2idx, self.idx2label) = vocabloader.getLabelVocab()
34 | self.num_labels = len(self.idx2label)
35 | print(" [#] Label vocab loaded. Number of labels : {}".format(
36 | self.num_labels))
37 |
38 | # Known WID Vocab
39 | (self.knwid2idx, self.idx2knwid) = vocabloader.getKnwnWidVocab()
40 | self.num_knwn_entities = len(self.idx2knwid)
41 | print(" [#] Loaded. Num of known wids : {}".format(
42 | self.num_knwn_entities))
43 |
44 | # Wid2Wikititle Map
45 | self.wid2WikiTitle = vocabloader.getWID2Wikititle()
46 | print(" [#] Size of Wid2Wikititle: {}".format(len(
47 | self.wid2WikiTitle)))
48 |
49 | # # Wid2TypeLabels Map
50 | # self.wid2TypeLabels = vocabloader.getWID2TypeLabels()
51 | # print(" [#] Total number of Wids : {}".format(len(
52 | # self.wid2TypeLabels)))
53 |
54 | # Coherence String Vocab
55 | print("Loading Coherence Strings Dicts ... ")
56 | (self.cohG92idx, self.idx2cohG9) = utils.load(
57 | config.cohstringG9_vocab_pkl)
58 | self.num_cohstr = len(self.idx2cohG9)
59 | print(" [#] Number of Coherence Strings in Vocab : {}".format(
60 | self.num_cohstr))
61 |
62 | # Known WID Description Vectors
63 | # self.kwnwid2descvecs = vocabloader.loadKnownWIDDescVecs()
64 | # print(" [#] Size of kwn wid desc vecs dict : {}".format(
65 | # len(self.kwnwid2descvecs)))
66 |
67 | # # Crosswikis
68 | # print("[#] Loading training/val crosswikis dictionary ... ")
69 | # self.test_kwnen_cwikis = vocabloader.getTestKnwEnCwiki()
70 | # self.test_allen_cwikis = vocabloader.getTestAllEnCwiki()
71 |
72 | # Crosswikis
73 | print("Loading Crosswikis dict. (takes ~2 mins to load)")
74 | self.crosswikis = utils.load(config.crosswikis_pruned_pkl)
75 | # self.crosswikis = {}
76 | print("Crosswikis loaded. Size: {}".format(len(self.crosswikis)))
77 |
78 | if self.pretrain_wordembed:
79 | stime = time.time()
80 | self.word2vec = vocabloader.loadGloveVectors()
81 | print("[#] Glove Vectors loaded!")
82 | ttime = (time.time() - stime)/float(60)
83 | print("[#] Time to load vectors : {} mins".format(ttime))
84 |
85 | print("[#] Test Mentions File : {}".format(test_mens_file))
86 |
87 | print("[#] Pre-loading test mentions ... ")
88 | self.mentions = utils.make_mentions_from_file(test_mens_file)
89 | self.men_idx = 0
90 | self.num_mens = len(self.mentions)
91 | self.epochs = 0
92 | print( "[#] Test Mentions : {}".format(self.num_mens))
93 |
94 | self.batch_size = batch_size
95 | print("[#] Batch Size: %d" % self.batch_size)
96 | self.num_cands = num_cands
97 | self.strict_context = strict_context
98 |
99 | print("\n[#]LOADING COMPLETE")
100 | # ******************* END __init__ *******************************
101 |
102 | def get_vector(self, word):
103 | if word in self.word2vec:
104 | return self.word2vec[word]
105 | else:
106 | return self.word2vec['unk']
107 |
108 | def reset_test(self):
109 | self.men_idx = 0
110 | self.epochs = 0
111 |
112 | def _read_mention(self):
113 | mention = self.mentions[self.men_idx]
114 | self.men_idx += 1
115 | if self.men_idx == self.num_mens:
116 | self.men_idx = 0
117 | self.epochs += 1
118 | return mention
119 |
120 | def _next_batch(self):
121 |         ''' Data : mid \t wid \t wikititle \t start \t end \t tokens \t labels
122 | start and end are inclusive
123 | '''
124 | # Sentence = s1 ... m1 ... mN, ... sN.
125 | # Left Batch = s1 ... m1 ... mN
126 | # Right Batch = sN ... mN ... m1
127 | (left_batch, right_batch) = ([], [])
128 |
129 | # Labels : Vector of 0s and 1s of size = number of labels = 113
130 | labels_batch = np.zeros([self.batch_size, self.num_labels])
131 |
132 | coh_indices = []
133 | coh_values = []
134 | if self.coherence:
135 | coh_matshape = [self.batch_size, self.num_cohstr]
136 | else:
137 | coh_matshape = []
138 |
139 | # Wiki Description: [B, N=100, D=300]
140 | # truewid_descvec_batch = []
141 |
142 | # Candidate WID idxs and their cprobs
143 | # First element is always true wid
144 | (wid_idxs_batch, wid_cprobs_batch) = ([], [])
145 |
146 | while len(left_batch) < self.batch_size:
147 | batch_el = len(left_batch)
148 | m = self._read_mention()
149 |
150 | for label in m.types:
151 | if label in self.label2idx:
152 | labelidx = self.label2idx[label]
153 | labels_batch[batch_el][labelidx] = 1.0
154 | #labels
155 |
156 | cohFound = False # If no coherence mention is found, then add unk
157 | if self.coherence:
158 | cohidxs = [] # Indexes in the [B, NumCoh] matrix
159 | cohvals = [] # 1.0 to indicate presence
160 | for cohstr in m.coherence:
161 | if cohstr in self.cohG92idx:
162 | cohidx = self.cohG92idx[cohstr]
163 | cohidxs.append([batch_el, cohidx])
164 | cohvals.append(1.0)
165 | cohFound = True
166 | if cohFound:
167 | coh_indices.extend(cohidxs)
168 | coh_values.extend(cohvals)
169 | else:
170 | cohidx = self.cohG92idx[self.unk_word]
171 | coh_indices.append([batch_el, cohidx])
172 | coh_values.append(1.0)
173 |
174 | # cohFound = False # If no coherence mention found, then add unk
175 | # if self.coherence:
176 | # for cohstr in m.coherence:
177 | # if cohstr in self.cohG92idx:
178 | # cohidx = self.cohG92idx[cohstr]
179 | # coh_indices.append([batch_el, cohidx])
180 | # coh_values.append(1.0)
181 | # cohFound = True
182 | # if not cohFound:
183 | # cohidx = self.cohG92idx[self.unk_word]
184 | # coh_indices.append([batch_el, cohidx])
185 | # coh_values.append(1.0)
186 |
187 | # Left and Right context includes mention surface
188 | left_tokens = m.sent_tokens[0:m.end_token+1]
189 | right_tokens = m.sent_tokens[m.start_token:][::-1]
190 |
191 | # Strict left and right context
192 | if self.strict_context:
193 | left_tokens = m.sent_tokens[0:m.start_token]
194 | right_tokens = m.sent_tokens[m.end_token+1:][::-1]
195 | # Left and Right context includes mention surface
196 | else:
197 | left_tokens = m.sent_tokens[0:m.end_token+1]
198 | right_tokens = m.sent_tokens[m.start_token:][::-1]
199 |
200 | if not self.pretrain_wordembed:
201 | left_idxs = [self.convert_word2idx(word)
202 | for word in left_tokens]
203 | right_idxs = [self.convert_word2idx(word)
204 | for word in right_tokens]
205 | else:
206 | left_idxs = left_tokens
207 | right_idxs = right_tokens
208 |
209 | left_batch.append(left_idxs)
210 | right_batch.append(right_idxs)
211 |
212 | # if m.wid in self.knwid2idx:
213 | # truewid_descvec_batch.append(self.kwnwid2descvecs[m.wid])
214 | # else:
215 | # truewid_descvec_batch.append(
216 | # self.kwnwid2descvecs[self.unk_wid])
217 |
218 | # wids : [true_knwn_idx, cand1_idx, cand2_idx, ..., unk_idx]
219 | # wid_cprobs : [cwikis probs or 0.0 for unks]
220 | (wid_idxs, wid_cprobs) = self.make_candidates_cprobs(m)
221 | wid_idxs_batch.append(wid_idxs)
222 | wid_cprobs_batch.append(wid_cprobs)
223 |
224 | # self.print_test_batch(m, wid_idxs, wid_cprobs)
225 | # print(m.docid)
226 |
227 | #end batch making
228 | coherence_batch = (coh_indices, coh_values, coh_matshape)
229 |
230 | # return (left_batch, right_batch, truewid_descvec_batch, labels_batch,
231 | # coherence_batch, wid_idxs_batch, wid_cprobs_batch)
232 | return (left_batch, right_batch, labels_batch,
233 | coherence_batch, wid_idxs_batch, wid_cprobs_batch)
234 |
235 | def print_test_batch(self, mention, wid_idxs, wid_cprobs):
236 | print("Surface : {} WID : {} WT: {}".format(
237 | mention.surface, mention.wid, self.wid2WikiTitle[mention.wid]))
238 | print(mention.wid in self.knwid2idx)
239 | for (idx,cprob) in zip(wid_idxs, wid_cprobs):
240 | print("({} : {:0.5f})".format(
241 | self.wid2WikiTitle[self.idx2knwid[idx]], cprob), end=" ")
242 | print("\n")
243 |
244 | def make_candidates_cprobs(self, m):
245 | # First wid_idx is true entity
246 | #if self.useKnownEntitesOnly:
247 | if m.wid in self.knwid2idx:
248 | wid_idxs = [self.knwid2idx[m.wid]]
249 | else:
250 | wid_idxs = [self.knwid2idx[self.unk_wid]]
251 | # else:
252 | # ''' Todo: Set wids_idxs[0] in a way to incorporate all entities'''
253 | # wids_idxs = [0]
254 |
255 | # This prob will be updated when going over cwikis candidates
256 | wid_cprobs = [0.0]
257 |
258 | # Crosswikis to use based on Known / All entities
259 | # if self.useKnownEntitesOnly:
260 | cwiki_dict = self.crosswikis
261 | # else:
262 | # cwiki_dict = self.test_all_cwikis
263 |
264 | # Indexing dict to use
265 | # Todo: When changed to all entities, indexing will change
266 | wid2idx = self.knwid2idx
267 |
268 | # Fill num_cands now
269 | surface = utils._getLnrm(m.surface)
270 | if surface in cwiki_dict:
271 | candwids_cprobs = cwiki_dict[surface][0:self.num_cands-1]
272 | (candwids, candwid_cprobs) = candwids_cprobs
273 | for (c, p) in zip(candwids, candwid_cprobs):
274 | if c in wid2idx:
275 | if c == m.wid: # Update cprob for true if in known set
276 | wid_cprobs[0] = p
277 | else:
278 | wid_idxs.append(wid2idx[c])
279 | wid_cprobs.append(p)
280 | # All possible candidates added now. Pad with unks
281 | assert len(wid_idxs) == len(wid_cprobs)
282 | remain = self.num_cands - len(wid_idxs)
283 | wid_idxs.extend([0]*remain)
284 | wid_cprobs.extend([0.0]*remain)
285 |
286 | wid_idxs = wid_idxs[0:self.num_cands]
287 | wid_cprobs = wid_cprobs[0:self.num_cands]
288 |
289 | return (wid_idxs, wid_cprobs)
290 |
291 | def embed_batch(self, batch):
292 | ''' Input is a padded batch of left or right contexts containing words
293 | Dimensions should be [B, padded_length]
294 | Output:
295 | Embed the word idxs using pretrain word embedding
296 | '''
297 | output_batch = []
298 | for sent in batch:
299 | word_embeddings = [self.get_vector(word) for word in sent]
300 | output_batch.append(word_embeddings)
301 | return output_batch
302 |
303 | def embed_mentions_batch(self, mentions_batch):
304 | ''' Input is batch of mention tokens as a list of list of tokens.
305 | Output: For each mention, average word embeddings '''
306 | embedded_mentions_batch = []
307 | for m_tokens in mentions_batch:
308 | outvec = np.zeros(300, dtype=float)
309 | for word in m_tokens:
310 | outvec += self.get_vector(word)
311 | outvec = outvec / len(m_tokens)
312 | embedded_mentions_batch.append(outvec)
313 | return embedded_mentions_batch
314 |
315 | def pad_batch(self, batch):
316 | if not self.pretrain_wordembed:
317 | pad_unit = self.word2idx[self.unk_word]
318 | else:
319 | pad_unit = self.unk_word
320 |
321 | lengths = [len(i) for i in batch]
322 | max_length = max(lengths)
323 | for i in range(0, len(batch)):
324 | batch[i].extend([pad_unit]*(max_length - lengths[i]))
325 | return (batch, lengths)
326 |
327 | def _next_padded_batch(self):
328 | # (left_batch, right_batch, truewid_descvec_batch,
329 | # labels_batch, coherence_batch,
330 | # wid_idxs_batch, wid_cprobs_batch) = self._next_batch()
331 | (left_batch, right_batch,
332 | labels_batch, coherence_batch,
333 | wid_idxs_batch, wid_cprobs_batch) = self._next_batch()
334 |
335 | (left_batch, left_lengths) = self.pad_batch(left_batch)
336 | (right_batch, right_lengths) = self.pad_batch(right_batch)
337 |
338 | if self.pretrain_wordembed:
339 | left_batch = self.embed_batch(left_batch)
340 | right_batch = self.embed_batch(right_batch)
341 |
342 | return (left_batch, left_lengths, right_batch, right_lengths,
343 | labels_batch, coherence_batch,
344 | wid_idxs_batch, wid_cprobs_batch)
345 |
346 | def convert_word2idx(self, word):
347 | if word in self.word2idx:
348 | return self.word2idx[word]
349 | else:
350 | return self.word2idx[self.unk_word]
351 |
352 | def next_test_batch(self):
353 | return self._next_padded_batch()
354 |
355 |
356 | def debugWIDIdxsBatch(self, wid_idxs_batch):
357 | WikiTitles = []
358 | for widxs in wid_idxs_batch:
359 | wits = [self.wid2WikiTitle[self.idx2knwid[wididx]] for wididx in widxs]
360 | WikiTitles.append(wits)
361 |
362 | return WikiTitles
363 |
364 | def widIdx2WikiTitle(self, widIdx):
365 | wid = self.idx2knwid[widIdx]
366 | wikiTitle = self.wid2WikiTitle[wid]
367 | return wikiTitle
368 |
369 |
370 | if __name__ == '__main__':
371 | sttime = time.time()
372 | batch_size = 1
373 | num_batch = 1000
374 | configpath = "configs/config.ini"
375 | config = Config(configpath, verbose=False)
376 | vocabloader = VocabLoader(config)
377 | b = TestDataReader(config=config,
378 | vocabloader=vocabloader,
379 | test_mens_file=config.test_file,
380 | num_cands=30,
381 | batch_size=batch_size,
382 | strict_context=False,
383 | pretrain_wordembed=False,
384 | coherence=False)
385 |
386 | stime = time.time()
387 |
388 | i = 0
389 | kwn = 0
390 | total_instances = 0
391 | while b.epochs < 1:
392 | (left_batch, left_lengths,
393 | right_batch, right_lengths,
394 | labels_batch, coherence_batch,
395 | wid_idxs_batch, wid_cprobs_batch) = b.next_test_batch()
396 |
397 | print(b.debugWIDIdxsBatch(wid_idxs_batch))
398 | print(wid_cprobs_batch)
399 |
400 | if i % 100 == 0:
401 | etime = time.time()
402 | t=etime-stime
403 | print("{} done. Time taken : {} seconds".format(i, t))
404 | i += 1
405 | etime = time.time()
406 | t=etime-stime
407 | tt = etime - sttime
408 | print("Batching time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, t))
409 | print("Total time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, tt))
410 |
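# Illustrative helper (added for exposition; not used elsewhere): labels_batch
# rows are multi-hot vectors over the FIGER type vocabulary. Assuming idx2label
# maps integer indices to type names (as loaded by the vocabloader), this
# recovers the active type strings of one row.
def _decode_label_row(label_row, idx2label):
    return [idx2label[i] for i, v in enumerate(label_row) if v == 1.0]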
--------------------------------------------------------------------------------
/readers/textanno_test_reader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import numpy as np
4 | import readers.utils as utils
5 | from readers.Mention import Mention
6 | from readers.config import Config
7 | from readers.vocabloader import VocabLoader
8 | import ccg_nlpy
9 | from ccg_nlpy.core.text_annotation import TextAnnotation
10 |
11 | start_word = ""
12 | end_word = ""
13 |
14 | # Reader for Text Annotations
15 | class TextAnnoTestReader(object):
16 | def __init__(self, config, vocabloader,
17 | num_cands, batch_size, strict_context=True,
18 | pretrain_wordembed=True, coherence=True,
19 | nerviewname="NER_CONLL"):
20 | self.typeOfReader = "inference"
21 | self.start_word = start_word
22 | self.end_word = end_word
23 | self.unk_word = 'unk' # In tune with word2vec
24 | self.unk_wid = ""
25 | self.tr_sup = 'tr_sup'
26 | self.tr_unsup = 'tr_unsup'
27 | self.pretrain_wordembed = pretrain_wordembed
28 | self.coherence = coherence
29 | self.nerviewname = nerviewname
30 |
31 | # Word Vocab
32 | (self.word2idx, self.idx2word) = vocabloader.getGloveWordVocab()
33 | self.num_words = len(self.idx2word)
34 |
35 | # Label Vocab
36 | (self.label2idx, self.idx2label) = vocabloader.getLabelVocab()
37 | self.num_labels = len(self.idx2label)
38 |
39 | # Known WID Vocab
40 | (self.knwid2idx, self.idx2knwid) = vocabloader.getKnwnWidVocab()
41 | self.num_knwn_entities = len(self.idx2knwid)
42 |
43 | # Wid2Wikititle Map
44 | self.wid2WikiTitle = vocabloader.getWID2Wikititle()
45 |
46 | # Coherence String Vocab
47 | print("Loading Coherence Strings Dicts ... ")
48 | (self.cohG92idx, self.idx2cohG9) = utils.load(
49 | config.cohstringG9_vocab_pkl)
50 | self.num_cohstr = len(self.idx2cohG9)
51 |
52 | # Crosswikis
53 | print("Loading Crosswikis dict. (takes ~2 mins to load)")
54 | self.crosswikis = utils.load(config.crosswikis_pruned_pkl)
55 | print("Crosswikis loaded. Size: {}".format(len(self.crosswikis)))
56 |
57 | if self.pretrain_wordembed:
58 | stime = time.time()
59 | self.word2vec = vocabloader.loadGloveVectors()
60 | print("[#] Glove Vectors loaded!")
61 | ttime = (time.time() - stime)/float(60)
62 |             print("[#] Time to load vectors : {} mins".format(ttime))
63 |
64 | # print("[#] Test Mentions File : {}".format(test_mens_file))
65 |
66 | # print("[#] Loading test file and preprocessing ... ")
67 | # with open(test_mens_file, 'r') as f:
68 | # tajsonstr = f.read()
69 | # ta = TextAnnotation(json_str=tajsonstr)
70 | #
71 | # (sentences_tokenized, modified_ner_cons_list) = self.processTestDoc(ta)
72 | #
73 | # self.mention_lines = self.convertSent2NerToMentionLines(
74 | # sentences_tokenized, modified_ner_cons_list)
75 | #
76 | # self.mentions = []
77 | # for line in self.mention_lines:
78 | # m = Mention(line)
79 | # self.mentions.append(m)
80 |
81 | self.men_idx = 0
82 | # self.num_mens = len(self.mentions)
83 | self.epochs = 0
84 | # print( "[#] Test Mentions : {}".format(self.num_mens))
85 |
86 | self.batch_size = batch_size
87 | print("[#] Batch Size: %d" % self.batch_size)
88 | self.num_cands = num_cands
89 | self.strict_context = strict_context
90 |
91 | print("\n[#]LOADING COMPLETE")
92 | #******************* END __init__ *********************************
93 |
94 | def new_test_file(self, test_mens_file):
95 | self.test_mens_file = test_mens_file
96 |
97 | with open(test_mens_file, 'r') as f:
98 | tajsonstr = f.read()
99 | ta = TextAnnotation(json_str=tajsonstr)
100 | self.textanno = ta
101 |
102 | (sentences_tokenized, modified_ner_cons_list) = self.processTestDoc(ta)
103 |
104 | self.mention_lines = self.convertSent2NerToMentionLines(
105 | sentences_tokenized, modified_ner_cons_list)
106 |
107 | self.mentions = []
108 | for line in self.mention_lines:
109 | m = Mention(line)
110 | self.mentions.append(m)
111 |
112 | self.men_idx = 0
113 | self.num_mens = len(self.mentions)
114 | self.epochs = 0
115 |
116 | def new_tajsonstr(self, tajsonstr):
117 | """ tajsonstr is a json str of a TA """
118 | ta = TextAnnotation(json_str=tajsonstr)
119 | self.new_ta(ta)
120 |
121 | def new_ta(self, ta):
122 | self.textanno = ta
123 |
124 | (sentences_tokenized, modified_ner_cons_list) = self.processTestDoc(ta)
125 |
126 | self.mention_lines = self.convertSent2NerToMentionLines(
127 | sentences_tokenized, modified_ner_cons_list)
128 |
129 | self.mentions = []
130 | for line in self.mention_lines:
131 | m = Mention(line)
132 | self.mentions.append(m)
133 |
134 | self.men_idx = 0
135 | self.num_mens = len(self.mentions)
136 | self.epochs = 0
137 |
138 |
139 |
140 | def get_vector(self, word):
141 | if word in self.word2vec:
142 | return self.word2vec[word]
143 | else:
144 | return self.word2vec['unk']
145 |
146 | def reset_test(self):
147 | self.men_idx = 0
148 | self.epochs = 0
149 |
150 | def processTestDoc(self, ccgdoc):
151 | doc_tokens = ccgdoc.get_tokens
152 | # sent_end_token_indices : contains index for the starting of the
153 | # next sentence.
154 | sent_end_token_indices = \
155 | ccgdoc.get_sentence_end_token_indices
156 | # List of tokenized sentences
157 | sentences_tokenized = []
158 | for i in range(0, len(sent_end_token_indices)):
159 | start = sent_end_token_indices[i-1] if i != 0 else 0
160 | end = sent_end_token_indices[i]
161 | sent_tokens = doc_tokens[start:end]
162 | sentences_tokenized.append(sent_tokens)
163 |
164 | # List of ner dicts from ccg pipeline
165 | ner_cons_list = []
166 | try:
167 | ner_cons_list = ccgdoc.get_view(self.nerviewname).cons_list
168 |         except Exception:
169 | print("NO NAMED ENTITIES IN THE DOC. EXITING")
170 |
171 | modified_ner_cons_list = []
172 |
173 | for orig_ner in ner_cons_list:
174 | ner = orig_ner.copy()
175 | # ner['end'] = ner['end'] + 1
176 | # ner['tokens'] = ' '.join(doc_tokens[ner['start']:ner['end']])
177 |
178 | found = False
179 | # idx = sentIdx, j = sentEndTokenIdx
180 | for idx, j in enumerate(sent_end_token_indices):
181 | sent_start_token = sent_end_token_indices[idx-1] \
182 | if idx != 0 else 0
183 | # ner['end'] is the idx of the token after ner
184 | if ner['end'] <= j:
185 | ner['start'] = ner['start'] - sent_start_token
186 | ner['end'] = ner['end'] - sent_start_token - 1
187 | ner['sent_idx'] = idx
188 |
189 | modified_ner_cons_list.append(ner)
190 |
191 | found = True
192 | if found:
193 | break
194 | return (sentences_tokenized, modified_ner_cons_list)
195 |
196 | def convertSent2NerToMentionLines(self, sentences_tokenized,
197 | modified_ner_cons_list):
198 | '''Convert NERs from document to list of mention strings'''
199 | mentions = []
200 | # Make Document Context String for whole document
201 | cohStr = ""
202 | # for sent_idx, s_nerDicts in sentidx2ners.items():
203 | # for s, ner in s_nerDicts:
204 | # cohStr += ner['tokens'].replace(' ', '_') + ' '
205 |
206 | for ner_men in modified_ner_cons_list:
207 | cohStr += ner_men['tokens'].replace(' ', '_') + ' '
208 |
209 | cohStr = cohStr.strip()
210 |
211 | for ner_men in modified_ner_cons_list:
212 | idx = ner_men['sent_idx']
213 | sentence = ' '.join(sentences_tokenized[idx])
214 |
215 | mention = "%s\t%s\t%s" % ("unk_mid", "unk_wid", "unkWT")
216 | mention = mention + '\t' + str(ner_men['start'])
217 | mention = mention + '\t' + str(ner_men['end'])
218 | mention = mention + '\t' + str(ner_men['tokens'])
219 | mention = mention + '\t' + sentence
220 | mention = mention + '\t' + "UNK_TYPES"
221 | mention = mention + '\t' + cohStr
222 | mentions.append(mention)
223 | return mentions
224 |
225 | def bracketMentionInSentence(self, s, nerDict):
226 | tokens = s.split(" ")
227 | start = nerDict['start']
228 | end = nerDict['end']
229 | tokens.insert(start, '[[')
230 | tokens.insert(end + 2, ']]')
231 | return ' '.join(tokens)
232 |
233 | def _read_mention(self):
234 | mention = self.mentions[self.men_idx]
235 | self.men_idx += 1
236 | if self.men_idx == self.num_mens:
237 | self.men_idx = 0
238 | self.epochs += 1
239 | return mention
240 |
241 | def _next_batch(self):
242 |         ''' Data : mid \t wid \t wikititle \t start \t end \t tokens \t labels
243 | start and end are inclusive
244 | '''
245 | # Sentence = s1 ... m1 ... mN, ... sN.
246 | # Left Batch = s1 ... m1 ... mN
247 | # Right Batch = sN ... mN ... m1
248 | (left_batch, right_batch) = ([], [])
249 |
250 | coh_indices = []
251 | coh_values = []
252 | if self.coherence:
253 | coh_matshape = [self.batch_size, self.num_cohstr]
254 | else:
255 | coh_matshape = []
256 |
257 | # Candidate WID idxs and their cprobs
258 | # First element is always true wid
259 | (wid_idxs_batch, wid_cprobs_batch) = ([], [])
260 |
261 | while len(left_batch) < self.batch_size:
262 | batch_el = len(left_batch)
263 | m = self._read_mention()
264 |
265 | # for label in m.types:
266 | # if label in self.label2idx:
267 | # labelidx = self.label2idx[label]
268 | # labels_batch[batch_el][labelidx] = 1.0
269 |
270 | cohFound = False # If no coherence mention is found, add unk
271 | if self.coherence:
272 | cohidxs = [] # Indexes in the [B, NumCoh] matrix
273 | cohvals = [] # 1.0 to indicate presence
274 | for cohstr in m.coherence:
275 | if cohstr in self.cohG92idx:
276 | cohidx = self.cohG92idx[cohstr]
277 | cohidxs.append([batch_el, cohidx])
278 | cohvals.append(1.0)
279 | cohFound = True
280 | if cohFound:
281 | coh_indices.extend(cohidxs)
282 | coh_values.extend(cohvals)
283 | else:
284 | cohidx = self.cohG92idx[self.unk_word]
285 | coh_indices.append([batch_el, cohidx])
286 | coh_values.append(1.0)
287 |
288 | # Left and Right context includes mention surface
289 | left_tokens = m.sent_tokens[0:m.end_token+1]
290 | right_tokens = m.sent_tokens[m.start_token:][::-1]
291 |
292 | # Strict left and right context
293 | if self.strict_context:
294 | left_tokens = m.sent_tokens[0:m.start_token]
295 | right_tokens = m.sent_tokens[m.end_token+1:][::-1]
296 | # Left and Right context includes mention surface
297 | else:
298 | left_tokens = m.sent_tokens[0:m.end_token+1]
299 | right_tokens = m.sent_tokens[m.start_token:][::-1]
300 |
301 | if not self.pretrain_wordembed:
302 | left_idxs = [self.convert_word2idx(word)
303 | for word in left_tokens]
304 | right_idxs = [self.convert_word2idx(word)
305 | for word in right_tokens]
306 | else:
307 | left_idxs = left_tokens
308 | right_idxs = right_tokens
309 |
310 | left_batch.append(left_idxs)
311 | right_batch.append(right_idxs)
312 |
313 | # wids : [true_knwn_idx, cand1_idx, cand2_idx, ..., unk_idx]
314 | # wid_cprobs : [cwikis probs or 0.0 for unks]
315 | (wid_idxs, wid_cprobs) = self.make_candidates_cprobs(m)
316 | wid_idxs_batch.append(wid_idxs)
317 | wid_cprobs_batch.append(wid_cprobs)
318 |
319 | coherence_batch = (coh_indices, coh_values, coh_matshape)
320 |
321 | return (left_batch, right_batch,
322 | coherence_batch, wid_idxs_batch, wid_cprobs_batch)
323 |
324 | def print_test_batch(self, mention, wid_idxs, wid_cprobs):
325 | print("Surface : {} WID : {} WT: {}".format(
326 | mention.surface, mention.wid, self.wid2WikiTitle[mention.wid]))
327 | print(mention.wid in self.knwid2idx)
328 | for (idx,cprob) in zip(wid_idxs, wid_cprobs):
329 | print("({} : {:0.5f})".format(
330 | self.wid2WikiTitle[self.idx2knwid[idx]], cprob), end=" ")
331 | print("\n")
332 |
333 | def make_candidates_cprobs(self, m):
334 | # Fill num_cands now
335 | surface = utils._getLnrm(m.surface)
336 | wid_idxs = []
337 | wid_cprobs = []
338 |
339 | # print(surface)
340 | if surface in self.crosswikis:
341 | # Pruned crosswikis has only known wids and 30 cands at max
342 | candwids_cprobs = self.crosswikis[surface][0:self.num_cands-1]
343 | (wids, wid_cprobs) = candwids_cprobs
344 | wid_idxs = [self.knwid2idx[wid] for wid in wids]
345 |
346 | # All possible candidates added now. Pad with unks
347 |
348 | # assert len(wid_idxs) == len(wid_cprobs)
349 | remain = self.num_cands - len(wid_idxs)
350 | wid_idxs.extend([0]*remain)
351 | remain = self.num_cands - len(wid_cprobs)
352 | wid_cprobs.extend([0.0]*remain)
353 |
354 | return (wid_idxs, wid_cprobs)
355 |
356 | def embed_batch(self, batch):
357 | ''' Input is a padded batch of left or right contexts containing words
358 | Dimensions should be [B, padded_length]
359 | Output:
360 | Embed the word idxs using pretrain word embedding
361 | '''
362 | output_batch = []
363 | for sent in batch:
364 | word_embeddings = [self.get_vector(word) for word in sent]
365 | output_batch.append(word_embeddings)
366 | return output_batch
367 |
368 | def embed_mentions_batch(self, mentions_batch):
369 | ''' Input is batch of mention tokens as a list of list of tokens.
370 | Output: For each mention, average word embeddings '''
371 | embedded_mentions_batch = []
372 | for m_tokens in mentions_batch:
373 | outvec = np.zeros(300, dtype=float)
374 | for word in m_tokens:
375 | outvec += self.get_vector(word)
376 | outvec = outvec / len(m_tokens)
377 | embedded_mentions_batch.append(outvec)
378 | return embedded_mentions_batch
379 |
380 | def pad_batch(self, batch):
381 | if not self.pretrain_wordembed:
382 | pad_unit = self.word2idx[self.unk_word]
383 | else:
384 | pad_unit = self.unk_word
385 |
386 | lengths = [len(i) for i in batch]
387 | max_length = max(lengths)
388 | for i in range(0, len(batch)):
389 | batch[i].extend([pad_unit]*(max_length - lengths[i]))
390 | return (batch, lengths)
391 |
392 | def _next_padded_batch(self):
393 | (left_batch, right_batch,
394 | coherence_batch,
395 | wid_idxs_batch, wid_cprobs_batch) = self._next_batch()
396 |
397 | (left_batch, left_lengths) = self.pad_batch(left_batch)
398 | (right_batch, right_lengths) = self.pad_batch(right_batch)
399 |
400 | if self.pretrain_wordembed:
401 | left_batch = self.embed_batch(left_batch)
402 | right_batch = self.embed_batch(right_batch)
403 |
404 | return (left_batch, left_lengths, right_batch, right_lengths,
405 | coherence_batch, wid_idxs_batch, wid_cprobs_batch)
406 |
407 | def convert_word2idx(self, word):
408 | if word in self.word2idx:
409 | return self.word2idx[word]
410 | else:
411 | return self.word2idx[self.unk_word]
412 |
413 | def next_test_batch(self):
414 | return self._next_padded_batch()
415 |
416 | def widIdx2WikiTitle(self, widIdx):
417 | wid = self.idx2knwid[widIdx]
418 | wikiTitle = self.wid2WikiTitle[wid]
419 | return wikiTitle
420 |
421 | if __name__ == '__main__':
422 | sttime = time.time()
423 | batch_size = 2
424 | num_batch = 1000
425 | configpath = "configs/all_mentions_config.ini"
426 | config = Config(configpath, verbose=False)
427 | vocabloader = VocabLoader(config)
428 | b = TextAnnoTestReader(config=config,
429 | vocabloader=vocabloader,
430 | num_cands=30,
431 | batch_size=batch_size,
432 | strict_context=False,
433 | pretrain_wordembed=True,
434 | coherence=True)
435 |
436 | stime = time.time()
437 |
438 | i = 0
439 | total_instances = 0
440 | while b.epochs < 1:
441 | (left_batch, left_lengths, right_batch, right_lengths,
442 | coherence_batch, wid_idxs_batch,
443 | wid_cprobs_batch) = b.next_test_batch()
444 | if i % 100 == 0:
445 | etime = time.time()
446 | t=etime-stime
447 | print("{} done. Time taken : {} seconds".format(i, t))
448 | i += 1
449 | etime = time.time()
450 | t=etime-stime
451 | tt = etime - sttime
452 | print("Total Instances : {}".format(total_instances))
453 | print("Batching time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, t))
454 | print("Total time (in secs) to make %d batches of size %d : %7.4f seconds" % (i, batch_size, tt))
455 |
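# Illustrative driver sketch (added for exposition): unlike the __main__ above,
# a real run must load a TextAnnotation JSON (via new_test_file / new_ta /
# new_tajsonstr) before asking for batches. The path handed to this helper is
# whatever TA JSON file the caller wants annotated.
def _run_on_single_ta(reader, ta_json_path):
    reader.new_test_file(ta_json_path)
    batches = []
    while reader.epochs < 1:
        batches.append(reader.next_test_batch())
    return batches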
--------------------------------------------------------------------------------
/readers/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import gc
4 | import sys
5 | import math
6 | import time
7 | import pickle
8 | import random
9 | import unicodedata
10 | import collections
11 | import numpy as np
12 |
13 | from readers.Mention import Mention
14 |
15 | def save(fname, obj):
16 | with open(fname, 'wb') as f:
17 | pickle.dump(obj, f)
18 |
19 | def load(fname):
20 | with open(fname, 'rb') as f:
21 | return pickle.load(f)
22 |
23 | def _getLnrm(arg):
24 | """Normalizes the given arg by stripping it of diacritics, lowercasing, and
25 | removing all non-alphanumeric characters.
26 | """
27 | arg = ''.join(
28 | [c for c in unicodedata.normalize('NFD', arg) if unicodedata.category(c) != 'Mn'])
29 | arg = arg.lower()
30 | arg = ''.join(
31 | [c for c in arg if c in set('abcdefghijklmnopqrstuvwxyz0123456789')])
32 | return arg
33 |
34 | def load_crosswikis(crosswikis_pkl):
35 | stime = time.time()
36 | print("[#] Loading normalized crosswikis dictionary ... ")
37 | crosswikis_dict = load(crosswikis_pkl)
38 | ttime = (time.time() - stime)/60.0
39 | print(" [#] Crosswikis dictionary loaded!. Time: {0:2.4f} mins. Size : {1}".format(
40 | ttime, len(crosswikis_dict)))
41 | return crosswikis_dict
42 | #end
43 |
44 | def load_widSet(
45 | widWikititle_file="/save/ngupta19/freebase/types_xiao/wid.WikiTitle"):
46 | print("Loading WIDs in the KB ... ")
47 | wids = set()
48 | with open(widWikititle_file, 'r') as f:
49 | text = f.read()
50 | lines = text.strip().split("\n")
51 | for line in lines:
52 | wids.add(line.split("\t")[0].strip())
53 |
54 | print("Loaded all WIDs : {}".format(len(wids)))
55 | return wids
56 |
57 | def make_mentions_from_file(mens_file, verbose=False):
58 | stime = time.time()
59 | with open(mens_file, 'r') as f:
60 | mention_lines = f.read().strip().split("\n")
61 | mentions = []
62 | for line in mention_lines:
63 | mentions.append(Mention(line))
64 | ttime = (time.time() - stime)
65 | if verbose:
66 | filename = mens_file.split("/")[-1]
67 |         print(" ## Time in loading {} mens : {} secs".format(filename, ttime))
68 | return mentions
69 |
70 |
71 | def get_mention_files(mentions_dir):
72 | mention_files = []
73 | for (dirpath, dirnames, filenames) in os.walk(mentions_dir):
74 | mention_files.extend(filenames)
75 | break
76 | #endfor
77 | return mention_files
78 |
79 |
80 | def decrSortedDict(d):
81 |     # Returns a list of (key, value) tuples in decreasing order of the values
82 |     return sorted(d.items(), key=lambda kv: kv[1], reverse=True)
83 |
84 | if __name__ == "__main__":
85 |     # load_crosswikis stands in for measureCrossWikisEntityConverage(), which is not defined in this file.
86 |     load_crosswikis("/save/ngupta19/crosswikis/crosswikis.normalized.pkl")
--------------------------------------------------------------------------------
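Note: `_getLnrm` above implements an "lnrm"-style surface normalization (strip diacritics, lowercase, drop non-alphanumerics), presumably so mention strings can be matched against the normalized CrossWikis dictionary loaded by `load_crosswikis`. A minimal standalone sketch of its behavior; the example strings are illustrative, not repository test cases:

    # Standalone check of the lnrm-style normalization in readers/utils.py.
    import unicodedata

    def lnrm(arg):
        # Drop combining marks (diacritics), lowercase, keep only ASCII alphanumerics.
        arg = ''.join(c for c in unicodedata.normalize('NFD', arg)
                      if unicodedata.category(c) != 'Mn')
        arg = arg.lower()
        return ''.join(c for c in arg if c in set('abcdefghijklmnopqrstuvwxyz0123456789'))

    assert lnrm("Émile Zola!") == "emilezola"
    assert lnrm("New York (state)") == "newyorkstate"
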
/readers/vocabloader.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import gc
4 | import sys
5 | import time
6 | import math
7 | import pickle
8 | import random
9 | import pprint
10 | import unicodedata
11 | import configparser
12 | import collections
13 | import numpy as np
14 | import readers.utils as utils
15 | from readers.config import Config
16 |
17 | class VocabLoader(object):
18 | def __init__(self, config):
19 | self.initialize_all_dicts()
20 | self.config = config
21 |
22 | def initialize_all_dicts(self):
23 | (self.word2idx, self.idx2word) = (None, None)
24 | (self.label2idx, self.idx2label) = (None, None)
25 | (self.knwid2idx, self.idx2knwid) = (None, None)
26 | self.wid2Wikititle = None
27 | self.wid2TypeLabels = None
28 | (self.test_knwen_cwikis, self.test_allen_cwikis) = (None, None)
29 | self.cwikis_slice = None
30 | self.glove2vec = None
31 | (self.gword2idx, self.gidx2word) = (None, None)
32 | self.crosswikis = None
33 |
34 | def loadCrosswikis(self):
35 |         if self.crosswikis is None:
36 | if not os.path.exists(self.config.crosswikis_pkl):
37 | print("Crosswikis pkl missing")
38 | sys.exit()
39 | self.crosswikis = utils.load(self.config.crosswikis_pkl)
40 | return self.crosswikis
41 |
42 | def getWordVocab(self):
43 |         if self.word2idx is None or self.idx2word is None:
44 | if not os.path.exists(self.config.word_vocab_pkl):
45 | print("Word Vocab PKL missing")
46 | sys.exit()
47 | print("Loading Word Vocabulary")
48 | (self.word2idx, self.idx2word) = utils.load(self.config.word_vocab_pkl)
49 | return (self.word2idx, self.idx2word)
50 |
51 | def getLabelVocab(self):
52 |         if self.label2idx is None or self.idx2label is None:
53 | if not os.path.exists(self.config.label_vocab_pkl):
54 | print("Label Vocab PKL missing")
55 | sys.exit()
56 | print("Loading Type Label Vocabulary")
57 | (self.label2idx, self.idx2label) = utils.load(self.config.label_vocab_pkl)
58 | return (self.label2idx, self.idx2label)
59 |
60 | def getKnwnWidVocab(self):
61 |         if self.knwid2idx is None or self.idx2knwid is None:
62 | if not os.path.exists(self.config.kwnwid_vocab_pkl):
63 | print("Known Entities Vocab PKL missing")
64 | sys.exit()
65 | print("Loading Known Entity Vocabulary ... ")
66 | (self.knwid2idx, self.idx2knwid) = utils.load(self.config.kwnwid_vocab_pkl)
67 | return (self.knwid2idx, self.idx2knwid)
68 |
69 | def getTestKnwEnCwiki(self):
70 |         if self.test_knwen_cwikis is None:
71 | if not os.path.exists(self.config.test_kwnen_cwikis_pkl):
72 | print("Test Known Entity CWikis Dict missing")
73 | sys.exit()
74 | print("Loading Test Data Known Entity CWIKI")
75 | self.test_knwen_cwikis = utils.load(self.config.test_kwnen_cwikis_pkl)
76 | return self.test_knwen_cwikis
77 |
78 | def getTestAllEnCwiki(self):
79 |         if self.test_allen_cwikis is None:
80 | if not os.path.exists(self.config.test_allen_cwikis_pkl):
81 | print("Test All Entity CWikis Dict missing")
82 | sys.exit()
83 | print("Loading Test Data All Entity CWIKI")
84 | self.test_allen_cwikis = utils.load(self.config.test_allen_cwikis_pkl)
85 | return self.test_allen_cwikis
86 |
87 | def getCrosswikisSlice(self):
88 |         if self.cwikis_slice is None:
89 | if not os.path.exists(self.config.crosswikis_slice):
90 | print("CWikis Slice Dict missing")
91 | sys.exit()
92 | print("Loading CWIKI Slice")
93 | self.cwikis_slice = utils.load(self.config.crosswikis_slice)
94 | return self.cwikis_slice
95 |
96 | def getWID2Wikititle(self):
97 |         if self.wid2Wikititle is None:
98 | if not os.path.exists(self.config.widWiktitle_pkl):
99 | print("wid2Wikititle pkl missing")
100 | sys.exit()
101 | print("Loading wid2Wikititle")
102 | self.wid2Wikititle = utils.load(self.config.widWiktitle_pkl)
103 | return self.wid2Wikititle
104 |
105 | def getWID2TypeLabels(self):
106 |         if self.wid2TypeLabels is None:
107 | if not os.path.exists(self.config.wid2typelabels_vocab_pkl):
108 | print("wid2TypeLabels pkl missing")
109 | sys.exit()
110 | print("Loading wid2TypeLabels")
111 | self.wid2TypeLabels = utils.load(self.config.wid2typelabels_vocab_pkl)
112 | return self.wid2TypeLabels
113 |
114 | def loadGloveVectors(self):
115 |         if self.glove2vec is None:
116 | if not os.path.exists(self.config.glove_pkl):
117 |             print("Glove_Vectors_PKL does not exist")
118 | sys.exit()
119 | print("Loading Glove Word Vectors")
120 | self.glove2vec = utils.load(self.config.glove_pkl)
121 | return self.glove2vec
122 |
123 | def getGloveWordVocab(self):
124 |         if self.gword2idx is None or self.gidx2word is None:
125 | if not os.path.exists(self.config.glove_word_vocab_pkl):
126 | print("Glove Word Vocab PKL missing")
127 | sys.exit()
128 | print("Loading Glove Word Vocabulary")
129 | (self.gword2idx, self.gidx2word) = utils.load(self.config.glove_word_vocab_pkl)
130 | return (self.gword2idx, self.gidx2word)
131 |
132 | if __name__ == '__main__':
133 | config = Config("configs/wcoh_config.ini")
134 | a = VocabLoader(config)
135 |     a.loadGloveVectors()  # stands in for loadWord2Vec(), which VocabLoader does not define
136 |
--------------------------------------------------------------------------------
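Note: every getter in `VocabLoader` follows the same lazy-load pattern: return the cached attribute if it is already set, otherwise check that the configured pickle exists, load it once with `utils.load`, and cache it. A minimal usage sketch, assuming the paths in `configs/config.ini` point at the preprocessed pickles:

    # Usage sketch (assumes the pickles referenced by configs/config.ini are available locally).
    from readers.config import Config
    from readers.vocabloader import VocabLoader

    config = Config("configs/config.ini")
    vocabloader = VocabLoader(config)

    word2idx, idx2word = vocabloader.getWordVocab()       # loaded from word_vocab_pkl, then cached
    label2idx, idx2label = vocabloader.getLabelVocab()    # type-label vocabulary
    knwid2idx, idx2knwid = vocabloader.getKnwnWidVocab()  # known-entity (WID) vocabulary

    # Subsequent calls return the cached dictionaries without touching disk again.
    assert vocabloader.getWordVocab()[0] is word2idx
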
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Requirements automatically generated by pigar.
2 | # https://github.com/Damnever/pigar
3 |
4 | # readers/inference_reader.py: 8
5 | ccg_nlpy == 0.11.0
6 |
7 | # evaluation/evaluate_inference.py: 3
8 | # evaluation/evaluate_types.py: 3
9 | # models/figer_model/coherence_model.py: 2
10 | # models/figer_model/coldStart.py: 4
11 | # models/figer_model/context_encoder.py: 3
12 | # models/figer_model/el_model.py: 3
13 | # models/figer_model/entity_posterior.py: 3
14 | # models/figer_model/joint_context.py: 3
15 | # models/figer_model/labeling_model.py: 3
16 | # models/figer_model/loss_optim.py: 2
17 | # models/figer_model/wiki_desc.py: 2
18 | # neuralel.py: 4
19 | # readers/inference_reader.py: 3
20 | # readers/utils.py: 11
21 | # readers/vocabloader.py: 13
22 | numpy == 1.11.3
23 |
24 | # models/base.py: 3
25 | # models/batch_normalizer.py: 1,3
26 | # models/figer_model/coherence_model.py: 3
27 | # models/figer_model/coldStart.py: 3
28 | # models/figer_model/context_encoder.py: 2
29 | # models/figer_model/el_model.py: 2
30 | # models/figer_model/entity_posterior.py: 2
31 | # models/figer_model/joint_context.py: 2
32 | # models/figer_model/labeling_model.py: 2
33 | # models/figer_model/loss_optim.py: 3
34 | # models/figer_model/wiki_desc.py: 3
35 | # neuralel.py: 5
36 | tensorflow == 0.12.1
37 |
--------------------------------------------------------------------------------