├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── embeddings.csv.zip
└── python
    ├── .gitignore
    ├── __init__.py
    ├── bin
    │   └── rtc_eval.sh
    ├── data
    │   └── valid_rels.txt
    └── eukg
        ├── Config.py
        ├── __init__.py
        ├── data
        │   ├── DataGenerator.py
        │   ├── __init__.py
        │   ├── create_semnet_triples.py
        │   ├── create_test_set.py
        │   ├── create_triples.py
        │   └── data_util.py
        ├── emb
        │   ├── EmbeddingModel.py
        │   ├── Smoothing.py
        │   └── __init__.py
        ├── gan
        │   ├── Discriminator.py
        │   ├── Generator.py
        │   ├── Generator.pyc
        │   ├── __init__.py
        │   ├── __init__.pyc
        │   ├── train_gan.py
        │   └── train_gan.pyc
        ├── save_embeddings.py
        ├── test
        │   ├── __init__.py
        │   ├── classification.py
        │   ├── nearest_neighbors.py
        │   ├── ppa.py
        │   └── ranking_evals.py
        ├── tf_util
        │   ├── ModelSaver.py
        │   ├── Trainable.py
        │   ├── Trainer.py
        │   └── __init__.py
        ├── threading_util.py
        └── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | embeddings.csv.zip filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ####################
2 | # Custom Rules
3 | ####################
4 | output/
5 | out/
6 | input/
7 | tmp/
8 | temp/
9 | work/
10 | working/
11 | /.tgitconfig
12 | *~dev
13 | **/.idea/
14 | #
15 | ####################
16 | # Logging
17 | ####################
18 | # logs
19 | *.errlog
20 | *.log
21 | *.log.zip
22 | log*.txt
23 | log.txt
24 | log/
25 | logs/
26 | #
27 | ####################
28 | # Kirkness
29 | ####################
30 | # kirkness
31 | *.svm_annotations
32 | *.svm_model
33 | *.svm_model.maps
34 | #
35 | ####################
36 | # Java
37 | ####################
38 | *.class
39 | #
40 | # Mobile Tools for Java (J2ME)
41 | .mtj.tmp/
42 | #
43 | # Package Files #
44 | *.jar
45 | *.war
46 | *.ear
47 | #
48 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
49 | hs_err_pid*
50 | #
51 | ####################
52 | # Scala
53 | ####################
54 | *.class
55 | *.log
56 | #
57 | # sbt specific
58 | .cache
59 | .history
60 | .lib/
61 | dist/*
62 | target/
63 | lib_managed/
64 | src_managed/
65 | **/project/boot/
66 | **/project/plugins/project/
67 | #
68 | # Scala-IDE specific
69 | .scala_dependencies
70 | .worksheet
71 | #
72 | # ENSIME specific
73 | .ensime_cache/
74 | .ensime
75 | #
76 | ####################
77 | # SBT
78 | ####################
79 | # Simple Build Tool
80 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control
81 | #
82 | target/
83 | lib_managed/
84 | src_managed/
85 | project/boot/
86 | .history
87 | .cache
88 | #
89 | ####################
90 | # Python
91 | ####################
92 | # Byte-compiled / optimized / DLL files
93 | __pycache__/
94 | *.py[cod]
95 | *$py.class
96 | #
97 | # C extensions
98 | *.so
99 | #
100 | # Distribution / packaging
101 | .Python
102 | env/
103 | build/
104 | develop-eggs/
105 | dist/
106 | downloads/
107 | eggs/
108 | .eggs/
109 | #lib/
110 | #lib64/
111 | parts/
112 | sdist/
113 | var/
114 | *.egg-info/
115 | .installed.cfg
116 | *.egg
117 | #
118 | # PyInstaller
119 | # Usually these files are written by a python script from a template
120 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
121 | *.manifest
122 | *.spec
123 | #
124 | # Installer logs
125 | pip-log.txt
126 | pip-delete-this-directory.txt
127 | #
128 | # Unit test / coverage reports
129 | htmlcov/
130 | .tox/
131 | .coverage
132 | .coverage.*
133 | .cache
134 | nosetests.xml
135 | coverage.xml
136 | *,cover
137 | .hypothesis/
138 | #
139 | # Translations
140 | *.mo
141 | *.pot
142 | #
143 | # Django stuff:
144 | *.log
145 | local_settings.py
146 | #
147 | # Flask stuff:
148 | instance/
149 | .webassets-cache
150 | #
151 | # Scrapy stuff:
152 | .scrapy
153 | #
154 | # Sphinx documentation
155 | **/docs/_build/
156 | #
157 | # PyBuilder
158 | target/
159 | #
160 | # Jupyter Notebook
161 | .ipynb_checkpoints
162 | #
163 | # pyenv
164 | .python-version
165 | #
166 | # celery beat schedule file
167 | celerybeat-schedule
168 | #
169 | # dotenv
170 | .env
171 | #
172 | # virtualenv
173 | .venv/
174 | venv/
175 | ENV/
176 | #
177 | # Spyder project settings
178 | .spyderproject
179 | #
180 | # Rope project settings
181 | .ropeproject
182 | #
183 | ####################
184 | # Windows
185 | ####################
186 | # Windows image file caches
187 | Thumbs.db
188 | ehthumbs.db
189 | #
190 | # Folder config file
191 | Desktop.ini
192 | #
193 | # Recycle Bin used on file shares
194 | $RECYCLE.BIN/
195 | #
196 | # Windows Installer files
197 | *.cab
198 | *.msi
199 | *.msm
200 | *.msp
201 | #
202 | # Windows shortcuts
203 | *.lnk
204 | #
205 | ####################
206 | # OSX
207 | ####################
208 | .DS_Store
209 | .AppleDouble
210 | .LSOverride
211 | #
212 | # Icon must end with two \r
213 | Icon
214 | #
215 | #
216 | # Thumbnails
217 | ._*
218 | #
219 | # Files that might appear in the root of a volume
220 | .DocumentRevisions-V100
221 | .fseventsd
222 | .Spotlight-V100
223 | .TemporaryItems
224 | .Trashes
225 | .VolumeIcon.icns
226 | #
227 | # Directories potentially created on remote AFP share
228 | .AppleDB
229 | .AppleDesktop
230 | Network Trash Folder
231 | Temporary Items
232 | .apdisk
233 | #
234 | ####################
235 | # Linux
236 | ####################
237 | *~
238 | #
239 | # temporary files which can be created if a process still has a handle open of a deleted file
240 | .fuse_hidden*
241 | #
242 | # KDE directory preferences
243 | .directory
244 | #
245 | # Linux trash folder which might appear on any partition or disk
246 | .Trash-*
247 | #
248 | # .nfs files are created when an open file is removed but is still being accessed
249 | .nfs*
250 | #
251 | ####################
252 | # Sublime Text
253 | ####################
254 | # cache files for sublime text
255 | *.tmlanguage.cache
256 | *.tmPreferences.cache
257 | *.stTheme.cache
258 | #
259 | # workspace files are user-specific
260 | *.sublime-workspace
261 | #
262 | # project files should be checked into the repository, unless a significant
263 | # proportion of contributors will probably not be using SublimeText
264 | # *.sublime-project
265 | #
266 | # sftp configuration file
267 | sftp-config.json
268 | #
269 | # Package control specific files
270 | Package Control.last-run
271 | Package Control.ca-list
272 | Package Control.ca-bundle
273 | Package Control.system-ca-bundle
274 | Package Control.cache/
275 | Package Control.ca-certs/
276 | bh_unicode_properties.cache
277 | #
278 | # Sublime-github package stores a github token in this file
279 | # https://packagecontrol.io/packages/sublime-github
280 | GitHub.sublime-settings
281 | #
282 | ####################
283 | # Vim
284 | ####################
285 | # swap
286 | [._]*.s[a-w][a-z]
287 | [._]s[a-w][a-z]
288 | # session
289 | Session.vim
290 | # temporary
291 | .netrwhist
292 | *~
293 | # auto-generated tag files
294 | tags
295 | #
296 | ####################
297 | # Matlab
298 | ####################
299 | ##---------------------------------------------------
300 | ## Remove autosaves generated by the Matlab editor
301 | ## We have git for backups!
302 | ##---------------------------------------------------
303 | #
304 | # Windows default autosave extension
305 | *.asv
306 | #
307 | # OSX / *nix default autosave extension
308 | *.m~
309 | #
310 | # Compiled MEX binaries (all platforms)
311 | *.mex*
312 | #
313 | # Simulink Code Generation
314 | slprj/
315 | #
316 | # Session info
317 | octave-workspace
318 | #
319 |
320 | ####################
321 | # JetBrainzzzzz
322 | ####################
323 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
324 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
325 | #
326 | # User-specific stuff:
327 | .idea
328 | **/.idea/workspace.xml
329 | **/.idea/tasks.xml
330 | #
331 | # Sensitive or high-churn files:
332 | **/.idea/dataSources.ids
333 | **/.idea/dataSources.xml
334 | **/.idea/dataSources.local.xml
335 | **/.idea/sqlDataSources.xml
336 | **/.idea/dynamic.xml
337 | **/.idea/uiDesigner.xml
338 | #
339 | # Gradle:
340 | **/.idea/gradle.xml
341 | **/.idea/libraries
342 | # SBT
343 | **/.idea/libraries/
344 | #
345 | # Mongo Explorer plugin:
346 | **/.idea/mongoSettings.xml
347 | #
348 | ## File-based project format:
349 | *.iws
350 | *.iml
351 | #
352 | ## Plugin-specific files:
353 | #
354 | # IntelliJ
355 | /out/
356 | #
357 | # mpeltonen/sbt-idea plugin
358 | .idea_modules/
359 | #
360 | # JIRA plugin
361 | atlassian-ide-plugin.xml
362 | #
363 | # Crashlytics plugin (for Android Studio and IntelliJ)
364 | com_crashlytics_export_strings.xml
365 | crashlytics.properties
366 | crashlytics-build.properties
367 | fabric.properties
368 | #
369 | ####################
370 | # iPython Notebooks
371 | ####################
372 | # Temporary data
373 | .ipynb_checkpoints/
374 | #
375 | ####################
376 | # Eclipse
377 | ####################
378 | *.pydevproject
379 | .metadata
380 | .gradle
381 | tmp/
382 | *.tmp
383 | *.bak
384 | *.swp
385 | *~.nib
386 | local.properties
387 | .settings/
388 | .loadpath
389 | #
390 | # Eclipse Core
391 | .project
392 | #
393 | # External tool builders
394 | .externalToolBuilders/
395 | #
396 | # Locally stored "Eclipse launch configurations"
397 | *.launch
398 | #
399 | # CDT-specific
400 | .cproject
401 | #
402 | # JDT-specific (Eclipse Java Development Tools)
403 | .classpath
404 | #
405 | # Java annotation processor (APT)
406 | .factorypath
407 | #
408 | # PDT-specific
409 | .buildpath
410 | #
411 | # sbteclipse plugin
412 | .target
413 | #
414 | # TeXlipse plugin
415 | .texlipse
416 | #
417 | ####################
418 | # NetBeans
419 | ####################
420 | **/nbproject/private/
421 | build/
422 | nbbuild/
423 | dist/
424 | nbdist/
425 | nbactions.xml
426 | nb-configuration.xml
427 | .nb-gradle/
428 | #
429 | ####################
430 | # btsync
431 | ####################
432 | .sync/
433 | .sync
434 | #
435 | ####################
436 | # Maven
437 | ####################
438 | target/
439 | pom.xml.tag
440 | pom.xml.releaseBackup
441 | pom.xml.versionsBackup
442 | pom.xml.next
443 | release.properties
444 | dependency-reduced-pom.xml
445 | buildNumber.properties
446 | **/.mvn/timing.properties
447 | #
448 | ####################
449 | # Archives
450 | ####################
451 | # It's better to unpack these files and commit the raw source because
452 | # git has its own built in compression methods.
453 | *.7z
454 | *.jar
455 | *.rar
456 | *.zip
457 | *.gz
458 | *.bzip
459 | *.bz2
460 | *.xz
461 | *.lzma
462 | *.cab
463 | #
464 | #packing-only formats
465 | *.iso
466 | *.tar
467 | #
468 | #package management formats
469 | *.dmg
470 | *.xpi
471 | *.gem
472 | *.egg
473 | *.deb
474 | *.rpm
475 | *.msi
476 | *.msm
477 | *.msp
478 | #
479 | ####################
480 | # Dropbox
481 | ####################
482 | # Dropbox settings and caches
483 | .dropbox
484 | .dropbox.attr
485 | .dropbox.cache
486 | #
487 | ####################
488 | # Unison
489 | ####################
490 | # unison
491 | .unison*
492 | #
493 | ####################
494 | # SVN
495 | ####################
496 | .svn/
497 | #
498 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code for the paper [Adversarial Learning of Knowledge Embeddings for the Unified Medical Language System](http://www.hlt.utdallas.edu/~ramon/papers/amia_cri_2019.pdf) to be presented at the AMIA Informatics Summit 2019.
2 |
3 | Requires TensorFlow 1.9.
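
A minimal environment sketch (the environment name and the CPU-only build are just examples; any setup that provides TensorFlow 1.9 works):
```bash
# create an isolated environment and install the pinned TensorFlow release
virtualenv eukg-env
source eukg-env/bin/activate
pip install tensorflow==1.9.0
```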
4 |
5 | #### Data Preprocessing:
6 | 1. First download/extract the [UMLS](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html). This project assumes the UMLS files are laid out as such:
7 | ```
8 | /
9 | META/
10 | MRCONSO.RRF
11 | MRSTY.RRF
12 | NET/
13 | SRSTRE1
14 | ```
15 | 2. Create Metathesaurus triples
16 |
17 | ```bash
18 | python -m eukg.data.create_triples
19 | ```
20 | This will create the Metathesaurus train/test triples in data/metathesaurus.
21 | 3. Create Semantic Network Triples
22 | ```bash
23 | python -m eukg.data.create_semnet_triples
24 | ```
25 |
26 | #### Training:
27 | To train the Metathesaurus Discriminator:
28 | ```bash
29 | python -m eukg.train --mode=disc --model=transd --run_name=transd-disc --no_semantic_network
30 | ```
31 | To train both the Metathesaurus and Semantic Network Discriminators:
32 | ```bash
33 | python -m eukg.train --mode=disc --model=transd --run_name=transd-sn-disc
34 | ```
35 | To train the Metathesaurus Generator:
36 | ```bash
37 | python -m eukg.train --mode=gen --model=distmult --run_name=dm-gen --no_semantic_network --learning_rate=1e-3
38 | ```
39 | To train the Metathesaurus and Semantic Network Generators:
40 | ```bash
41 | python -m eukg.train --mode=gen --model=distmult --run_name=dm-sn-gen --learning_rate=1e-3
42 | ```
43 | To train the full GAN model:
44 | ```bash
45 | python -m eukg.train --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc --gen_run_name=dm-sn-gen
46 | ```
47 | Note that the GAN model requires a pretrained discriminator and generator; an end-to-end training sequence is sketched below.
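
Putting the steps above together, an end-to-end run (a sketch using the exact commands from this README; the run names are arbitrary labels) looks like:
```bash
# 1. pretrain the discriminator over both the Metathesaurus and the Semantic Network
python -m eukg.train --mode=disc --model=transd --run_name=transd-sn-disc
# 2. pretrain the generator
python -m eukg.train --mode=gen --model=distmult --run_name=dm-sn-gen --learning_rate=1e-3
# 3. train the GAN, pointing it at the two pretrained runs
python -m eukg.train --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc --gen_run_name=dm-sn-gen
```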
48 |
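Pretrained embeddings are also distributed with this repository as `embeddings.csv.zip` (tracked with Git LFS, see `.gitattributes`). A minimal sketch for fetching and inspecting them, assuming the `git-lfs` extension is installed and that the archive contains a single `embeddings.csv` (the column layout is not documented here, so check the header first):
```bash
git lfs pull               # download the zip referenced by the LFS pointer (roughly 800 MB)
unzip embeddings.csv.zip   # assumed to extract embeddings.csv into the repo root
head -n 2 embeddings.csv   # peek at the first rows to see the column layout
```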
--------------------------------------------------------------------------------
/embeddings.csv.zip:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:40bf01ac241199bb26f3163fa0ecf86aebaba9c526a8fe363dbe1b8cae026da2
3 | size 823572063
4 |
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | ####################
2 | # Custom Rules
3 | ####################
4 | output/
5 | out/
6 | input/
7 | tmp/
8 | temp/
9 | work/
10 | working/
11 | /.tgitconfig
12 | *~dev
13 | **/.idea/
14 | #
15 | ####################
16 | # Logging
17 | ####################
18 | # logs
19 | *.errlog
20 | *.log
21 | *.log.zip
22 | log*.txt
23 | log.txt
24 | log/
25 | logs/
26 | #
27 | ####################
28 | # Kirkness
29 | ####################
30 | # kirkness
31 | *.svm_annotations
32 | *.svm_model
33 | *.svm_model.maps
34 | #
35 | ####################
36 | # Java
37 | ####################
38 | *.class
39 | #
40 | # Mobile Tools for Java (J2ME)
41 | .mtj.tmp/
42 | #
43 | # Package Files #
44 | *.jar
45 | *.war
46 | *.ear
47 | #
48 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
49 | hs_err_pid*
50 | #
51 | ####################
52 | # Scala
53 | ####################
54 | *.class
55 | *.log
56 | #
57 | # sbt specific
58 | .cache
59 | .history
60 | .lib/
61 | dist/*
62 | target/
63 | lib_managed/
64 | src_managed/
65 | **/project/boot/
66 | **/project/plugins/project/
67 | #
68 | # Scala-IDE specific
69 | .scala_dependencies
70 | .worksheet
71 | #
72 | # ENSIME specific
73 | .ensime_cache/
74 | .ensime
75 | #
76 | ####################
77 | # SBT
78 | ####################
79 | # Simple Build Tool
80 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control
81 | #
82 | target/
83 | lib_managed/
84 | src_managed/
85 | project/boot/
86 | .history
87 | .cache
88 | #
89 | ####################
90 | # Python
91 | ####################
92 | # Byte-compiled / optimized / DLL files
93 | __pycache__/
94 | *.py[cod]
95 | *$py.class
96 | #
97 | # C extensions
98 | *.so
99 | #
100 | # Distribution / packaging
101 | .Python
102 | env/
103 | build/
104 | develop-eggs/
105 | dist/
106 | downloads/
107 | eggs/
108 | .eggs/
109 | #lib/
110 | #lib64/
111 | parts/
112 | sdist/
113 | var/
114 | *.egg-info/
115 | .installed.cfg
116 | *.egg
117 | #
118 | # PyInstaller
119 | # Usually these files are written by a python script from a template
120 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
121 | *.manifest
122 | *.spec
123 | #
124 | # Installer logs
125 | pip-log.txt
126 | pip-delete-this-directory.txt
127 | #
128 | # Unit test / coverage reports
129 | htmlcov/
130 | .tox/
131 | .coverage
132 | .coverage.*
133 | .cache
134 | nosetests.xml
135 | coverage.xml
136 | *,cover
137 | .hypothesis/
138 | #
139 | # Translations
140 | *.mo
141 | *.pot
142 | #
143 | # Django stuff:
144 | *.log
145 | local_settings.py
146 | #
147 | # Flask stuff:
148 | instance/
149 | .webassets-cache
150 | #
151 | # Scrapy stuff:
152 | .scrapy
153 | #
154 | # Sphinx documentation
155 | **/docs/_build/
156 | #
157 | # PyBuilder
158 | target/
159 | #
160 | # Jupyter Notebook
161 | .ipynb_checkpoints
162 | #
163 | # pyenv
164 | .python-version
165 | #
166 | # celery beat schedule file
167 | celerybeat-schedule
168 | #
169 | # dotenv
170 | .env
171 | #
172 | # virtualenv
173 | .venv/
174 | venv/
175 | ENV/
176 | #
177 | # Spyder project settings
178 | .spyderproject
179 | #
180 | # Rope project settings
181 | .ropeproject
182 | #
183 | ####################
184 | # Windows
185 | ####################
186 | # Windows image file caches
187 | Thumbs.db
188 | ehthumbs.db
189 | #
190 | # Folder config file
191 | Desktop.ini
192 | #
193 | # Recycle Bin used on file shares
194 | $RECYCLE.BIN/
195 | #
196 | # Windows Installer files
197 | *.cab
198 | *.msi
199 | *.msm
200 | *.msp
201 | #
202 | # Windows shortcuts
203 | *.lnk
204 | #
205 | ####################
206 | # OSX
207 | ####################
208 | .DS_Store
209 | .AppleDouble
210 | .LSOverride
211 | #
212 | # Icon must end with two \r
213 | Icon
214 | #
215 | #
216 | # Thumbnails
217 | ._*
218 | #
219 | # Files that might appear in the root of a volume
220 | .DocumentRevisions-V100
221 | .fseventsd
222 | .Spotlight-V100
223 | .TemporaryItems
224 | .Trashes
225 | .VolumeIcon.icns
226 | #
227 | # Directories potentially created on remote AFP share
228 | .AppleDB
229 | .AppleDesktop
230 | Network Trash Folder
231 | Temporary Items
232 | .apdisk
233 | #
234 | ####################
235 | # Linux
236 | ####################
237 | *~
238 | #
239 | # temporary files which can be created if a process still has a handle open of a deleted file
240 | .fuse_hidden*
241 | #
242 | # KDE directory preferences
243 | .directory
244 | #
245 | # Linux trash folder which might appear on any partition or disk
246 | .Trash-*
247 | #
248 | # .nfs files are created when an open file is removed but is still being accessed
249 | .nfs*
250 | #
251 | ####################
252 | # Sublime Text
253 | ####################
254 | # cache files for sublime text
255 | *.tmlanguage.cache
256 | *.tmPreferences.cache
257 | *.stTheme.cache
258 | #
259 | # workspace files are user-specific
260 | *.sublime-workspace
261 | #
262 | # project files should be checked into the repository, unless a significant
263 | # proportion of contributors will probably not be using SublimeText
264 | # *.sublime-project
265 | #
266 | # sftp configuration file
267 | sftp-config.json
268 | #
269 | # Package control specific files
270 | Package Control.last-run
271 | Package Control.ca-list
272 | Package Control.ca-bundle
273 | Package Control.system-ca-bundle
274 | Package Control.cache/
275 | Package Control.ca-certs/
276 | bh_unicode_properties.cache
277 | #
278 | # Sublime-github package stores a github token in this file
279 | # https://packagecontrol.io/packages/sublime-github
280 | GitHub.sublime-settings
281 | #
282 | ####################
283 | # Vim
284 | ####################
285 | # swap
286 | [._]*.s[a-w][a-z]
287 | [._]s[a-w][a-z]
288 | # session
289 | Session.vim
290 | # temporary
291 | .netrwhist
292 | *~
293 | # auto-generated tag files
294 | tags
295 | #
296 | ####################
297 | # Matlab
298 | ####################
299 | ##---------------------------------------------------
300 | ## Remove autosaves generated by the Matlab editor
301 | ## We have git for backups!
302 | ##---------------------------------------------------
303 | #
304 | # Windows default autosave extension
305 | *.asv
306 | #
307 | # OSX / *nix default autosave extension
308 | *.m~
309 | #
310 | # Compiled MEX binaries (all platforms)
311 | *.mex*
312 | #
313 | # Simulink Code Generation
314 | slprj/
315 | #
316 | # Session info
317 | octave-workspace
318 | #
319 |
320 | ####################
321 | # JetBrainzzzzz
322 | ####################
323 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
324 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
325 | #
326 | # User-specific stuff:
327 | .idea
328 | **/.idea/workspace.xml
329 | **/.idea/tasks.xml
330 | #
331 | # Sensitive or high-churn files:
332 | **/.idea/dataSources.ids
333 | **/.idea/dataSources.xml
334 | **/.idea/dataSources.local.xml
335 | **/.idea/sqlDataSources.xml
336 | **/.idea/dynamic.xml
337 | **/.idea/uiDesigner.xml
338 | #
339 | # Gradle:
340 | **/.idea/gradle.xml
341 | **/.idea/libraries
342 | # SBT
343 | **/.idea/libraries/
344 | #
345 | # Mongo Explorer plugin:
346 | **/.idea/mongoSettings.xml
347 | #
348 | ## File-based project format:
349 | *.iws
350 | *.iml
351 | #
352 | ## Plugin-specific files:
353 | #
354 | # IntelliJ
355 | /out/
356 | #
357 | # mpeltonen/sbt-idea plugin
358 | .idea_modules/
359 | #
360 | # JIRA plugin
361 | atlassian-ide-plugin.xml
362 | #
363 | # Crashlytics plugin (for Android Studio and IntelliJ)
364 | com_crashlytics_export_strings.xml
365 | crashlytics.properties
366 | crashlytics-build.properties
367 | fabric.properties
368 | #
369 | ####################
370 | # iPython Notebooks
371 | ####################
372 | # Temporary data
373 | .ipynb_checkpoints/
374 | #
375 | ####################
376 | # Eclipse
377 | ####################
378 | *.pydevproject
379 | .metadata
380 | .gradle
381 | tmp/
382 | *.tmp
383 | *.bak
384 | *.swp
385 | *~.nib
386 | local.properties
387 | .settings/
388 | .loadpath
389 | #
390 | # Eclipse Core
391 | .project
392 | #
393 | # External tool builders
394 | .externalToolBuilders/
395 | #
396 | # Locally stored "Eclipse launch configurations"
397 | *.launch
398 | #
399 | # CDT-specific
400 | .cproject
401 | #
402 | # JDT-specific (Eclipse Java Development Tools)
403 | .classpath
404 | #
405 | # Java annotation processor (APT)
406 | .factorypath
407 | #
408 | # PDT-specific
409 | .buildpath
410 | #
411 | # sbteclipse plugin
412 | .target
413 | #
414 | # TeXlipse plugin
415 | .texlipse
416 | #
417 | ####################
418 | # NetBeans
419 | ####################
420 | **/nbproject/private/
421 | build/
422 | nbbuild/
423 | dist/
424 | nbdist/
425 | nbactions.xml
426 | nb-configuration.xml
427 | .nb-gradle/
428 | #
429 | ####################
430 | # btsync
431 | ####################
432 | .sync/
433 | .sync
434 | #
435 | ####################
436 | # Maven
437 | ####################
438 | target/
439 | pom.xml.tag
440 | pom.xml.releaseBackup
441 | pom.xml.versionsBackup
442 | pom.xml.next
443 | release.properties
444 | dependency-reduced-pom.xml
445 | buildNumber.properties
446 | **/.mvn/timing.properties
447 | #
448 | ####################
449 | # Archives
450 | ####################
451 | # It's better to unpack these files and commit the raw source because
452 | # git has its own built in compression methods.
453 | *.7z
454 | *.jar
455 | *.rar
456 | *.zip
457 | *.gz
458 | *.bzip
459 | *.bz2
460 | *.xz
461 | *.lzma
462 | *.cab
463 | #
464 | #packing-only formats
465 | *.iso
466 | *.tar
467 | #
468 | #package management formats
469 | *.dmg
470 | *.xpi
471 | *.gem
472 | *.egg
473 | *.deb
474 | *.rpm
475 | *.msi
476 | *.msm
477 | *.msp
478 | #
479 | ####################
480 | # Dropbox
481 | ####################
482 | # Dropbox settings and caches
483 | .dropbox
484 | .dropbox.attr
485 | .dropbox.cache
486 | #
487 | ####################
488 | # Unison
489 | ####################
490 | # unison
491 | .unison*
492 | #
493 | ####################
494 | # SVN
495 | ####################
496 | .svn/
497 | #
498 |
--------------------------------------------------------------------------------
/python/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '/home/rmm120030/code/python')
3 |
--------------------------------------------------------------------------------
/python/bin/rtc_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # transe
3 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-disc
4 | # transd
5 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-disc_e0init --embedding_size=100
6 | # transe-sn
7 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-sn-disc
8 | # transd-sn
9 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-sn-disc_e0init --embedding_size=100
10 | # transe-sn gan
11 | pytf -m eukg.test.classification --mode=gan --model=transe --run_name=sn-gan --dis_run_name=transe-sn-disc
12 | # transd-sn gan
13 | pytf -m eukg.test.classification --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc_e0init --embedding_size=100
14 | ################ SN ##################
15 | # transe-sn
16 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-sn-disc --sn_eval --eval_mode=sn
17 | # transd-sn
18 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-sn-disc_e0init --embedding_size=100 --sn_eval --eval_mode=sn
19 | # transe-sn gan
20 | pytf -m eukg.test.classification --mode=gan --model=transe --run_name=sn-gan --dis_run_name=transe-sn-disc --sn_eval --eval_mode=sn
21 | # transd-sn gan
22 | pytf -m eukg.test.classification --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc_e0init --embedding_size=100 --sn_eval --eval_mode=sn
23 |
--------------------------------------------------------------------------------
/python/data/valid_rels.txt:
--------------------------------------------------------------------------------
1 | has_tradename
2 | has_ingredients
3 | organism_has_gene
4 | has_presentation_strength_numerator_value
5 | replaces
6 | chemical_or_drug_affects_gene_product
7 | has_supersystem
8 | disease_excludes_normal_tissue_origin
9 | has_specialty
10 | has_branch
11 | possibly_equivalent_to
12 | sib_in_isa
13 | isa
14 | procedure_site_of
15 | has_time_aspect
16 | component_of
17 | regimen_has_accepted_use_for_disease
18 | add_on_code_for
19 | has_basic_dose_form
20 | has_quotient
21 | partially_excised_anatomy_has_procedure
22 | disease_has_cytogenetic_abnormality
23 | direct_device_of
24 | disease_excludes_normal_cell_origin
25 | has_at_risk_population
26 | has_subject_of_information
27 | is_location_of_anatomic_structure
28 | primary_mapped_to
29 | may_be_finding_of_disease
30 | subject_relationship_context_of
31 | form_of
32 | expected_outcome_of
33 | actual_outcome_of
34 | process_includes_biological_process
35 | has_target
36 | is_metastatic_anatomic_site_of_disease
37 | has_alternative
38 | distal_to
39 | challenge_of
40 | state_of_matter_of
41 | chemical_or_drug_initiates_biological_process
42 | has_contraindicated_physiologic_effect
43 | has_method
44 | drains_into
45 | allele_in_chromosomal_location
46 | biomarker_type_includes_gene
47 | anterolateral_to
48 | clinical_course_of
49 | gene_product_is_biomarker_of
50 | approach_of_possibly_included
51 | posteroinferior_to
52 | during
53 | disease_has_associated_gene
54 | attaches_to
55 | is_chemical_classification_of_gene_product
56 | has_nerve_supply
57 | has_precise_ingredient
58 | has_access
59 | has_pathological_process
60 | has_measurement_method
61 | has_exam
62 | has_laterality
63 | refers_to
64 | was_a
65 | completely_excised_anatomy_has_procedure
66 | reformulated_to
67 | disease_may_have_associated_disease
68 | is_physiologic_effect_of_chemical_or_drug
69 | enzyme_metabolizes_chemical_or_drug
70 | exhibited_by
71 | has_dose_form_intended_site
72 | has_suffix
73 | has_context_binding
74 | has_modality_subtype
75 | has_interpretation
76 | has_specimen
77 | has_possibly_included_pathology
78 | has_free_acid_or_base_form
79 | connected_to
80 | uses_energy
81 | has_arterial_supply
82 | has_count
83 | classifies
84 | constitutes
85 | positively_regulates
86 | has_lymphatic_drainage
87 | has_active_ingredient
88 | has_finding_informer
89 | has_extent
90 | has_possibly_included_procedure_site
91 | may_be_normal_cell_origin_of_disease
92 | surrounds
93 | has_recipient_category
94 | has_access_instrument
95 | icd_asterisk
96 | has_quantified_form
97 | has_subject
98 | chemical_or_drug_is_product_of_biological_process
99 | disease_is_marked_by_gene
100 | has_location
101 | has_excluded_specimen
102 | has_direct_site
103 | inferolateral_to
104 | fuses_with
105 | superomedial_to
106 | develops_into
107 | endogenous_product_related_to
108 | adjacent_to
109 | has_clinician_form
110 | has_intent
111 | chemical_or_drug_affects_abnormal_cell
112 | process_involves_gene
113 | has_focus
114 | has_object_guidance
115 | has_doseformgroup
116 | has_common_name
117 | has_component
118 | has_entire_anatomy_structure
119 | has_diagnostic_criteria
120 | has_sign_or_symptom
121 | has_maneuver_type
122 | gene_product_has_biochemical_function
123 | has_segmental_supply
124 | pathway_has_gene_element
125 | has_associated_etiologic_finding
126 | procedure_has_target_anatomy
127 | regulates
128 | sib_in_tributary_of
129 | has_secondary_segmental_supply
130 | has_primary_segmental_supply
131 | do_not_code_with
132 | has_alias
133 | has_finding_context
134 | has_associated_morphology
135 | has_expanded_form
136 | exhibits
137 | has_excluded_method
138 | has_indirect_procedure_site
139 | print_name_of
140 | has_property
141 | has_route_of_administration
142 | has_manifestation
143 | biological_process_has_associated_location
144 | complex_has_physical_part
145 | has_scale_type
146 | may_be_cytogenetic_abnormality_of_disease
147 | has_dose_form_transformation
148 | disease_is_stage
149 | has_locale
150 | has_british_form
151 | disease_mapped_to_chromosome
152 | pathogenesis_of_disease_involves_gene
153 | disease_excludes_abnormal_cell
154 | has_definitional_manifestation
155 | has_causative_agent
156 | has_origin
157 | has_allelic_variant
158 | has_fragments_for_synonyms
159 | has_regional_part
160 | disease_excludes_primary_anatomic_site
161 | continuous_distally_with
162 | precondition_of
163 | has_development_type
164 | cell_type_is_associated_with_eo_disease
165 | has_possibly_included_component
166 | has_possibly_included_surgical_extent
167 | forms
168 | temporally_related_to
169 | happens_during
170 | has_related_developmental_entity
171 | has_technique
172 | has_possibly_included_patient_type
173 | associated_genetic_condition
174 | chemical_or_drug_plays_role_in_biological_process
175 | may_qualify
176 | gene_product_has_abnormality
177 | has_possibly_included_associated_finding
178 | has_ingredient
179 | eo_disease_has_property_or_attribute
180 | has_presentation_strength_numerator_unit
181 | has_presentation_strength_denominator_unit
182 | has_presentation_strength_denominator_value
183 | has_basis_of_strength_substance
184 | has_measured_component
185 | has_presence_guidance
186 | has_temporal_context
187 | disease_mapped_to_gene
188 | has_surgical_approach
189 | biomarker_type_includes_gene_product
190 | has_part_anatomy_structure
191 | has_specimen_source_identity
192 | has_possibly_included_procedure_device
193 | tissue_is_expression_site_of_gene_product
194 | has_conceptual_part
195 | has_multi_level_category
196 | has_episodicity
197 | derives
198 | has_system
199 | has_indirect_device
200 | lateral_to
201 | has_time_modifier
202 | has_excluded_procedure_device
203 | has_pharmacokinetics
204 | activity_of_allele
205 | homonym_for
206 | manufactures
207 | has_lateral_location_presence
208 | has_given_pharmaceutical_substance
209 | has_indirect_morphology
210 | has_scale
211 | mth_has_expanded_form
212 | related_part
213 | disease_has_finding
214 | has_snomed_parent
215 | uses
216 | see
217 | posterior_to
218 | continuous_with
219 | use
220 | has_imaging_focus
221 | genetic_biomarker_related_to
222 | characterizes
223 | allele_plays_altered_role_in_process
224 | has_panel_element
225 | has_venous_drainage
226 | chemotherapy_regimen_has_component
227 | includes
228 | has_consumer_friendly_form
229 | excised_anatomy_has_procedure
230 | has_approach
231 | has_lateral_anatomic_location
232 | has_physiologic_effect
233 | sib_in_part_of
234 | sib_in_branch_of
235 | has_dose_form
236 | has_instrumentation
237 | chromosome_involved_in_cytogenetic_abnormality
238 | has_active_moiety
239 | has_revision_status
240 | has_direct_substance
241 | disease_has_primary_anatomic_site
242 | has_procedure_device
243 | has_member
244 | temporally_follows
245 | disease_has_normal_cell_origin
246 | biological_process_has_result_anatomy
247 | classifies_class_code
248 | biological_process_results_from_biological_process
249 | has_associated_function
250 | has_answer
251 | has_patient_type
252 | has_adjustment
253 | has_phenotype
254 | related_object
255 | gene_product_sequence_variation_encoded_by_gene_mutant
256 | consider
257 | has_surgical_extent
258 | icd_dagger
259 | gene_product_malfunction_associated_with_disease
260 | has_associated_condition
261 | has_timing_of
262 | mth_has_xml_form
263 | mth_has_plain_text_form
264 | superior_to
265 | has_snomed_synonym
266 | disease_excludes_molecular_abnormality
267 | eo_disease_has_associated_eo_anatomy
268 | part_of
269 | allele_has_abnormality
270 | has_developmental_stage
271 | has_property_type
272 | uses_possibly_included_substance
273 | related_to
274 | has_direct_procedure_site
275 | has_view_type
276 | is_grade_of_disease
277 | measures
278 | contains
279 | has_pathology
280 | same_as
281 | has_procedure_context
282 | moved_to
283 | has_action_guidance
284 | bounds
285 | has_divisor
286 | disease_has_associated_disease
287 | occurs_in
288 | has_modality_type
289 | disease_excludes_finding
290 | has_specimen_substance
291 | has_defining_characteristic
292 | supported_concept_relationship_in
293 | has_risk_factor
294 | receives_input_from
295 | disease_may_have_molecular_abnormality
296 | modifies
297 | gene_has_abnormality
298 | has_dose_form_administration_method
299 | has_supported_concept_property
300 | has_dose_form_release_characteristic
301 | has_chemical_structure
302 | receives_projection
303 | afferent_to
304 | has_multipart
305 | has_specimen_procedure
306 | has_inactive_ingredient
307 | has_excluded_procedure_site
308 | has_constitutional_part
309 | has_disposition
310 | articulates_with
311 | has_specimen_source_topography
312 | interprets
313 | has_excluded_pathology
314 | posterolateral_to
315 | has_possibly_included_associated_procedure
316 | has_onset
317 | has_aggregation_view
318 | disease_may_have_abnormal_cell
319 | biological_process_has_initiator_process
320 | has_lab_number
321 | has_course
322 | mth_british_form_of
323 | uses_device
324 | disease_has_normal_tissue_origin
325 | has_tributary
326 | disease_excludes_cytogenetic_abnormality
327 | chemical_or_drug_has_mechanism_of_action
328 | has_evaluation
329 | chemical_or_drug_affects_cell_type_or_tissue
330 | mapped_to
331 | has_approach_guidance
332 | concept_in_subset
333 | gene_product_plays_role_in_biological_process
334 | occurs_before
335 | gene_encodes_gene_product
336 | disease_has_abnormal_cell
337 | has_direct_morphology
338 | negatively_regulates
339 | gene_product_is_element_in_pathway
340 | molecular_abnormality_involves_gene
341 | has_severity
342 | has_finding_method
343 | has_communication_with_wound
344 | has_associated_finding
345 | gene_product_has_gene_product_variant
346 | has_specimen_source_morphology
347 | has_mechanism_of_action
348 | has_projection
349 | has_possibly_included_method
350 | has_imaged_location
351 | gene_product_has_organism_source
352 | eo_disease_maps_to_human_disease
353 | has_associated_procedure
354 | analyzes
355 | has_insertion
356 | has_contraindicated_mechanism_of_action
357 | disease_has_molecular_abnormality
358 | gene_product_has_structural_domain_or_motif
359 | gene_product_has_associated_anatomy
360 | has_inherent_attribute
361 | uses_substance
362 | has_physical_part_of_anatomic_structure
363 | has_finding_site
364 | has_parent
365 | has_procedure_morphology
366 | cause_of
367 | disease_has_associated_anatomic_site
368 | matures_into
369 | has_excluded_associated_procedure
370 | corresponds_to
371 | transforms_into
372 | has_continuation_branch
373 | posterosuperior_to
374 | site_of_metabolism
375 | has_cdrh_parent
376 | associated_with
377 | has_pharmaceutical_route
378 | has_inheritance_type
379 | direct_right_of
380 | has_related_factor
381 | uses_excluded_substance
382 | has_class
383 | gene_in_chromosomal_location
384 | other_mapped_to
385 | has_priority
386 | uses_access_device
387 | has_therapeutic_class
388 | has_nichd_parent
389 |
--------------------------------------------------------------------------------
/python/eukg/Config.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # control options
4 | tf.flags.DEFINE_string("data_dir", "/users/rmm120030/working/umls-mke/new", "training data directory")
5 | tf.flags.DEFINE_string("model_dir", "/users/rmm120030/working/umls-mke/new/model", "model directory")
6 | tf.flags.DEFINE_string("model", "transd", "Model: [transe, transd, distmult]")
7 | tf.flags.DEFINE_string("mode", "disc", "Mode: [disc, gen, gan]")
8 | tf.flags.DEFINE_string("run_name", None, "Run name")
9 | tf.flags.DEFINE_string("summaries_dir", "/users/rmm120030/working/umls-mke/new/model", "model summary dir")
10 | tf.flags.DEFINE_integer("batch_size", 564, "batch size")
11 | tf.flags.DEFINE_bool("load", False, "Load model?")
12 | tf.flags.DEFINE_bool("load_embeddings", False, "Load embeddings?")
13 | tf.flags.DEFINE_string("embedding_file", None, "Embedding matrix npz")
14 | tf.flags.DEFINE_integer("num_epochs", 100, "Number of epochs?")
15 | tf.flags.DEFINE_float("val_proportion", 0.1, "Proportion of training data to hold out for validation")
16 | tf.flags.DEFINE_integer("progress_update_interval", 1000, "Number of batches between progress updates")
17 | tf.flags.DEFINE_integer("val_progress_update_interval", 1800, "Number of batches between progress updates")
18 | tf.flags.DEFINE_integer("save_interval", 1800, "Number of seconds between model saves")
19 | tf.flags.DEFINE_integer("batches_per_epoch", 1, "Number of batches per training epoch")
20 | tf.flags.DEFINE_integer("max_batches_per_epoch", None, "Maximum number of batches per training epoch")
21 | tf.flags.DEFINE_string("embedding_device", "gpu", "Device to do embedding lookups on [gpu, cpu]")
22 | tf.flags.DEFINE_string("optimizer", "adam", "Optimizer [adam, sgd]")
23 | tf.flags.DEFINE_string("save_strategy", "epoch", "Save every epoch or saved every"
24 | " flags.save_interval seconds [epoch, timed]")
25 |
26 | # eval control options
27 | tf.flags.DEFINE_string("eval_mode", "save", "Evaluation mode: [save, calc]")
28 | tf.flags.DEFINE_string("eval_dir", "eval", "directory for evaluation outputs")
29 | tf.flags.DEFINE_integer("shard", 1, "Shard number for distributed eval.")
30 | tf.flags.DEFINE_integer("num_shards", 1, "Total number of shards for distributed eval.")
31 | tf.flags.DEFINE_bool("save_ranks", True, "Save ranks? (turn off while debugging)")
32 |
33 | # gan control options
34 | tf.flags.DEFINE_string("dis_run_name", None, "Run name for the discriminator model")
35 | tf.flags.DEFINE_string("gen_run_name", "dm-gen", "Run name for the generator model")
36 | tf.flags.DEFINE_string("sn_gen_run_name", "dm-sn-gen", "Run name for the semnet generator model")
37 |
38 | # model params
39 | tf.flags.DEFINE_float("learning_rate", 1e-5, "Starting learning rate")
40 | tf.flags.DEFINE_float("decay_rate", 0.96, "LR decay rate")
41 | tf.flags.DEFINE_float("momentum", 0.9, "Momentum")
42 | tf.flags.DEFINE_float("gamma", 0.1, "Margin parameter for loss")
43 | tf.flags.DEFINE_integer("vocab_size", 3210965, "Number of unique concepts+relations")
44 | tf.flags.DEFINE_integer("embedding_size", 50, "embedding size")
45 | tf.flags.DEFINE_integer("energy_norm_ord", 1,
46 | "Order of the normalization function used to quantify difference between h+r and t")
47 | tf.flags.DEFINE_integer("max_concepts_per_type", 1000, "Maximum number of concepts to average for semtype loss")
48 | tf.flags.DEFINE_integer("num_generator_samples", 100, "Number of negative samples for each generator example")
49 | tf.flags.DEFINE_string("p_init", "zeros",
50 | "Projection vectors initializer: [zeros, xavier, uniform]. Uniform is in [-.1,.1]")
51 |
52 | # semnet params
53 | tf.flags.DEFINE_bool("no_semantic_network", False, "Do not add semantic network loss to the graph?")
54 | tf.flags.DEFINE_float("semnet_alignment_param", 0.5, "Parameter to control semantic network loss")
55 | tf.flags.DEFINE_float("semnet_energy_param", 0.5, "Parameter to control semantic network loss")
56 | tf.flags.DEFINE_bool("sn_eval", False, "Train this model with subset of sn to evaluate the SN embeddings?")
57 |
58 | # distmult params
59 | tf.flags.DEFINE_float("regularization_parameter", 1e-4, "Regularization term weight")
60 | tf.flags.DEFINE_string("energy_activation", 'sigmoid',
61 | "Energy activation function [None, tanh, relu, sigmoid]")
62 |
63 |
64 | flags = tf.flags.FLAGS
65 |
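# A minimal sketch of how these flags are consumed (assuming the TF 1.x flags API
# used above): once the command line has been parsed, e.g. inside tf.app.run(),
# values are read as plain attributes and fall back to the DEFINE_* defaults.
#
#   from eukg.Config import flags
#   print(flags.model, flags.mode, flags.batch_size)   # e.g. 'transd', 'disc', 564
#   if flags.mode == 'gan':
#     print(flags.dis_run_name)                        # discriminator run reused by the GAN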
--------------------------------------------------------------------------------
/python/eukg/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '/home/rmm120030/code/python')
3 |
--------------------------------------------------------------------------------
/python/eukg/data/DataGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from itertools import izip
5 | from collections import defaultdict
6 | from tqdm import tqdm
7 | import os
8 | import ujson
9 |
10 | import data_util
11 |
12 |
13 | class DataGenerator:
14 | def __init__(self, data, train_idx, val_idx, config, type2cuis=None, test_mode=False):
15 | self.data = data
16 | self.train_idx = train_idx
17 | self.val_idx = val_idx
18 | self.config = config
19 | self._sampler = None
20 | self._sn_sampler = None
21 | if not config.no_semantic_network:
22 | assert type2cuis
23 | self._sn_sampler = NegativeSampler(data['sn_subj'], data['sn_rel'], data['sn_obj'], 'semnet')
24 |       # if we wish to train this model for sn eval, only use the semantic network training split
25 | if config.sn_eval:
26 | with np.load(os.path.join(config.data_dir, 'semnet', 'train.npz')) as sn_npz:
27 | for key, val in sn_npz.iteritems():
28 | self.data['sn_' + key] = val
29 |
30 | self.type2cuis = type2cuis
31 | self.test_mode = test_mode
32 |
33 | @property
34 | def sn_sampler(self):
35 | # if self._sn_sampler is None:
36 | # self._sn_sampler = self.init_sn_sampler()
37 | return self._sn_sampler
38 |
39 | @property
40 | def sampler(self):
41 | if self._sampler is None:
42 | self._sampler = self.init_sampler()
43 | return self._sampler
44 |
45 | # must include test data in negative sampler
46 | def init_sampler(self):
47 | if self.test_mode:
48 | _, test_data, _, _ = data_util.load_metathesaurus_data(self.config.data_dir, 0.)
49 | else:
50 | test_data = data_util.load_metathesaurus_test_data(self.config.data_dir)
51 | valid_triples = set()
52 | for s, r, o in zip(self.data['subj'], self.data['rel'], self.data['obj']):
53 | valid_triples.add((s, r, o))
54 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']):
55 | valid_triples.add((s, r, o))
56 |
57 | return NegativeSampler(valid_triples=valid_triples, name='mt')
58 |
59 | def generate_mt(self, is_training):
60 | idxs = self.train_idx if is_training else self.val_idx
61 | batch_size = self.config.batch_size
62 | subj, rel, obj = self.data['subj'], self.data['rel'], self.data['obj']
63 | nsubj, nobj = self.sampler.sample(subj, rel, obj)
64 | num_batches = int(math.floor(float(len(idxs)) / batch_size))
65 | print('\n\ngenerating %d batches' % num_batches)
66 | for b in xrange(num_batches):
67 | idx = idxs[b * batch_size: (b + 1) * batch_size]
68 | yield rel[idx], subj[idx], obj[idx], nsubj[idx], nobj[idx]
69 |
70 | def generate_mt_gen_mode(self, is_training):
71 | idxs = self.train_idx if is_training else self.val_idx
72 | batch_size = self.config.batch_size
73 | subj, rel, obj = self.data['subj'], self.data['rel'], self.data['obj']
74 | num_batches = int(math.floor(float(len(idxs)) / batch_size))
75 | print('\n\ngenerating %d batches in generation mode' % num_batches)
76 | for b in xrange(num_batches):
77 | idx = idxs[b * batch_size: (b + 1) * batch_size]
78 | sampl_subj, sampl_obj = self.sampler.sample_for_generator(subj[idx], rel[idx], obj[idx],
79 | self.config.num_generator_samples)
80 | yield rel[idx], subj[idx], obj[idx], sampl_subj, sampl_obj
81 |
82 | def generate_sn(self, is_training):
83 | print('\n\ngenerating SN data')
84 | if is_training:
85 | idxs = np.random.permutation(np.arange(len(self.data['sn_rel'])))
86 | data = self.data
87 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj']
88 | nsubj, nobj = self.sn_sampler.sample(subj, rel, obj)
89 | sn_offset = 0
90 | batch_size = self.config.batch_size / 4
91 | while True:
92 | idx, sn_offset = get_next_k_idxs(idxs, batch_size, sn_offset)
93 | types = np.unique([subj[idx], obj[idx], nsubj[idx], nobj[idx]])
94 | concepts = np.zeros([len(types), self.config.max_concepts_per_type], dtype=np.int32)
95 | concept_lens = np.zeros([len(types)], dtype=np.int32)
96 | for i, tid in enumerate(types):
97 | concepts_of_type_t = self.type2cuis[tid] if tid in self.type2cuis else []
98 | random.shuffle(concepts_of_type_t)
99 | concepts_of_type_t = concepts_of_type_t[:self.config.max_concepts_per_type]
100 | concept_lens[i] = len(concepts_of_type_t)
101 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t
102 |
103 | yield rel[idx], subj[idx], obj[idx], nsubj[idx], nobj[idx], \
104 | concepts, concept_lens, types
105 | else:
106 | while True:
107 | yield [0], [0], [0], [0], [0], np.zeros([1, 1000], dtype=np.int32), [1], [0]
108 |
109 | def generate_sn_gen_mode(self, is_training):
110 | print('\n\ngenerating SN data in generation mode')
111 | num_samples = 10
112 | if is_training:
113 | idxs = np.random.permutation(np.arange(len(self.data['sn_rel'])))
114 | subj, rel, obj = self.data['sn_subj'], self.data['sn_rel'], self.data['sn_obj']
115 | sn_offset = 0
116 | batch_size = self.config.batch_size / 4
117 | while True:
118 | idx, sn_offset = get_next_k_idxs(idxs, batch_size, sn_offset)
119 | subj_ = subj[idx]
120 | rel_ = rel[idx]
121 | obj_ = obj[idx]
122 | subj_samples, obj_samples = self.sn_sampler.sample_for_generator(subj_, rel_, obj_, num_samples)
123 |
124 | types = np.unique(np.concatenate([subj_, obj_, subj_samples.flatten(), obj_samples.flatten()]))
125 | concepts = np.zeros([len(types), self.config.max_concepts_per_type], dtype=np.int32)
126 | concept_lens = np.zeros([len(types)], dtype=np.int32)
127 | for i, tid in enumerate(types):
128 | concepts_of_type_t = self.type2cuis[tid] if tid in self.type2cuis else []
129 | random.shuffle(concepts_of_type_t)
130 | concepts_of_type_t = concepts_of_type_t[:self.config.max_concepts_per_type]
131 | concept_lens[i] = len(concepts_of_type_t)
132 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t
133 |
134 | yield rel_, subj_, obj_, subj_samples, obj_samples, \
135 | concepts, concept_lens, types
136 | else:
137 | while True:
138 | yield [0], [0], [0], [0, 0], [0, 0], np.zeros([1, 1000], dtype=np.int32), [1], [0]
139 |
140 | def num_train_batches(self):
141 | return int(math.floor(float(len(self.train_idx)) / self.config.batch_size))
142 |
143 | def num_val_batches(self):
144 | return int(math.floor(float(len(self.val_idx)) / self.config.batch_size))
145 |
146 |
147 | class NegativeSampler:
148 | def __init__(self, subj=None, rel=None, obj=None, name=None, cachedir="/home/rmm120030/working/umls-mke/.cache",
149 | valid_triples=None):
150 | # cachedir = os.path.join(cachedir, name)
151 | # if os.path.exists(cachedir):
152 | # start = time.time()
153 | # print('loading negative sampler maps from %s' % cachedir)
154 | # self.sr2o = load_dict(os.path.join(cachedir, 'sr2o'))
155 | # self.or2s = load_dict(os.path.join(cachedir, 'or2s'))
156 | # self.concepts = ujson.load(open(os.path.join(cachedir, 'concepts.json')))
157 | # print('done! Took %.2f seconds' % (time.time() - start))
158 | # else:
159 | self.sr2o = defaultdict(set)
160 | self.or2s = defaultdict(set)
161 | concepts = set()
162 | triples = zip(subj, rel, obj) if valid_triples is None else valid_triples
163 | for s, r, o in tqdm(triples, desc='building triple maps', total=len(triples)):
164 | # s, r, o = int(s), int(r), int(o)
165 | self.sr2o[(s, r)].add(o)
166 | self.or2s[(o, r)].add(s)
167 | concepts.update([s, o])
168 | self.concepts = list(concepts)
169 |
170 | # print('\n\ncaching negative sampler maps to %s' % cachedir)
171 | # os.makedirs(cachedir)
172 | # save_dict(self.sr2o, os.path.join(cachedir, 'sr2o'))
173 | # save_dict(self.or2s, os.path.join(cachedir, 'or2s'))
174 | # ujson.dump(self.concepts, open(os.path.join(cachedir, 'concepts.json'), 'w+'))
175 | # print('done!')
176 |
177 | def _neg_sample(self, s_, r_, o_, replace_s):
178 | while True:
179 | c = random.choice(self.concepts)
180 | if replace_s and c not in self.or2s[(o_, r_)]:
181 | return c, o_
182 | elif not replace_s and c not in self.sr2o[(s_, r_)]:
183 | return s_, c
184 |
185 | def sample(self, subj, rel, obj):
186 | neg_subj = []
187 | neg_obj = []
188 | print("\n")
189 | for s, r, o in tqdm(zip(subj, rel, obj), desc='negative sampling', total=len(subj)):
190 | ns, no = self._neg_sample(s, r, o, random.random() > 0.5)
191 | neg_subj.append(ns)
192 | neg_obj.append(no)
193 |
194 | return np.asarray(neg_subj, dtype=np.int32), np.asarray(neg_obj, dtype=np.int32)
195 |
196 | def _sample_k(self, subj, rel, obj, k):
197 | neg_subj = []
198 | neg_obj = []
199 | for i in xrange(k):
200 | ns, no = self._neg_sample(subj, rel, obj, random.random() > 0.5)
201 | neg_subj.append(ns)
202 | neg_obj.append(no)
203 |
204 | return neg_subj, neg_obj
205 |
206 | def sample_for_generator(self, subj_array, rel_array, obj_array, k):
207 | subj_samples = []
208 | obj_samples = []
209 | for s, r, o in zip(subj_array, rel_array, obj_array):
210 | ns, no = self._sample_k(s, r, o, k)
211 | subj_samples.append(ns)
212 | obj_samples.append(no)
213 |
214 | return np.asarray(subj_samples, dtype=np.int32), np.asarray(obj_samples, dtype=np.int32)
215 |
216 | def invalid_concepts(self, subj, rel, obj, replace_subj):
217 | if replace_subj:
218 | return [c for c in self.concepts if c not in self.or2s[(obj, rel)]]
219 | else:
220 | return [c for c in self.concepts if c not in self.sr2o[(subj, rel)]]
221 |
222 |
223 | def get_next_k_idxs(all_idxs, k, offset):
224 | if offset + k < len(all_idxs):
225 | idx = all_idxs[offset: offset + k]
226 | offset += k
227 | else:
228 | random.shuffle(all_idxs)
229 | offset = k
230 | idx = all_idxs[:offset]
231 | return idx, offset
232 |
233 |
234 | def wrap_generators(mt_gen, sn_gen, is_training):
235 | if is_training:
236 | for mt_batch, sn_batch in izip(mt_gen(True), sn_gen(True)):
237 | yield mt_batch + sn_batch
238 | else:
239 |     for b in mt_gen(False):  # validation uses the held-out split; no semantic network batches
240 | yield b
241 |
242 |
243 | def save_dict(d, savepath):
244 | keys, values = [], []
245 | for k, v in d.iteritems():
246 | keys.append(list(k))
247 | values.append(list(v))
248 |
249 | ujson.dump(keys, open(savepath + '_keys.json', 'w+'))
250 | ujson.dump(values, open(savepath + '_values.json', 'w+'))
251 |
252 |
253 | def load_dict(savepath):
254 | keys = ujson.load(open(savepath + '_keys.json'))
255 | values = ujson.load(open(savepath + '_values.json'))
256 | return {tuple(k): set(v) for k, v in izip(keys, values)}
257 |
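# A minimal usage sketch for NegativeSampler, assuming hypothetical toy id arrays.
# sample() corrupts either the subject or the object of each triple (coin flip),
# rejecting corruptions that collide with a known valid triple, while
# sample_for_generator() draws k candidate corruptions per triple.
#
#   toy_subj = np.array([0, 1], dtype=np.int32)
#   toy_rel = np.array([2, 2], dtype=np.int32)
#   toy_obj = np.array([1, 3], dtype=np.int32)
#   sampler = NegativeSampler(toy_subj, toy_rel, toy_obj, name='toy')
#   nsubj, nobj = sampler.sample(toy_subj, toy_rel, toy_obj)                      # shape [2]
#   ksubj, kobj = sampler.sample_for_generator(toy_subj, toy_rel, toy_obj, k=5)   # shape [2, 5]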
--------------------------------------------------------------------------------
/python/eukg/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/data/__init__.py
--------------------------------------------------------------------------------
/python/eukg/data/create_semnet_triples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import json
4 | import os
5 | from tqdm import tqdm
6 | from collections import defaultdict
7 | import numpy as np
8 |
9 | from create_test_set import split
10 |
11 |
12 | def process_mapping(umls_dir, data_dir):
13 | print('Creating mapping of semtypes to cuis and vice versa...')
14 | name2id = json.load(open(os.path.join(data_dir, 'name2id.json')))
15 | cui2semtypes = defaultdict(list)
16 | semtype2cuis = defaultdict(list)
17 | with open(os.path.join(umls_dir, 'META', 'MRSTY.RRF'), 'r') as f:
18 | reader = csv.reader(f, delimiter='|')
19 | for row in tqdm(reader, desc="reading", total=3395307):
20 | cui = row[0].strip()
21 | if cui in name2id:
22 | cid = name2id[cui]
23 | tui = row[1]
24 | if tui in name2id:
25 | tid = name2id[tui]
26 | else:
27 | tid = len(name2id)
28 | name2id[tui] = tid
29 | cui2semtypes[cid].append(tid)
30 | semtype2cuis[tid].append(cid)
31 | print('Processed type mappings for %d semantic types' % len(semtype2cuis))
32 | c2s_lens = sorted([len(l) for l in cui2semtypes.values()])
33 | print('Maximum # semtypes for a cui: %d' % max(c2s_lens))
34 | print('Average # semtypes for a cui: %.2f' % (float(sum(c2s_lens)) / len(c2s_lens)))
35 | print('Median # semtypes for a cui: %.2f' % c2s_lens[len(c2s_lens)/2])
36 |
37 | s2c_lens = sorted([len(l) for l in semtype2cuis.values()])
38 | print('Maximum # cuis for a semtype: %d' % max(s2c_lens))
39 | print('Average # cuis for a semtype: %.2f' % (float(sum(s2c_lens)) / len(s2c_lens)))
40 | print('Median # cuis for a semtype: %.2f' % s2c_lens[len(s2c_lens)/2])
41 | print('%% under 1k: %.4f' % (float(len([l for l in s2c_lens if l < 1000])) / len(s2c_lens)))
42 | print('%% under 2k: %.4f' % (float(len([l for l in s2c_lens if l < 2000])) / len(s2c_lens)))
43 |
44 | # json.dump(name2id, open('/home/rmm120030/working/umls-mke/data/name2id.json', 'w+'))
45 |   json.dump(cui2semtypes, open(os.path.join(data_dir, 'semnet', 'cui2semtypes.json'), 'w+'))
46 | json.dump(semtype2cuis, open(os.path.join(data_dir, 'semnet', 'semtype2cuis.json'), 'w+'))
47 |
48 |
49 | def semnet_triples(umls_dir, data_dir):
50 | print('Creating semantic network triples...')
51 | name2id = json.load(open(os.path.join(data_dir, 'name2id.json')))
52 |
53 | total_relations = 0
54 | new_relations = 0
55 | tui2id = {}
56 | # relations which have a metathesaurus analog are mapped to the MT embedding
57 | with open(os.path.join(umls_dir, 'NET', 'SRDEF')) as f:
58 | reader = csv.reader(f, delimiter='|')
59 | for row in reader:
60 | tui = row[1]
61 | if row[0] == 'RL':
62 | total_relations += 1
63 | rel = row[2]
64 | if rel in name2id:
65 | print('reusing relation embedding for %s' % rel)
66 | name2id[tui] = name2id[rel]
67 | elif tui not in name2id:
68 | new_relations += 1
69 | name2id[tui] = len(name2id)
70 | else:
71 | if tui not in name2id:
72 | name2id[tui] = len(name2id)
73 | tui2id[tui] = name2id[tui]
74 |
75 | print('Created %d of %d new relations' % (new_relations, total_relations))
76 | print('%d total embeddings' % len(name2id))
77 | json.dump(name2id, open(os.path.join(data_dir, 'name2id.json'), 'w+'))
78 | json.dump(tui2id, open(os.path.join(data_dir, 'semnet', 'tui2id.json'), 'w+'))
79 |
80 | subj, rel, obj = [], [], []
81 | with open(os.path.join(umls_dir, 'NET', 'SRSTRE1'), 'r') as f:
82 | reader = csv.reader(f, delimiter='|')
83 | for row in reader:
84 | subj.append(name2id[row[0]])
85 | rel.append(name2id[row[1]])
86 | obj.append(name2id[row[2]])
87 |
88 | print('Saving the %d triples of the semantic network graph' % len(rel))
89 | split(np.asarray(subj, dtype=np.int32),
90 | np.asarray(rel, dtype=np.int32),
91 | np.asarray(obj, dtype=np.int32),
92 | os.path.join(data_dir, 'semnet'),
93 | 'semnet',
94 | 600)
95 |
96 |
97 | if __name__ == "__main__":
98 |   parser = argparse.ArgumentParser(description='Extract semantic network triples and semtype-to-cui mappings from the UMLS NET and META files')
99 |   parser.add_argument('umls_dir', help='UMLS directory')
100 |   parser.add_argument('--output', default='data', help='data directory in which the output files will be created')
101 |
102 | args = parser.parse_args()
103 | semnet_triples(args.umls_dir, args.output)
104 | process_mapping(args.umls_dir, args.output)
105 |
--------------------------------------------------------------------------------
/python/eukg/data/create_test_set.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import os
4 |
5 |
6 | def split(subj, rel, obj, data_dir, graph_name='metathesaurus', num_test=100000):
7 | valid_triples = set()
8 | for s, r, o in zip(subj, rel, obj):
9 | valid_triples.add((s, r, o))
10 | subj = np.asarray([s for s, _, _ in valid_triples])
11 | rel = np.asarray([r for _, r, _ in valid_triples])
12 | obj = np.asarray([o for _, _, o in valid_triples])
13 |
14 | perm = np.random.permutation(np.arange(len(rel)))
15 | test_idx = perm[:num_test]
16 | train_idx = perm[num_test:]
17 | print('created train/test splits (%d/%d)' % (len(train_idx), len(test_idx)))
18 |
19 | np.savez_compressed(os.path.join(data_dir, graph_name, 'train.npz'),
20 | subj=subj[train_idx],
21 | rel=rel[train_idx],
22 | obj=obj[train_idx])
23 | print('saved train set')
24 |
25 | np.savez_compressed(os.path.join(data_dir, graph_name, 'test.npz'),
26 | subj=subj[test_idx],
27 | rel=rel[test_idx],
28 | obj=obj[test_idx])
29 | print('saved test set')
30 |
31 |
32 | def from_train_file(data_dir, graph_name='metathesaurus', num_test=100000):
33 | npz = np.load(os.path.join(data_dir, graph_name, 'triples.npz'))
34 | data = dict(npz.iteritems())
35 | npz.close()
36 | print('read all data')
37 | split(data['subj'], data['rel'], data['obj'], data_dir, graph_name, num_test)
38 |
39 |
40 | def from_train_test_files():
41 | data_dir = sys.argv[1]
42 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'train.npz'))
43 | data = dict(npz.iteritems())
44 | npz.close()
45 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'test.npz'))
46 | subj = np.concatenate((data['subj'], npz['subj']))
47 | rel = np.concatenate((data['rel'], npz['rel']))
48 | obj = np.concatenate((data['obj'], npz['obj']))
49 | npz.close()
50 | print('read all data')
51 | valid_triples = set()
52 | for s, r, o in zip(subj, rel, obj):
53 | valid_triples.add((s, r, o))
54 | subj = np.asarray([s for s, _, _ in valid_triples])
55 | rel = np.asarray([r for _, r, _ in valid_triples])
56 | obj = np.asarray([o for _, _, o in valid_triples])
57 |
58 | perm = np.random.permutation(np.arange(len(rel)))
59 | num_test = 100000
60 | test_idx = perm[:num_test]
61 | train_idx = perm[num_test:]
62 | print('created train/test splits (%d/%d)' % (len(train_idx), len(test_idx)))
63 |
64 | np.savez_compressed(os.path.join(data_dir, 'metathesaurus', 'train.npz'),
65 | subj=subj[train_idx],
66 | rel=rel[train_idx],
67 | obj=obj[train_idx])
68 | print('saved train set')
69 |
70 | np.savez_compressed(os.path.join(data_dir, 'metathesaurus', 'test.npz'),
71 | subj=subj[test_idx],
72 | rel=rel[test_idx],
73 | obj=obj[test_idx])
74 | print('saved test set')
75 |
76 |
77 | if __name__ == "__main__":
78 |   if len(sys.argv) > 3:
79 | from_train_file(sys.argv[1], sys.argv[2], int(sys.argv[3]))
80 | else:
81 | from_train_file(sys.argv[1])
82 |
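# Typical invocations, assuming the data layout used above
# (<data_dir>/<graph_name>/triples.npz in, train.npz and test.npz out):
#
#   python create_test_set.py <data_dir>                           # metathesaurus graph, 100k test triples
#   python create_test_set.py <data_dir> <graph_name> <num_test>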
--------------------------------------------------------------------------------
/python/eukg/data/create_triples.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import csv
4 | import json
5 | from tqdm import tqdm
6 | import argparse
7 |
8 | from create_test_set import split
9 |
10 |
11 | def metathesaurus_triples(rrf_file, output_dir, valid_relations):
12 | triples = set()
13 | conc2id = {}
14 |
15 | def add_concept(conc):
16 | if conc in conc2id:
17 | cid = conc2id[conc]
18 | else:
19 | cid = len(conc2id)
20 | conc2id[conc] = cid
21 | return cid
22 |
23 | with open(rrf_file, 'r') as f:
24 | reader = csv.reader(f, delimiter='|')
25 | for row in tqdm(reader, desc="reading", total=37207861):
26 | if row[7] in valid_relations:
27 | sid = add_concept(row[0])
28 | rid = add_concept(row[7])
29 | oid = add_concept(row[4])
30 | triples.add((sid, rid, oid))
31 |
32 | subjs, rels, objs = zip(*triples)
33 | snp = np.asarray(subjs, dtype=np.int32)
34 | rnp = np.asarray(rels, dtype=np.int32)
35 | onp = np.asarray(objs, dtype=np.int32)
36 |
37 | id2conc = {v: k for k, v in conc2id.iteritems()}
38 | concepts = [id2conc[i] for i in np.unique(np.concatenate((subjs, objs)))]
39 | relations = [id2conc[i] for i in set(rels)]
40 |
41 | print("Saving %d unique triples to %s. %d concepts spanning %d relations" % (rnp.shape[0], output_dir, len(concepts),
42 | len(relations)))
43 |
44 | split(snp, rnp, onp, output_dir)
45 |
46 | json.dump(conc2id, open(os.path.join(output_dir, 'name2id.json'), 'w+'))
47 | json.dump(concepts, open(os.path.join(output_dir, 'concept_vocab.json'), 'w+'))
48 | json.dump(relations, open(os.path.join(output_dir, 'relation_vocab.json'), 'w+'))
49 |
50 |
51 | def metathesaurus_triples_trimmed(rrf_file, output_dir, valid_concepts, valid_relations, important_concepts=None):
52 | triples = set()
53 | conc2id = {}
54 |
55 | def add_concept(conc):
56 | if conc in conc2id:
57 | cid = conc2id[conc]
58 | else:
59 | cid = len(conc2id)
60 | conc2id[conc] = cid
61 | return cid
62 |
63 | with open(rrf_file, 'r') as f:
64 | reader = csv.reader(f, delimiter='|')
65 | for row in tqdm(reader, desc="reading", total=37207861):
66 | if (row[0] in valid_concepts and row[4] in valid_concepts and row[7] in valid_relations) or \
67 | (important_concepts is not None and row[7] != '' and (row[0] in important_concepts or row[4] in important_concepts)):
68 | sid = add_concept(row[0])
69 | rid = add_concept(row[7])
70 | oid = add_concept(row[4])
71 | triples.add((sid, rid, oid))
72 |
73 | subjs, rels, objs = zip(*triples)
74 | snp = np.asarray(subjs, dtype=np.int32)
75 | rnp = np.asarray(rels, dtype=np.int32)
76 | onp = np.asarray(objs, dtype=np.int32)
77 |
78 | id2conc = {v: k for k, v in conc2id.iteritems()}
79 | concepts = [id2conc[i] for i in np.unique(np.concatenate((subjs, objs)))]
80 |   relations = [id2conc[i] for i in set(rels)]
81 |
82 | print("Saving %d unique triples to %s" % (rnp.shape[0], output_dir))
83 |
84 | np.savez_compressed(os.path.join(output_dir, 'triples'),
85 | subj=snp,
86 | rel=rnp,
87 | obj=onp)
88 | json.dump(conc2id, open(os.path.join(output_dir, 'name2id.json'), 'w+'))
89 |   json.dump(concepts, open(os.path.join(output_dir, 'concept_vocab.json'), 'w+'))
90 |   json.dump(relations, open(os.path.join(output_dir, 'relation_vocab.json'), 'w+'))
91 |
92 |
93 | def main():
94 | parser = argparse.ArgumentParser(description='Extract relation triples into a compressed numpy file from MRCONSO.RRF')
95 |   parser.add_argument('umls_dir', help='UMLS installation directory containing META/MRCONSO.RRF')
96 |   parser.add_argument('--output', default='data', help='data directory in which the triple and vocabulary files will be created')
97 | parser.add_argument('--valid_relations', default='data/valid_rels.txt',
98 | help='plaintext list of relations we want to extract triples for, one per line.')
99 |
100 | args = parser.parse_args()
101 |
102 | valid_relations = set([rel.strip() for rel in open(args.valid_relations)])
103 |
104 | metathesaurus_triples(os.path.join(args.umls_dir, 'META', 'MRCONSO.RRF'), args.output, valid_relations)
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/python/eukg/data/data_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 | from tqdm import tqdm
4 | import sys
5 | import math
6 | import os
7 | import json
8 | import random
9 |
10 |
11 | def negative_sampling_from_file(np_file):
12 | npz = np.load(np_file)
13 | subj, rel, obj = npz['subj'], npz['rel'], npz['obj']
14 |
15 |   sampler = NegativeSampler(subj, rel, obj, 'mt')  # a name is required; it only labels the (currently disabled) cache dir
16 | neg_subj, neg_obj = sampler.sample(subj, rel, obj)
17 |
18 | npz.close()
19 | np.savez_compressed(np_file,
20 | subj=subj,
21 | rel=rel,
22 | obj=obj,
23 | nsubj=neg_subj,
24 | nobj=neg_obj)
25 |
26 |
27 | def load_metathesaurus_data(data_dir, val_proportion):
28 | random.seed(1337)
29 |
30 | cui2id = json.load(open(os.path.join(data_dir, 'name2id.json')))
31 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'train.npz'))
32 | data = dict(npz.iteritems())
33 | npz.close()
34 |
35 | perm = np.random.permutation(np.arange(len(data['rel'])))
36 | num_val = int(math.ceil(len(perm) * val_proportion))
37 | val_idx = perm[:num_val]
38 | train_idx = perm[num_val:]
39 |
40 | return cui2id, data, train_idx, val_idx
41 |
42 |
43 | def load_semantic_network_data(data_dir, data_map):
44 | type2cuis = json.load(open(os.path.join(data_dir, 'semnet', 'semtype2cuis.json')))
45 | npz = np.load(os.path.join(data_dir, 'semnet', 'triples.npz'))
46 | for key, val in npz.iteritems():
47 | data_map['sn_' + key] = val
48 | npz.close()
49 |
50 | return type2cuis
51 |
52 |
53 | def load_metathesaurus_test_data(data_dir):
54 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'test.npz'))
55 | data = dict(npz.iteritems())
56 | npz.close()
57 |
58 | return data
59 |
60 |
61 | def save_config(outdir, config):
62 | print('saving config to %s' % outdir)
63 | with open('%s/config.json' % outdir, 'w+') as f:
64 | json.dump(config.flag_values_dict(), f)
65 |
66 |
67 | def main():
68 | negative_sampling_from_file(sys.argv[1])
69 |
70 |
71 | class NegativeSampler:
72 | def __init__(self, subj, rel, obj, name, cachedir="/tmp"):
73 | cachedir = os.path.join(cachedir, name)
74 | # if os.path.exists(cachedir):
75 | # print('loading negative sampler maps from %s' % cachedir)
76 | # self.sr2o = defaultdict(list, json.load(open(os.path.join(cachedir, 'sr2o.json'))))
77 | # self.or2s = defaultdict(list, json.load(open(os.path.join(cachedir, 'or2s.json'))))
78 | # self.concepts = json.load(open(os.path.join(cachedir, 'concepts.json')))
79 | # else:
80 | self.sr2o = defaultdict(set)
81 | self.or2s = defaultdict(set)
82 | concepts = set()
83 | print("\n")
84 | for s, r, o in tqdm(zip(subj, rel, obj), desc='building triple maps', total=len(subj)):
85 | self.sr2o[(s, r)].add(o)
86 | self.or2s[(o, r)].add(s)
87 | concepts.update([s, o])
88 | self.concepts = list(concepts)
89 |
90 | # os.makedirs(cachedir)
91 | # json.dump({k: list(v) for k, v in self.sr2o.iteritems()}, open(os.path.join(cachedir, 'sr2o.json'), 'w+'))
92 | # json.dump({k: list(v) for k, v in self.or2s.iteritems()}, open(os.path.join(cachedir, 'or2s.json'), 'w+'))
93 | # json.dump(self.concepts, open(os.path.join(cachedir, 'concepts.json'), 'w+'))
94 |
95 | def _neg_sample(self, s_, r_, o_, replace_s):
96 | while True:
97 | c = random.choice(self.concepts)
98 | if replace_s and c not in self.or2s[(o_, r_)]:
99 | return c, o_
100 | elif not replace_s and c not in self.sr2o[(s_, r_)]:
101 | return s_, c
102 |
103 | def sample(self, subj, rel, obj):
104 | neg_subj = []
105 | neg_obj = []
106 | print("\n")
107 | for s, r, o in tqdm(zip(subj, rel, obj), desc='negative sampling', total=len(subj)):
108 |       ns, no = self._neg_sample(s, r, o, random.random() > 0.5)  # coin flip: corrupt subject or object
109 | neg_subj.append(ns)
110 | neg_obj.append(no)
111 |
112 | return np.asarray(neg_subj, dtype=np.int32), np.asarray(neg_obj, dtype=np.int32)
113 |
114 | def _sample_k(self, subj, rel, obj, k):
115 | neg_subj = []
116 | neg_obj = []
117 | for i in xrange(k):
118 |       ns, no = self._neg_sample(subj, rel, obj, random.random() > 0.5)
119 | neg_subj.append(ns)
120 | neg_obj.append(no)
121 |
122 | return neg_subj, neg_obj
123 |
124 | def sample_for_generator(self, subj_array, rel_array, obj_array, k):
125 | subj_samples = []
126 | obj_samples = []
127 | for s, r, o in zip(subj_array, rel_array, obj_array):
128 | ns, no = self._sample_k(s, r, o, k)
129 | subj_samples.append(ns)
130 | obj_samples.append(no)
131 |
132 | return np.asarray(subj_samples), np.asarray(obj_samples)
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
137 |
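# A toy sketch of negative_sampling_from_file (file name and arrays are hypothetical):
# it rewrites the given .npz in place, keeping subj/rel/obj and adding nsubj/nobj arrays
# in which, for each triple, either the subject or the object has been replaced by a
# random concept that does not form another known triple.
#
#   np.savez_compressed('toy.npz',
#                       subj=np.array([0, 1], dtype=np.int32),
#                       rel=np.array([2, 2], dtype=np.int32),
#                       obj=np.array([1, 3], dtype=np.int32))
#   negative_sampling_from_file('toy.npz')
#   print(np.load('toy.npz').files)   # subj, rel, obj plus nsubj, nobj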
--------------------------------------------------------------------------------
/python/eukg/emb/EmbeddingModel.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tensorflow.contrib import layers
4 |
5 |
6 | class BaseModel:
7 | def __init__(self, config):
8 | self.embedding_size = config.embedding_size
9 | self.embedding_device = config.embedding_device
10 | self.vocab_size = config.vocab_size
11 |
12 | self.embeddings = None
13 | self.ids_to_update = tf.placeholder(dtype=tf.int32, shape=[None], name='used_vectors')
14 | self.norm_op = None
15 |
16 | def energy(self, head, rel, tail, norm_ord='euclidean'):
17 | """
18 | Calculates the energies of the batch of triples corresponding to head, rel and tail.
19 | Energy should be in (-inf, 0]
20 | :param head: [batch_size] vector of head entity ids
21 | :param rel: [batch_size] vector of relation ids
22 | :param tail: [batch_size] vector of tail entity ids
23 | :param norm_ord: order of the normalization function
24 | :return: [batch_size] vector of energies of the passed triples
25 | """
26 | raise NotImplementedError("subclass should implement")
27 |
28 | def embedding_lookup(self, ids):
29 | """
30 | returns embedding vectors or tuple of embedding vectors for the passed ids
31 | :param ids: ids of embedding vectors in an embedding matrix
32 | :return: embedding vectors or tuple of embedding vectors for the passed ids
33 | """
34 | raise NotImplementedError("subclass should implement")
35 |
36 | def normalize_parameters(self):
37 | """
38 | Returns the op that enforces the constraint that each embedding vector have a norm <= 1
39 | :return: the op that enforces the constraint that each embedding vector have a norm <= 1
40 | """
41 | raise NotImplementedError("subclass should implement")
42 |
43 | # noinspection PyMethodMayBeStatic
44 | def normalize(self, rows, mat, ids):
45 | """
46 | Normalizes the rows of the matrix mat corresponding to ids s.t. |row|_2 <= 1 for each row.
47 | :param rows: rows from the matrix mat, embedding vectors
48 | :param mat: matrix of embedding vectors
49 | :param ids: ids of rows in mat
50 | :return: the scatter update op that updates only the rows of mat corresponding to ids
51 | """
52 |     norm = tf.norm(rows, axis=-1, keepdims=True)  # per-row L2 norm, as documented above
53 | scaling = 1. / tf.maximum(norm, 1.)
54 | scaled = scaling * rows
55 | return tf.scatter_update(mat, ids, scaled)
56 |
57 |
58 | class TransE(BaseModel):
59 | def __init__(self, config, embeddings_dict=None):
60 | BaseModel.__init__(self, config)
61 | with tf.device("/%s:0" % self.embedding_device):
62 | if embeddings_dict is None:
63 | self.embeddings = tf.get_variable("embeddings",
64 | shape=[self.vocab_size, self.embedding_size],
65 | dtype=tf.float32,
66 | initializer=layers.xavier_initializer())
67 | else:
68 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings")
69 |
70 | def energy(self, head, rel, tail, norm_ord='euclidean'):
71 | h = self.embedding_lookup(head)
72 | r = self.embedding_lookup(rel)
73 | t = self.embedding_lookup(tail)
74 |
75 | return tf.norm(h + r - t,
76 | ord=norm_ord,
77 | axis=-1,
78 | keepdims=False,
79 | name='energy')
80 |
81 | def embedding_lookup(self, ids):
82 | with tf.device("/%s:0" % self.embedding_device):
83 | return tf.nn.embedding_lookup(self.embeddings, ids)
84 |
85 | def normalize_parameters(self):
86 | """
87 |     Enforces the constraint that each embedding vector corresponding to ids_to_update has norm <= 1.0
88 | """
89 | params1 = self.embedding_lookup(self.ids_to_update)
90 | self.norm_op = self.normalize(params1, self.embeddings, self.ids_to_update)
91 |
92 | return self.norm_op
93 |
94 |
95 | class TransD(BaseModel):
96 | def __init__(self, config, embeddings_dict=None):
97 | BaseModel.__init__(self, config)
98 | with tf.device("/%s:0" % self.embedding_device):
99 | if embeddings_dict is None:
100 | print('Initializing embeddings.')
101 | self.embeddings = tf.get_variable("embeddings",
102 | shape=[self.vocab_size, self.embedding_size], # 364373
103 | dtype=tf.float32,
104 | initializer=layers.xavier_initializer())
105 | else:
106 | print('Loading embeddings.')
107 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings")
108 |     with tf.device("/%s:%d" % (self.embedding_device, 1 if self.embedding_device == 'cpu' else 0)):  # match the placement used in embedding_lookup
109 | if embeddings_dict is None or 'p_embs' not in embeddings_dict:
110 | print('Initializing projection embeddings.')
111 | if config.p_init == 'zeros':
112 | p_init = tf.initializers.zeros()
113 | elif config.p_init == 'xavier':
114 | p_init = layers.xavier_initializer()
115 | elif config.p_init == 'uniform':
116 | p_init = tf.initializers.random_uniform(minval=-0.1, maxval=0.1, dtype=tf.float32)
117 | else:
118 | raise Exception('unrecognized p initializer: %s' % config.p_init)
119 |
120 | # projection embeddings initialized to zeros
121 | self.p_embeddings = tf.get_variable("p_embeddings",
122 | shape=[self.vocab_size, self.embedding_size],
123 | dtype=tf.float32,
124 | initializer=p_init)
125 | else:
126 | print('Loading projection embeddings.')
127 | self.p_embeddings = tf.Variable(embeddings_dict['p_embs'], name="p_embeddings")
128 |
129 | def energy(self, head, rel, tail, norm_ord='euclidean'):
130 | """
131 | Computes the TransD energy of a relation triple
132 | :param head: head concept embedding ids [batch_size]
133 | :param rel: relation embedding ids [batch_size]
134 | :param tail: tail concept embedding ids [batch_size]
135 | :param norm_ord: norm order ['euclidean', 'fro', 'inf', 1, 2, 3, etc.]
136 | :return: [batch_size] vector of energies
137 | """
138 | # x & x_proj both [batch_size, embedding_size]
139 | h, h_proj = self.embedding_lookup(head)
140 | r, r_proj = self.embedding_lookup(rel)
141 | t, t_proj = self.embedding_lookup(tail)
142 |
143 | # [batch_size]
144 | return tf.norm(self.project(h, h_proj, r_proj) + r - self.project(t, t_proj, r_proj),
145 | ord=norm_ord,
146 | axis=1,
147 | keepdims=False,
148 | name="energy")
149 |
150 | # noinspection PyMethodMayBeStatic
151 | def project(self, c, c_proj, r_proj):
152 | """
153 | Computes the projected concept embedding for relation r according to TransD:
154 | (c_proj^T*c)*r_proj + c
155 | :param c: concept embeddings [batch_size, embedding_size]
156 | :param c_proj: concept projection embeddings [batch_size, embedding_size]
157 | :param r_proj: relation projection embeddings [batch_size, embedding_size]
158 | :return: projected concept embedding [batch_size, embedding_size]
159 | """
160 | return c + tf.reduce_sum(c * c_proj, axis=-1, keepdims=True) * r_proj
161 |
162 | def embedding_lookup(self, ids):
163 | with tf.device("/%s:0" % self.embedding_device):
164 | params1 = tf.nn.embedding_lookup(self.embeddings, ids)
165 | with tf.device("/%s:%d" % (self.embedding_device, 1 if self.embedding_device == 'cpu' else 0)):
166 | params2 = tf.nn.embedding_lookup(self.p_embeddings, ids)
167 | return params1, params2
168 |
169 | def normalize_parameters(self):
170 | """
171 | Normalizes the vectors of embeddings corresponding to the passed ids
172 | :return: the normalization op
173 | """
174 | with tf.device("/%s:0" % self.embedding_device):
175 | params1 = tf.nn.embedding_lookup(self.embeddings, self.ids_to_update)
176 | with tf.device("/%s:%d" % (self.embedding_device, 1 if self.embedding_device == 'cpu' else 0)):
177 | params2 = tf.nn.embedding_lookup(self.p_embeddings, self.ids_to_update)
178 |
179 | n1 = self.normalize(params1, self.embeddings, self.ids_to_update)
180 | n2 = self.normalize(params2, self.p_embeddings, self.ids_to_update)
181 | self.norm_op = n1, n2
182 |
183 | return self.norm_op
184 |
185 |
186 | class DistMult(TransE):
187 | def __init__(self, config, embeddings_dict=None):
188 | BaseModel.__init__(self, config)
189 | if config.energy_activation == 'relu':
190 | self.energy_activation = tf.nn.relu
191 | elif config.energy_activation == 'tanh':
192 | self.energy_activation = tf.nn.tanh
193 | elif config.energy_activation == 'sigmoid':
194 | self.energy_activation = tf.nn.sigmoid
195 | elif config.energy_activation is None:
196 | self.energy_activation = lambda x: x
197 | else:
198 | raise Exception('Unrecognized activation: %s' % config.energy_activation)
199 |
200 | with tf.device("/%s:0" % self.embedding_device):
201 | if embeddings_dict is None:
202 | self.embeddings = tf.get_variable("embeddings",
203 | shape=[self.vocab_size, self.embedding_size],
204 | dtype=tf.float32,
205 | initializer=tf.initializers.random_uniform(minval=-0.5,
206 | maxval=0.5,
207 | dtype=tf.float32))
208 | else:
209 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings")
210 |
211 | def energy(self, head, rel, tail, norm_ord='euclidean'):
212 | h = self.embedding_lookup(head)
213 | r = self.embedding_lookup(rel)
214 | t = self.embedding_lookup(tail)
215 |
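    # DistMult scores a triple with the trilinear product sum_i h_i * r_i * t_i; the optional
    # activation maps that score into the energy interface shared with TransE/TransD.
    # [batch_size]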
216 | return self.energy_activation(tf.reduce_sum(h * r * t,
217 | axis=-1,
218 | keepdims=False))
219 |
220 | def normalize_parameters(self):
221 | return tf.no_op()
222 |
223 | def regularization(self, parameters):
224 | reg_term = 0
225 | for p in parameters:
226 | reg_term += tf.reduce_sum(tf.norm(self.embedding_lookup(p)))
227 | return reg_term
228 |
--------------------------------------------------------------------------------
/python/eukg/emb/Smoothing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # from .BaseModel import BaseModel
4 |
5 |
6 | def add_semantic_network_loss(model):
7 | """
8 | Adds the semantic network loss to the graph
9 | :param model:
10 | :type model: BaseModel.BaseModel
11 |   :return: the semantic network loss - a scalar real-valued tensor
12 | """
13 | print('Adding semantic network to graph')
14 | # dataset3: sr, ss, so, sns, sno, t, conc, counts
15 | model.smoothing_placeholders['sn_relations'] = rel = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_relations')
16 | model.smoothing_placeholders['sn_pos_subj'] = psubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_subj')
17 | model.smoothing_placeholders['sn_pos_obj'] = pobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_obj')
18 | model.smoothing_placeholders['sn_neg_subj'] = nsubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_neg_subj')
19 | model.smoothing_placeholders['sn_neg_obj'] = nobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_neg_obj')
20 | model.smoothing_placeholders['sn_types'] = types = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_types')
21 | model.smoothing_placeholders['sn_concepts'] = concepts = tf.placeholder(dtype=tf.int32,
22 | shape=[None, model.max_concepts_per_type],
23 | name='sn_concepts')
24 | model.smoothing_placeholders['sn_conc_counts'] = counts = tf.placeholder(dtype=tf.int32,
25 | shape=[None],
26 | name='sn_conc_counts')
27 |
28 | with tf.variable_scope("energy"):
29 | model.sn_pos_energy = pos_energy = model.embedding_model.energy(psubj, rel, pobj, model.energy_norm)
30 | with tf.variable_scope("energy", reuse=True):
31 | model.sn_neg_energy = neg_energy = model.embedding_model.energy(nsubj, rel, nobj, model.energy_norm)
32 | energy_loss = tf.reduce_mean(tf.nn.relu(model.gamma - neg_energy + pos_energy))
33 | model.sn_reward = -tf.reduce_mean(neg_energy, name='sn_reward')
34 |
35 | # [batch_size, embedding_size]
36 | type_embeddings = model.embedding_model.embedding_lookup(types)
37 | # [batch_size, max_concepts_per_type, embedding_size]
38 | concepts_embeddings = model.embedding_model.embedding_lookup(concepts)
39 |
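  # Alignment loss: mean absolute difference between each semantic type embedding and the masked
  # average of the embeddings of concepts having that type (padded concept slots are masked out
  # and counts are clipped to >= 1 to avoid division by zero).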
40 | def calc_alignment_loss(_type_embeddings, _concept_embeddings):
41 | mask = tf.expand_dims(tf.sequence_mask(counts, maxlen=model.max_concepts_per_type, dtype=tf.float32), axis=-1)
42 | sum_ = tf.reduce_sum(mask * _concept_embeddings, axis=1, keepdims=False)
43 | float_counts = tf.to_float(tf.expand_dims(counts, axis=-1))
44 | avg_conc_embeddings = sum_ / tf.maximum(float_counts, tf.ones_like(float_counts))
45 | return tf.reduce_mean(tf.abs(_type_embeddings - avg_conc_embeddings))
46 |
47 | if isinstance(type_embeddings, tuple):
48 | alignment_loss = 0
49 | for type_embeddings_i, concepts_embeddings_i in zip(type_embeddings, concepts_embeddings):
50 | alignment_loss += calc_alignment_loss(type_embeddings_i, concepts_embeddings_i)
51 | else:
52 | alignment_loss = calc_alignment_loss(type_embeddings, concepts_embeddings)
53 |
54 | # summary
55 | tf.summary.scalar('sn_energy_loss', energy_loss)
56 | tf.summary.scalar('sn_alignment_loss', alignment_loss)
57 | # tf.summary.scalar('sn_accuracy', accuracy)
58 | avg_pos_energy = tf.reduce_mean(pos_energy)
59 | tf.summary.scalar('sn_pos_energy', avg_pos_energy)
60 | avg_neg_energy = tf.reduce_mean(neg_energy)
61 | tf.summary.scalar('sn_neg_energy', avg_neg_energy)
62 |
63 |   # combined semantic network loss: weighted sum of the energy and alignment terms
64 | loss = model.semnet_energy_param * energy_loss + model.semnet_alignment_param * alignment_loss
65 | return loss
66 |
67 |
68 | def add_gen_semantic_network(model):
69 | """
70 | Adds the semantic network loss to the graph
71 | :param model:
72 | :type model: ..gan.Generator.Generator
73 |   :return: a tuple of (energy loss, alignment loss) scalar tensors
74 | """
75 | print('Adding semantic network to graph')
76 | # dataset4: sr, ss, so, sns, sno, t, conc, counts
77 | model.smoothing_placeholders['sn_relations'] = rel = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_relations')
78 | model.smoothing_placeholders['sn_pos_subj'] = psubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_subj')
79 | model.smoothing_placeholders['sn_pos_obj'] = pobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_obj')
80 | model.smoothing_placeholders['sn_neg_subj'] = nsubj = tf.placeholder(dtype=tf.int32, shape=[None, None],
81 | name='sn_neg_subj')
82 | model.smoothing_placeholders['sn_neg_obj'] = nobj = tf.placeholder(dtype=tf.int32, shape=[None, None],
83 | name='sn_neg_obj')
84 | model.smoothing_placeholders['sn_concepts'] = concepts = tf.placeholder(dtype=tf.int32,
85 | shape=[None, model.max_concepts_per_type],
86 | name='sn_concepts')
87 | model.smoothing_placeholders['sn_conc_counts'] = counts = tf.placeholder(dtype=tf.int32,
88 | shape=[None],
89 | name='sn_conc_counts')
90 | model.smoothing_placeholders['sn_types'] = types = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_types')
91 |
92 | true_energy = model.embedding_model.energy(psubj, rel, pobj)
93 | sampl_energy = model.embedding_model.energy(nsubj, tf.expand_dims(rel, axis=-1), nobj, model.energy_norm)
94 | if model.gan_mode:
95 | model.sampl_distributions = tf.nn.softmax(-sampl_energy, axis=-1)
96 | model.type_probabilities = tf.gather_nd(model.sampl_distributions, model.gan_loss_sample, name='sampl_probs')
97 | else:
98 | sm_numerator = tf.exp(-true_energy)
99 | exp_sampl_energies = tf.exp(-sampl_energy)
100 | sm_denominator = tf.reduce_sum(exp_sampl_energies, axis=-1) + sm_numerator
101 | model.type_probabilities = sm_numerator / sm_denominator
102 | energy_loss = -tf.reduce_mean(tf.log(model.type_probabilities))
103 |
104 | # [batch_size, embedding_size]
105 | type_embeddings = model.embedding_model.embedding_lookup(types)
106 | # [batch_size, max_concepts_per_type, embedding_size]
107 | concepts_embeddings = model.embedding_model.embedding_lookup(concepts)
108 |
109 | def calc_alignment_loss(_type_embeddings, _concept_embeddings):
110 | mask = tf.expand_dims(tf.sequence_mask(counts, maxlen=model.max_concepts_per_type, dtype=tf.float32), axis=-1)
111 | sum_ = tf.reduce_sum(mask * _concept_embeddings, axis=1, keepdims=False)
112 | float_counts = tf.to_float(tf.expand_dims(counts, axis=-1))
113 | avg_conc_embeddings = sum_ / tf.maximum(float_counts, tf.ones_like(float_counts))
114 | return tf.reduce_mean(tf.abs(_type_embeddings - avg_conc_embeddings))
115 |
116 | if isinstance(type_embeddings, tuple):
117 | alignment_loss = 0
118 | for type_embeddings_i, concepts_embeddings_i in zip(type_embeddings, concepts_embeddings):
119 | alignment_loss += calc_alignment_loss(type_embeddings_i, concepts_embeddings_i)
120 | else:
121 | alignment_loss = calc_alignment_loss(type_embeddings, concepts_embeddings)
122 |
123 | return energy_loss, alignment_loss
124 |
--------------------------------------------------------------------------------
/python/eukg/emb/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/rmm120030/code/python/tf_util')
3 |
--------------------------------------------------------------------------------
/python/eukg/gan/Discriminator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from ..data import DataGenerator
5 | from ..emb import Smoothing
6 |
7 | from tf_util import Trainable
8 |
9 |
10 | class BaseModel(Trainable):
11 | def __init__(self, config, embedding_model, data_generator=None):
12 | """
13 | :param config: config map
14 | :param embedding_model: KG Embedding Model
15 | :type embedding_model: EmbeddingModel.BaseModel
16 | :param data_generator: data generator
17 | :type data_generator: DataGenerator.DataGenerator
18 | """
19 | Trainable.__init__(self)
20 | self.model = config.model
21 | self.embedding_model = embedding_model
22 | self.data_generator = data_generator
23 |
24 | # class variable declarations
25 | self.batch_size = config.batch_size
26 | # self.embedding_size = config.embedding_size
27 | self.vocab_size = config.vocab_size
28 | self.gamma = config.gamma
29 | self.embedding_device = config.embedding_device
30 | self.max_concepts_per_type = config.max_concepts_per_type
31 | # self.is_training = is_training
32 | self.energy_norm = config.energy_norm_ord
33 | self.use_semantic_network = not config.no_semantic_network
34 | self.semnet_alignment_param = config.semnet_alignment_param
35 | self.semnet_energy_param = config.semnet_energy_param
36 |     self.regularization_parameter = config.regularization_parameter
37 |
38 | # optimization
39 | self.learning_rate = tf.train.exponential_decay(config.learning_rate,
40 | tf.train.get_or_create_global_step(),
41 | config.batches_per_epoch,
42 | config.decay_rate,
43 | staircase=True)
44 | if config.optimizer == "adam":
45 | self.optimizer = lambda: tf.train.AdamOptimizer(self.learning_rate)
46 | else:
47 | self.optimizer = lambda: tf.train.MomentumOptimizer(self.learning_rate,
48 | config.momentum,
49 | use_nesterov=True)
50 |
51 | # input placeholders
52 | self.relations = tf.placeholder(dtype=tf.int32, shape=[None], name='relations')
53 | self.pos_subj = tf.placeholder(dtype=tf.int32, shape=[None], name='pos_subj')
54 | self.pos_obj = tf.placeholder(dtype=tf.int32, shape=[None], name='pos_obj')
55 | self.neg_subj = tf.placeholder(dtype=tf.int32, shape=[None], name='neg_subj')
56 | self.neg_obj = tf.placeholder(dtype=tf.int32, shape=[None], name='neg_obj')
57 | self.labels = tf.ones_like(self.relations, dtype=tf.int32)
58 | self.smoothing_placeholders = {}
59 |
60 | self.pos_energy = None
61 | self.neg_energy = None
62 | self.predictions = None
63 | self.reward = None
64 | self.sn_reward = None
65 | self.loss = None
66 | self.streaming_accuracy = None
67 | self.accuracy = None
68 | self.avg_pos_energy = None
69 | self.avg_neg_energy = None
70 | self.summary = None
71 | self.train_op = None
72 |
73 | # define reset op (for resetting counts for streaming metrics after each validation epoch)
74 | self.reset_streaming_metrics_op = tf.variables_initializer(tf.local_variables())
75 |
76 | # define norm op
77 | self.norm_op = self.embedding_model.normalize_parameters()
78 | self.ids_to_update = self.embedding_model.ids_to_update
79 |
80 | def build(self):
81 | # energies
82 | with tf.variable_scope("energy"):
83 | self.pos_energy = self.embedding_model.energy(self.pos_subj, self.relations, self.pos_obj, self.energy_norm)
84 | with tf.variable_scope("energy", reuse=True):
85 | self.neg_energy = self.embedding_model.energy(self.neg_subj, self.relations, self.neg_obj, self.energy_norm)
86 | self.predictions = tf.argmax(tf.stack([self.pos_energy, self.neg_energy], axis=1), axis=1, output_type=tf.int32)
87 | self.reward = tf.reduce_mean(self.neg_energy, name='reward')
88 |
89 | # loss
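    # Margin-based ranking loss: max(0, gamma + E(pos) - E(neg)) encourages the energy of the
    # true triple to be at least gamma lower than that of its corrupted counterpart.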
90 | self.loss = tf.reduce_mean(tf.nn.relu(self.gamma - self.neg_energy + self.pos_energy), name='loss')
91 |
92 | if self.model == 'distmult':
93 |       reg = self.regularization_parameter * self.embedding_model.regularization([self.pos_subj, self.pos_obj,
94 | self.neg_subj, self.neg_obj,
95 | self.relations])
96 | tf.summary.scalar('reg', reg)
97 | tf.summary.scalar('margin_loss', self.loss)
98 | self.loss += reg
99 |
100 | if self.use_semantic_network:
101 | semnet_loss = Smoothing.add_semantic_network_loss(self)
102 | self.loss += semnet_loss
103 | tf.summary.scalar('sn_loss', semnet_loss)
104 | # backprop
105 | self.train_op = self.optimizer().minimize(self.loss, tf.train.get_or_create_global_step())
106 |
107 | # summary
108 | tf.summary.scalar('loss', self.loss)
109 | _, self.streaming_accuracy = tf.metrics.accuracy(labels=self.labels, predictions=self.predictions)
110 | tf.summary.scalar('streaming_accuracy', self.streaming_accuracy)
111 | self.accuracy = tf.reduce_mean(tf.to_float(tf.equal(self.predictions, self.labels)))
112 | tf.summary.scalar('accuracy', self.accuracy)
113 | self.avg_pos_energy = tf.reduce_mean(self.pos_energy)
114 | tf.summary.scalar('pos_energy', self.avg_pos_energy)
115 | self.avg_neg_energy = tf.reduce_mean(self.neg_energy)
116 | tf.summary.scalar('neg_energy', self.avg_neg_energy)
117 | tf.summary.scalar('margin', self.avg_neg_energy - self.avg_pos_energy)
118 | self.summary = tf.summary.merge_all()
119 |
120 | def fetches(self, is_training, verbose=False):
121 | fetches = [self.summary, self.loss]
122 | if verbose:
123 | if is_training:
124 | fetches += [self.accuracy]
125 | else:
126 | fetches += [self.streaming_accuracy]
127 | fetches += [self.avg_pos_energy, self.avg_neg_energy]
128 | if is_training:
129 | fetches += [self.train_op]
130 | return fetches
131 |
132 | def prepare_feed_dict(self, batch, is_training, **kwargs):
133 | # return {}
134 | if self.use_semantic_network:
135 | if is_training:
136 | rel, psub, pobj, nsub, nobj, sn_rel, sn_psub, sn_pobj, sn_nsub, sn_nobj, conc, c_lens, types = batch
137 | return {self.relations: rel,
138 | self.pos_subj: psub,
139 | self.pos_obj: pobj,
140 | self.neg_subj: nsub,
141 | self.neg_obj: nobj,
142 | self.smoothing_placeholders['sn_relations']: sn_rel,
143 | self.smoothing_placeholders['sn_pos_subj']: sn_psub,
144 | self.smoothing_placeholders['sn_pos_obj']: sn_pobj,
145 | self.smoothing_placeholders['sn_neg_subj']: sn_nsub,
146 | self.smoothing_placeholders['sn_neg_obj']: sn_nobj,
147 | self.smoothing_placeholders['sn_concepts']: conc,
148 | self.smoothing_placeholders['sn_conc_counts']: c_lens,
149 | self.smoothing_placeholders['sn_types']: types}
150 | else:
151 | rel, psub, pobj, nsub, nobj = batch
152 | return {self.relations: rel,
153 | self.pos_subj: psub,
154 | self.pos_obj: pobj,
155 | self.neg_subj: nsub,
156 | self.neg_obj: nobj,
157 | self.smoothing_placeholders['sn_relations']: [0],
158 | self.smoothing_placeholders['sn_pos_subj']: [0],
159 | self.smoothing_placeholders['sn_pos_obj']: [0],
160 | self.smoothing_placeholders['sn_neg_subj']: [0],
161 | self.smoothing_placeholders['sn_neg_obj']: [0],
162 | self.smoothing_placeholders['sn_concepts']: np.zeros([1, 1000], dtype=np.int32),
163 | self.smoothing_placeholders['sn_conc_counts']: [1],
164 | self.smoothing_placeholders['sn_types']: [0]}
165 | else:
166 | rel, psub, pobj, nsub, nobj = batch
167 | return {self.relations: rel,
168 | self.pos_subj: psub,
169 | self.pos_obj: pobj,
170 | self.neg_subj: nsub,
171 | self.neg_obj: nobj}
172 |
173 | def progress_update(self, batch, fetched, **kwargs):
174 | print('Avg loss of last batch: %.4f' % np.average(fetched[1]))
175 | print('Accuracy of last batch: %.4f' % np.average(fetched[2]))
176 | print('Avg pos energy of last batch: %.4f' % np.average(fetched[3]))
177 | print('Avg neg energy of last batch: %.4f' % np.average(fetched[4]))
178 |
179 | def data_provider(self, config, is_training, **kwargs):
180 | if self.use_semantic_network:
181 | return DataGenerator.wrap_generators(self.data_generator.generate_mt,
182 | self.data_generator.generate_sn, is_training)
183 | else:
184 | return self.data_generator.generate_mt(is_training)
185 |
--------------------------------------------------------------------------------
/python/eukg/gan/Generator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from Discriminator import BaseModel
5 | from ..emb import Smoothing
6 | from ..data import DataGenerator
7 |
8 |
9 | class Generator(BaseModel):
10 | def __init__(self, config, embedding_model, data_generator=None):
11 | BaseModel.__init__(self, config, embedding_model, data_generator)
12 | self.num_samples = config.num_generator_samples
13 | self.gan_mode = False
14 |
15 | # dataset2: s, r, o, ns, no
16 | self.neg_subj = tf.placeholder(dtype=tf.int32, shape=[None, self.num_samples], name="neg_subj")
17 | self.neg_obj = tf.placeholder(dtype=tf.int32, shape=[None, self.num_samples], name="neg_obj")
18 | self.discounted_reward = tf.placeholder(dtype=tf.float32, shape=[], name="discounted_reward")
19 | self.gan_loss_sample = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="gan_loss_sample")
20 |
21 | self.probabilities = None
22 | self.sampl_energies = None
23 | self.true_energies = None
24 | self.probability_distributions = None
25 |
26 | # semantic network vars
27 | self.type_probabilities = None
28 |
29 | def build(self):
30 | summary = []
31 | # [batch_size, num_samples]
32 | self.sampl_energies = self.embedding_model.energy(self.neg_subj,
33 | tf.expand_dims(self.relations, axis=1),
34 | self.neg_obj)
35 | # [batch_size]
36 | self.true_energies = self.embedding_model.energy(self.pos_subj,
37 | self.relations,
38 | self.pos_obj)
39 |
40 | # backprop
41 | optimizer = self.optimizer()
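    # Softmax over negated energies: the true triple competes against the generator's own sampled
    # negatives, so minimizing -log(probability) pushes the true triple's energy below the samples'
    # (a sampled-softmax style objective).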
42 | # [batch_size]
43 | sm_numerator = tf.exp(-self.true_energies)
44 | # [batch_size]
45 |     exp_sampl_energies = tf.exp(-self.sampl_energies)
46 |     sm_denominator = tf.reduce_sum(exp_sampl_energies, axis=-1) + sm_numerator
47 | # [batch_size]
48 | self.probabilities = sm_numerator / sm_denominator
49 | self.loss = -tf.reduce_mean(tf.log(self.probabilities))
50 |
51 | # regularization for distmult
52 | if self.model == "distmult":
53 |       reg = self.regularization_parameter * self.embedding_model.regularization([self.pos_subj, self.pos_obj,
54 | self.neg_subj, self.neg_obj,
55 | self.relations])
56 | summary += [tf.summary.scalar('reg', reg),
57 | tf.summary.scalar('log_prob', self.loss)]
58 | self.loss += reg
59 |
60 | if self.use_semantic_network:
61 | sn_energy_loss, sn_alignment_loss = Smoothing.add_gen_semantic_network(self)
62 | self.loss += self.semnet_energy_param * sn_energy_loss + self.semnet_alignment_param * sn_alignment_loss
63 | summary += [tf.summary.scalar('sn_energy_loss', sn_energy_loss / self.batch_size),
64 | tf.summary.scalar('sn_alignment_loss', sn_alignment_loss / self.batch_size)]
65 |
66 | self.avg_pos_energy = tf.reduce_mean(self.true_energies)
67 | self.avg_neg_energy = tf.reduce_mean(self.sampl_energies)
68 | summary += [tf.summary.scalar('loss', self.loss),
69 | tf.summary.scalar('avg_prob', tf.reduce_mean(self.probabilities)),
70 | tf.summary.scalar('min_prob', tf.reduce_min(self.probabilities)),
71 | tf.summary.scalar('max_prob', tf.reduce_max(self.probabilities)),
72 | tf.summary.scalar('pos_energy', self.avg_pos_energy),
73 | tf.summary.scalar('neg_energy', self.avg_neg_energy),
74 | tf.summary.scalar('margin', self.avg_pos_energy - self.avg_neg_energy)]
75 | self.train_op = optimizer.minimize(self.loss, tf.train.get_or_create_global_step())
76 |
77 | # summary
78 | self.summary = tf.summary.merge(summary)
79 |
80 | def fetches(self, is_training, verbose=False):
81 | fetches = [self.summary, self.loss]
82 | if verbose:
83 | fetches += [self.probabilities, self.avg_pos_energy, self.avg_neg_energy]
84 | if is_training:
85 | fetches += [self.train_op]
86 | return fetches
87 |
88 | def prepare_feed_dict(self, batch, is_training, **kwargs):
89 | if self.use_semantic_network:
90 | if is_training:
91 | rel, psub, pobj, nsub, nobj, sn_rel, sn_psub, sn_pobj, sn_nsub, sn_nobj, conc, c_lens, types = batch
92 | return {self.relations: rel,
93 | self.pos_subj: psub,
94 | self.pos_obj: pobj,
95 | self.neg_subj: nsub,
96 | self.neg_obj: nobj,
97 | self.smoothing_placeholders['sn_relations']: sn_rel,
98 | self.smoothing_placeholders['sn_pos_subj']: sn_psub,
99 | self.smoothing_placeholders['sn_pos_obj']: sn_pobj,
100 | self.smoothing_placeholders['sn_neg_subj']: sn_nsub,
101 | self.smoothing_placeholders['sn_neg_obj']: sn_nobj,
102 | self.smoothing_placeholders['sn_concepts']: conc,
103 | self.smoothing_placeholders['sn_conc_counts']: c_lens,
104 | self.smoothing_placeholders['sn_types']: types}
105 | else:
106 | rel, psub, pobj, nsub, nobj = batch
107 | return {self.relations: rel,
108 | self.pos_subj: psub,
109 | self.pos_obj: pobj,
110 | self.neg_subj: nsub,
111 | self.neg_obj: nobj,
112 | self.smoothing_placeholders['sn_relations']: [0],
113 | self.smoothing_placeholders['sn_pos_subj']: [0],
114 | self.smoothing_placeholders['sn_pos_obj']: [0],
115 | self.smoothing_placeholders['sn_neg_subj']: [[0]],
116 | self.smoothing_placeholders['sn_neg_obj']: [[0]],
117 | self.smoothing_placeholders['sn_concepts']: np.zeros([1, 1000], dtype=np.int32),
118 | self.smoothing_placeholders['sn_conc_counts']: [1],
119 | self.smoothing_placeholders['sn_types']: [0]}
120 | else:
121 | rel, psub, pobj, nsub, nobj = batch
122 | return {self.relations: rel,
123 | self.pos_subj: psub,
124 | self.pos_obj: pobj,
125 | self.neg_subj: nsub,
126 | self.neg_obj: nobj}
127 |
128 | def progress_update(self, batch, fetched, **kwargs):
129 | print('Avg loss of last batch: %.4f' % np.average(fetched[1]))
130 | print('Avg probability of last batch: %.4f' % np.average(fetched[2]))
131 | print('Avg pos energy of last batch: %.4f' % np.average(fetched[3]))
132 | print('Avg neg energy of last batch: %.4f' % np.average(fetched[4]))
133 |
134 | def data_provider(self, config, is_training, **kwargs):
135 | if self.use_semantic_network:
136 | return DataGenerator.wrap_generators(self.data_generator.generate_mt_gen_mode,
137 | self.data_generator.generate_sn_gen_mode, is_training)
138 | else:
139 | return self.data_generator.generate_mt_gen_mode(is_training)
140 |
141 |
142 | class GanGenerator(Generator):
143 | def __init__(self, config, embedding_model, data_generator=None):
144 | Generator.__init__(self, config, embedding_model, data_generator)
145 | self.gan_mode = True
146 | self.sampl_distributions = None
147 |
148 | def build(self):
149 | # [batch_size, num_samples]
150 | self.sampl_energies = self.embedding_model.energy(self.neg_subj,
151 | tf.expand_dims(self.relations, axis=1),
152 | self.neg_obj)
153 | # [batch_size]
154 | self.true_energies = self.embedding_model.energy(self.pos_subj,
155 | self.relations,
156 | self.pos_obj)
157 |
158 | optimizer = self.optimizer()
159 | if self.use_semantic_network:
160 | # this method also adds values for self.sampl_distributions and self.type_probabilities
161 | loss, _ = Smoothing.add_gen_semantic_network(self)
162 | grads_and_vars = optimizer.compute_gradients(loss)
163 | vars_with_grad = [v for g, v in grads_and_vars if g is not None]
164 | if not vars_with_grad:
165 | raise ValueError(
166 | "No gradients provided for any variable, check your graph for ops"
167 | " that do not support gradients, between variables %s and loss %s." %
168 | ([str(v) for _, v in grads_and_vars], loss))
169 | discounted_grads_and_vars = [(self.discounted_reward * g, v) for g, v in grads_and_vars if g is not None]
170 | self.train_op = optimizer.apply_gradients(discounted_grads_and_vars,
171 | global_step=tf.train.get_or_create_global_step())
172 | summary = [tf.summary.scalar('avg_st_prob', tf.reduce_mean(self.type_probabilities)),
173 | tf.summary.scalar('sn_loss', loss / self.batch_size),
174 | tf.summary.scalar('reward', self.discounted_reward)]
175 | else:
176 | # [batch_size, num_samples] - this is for sampling during GAN training
177 | self.probability_distributions = tf.nn.softmax(self.sampl_energies, axis=-1)
178 | self.probabilities = tf.gather_nd(self.probability_distributions, self.gan_loss_sample, name='sampl_probs')
179 | loss = -tf.reduce_sum(tf.log(self.probabilities))
180 | summary = [tf.summary.scalar('avg_sampled_prob', tf.reduce_mean(self.probabilities))]
181 |
182 | # if training as part of a GAN, gradients should be scaled by discounted_reward
183 | grads_and_vars = optimizer.compute_gradients(loss)
184 | vars_with_grad = [v for g, v in grads_and_vars if g is not None]
185 | if not vars_with_grad:
186 | raise ValueError(
187 | "No gradients provided for any variable, check your graph for ops"
188 | " that do not support gradients, between variables %s and loss %s." %
189 | ([str(v) for _, v in grads_and_vars], loss))
190 | discounted_grads_and_vars = [(self.discounted_reward * g, v) for g, v in grads_and_vars if g is not None]
191 | self.train_op = optimizer.apply_gradients(discounted_grads_and_vars,
192 | global_step=tf.train.get_or_create_global_step())
193 |
194 | # reporting loss
195 | self.loss = loss / self.batch_size
196 | # summary
197 | self.summary = tf.summary.merge(summary)
198 |
--------------------------------------------------------------------------------
/python/eukg/gan/Generator.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/Generator.pyc
--------------------------------------------------------------------------------
/python/eukg/gan/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/rmm120030/code/python/tf_util')
--------------------------------------------------------------------------------
/python/eukg/gan/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/__init__.pyc
--------------------------------------------------------------------------------
/python/eukg/gan/train_gan.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import math
4 | import random
5 | from tqdm import tqdm
6 | import numpy as np
7 | from itertools import izip
8 |
9 | from .. import Config
10 | from ..data import data_util, DataGenerator
11 | from ..emb import EmbeddingModel
12 | from Generator import GanGenerator
13 | import Discriminator
14 |
15 |
16 | # noinspection PyUnboundLocalVariable
17 | def train():
18 | config = Config.flags
19 |
20 | use_semnet = not config.no_semantic_network
21 |
22 | # init model dir
23 | gan_model_dir = os.path.join(config.model_dir, config.model, config.run_name)
24 | if not os.path.exists(gan_model_dir):
25 | os.makedirs(gan_model_dir)
26 |
27 | # init summaries dir
28 | config.summaries_dir = os.path.join(config.summaries_dir, config.run_name)
29 | if not os.path.exists(config.summaries_dir):
30 | os.makedirs(config.summaries_dir)
31 |
32 | # save the config
33 | data_util.save_config(gan_model_dir, config)
34 |
35 | # load data
36 | cui2id, data, train_idx, val_idx = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion)
37 | config.val_progress_update_interval = int(math.floor(float(len(val_idx)) / config.batch_size))
38 | config.batches_per_epoch = int(math.floor(float(len(train_idx)) / config.batch_size))
39 | if not config.no_semantic_network:
40 | type2cuis = data_util.load_semantic_network_data(config.data_dir, data)
41 | else:
42 | type2cuis = None
43 | data_generator = DataGenerator.DataGenerator(data, train_idx, val_idx, config, type2cuis)
44 |
45 | with tf.Graph().as_default(), tf.Session() as session:
46 | # init models
47 | with tf.variable_scope(config.dis_run_name):
48 | discriminator = init_model(config, 'disc')
49 | with tf.variable_scope(config.gen_run_name):
50 | config.no_semantic_network = True
51 | config.learning_rate = 1e-1
52 | generator = init_model(config, 'gen')
53 | if use_semnet:
54 | with tf.variable_scope(config.sn_gen_run_name):
55 | config.no_semantic_network = False
56 | sn_generator = init_model(config, 'sn_gen')
57 |
58 | tf.global_variables_initializer().run()
59 | tf.local_variables_initializer().run()
60 |
61 | # init saver
62 | dis_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=config.dis_run_name),
63 | max_to_keep=10)
64 | gen_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=config.gen_run_name),
65 | max_to_keep=10)
66 |
67 | # load models
68 | dis_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, config.dis_run_name))
69 | dis_saver.restore(session, dis_ckpt)
70 | gen_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, "distmult", config.gen_run_name))
71 | gen_saver.restore(session, gen_ckpt)
72 | if use_semnet:
73 | sn_gen_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
74 | scope=config.sn_gen_run_name),
75 | max_to_keep=10)
76 | sn_gen_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, "distmult", config.sn_gen_run_name))
77 | sn_gen_saver.restore(session, sn_gen_ckpt)
78 |
79 | # finalize graph
80 | tf.get_default_graph().finalize()
81 |
82 | # define streaming_accuracy reset per epoch
83 | print('local variables that will be reinitialized every epoch: %s' % tf.local_variables())
84 | reset_local_vars = lambda: session.run(discriminator.reset_streaming_metrics_op)
85 |
86 | # init summary directories and summary writers
87 | if not os.path.exists(os.path.join(config.summaries_dir, 'train')):
88 | os.makedirs(os.path.join(config.summaries_dir, 'train'))
89 | train_summary_writer = tf.summary.FileWriter(os.path.join(config.summaries_dir, 'train'))
90 | if not os.path.exists(os.path.join(config.summaries_dir, 'val')):
91 | os.makedirs(os.path.join(config.summaries_dir, 'val'))
92 | val_summary_writer = tf.summary.FileWriter(os.path.join(config.summaries_dir, 'val'))
93 |
94 | # config_map = config.flag_values_dict()
95 | # config_map['data'] = data
96 | # config_map['train_idx'] = train_idx
97 | # config_map['val_idx'] = val_idx
98 |
99 | global_step = 0
100 | for ep in xrange(config.num_epochs):
101 | print('----------------------------')
102 | print('Begin Train Epoch %d' % ep)
103 | if use_semnet:
104 | global_step = train_epoch_sn(session, discriminator, generator, sn_generator, config, data_generator,
105 | train_summary_writer, global_step)
106 | sn_gen_saver.save(session, os.path.join(gan_model_dir, 'sn_generator', config.model), global_step=global_step)
107 | else:
108 | global_step = train_epoch(session, discriminator, generator, config, data_generator, train_summary_writer,
109 | global_step)
110 | print("Saving models to %s at step %d" % (gan_model_dir, global_step))
111 | dis_saver.save(session, os.path.join(gan_model_dir, 'discriminator', config.model), global_step=global_step)
112 | gen_saver.save(session, os.path.join(gan_model_dir, 'generator', config.model), global_step=global_step)
113 | reset_local_vars()
114 | print('----------------------------')
115 | print('Begin Validation Epoch %d' % ep)
116 | validation_epoch(session, discriminator, config, data_generator, val_summary_writer, global_step)
117 |
118 |
119 | def init_model(config, mode):
120 | print('Initializing %s model...' % mode)
121 |
122 | if mode == 'disc':
123 | if config.model == 'transe':
124 | em = EmbeddingModel.TransE(config)
125 | elif config.model == 'transd':
126 | # config.embedding_size = config.embedding_size / 2
127 | em = EmbeddingModel.TransD(config)
128 | # config.embedding_size = config.embedding_size * 2
129 | else:
130 | raise ValueError('Unrecognized model type: %s' % config.model)
131 | model = Discriminator.BaseModel(config, em)
132 | elif mode == 'gen':
133 | em = EmbeddingModel.DistMult(config)
134 | model = GanGenerator(config, em)
135 | elif mode == 'sn_gen':
136 | em = EmbeddingModel.DistMult(config)
137 | model = GanGenerator(config, em)
138 | else:
139 |     raise ValueError('Unrecognized mode: %s' % mode)
140 |
141 | model.build()
142 | return model
143 |
144 |
145 | def find_unique(tensor_list):
146 | if max([len(t.shape) for t in tensor_list[:10]]) == 1:
147 | return np.unique(np.concatenate(tensor_list[:10]))
148 | else:
149 | return np.unique(np.concatenate([t.flatten() for t in tensor_list[:10]]))
150 |
151 |
152 | def sample_corrupted_triples(sampl_sub, sampl_obj, probability_distributions, idx_np):
153 | nsub = []
154 | nobj = []
155 | sampl_idx = []
156 | for i, dist in enumerate(probability_distributions):
157 | [j] = np.random.choice(idx_np, [1], p=dist)
158 | nsub.append(sampl_sub[i, j])
159 | nobj.append(sampl_obj[i, j])
160 | sampl_idx.append([i, j])
161 | nsub = np.asarray(nsub)
162 | nobj = np.asarray(nobj)
163 | return nsub, nobj, sampl_idx
164 |
165 |
166 | def train_epoch(session, discriminator, generator, config, data_generator, summary_writer, global_step):
167 | baseline = 0.
168 | console_update_interval = config.progress_update_interval
169 | pbar = tqdm(total=console_update_interval)
170 | idx_np = np.arange(config.num_generator_samples)
171 | for b, batch in enumerate(data_generator.generate_mt_gen_mode(True)):
172 | verbose_batch = b > 0 and b % console_update_interval == 0
173 |
174 | # generation
175 | gen_feed_dict = generator.prepare_feed_dict(batch, True)
176 | probability_distributions = session.run(generator.probability_distributions, gen_feed_dict)
177 | rel, psub, pobj, sampl_sub, sampl_obj = batch
178 | nsub = []
179 | nobj = []
180 | sampl_idx = []
181 | for i, dist in enumerate(probability_distributions):
182 | [j] = np.random.choice(idx_np, [1], p=dist)
183 | nsub.append(sampl_sub[i, j])
184 | nobj.append(sampl_obj[i, j])
185 | sampl_idx.append([i, j])
186 | nsub = np.asarray(nsub)
187 | nobj = np.asarray(nobj)
188 |
189 | # discrimination
190 | dis_fetched = session.run(discriminator.fetches(True, verbose_batch) + [discriminator.reward],
191 | {discriminator.relations: rel,
192 | discriminator.pos_subj: psub,
193 | discriminator.pos_obj: pobj,
194 | discriminator.neg_subj: nsub,
195 | discriminator.neg_obj: nobj})
196 |
197 | # generation reward
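    # The discriminator's mean energy on the sampled negatives is the generator's reward; the
    # previous batch's reward acts as a simple baseline, so only the change in reward (the
    # "discounted" reward below) scales the generator's policy-gradient update.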
198 | discounted_reward = dis_fetched[-1] - baseline
199 | baseline = dis_fetched[-1]
200 | gen_feed_dict[generator.discounted_reward] = discounted_reward
201 | gen_feed_dict[generator.gan_loss_sample] = np.asarray(sampl_idx)
202 | gen_fetched = session.run([generator.summary, generator.loss, generator.probabilities, generator.train_op],
203 | gen_feed_dict)
204 | # assert gloss == gen_fetched[1], \
205 | # "Forward pass for generation step does not match forward pass for generator learning! %f != %f" \
206 | # % (gloss, gen_fetched[1])
207 |
208 | # update tensorboard summary
209 | summary_writer.add_summary(dis_fetched[0], global_step)
210 | summary_writer.add_summary(gen_fetched[0], global_step)
211 | global_step += 1
212 |
213 | # perform normalization
214 | session.run([generator.norm_op, discriminator.norm_op],
215 | {generator.ids_to_update: find_unique(batch),
216 | discriminator.ids_to_update: find_unique([rel, psub, pobj, nsub, nobj])})
217 |
218 |       # update progress bar
219 | pbar.set_description("Training Batch: %d. GLoss: %.4f. DLoss: %.4f. Reward: %.4f" %
220 | (b, gen_fetched[1], dis_fetched[1], discounted_reward))
221 | pbar.update()
222 |
223 | if verbose_batch:
224 | print('Discriminator:')
225 | discriminator.progress_update(batch, dis_fetched)
226 | print('Generator:')
227 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(gen_fetched[2]))
228 | pbar.close()
229 | pbar = tqdm(total=console_update_interval)
230 | pbar.close()
231 |
232 | return global_step
233 |
234 |
235 | def train_epoch_sn(sess, discriminator, generator, sn_generator, config, data_generator, summary_writer, global_step):
236 | baseline = 0.
237 | sn_baseline = 0.
238 | pbar = None
239 | console_update_interval = config.progress_update_interval
240 | idx_np = np.arange(config.num_generator_samples)
241 | sn_idx_np = np.arange(10)
242 | for b, (mt_batch, sn_batch) in enumerate(izip(data_generator.generate_mt_gen_mode(True),
243 | data_generator.generate_sn_gen_mode(True))):
244 | verbose_batch = b > 0 and b % console_update_interval == 0
245 |
246 | # mt generation
247 | gen_feed_dict = generator.prepare_feed_dict(mt_batch, True)
248 | concept_distributions = sess.run(generator.probability_distributions, gen_feed_dict)
249 | rel, psub, pobj, sampl_sub, sampl_obj = mt_batch
250 | nsub, nobj, sampl_idx = sample_corrupted_triples(sampl_sub, sampl_obj, concept_distributions, idx_np)
251 |
252 | # sn generation
253 | sn_gen_feed_dict = {sn_generator.smoothing_placeholders['sn_relations']: sn_batch[0],
254 | sn_generator.smoothing_placeholders['sn_neg_subj']: sn_batch[3],
255 | sn_generator.smoothing_placeholders['sn_neg_obj']: sn_batch[4]}
256 | type_distributions = sess.run(sn_generator.sampl_distributions, sn_gen_feed_dict)
257 | sn_nsub, sn_nobj, sn_sampl_idx = sample_corrupted_triples(sn_batch[3], sn_batch[4], type_distributions, sn_idx_np)
258 | types = np.unique(np.concatenate([sn_batch[1], sn_batch[2], sn_nsub, sn_nobj]))
259 | concepts = np.zeros([len(types), config.max_concepts_per_type], dtype=np.int32)
260 | concept_lens = np.zeros([len(types)], dtype=np.int32)
261 | for i, tid in enumerate(types):
262 | concepts_of_type_t = data_generator.type2cuis[tid] if tid in data_generator.type2cuis else []
263 | random.shuffle(concepts_of_type_t)
264 | concepts_of_type_t = concepts_of_type_t[:config.max_concepts_per_type]
265 | concept_lens[i] = len(concepts_of_type_t)
266 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t
267 |
268 | # discrimination
269 | dis_fetched = sess.run(discriminator.fetches(True, verbose_batch) + [discriminator.sn_reward, discriminator.reward],
270 | {discriminator.relations: rel,
271 | discriminator.pos_subj: psub,
272 | discriminator.pos_obj: pobj,
273 | discriminator.neg_subj: nsub,
274 | discriminator.neg_obj: nobj,
275 | discriminator.smoothing_placeholders['sn_relations']: sn_batch[0],
276 | discriminator.smoothing_placeholders['sn_pos_subj']: sn_batch[1],
277 | discriminator.smoothing_placeholders['sn_pos_obj']: sn_batch[2],
278 | discriminator.smoothing_placeholders['sn_neg_subj']: sn_nsub,
279 | discriminator.smoothing_placeholders['sn_neg_obj']: sn_nobj,
280 | discriminator.smoothing_placeholders['sn_types']: types,
281 | discriminator.smoothing_placeholders['sn_concepts']: concepts,
282 | discriminator.smoothing_placeholders['sn_conc_counts']: concept_lens})
283 |
284 | # generation reward
285 | discounted_reward = dis_fetched[-1] - baseline
286 | baseline = dis_fetched[-1]
287 | gen_feed_dict[generator.discounted_reward] = discounted_reward
288 | gen_feed_dict[generator.gan_loss_sample] = np.asarray(sampl_idx)
289 | gen_fetched = sess.run([generator.summary, generator.loss, generator.probabilities, generator.train_op],
290 | gen_feed_dict)
291 |
292 | # sn generation reward
293 | sn_discounted_reward = dis_fetched[-2] - sn_baseline
294 | sn_baseline = dis_fetched[-2]
295 | sn_gen_feed_dict[sn_generator.discounted_reward] = sn_discounted_reward
296 | sn_gen_feed_dict[sn_generator.gan_loss_sample] = np.asarray(sn_sampl_idx)
297 | sn_gen_fetched = sess.run([sn_generator.summary, sn_generator.loss,
298 | sn_generator.type_probabilities, sn_generator.train_op],
299 | sn_gen_feed_dict)
300 |
301 | # update tensorboard summary
302 | summary_writer.add_summary(dis_fetched[0], global_step)
303 | summary_writer.add_summary(gen_fetched[0], global_step)
304 | summary_writer.add_summary(sn_gen_fetched[0], global_step)
305 | global_step += 1
306 |
307 | # perform normalization
308 | sess.run([generator.norm_op, discriminator.norm_op, sn_generator.norm_op],
309 | {generator.ids_to_update: find_unique(mt_batch + sn_batch),
310 | discriminator.ids_to_update: find_unique([rel, psub, pobj, nsub, nobj]),
311 | sn_generator.ids_to_update: sn_batch[7]})
312 |
313 |       # update progress bar
314 | pbar = tqdm(total=console_update_interval) if pbar is None else pbar
315 | pbar.set_description("Training Batch: %d. GLoss: %.4f. SN_GLoss: %.4f. DLoss: %.4f." %
316 | (b, gen_fetched[1], sn_gen_fetched[1], dis_fetched[1]))
317 | pbar.update()
318 |
319 | if verbose_batch:
320 | print('Discriminator:')
321 | discriminator.progress_update(mt_batch, dis_fetched)
322 | print('Generator:')
323 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(gen_fetched[2]))
324 | print('SN Generator:')
325 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(sn_gen_fetched[2]))
326 | pbar.close()
327 | pbar = tqdm(total=console_update_interval)
328 | if pbar:
329 | pbar.close()
330 |
331 | return global_step
332 |
333 |
334 | def validation_epoch(session, model, config, data_generator, summary_writer, global_step):
335 | console_update_interval = config.val_progress_update_interval
336 | pbar = tqdm(total=console_update_interval)
337 | # validation epoch
338 | for b, batch in enumerate(data_generator.generate_mt(False)):
339 | verbose_batch = b > 0 and b % console_update_interval == 0
340 |
341 | fetched = session.run(model.fetches(False, verbose=verbose_batch), model.prepare_feed_dict(batch, False))
342 |
343 | # update tensorboard summary
344 | summary_writer.add_summary(fetched[0], global_step)
345 | global_step += 1
346 |
347 |     # update progress bar
348 | pbar.set_description("Validation Batch: %d. Loss: %.4f" % (b, fetched[1]))
349 | pbar.update()
350 |
351 | if verbose_batch:
352 | model.progress_update(batch, fetched)
353 | pbar.close()
354 | pbar = tqdm(total=console_update_interval)
355 | pbar.close()
356 |
--------------------------------------------------------------------------------
/python/eukg/gan/train_gan.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/train_gan.pyc
--------------------------------------------------------------------------------
/python/eukg/save_embeddings.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import numpy as np
4 |
5 | import Config
6 | from train import init_model
7 |
8 |
9 | def main():
10 | config = Config.flags
11 |
12 | if config.mode == 'gan':
13 | scope = config.dis_run_name
14 | run_name = config.run_name + "/discriminator"
15 | config.mode = 'disc'
16 | else:
17 | scope = config.run_name
18 | run_name = config.run_name
19 |
20 | with tf.Graph().as_default(), tf.Session() as session:
21 | with tf.variable_scope(scope):
22 | model = init_model(config, None)
23 | saver = tf.train.Saver([var for var in tf.global_variables() if 'embeddings' in var.name])
24 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, run_name))
25 | print('Loading checkpoint: %s' % ckpt)
26 | saver.restore(session, ckpt)
27 |
28 | embeddings = session.run(model.embedding_model.embeddings)
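    # For TransD, each row of the saved matrix is the concatenation of the entity embedding and
    # its projection embedding.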
29 | if config.model == 'transd':
30 | embeddings = np.concatenate((embeddings, session.run(model.embedding_model.p_embeddings)), axis=1)
31 |
32 | np.savez_compressed(config.embedding_file,
33 | embs=embeddings)
34 |
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
--------------------------------------------------------------------------------
/python/eukg/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/test/__init__.py
--------------------------------------------------------------------------------
/python/eukg/test/classification.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import os
5 | from tqdm import tqdm
6 |
7 | from ..data import data_util, DataGenerator
8 | from .. import Config, train
9 | from sklearn.svm import LinearSVC
10 | from sklearn import metrics
11 |
12 | config = Config.flags
13 |
14 |
15 | def evaluate():
16 | random.seed(1337)
17 | np.random.seed(1337)
18 | config.no_semantic_network = True
19 | config.batch_size = 2000
20 |
21 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion)
22 | test_data = data_util.load_metathesaurus_test_data(config.data_dir)
23 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir))
24 | concept_ids = np.unique(np.concatenate([train_data['subj'], train_data['obj'], test_data['subj'], test_data['obj']]))
25 | print('%d total unique concepts' % len(concept_ids))
26 | val_idx = np.random.permutation(np.arange(len(train_data['rel'])))[:100000]
27 | val_data_generator = DataGenerator.DataGenerator(train_data,
28 | train_idx=val_idx,
29 | val_idx=[],
30 | config=config,
31 | test_mode=True)
32 |
33 | valid_triples = set()
34 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']):
35 | valid_triples.add((s, r, o))
36 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']):
37 | valid_triples.add((s, r, o))
38 | print('%d valid triples' % len(valid_triples))
39 |
40 | model_name = config.run_name
41 | if config.mode == 'gan':
42 | scope = config.dis_run_name
43 | model_name += '/discriminator'
44 | config.mode = 'disc'
45 | else:
46 | scope = config.run_name
47 |
48 | with tf.Graph().as_default(), tf.Session() as session:
49 | # init model
50 | with tf.variable_scope(scope):
51 | model = train.init_model(config, None)
52 |
53 | tf.global_variables_initializer().run()
54 | tf.local_variables_initializer().run()
55 |
56 | # init saver
57 | tf_saver = tf.train.Saver(max_to_keep=10)
58 |
59 | # load model
60 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
61 | print('Loading checkpoint: %s' % ckpt)
62 | tf_saver.restore(session, ckpt)
63 | tf.get_default_graph().finalize()
64 |
65 | scores = []
66 | labels = []
67 | for r, s, o, ns, no in tqdm(val_data_generator.generate_mt(True), total=val_data_generator.num_train_batches()):
68 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r,
69 | model.pos_subj: s,
70 | model.pos_obj: o,
71 | model.neg_subj: ns,
72 | model.neg_obj: no})
73 | scores += pscores.tolist()
74 | labels += np.ones_like(pscores, dtype=np.int).tolist()
75 | scores += nscores.tolist()
76 | labels += np.zeros_like(nscores, dtype=np.int).tolist()
77 | print('Calculated scores. Training SVM.')
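    # The linear SVM only sees the 1-D energy score (reshaped to [-1, 1]), so it effectively
    # learns a decision threshold separating true from corrupted triples.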
78 | svm = LinearSVC(dual=False)
79 | svm.fit(np.asarray(scores).reshape(-1, 1), labels)
80 | print('Done.')
81 |
82 | data_generator = DataGenerator.DataGenerator(test_data,
83 | train_idx=np.arange(len(test_data['rel'])),
84 | val_idx=[],
85 | config=config,
86 | test_mode=True)
87 | data_generator._sampler = val_data_generator.sampler
88 | scores, labels = [], []
89 | for r, s, o, ns, no in tqdm(data_generator.generate_mt(True), desc='classifying',
90 | total=data_generator.num_train_batches()):
91 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r,
92 | model.pos_subj: s,
93 | model.pos_obj: o,
94 | model.neg_subj: ns,
95 | model.neg_obj: no})
96 | scores += pscores.tolist()
97 | labels += np.ones_like(pscores, dtype=np.int).tolist()
98 | scores += nscores.tolist()
99 | labels += np.zeros_like(nscores, dtype=np.int).tolist()
100 | predictions = svm.predict(np.asarray(scores).reshape(-1, 1))
101 | print('pred: %s' % predictions.shape)
102 | print('lbl: %d' % len(labels))
103 | print('Relation Triple Classification Accuracy: %.4f' % metrics.accuracy_score(labels, predictions))
104 | print('Relation Triple Classification Precision: %.4f' % metrics.precision_score(labels, predictions))
105 | print('Relation Triple Classification Recall: %.4f' % metrics.recall_score(labels, predictions))
106 | print(metrics.classification_report(labels, predictions))
107 |
108 |
109 | def evaluate_sn():
110 | random.seed(1337)
111 | config.no_semantic_network = False
112 |
113 | data = {}
114 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.)
115 | _ = data_util.load_semantic_network_data(config.data_dir, data)
116 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj']
117 | print('Loaded %d sn triples from %s' % (len(rel), config.data_dir))
118 |
119 | valid_triples = set()
120 | for trip in zip(subj, rel, obj):
121 | valid_triples.add(trip)
122 | print('%d valid triples' % len(valid_triples))
123 | idxs = np.random.permutation(np.arange(len(rel)))
124 | idxs = idxs[:600]
125 | subj, rel, obj = subj[idxs], rel[idxs], obj[idxs]
126 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='???')
127 | nsubj, nobj = sampler.sample(subj, rel, obj)
128 |
129 | model_name = config.run_name
130 | if config.mode == 'gan':
131 | scope = config.dis_run_name
132 | model_name += '/discriminator'
133 | config.mode = 'disc'
134 | else:
135 | scope = config.run_name
136 |
137 | with tf.Graph().as_default(), tf.Session() as session:
138 | # init model
139 | with tf.variable_scope(scope):
140 | model = train.init_model(config, None)
141 |
142 | tf.global_variables_initializer().run()
143 | tf.local_variables_initializer().run()
144 |
145 | # init saver
146 | tf_saver = tf.train.Saver(max_to_keep=10)
147 |
148 | # load model
149 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
150 | print('Loading checkpoint: %s' % ckpt)
151 | tf_saver.restore(session, ckpt)
152 | tf.get_default_graph().finalize()
153 |
154 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel,
155 | model.smoothing_placeholders['sn_pos_subj']: subj,
156 | model.smoothing_placeholders['sn_pos_obj']: obj,
157 | model.smoothing_placeholders['sn_neg_subj']: nsubj,
158 | model.smoothing_placeholders['sn_neg_obj']: nobj}
159 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict)
160 | scores = np.concatenate((pscores, nscores))
161 | labels = np.concatenate((np.ones_like(pscores, dtype=np.int), np.zeros_like(nscores, dtype=np.int)))
162 | print('Calculated scores. Training SVM.')
163 | svm = LinearSVC(dual=False)
164 | svm.fit(scores.reshape(-1, 1), labels)
165 | print('Done.')
166 |
167 | with np.load(os.path.join(config.data_dir, 'semnet', 'triples.npz')) as npz:
168 | subj = npz['subj']
169 | rel = npz['rel']
170 | obj = npz['obj']
171 | nsubj, nobj = sampler.sample(subj, rel, obj)
172 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel,
173 | model.smoothing_placeholders['sn_pos_subj']: subj,
174 | model.smoothing_placeholders['sn_pos_obj']: obj,
175 | model.smoothing_placeholders['sn_neg_subj']: nsubj,
176 | model.smoothing_placeholders['sn_neg_obj']: nobj}
177 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict)
178 | predictions = svm.predict(pscores.reshape(-1, 1)).tolist()
179 | labels = np.ones_like(pscores, dtype=np.int).tolist()
180 | predictions += svm.predict(nscores.reshape(-1, 1)).tolist()
181 | labels += np.zeros_like(nscores, dtype=np.int).tolist()
182 | print('SN Relation Triple Classification Accuracy: %.4f' % metrics.accuracy_score(labels, predictions))
183 | print('SN Relation Triple Classification Precision: %.4f' % metrics.precision_score(labels, predictions))
184 | print('SN Relation Triple Classification Recall: %.4f' % metrics.recall_score(labels, predictions))
185 | print(metrics.classification_report(labels, predictions))
186 |
187 |
188 |
189 | if __name__ == "__main__":
190 | if config.eval_mode == 'sn':
191 | evaluate_sn()
192 | else:
193 | evaluate()
194 |
--------------------------------------------------------------------------------
/python/eukg/test/nearest_neighbors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import json
4 | import os
5 | import csv
6 | from tqdm import tqdm
7 | from collections import defaultdict
8 |
9 | csv.field_size_limit(sys.maxsize)
10 |
11 |
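# Brute-force nearest neighbors by Euclidean distance over the full embedding matrix; k is bumped
# by one because the query embedding is its own nearest neighbor and is included in the results.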
12 | def knn(embeddings, idx, k):
13 | k = k+1
14 | e = np.expand_dims(embeddings[idx], axis=0)
15 | distances = np.linalg.norm(e - embeddings, axis=1)
16 | print(distances.shape)
17 | idxs = list(range(embeddings.shape[0]))
18 | idxs.sort(key=lambda i_: distances[i_])
19 |
20 | return zip(idxs[:k], [distances[i] for i in idxs[:k]])
21 |
22 |
23 | def main():
24 | cui2names = defaultdict(list)
25 | with open('/home/rmm120030/working/umls-mke/umls/META/MRCONSO.RRF', 'r') as f:
26 | reader = csv.reader(f, delimiter='|')
27 | for row in tqdm(reader, desc="reading mrconso", total=8157818):
28 | cui2names[row[0]].append(row[14])
29 |
30 | cui2id = json.load(open(os.path.join('/home/rmm120030/working/umls-mke/data', 'name2id.json')))
31 | id2cui = {v: k for k, v in cui2id.iteritems()}
32 | with np.load(sys.argv[1]) as npz:
33 | embeddings = npz['embs']
34 | print('Loaded embedding matrix: %s' % str(embeddings.shape))
35 |
36 | while True:
37 | cui = raw_input('enter CUI (or \'exit\' to stop): ')
38 | if cui == "exit":
39 | exit()
40 | if cui in cui2id:
41 |       neighbors = knn(embeddings, cui2id[cui], 10)
42 |       for i, dist in neighbors:
43 | c = id2cui[i]
44 | print(' %.6f - %s - %s' % (dist, c, cui2names[c][:5]))
45 | else:
46 | print('No embedding for CUI %s' % cui)
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/python/eukg/test/ppa.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import os
5 | import json
6 | from tqdm import tqdm
7 |
8 | from ..data import data_util, DataGenerator
9 | from .. import Config, train
10 |
11 | config = Config.flags
12 |
13 |
14 | def evaluate():
15 | random.seed(1337)
16 | config.no_semantic_network = True
17 |
18 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion)
19 | id2cui = {v: k for k, v in cui2id.iteritems()}
20 | test_data = data_util.load_metathesaurus_test_data(config.data_dir)
21 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir))
22 | concept_ids = np.unique(np.concatenate([train_data['subj'], train_data['obj'], test_data['subj'], test_data['obj']]))
23 | print('%d total unique concepts' % len(concept_ids))
24 | data_generator = DataGenerator.DataGenerator(test_data,
25 | train_idx=np.arange(len(test_data['rel'])),
26 | val_idx=[],
27 | config=config,
28 | test_mode=True)
29 |
30 | valid_triples = set()
31 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']):
32 | valid_triples.add((s, r, o))
33 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']):
34 | valid_triples.add((s, r, o))
35 | print('%d valid triples' % len(valid_triples))
36 |
37 | model_name = config.run_name
38 | if config.mode == 'gan':
39 | scope = config.dis_run_name
40 | model_name += '/discriminator'
41 | config.mode = 'disc'
42 | else:
43 | scope = config.run_name
44 |
45 | with tf.Graph().as_default(), tf.Session() as session:
46 | # init model
47 | with tf.variable_scope(scope):
48 | model = train.init_model(config, data_generator)
49 |
50 | tf.global_variables_initializer().run()
51 | tf.local_variables_initializer().run()
52 |
53 | # init saver
54 | tf_saver = tf.train.Saver(max_to_keep=10)
55 |
56 | # load model
57 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
58 | print('Loading checkpoint: %s' % ckpt)
59 | tf_saver.restore(session, ckpt)
60 | tf.get_default_graph().finalize()
61 |
62 | def decode_triple(s_, r_, o_, score):
63 | return id2cui[s_], id2cui[r_], id2cui[o_], str(score)
64 |
65 | incorrect = [] # list of tuples of (triple, corrupted_triple)
66 | num_correct = 0
67 | total = 0
68 | for r, s, o, ns, no in tqdm(data_generator.generate_mt(True), total=data_generator.num_train_batches()):
69 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r,
70 | model.pos_subj: s,
71 | model.pos_obj: o,
72 | model.neg_subj: ns,
73 | model.neg_obj: no})
74 | total += len(r)
75 | for pscore, nscore, rel, subj, obj, nsubj, nobj in zip(pscores, nscores, r, s, o, ns, no):
76 | if pscore < nscore:
77 | num_correct += 1
78 | else:
79 | incorrect.append((decode_triple(subj, rel, obj, pscore), decode_triple(nsubj, rel, nobj, nscore)))
80 | ppa = float(num_correct)/total
81 | print('PPA: %.4f' % ppa)
82 | outdir = os.path.join(config.eval_dir, config.run_name)
83 | if not os.path.exists(outdir):
84 | os.makedirs(outdir)
85 | json.dump(incorrect, open(os.path.join(outdir, 'ppa_incorrect.json'), 'w+'))
86 | with open(os.path.join(outdir, 'ppa.txt'), 'w+') as f:
87 | f.write(str(ppa))
88 |
89 |
90 | def evaluate_sn():
91 | random.seed(1337)
92 | config.no_semantic_network = False
93 |
94 | data = {}
95 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.)
96 | id2cui = {v: k for k, v in cui2id.iteritems()}
97 | _ = data_util.load_semantic_network_data(config.data_dir, data)
98 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj']
99 | print('Loaded %d sn triples from %s' % (len(rel), config.data_dir))
100 |
101 | valid_triples = set()
102 | for trip in zip(subj, rel, obj):
103 | valid_triples.add(trip)
104 | print('%d valid triples' % len(valid_triples))
105 | idxs = np.arange(len(rel))
106 | np.random.shuffle(idxs)
107 | idxs = idxs[:600]
108 | subj, rel, obj = subj[idxs], rel[idxs], obj[idxs]
109 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='???')
110 | nsubj, nobj = sampler.sample(subj, rel, obj)
111 |
112 | model_name = config.run_name
113 | if config.mode == 'gan':
114 | scope = config.dis_run_name
115 | model_name += '/discriminator'
116 | config.mode = 'disc'
117 | else:
118 | scope = config.run_name
119 |
120 | with tf.Graph().as_default(), tf.Session() as session:
121 | # init model
122 | with tf.variable_scope(scope):
123 | model = train.init_model(config, None)
124 |
125 | tf.global_variables_initializer().run()
126 | tf.local_variables_initializer().run()
127 |
128 | # init saver
129 | tf_saver = tf.train.Saver(max_to_keep=10)
130 |
131 | # load model
132 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
133 | print('Loading checkpoint: %s' % ckpt)
134 | tf_saver.restore(session, ckpt)
135 | tf.get_default_graph().finalize()
136 |
137 | def decode_triple(s_, r_, o_, score):
138 | return id2cui[s_], id2cui[r_], id2cui[o_], str(score)
139 |
140 | incorrect = [] # list of tuples of (triple, corrupted_triple)
141 | num_correct = 0
142 | total = len(rel)
143 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel,
144 | model.smoothing_placeholders['sn_pos_subj']: subj,
145 | model.smoothing_placeholders['sn_pos_obj']: obj,
146 | model.smoothing_placeholders['sn_neg_subj']: nsubj,
147 | model.smoothing_placeholders['sn_neg_obj']: nobj}
148 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict)
149 | for pscore, nscore, r, s, o, ns, no in zip(pscores, nscores, rel, subj, obj, nsubj, nobj):
150 | if pscore < nscore:
151 | num_correct += 1
152 | else:
153 | incorrect.append((decode_triple(s, r, o, pscore), decode_triple(ns, r, no, nscore)))
154 | ppa = float(num_correct)/total
155 | print('PPA: %.4f' % ppa)
156 | outdir = os.path.join(config.eval_dir, config.run_name)
157 | if not os.path.exists(outdir):
158 | os.makedirs(outdir)
159 | json.dump(incorrect, open(os.path.join(outdir, 'sn_ppa_incorrect.json'), 'w+'))
160 | with open(os.path.join(outdir, 'sn_ppa.txt'), 'w+') as f:
161 | f.write(str(ppa))
162 |
163 |
164 | if __name__ == "__main__":
165 | if config.eval_mode == 'sn':
166 | evaluate_sn()
167 | else:
168 | evaluate()
169 |
--------------------------------------------------------------------------------
/python/eukg/test/ranking_evals.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import math
5 | import os
6 | import json
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 |
10 | from ..data import data_util, DataGenerator
11 | from .. import Config, train
12 | from ..threading_util import synchronized, parallel_stream
13 |
14 |
15 | config = Config.flags
16 |
17 |
18 | def save_ranks():
19 | random.seed(1337)
20 | config.batch_size = 4096
21 |
22 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion)
23 | id2cui = {v: k for k, v in cui2id.iteritems()}
24 | test_data = data_util.load_metathesaurus_test_data(config.data_dir)
25 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir))
26 |
27 | valid_triples = set()
28 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']):
29 | valid_triples.add((s, r, o))
30 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']):
31 | valid_triples.add((s, r, o))
32 | print('%d valid triples' % len(valid_triples))
33 |
34 | model_name = config.run_name
35 | if config.mode == 'gan':
36 | scope = config.dis_run_name
37 | model_name += '/discriminator'
38 | config.mode = 'disc'
39 | else:
40 | scope = config.run_name
41 |
42 | with tf.Graph().as_default(), tf.Session() as session:
43 | # init model
44 | with tf.variable_scope(scope):
45 | model = train.init_model(config, None)
46 |
47 | tf.global_variables_initializer().run()
48 | tf.local_variables_initializer().run()
49 |
50 | # init saver
51 | tf_saver = tf.train.Saver(max_to_keep=10)
52 |
53 | # load model
54 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
55 | print('Loading checkpoint: %s' % ckpt)
56 | tf_saver.restore(session, ckpt)
57 | tf.get_default_graph().finalize()
58 |
59 | if not config.save_ranks:
60 | print('WARNING: ranks will not be saved! This run should only be for debugging purposes!')
61 |
62 | outdir = os.path.join(config.eval_dir, config.run_name)
63 | if not os.path.exists(outdir):
64 | os.makedirs(outdir)
65 |
66 | chunk_size = (1. / config.num_shards) * len(test_data['rel'])
67 | print('Chunk size: %f' % chunk_size)
68 | start = int((config.shard - 1) * chunk_size)
69 | end = (len(test_data['rel']) if config.shard == int(config.num_shards) else int(config.shard * chunk_size)) \
70 | if config.save_ranks else start + 1000
71 | print('Processing data from idx %d to %d' % (start, end))
72 |
73 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='gan')
74 | global ranks
75 | global first
76 | ranks = []
77 | first = True
78 | open_file = lambda: open(os.path.join(outdir, 'ranks_%d.json' % config.shard), 'w+') \
79 | if config.save_ranks else open("/dev/null")
80 | with open_file() as f, tqdm(total=(end-start)) as pbar:
81 | if config.save_ranks:
82 | f.write('[')
83 |
84 | # define thread function
85 | def calculate(triple):
86 | s, r, o = triple
87 | # subj
88 | invalid_concepts = [s] + synchronized(lambda: sampler.invalid_concepts(s, r, o, True))
89 | subj_scores, subj_ranking = calculate_scores(s, r, o, True, invalid_concepts, session, model, config.batch_size)
90 | srank = subj_ranking[s] if s in subj_ranking else -1
91 |
92 | # obj
93 | invalid_concepts = [o] + synchronized(lambda: sampler.invalid_concepts(s, r, o, False))
94 | obj_scores, obj_ranking = calculate_scores(s, r, o, False, invalid_concepts, session, model, config.batch_size)
95 | orank = obj_ranking[o] if o in obj_ranking else -1
96 | json_string = json.dumps([str(srank),
97 | str(orank),
98 | str(obj_scores[o]),
99 | id2cui[s], id2cui[r], id2cui[o]])
100 | return srank, orank, json_string
101 |
102 | # define save function
103 | def save(future):
104 | global ranks
105 | global first
106 |
107 | srank, orank, json_string = future.result()
108 | ranks += [srank, orank]
109 | if config.save_ranks:
110 | if not first:
111 | f.write(',')
112 | first = False
113 | f.write(json_string)
114 | npr = np.asarray(ranks, dtype=np.float)
115 | if config.save_ranks:
116 | pbar.set_description('Srank: %5d. Orank: %5d.' % (srank, orank))
117 | else:
118 | pbar.set_description('Srank: %5d. Orank: %5d. MRR: %.4f. H@10: %.4f' %
119 | (srank, orank, mrr(npr), hits_at_10(npr)))
120 | pbar.update()
121 |
122 | parallel_stream(zip(test_data['subj'][start:end], test_data['rel'][start:end], test_data['obj'][start:end]),
123 | parallelizable_fn=calculate,
124 | future_consumer=save)
125 |
126 | # finish json
127 | if config.save_ranks:
128 | f.write(']')
129 |
130 | ranks_np = np.asarray(ranks, dtype=np.float)
131 | print('Mean Reciprocal Rank of shard %d: %.4f' % (config.shard, mrr(ranks_np)))
132 | print('Mean Rank of shard %d: %.2f' % (config.shard, mr(ranks_np)))
133 | print('Hits @ 10 of shard %d: %.4f' % (config.shard, hits_at_10(ranks_np)))
134 |
135 |
136 | def calculate_scores(subj, rel, obj, replace_subject, concept_ids, session, model, batch_size):
137 | num_batches = int(math.ceil(float(len(concept_ids))/batch_size))
138 | subjects = np.full(batch_size, subj, dtype=np.int32)
139 | relations = np.full(batch_size, rel, dtype=np.int32)
140 | objects = np.full(batch_size, obj, dtype=np.int32)
141 | scores = {}
142 |
143 | for b in xrange(num_batches):
144 | concepts = concept_ids[b*batch_size:(b+1)*batch_size]
145 | feed_dict = {model.pos_subj: subjects,
146 | model.relations: relations,
147 | model.pos_obj: objects}
148 |
149 | # pad concepts if necessary
150 | if len(concepts) < batch_size:
151 | concepts = np.pad(concepts, (0, batch_size - len(concepts)), mode='constant', constant_values=0)
152 |
153 | # replace subj/obj in feed dict
154 | if replace_subject:
155 | feed_dict[model.pos_subj] = concepts
156 | else:
157 | feed_dict[model.pos_obj] = concepts
158 |
159 | # calculate energies
160 | energies = session.run(model.pos_energy, feed_dict)
161 |
162 | # store scores
163 | for i, cid in enumerate(concept_ids[b*batch_size:(b+1)*batch_size]):
164 | scores[cid] = energies[i]
165 |
166 | ranking = sorted(scores.keys(), key=lambda k: scores[k])
167 |
168 | rank_map = {}
169 | prev_rank = 0
170 | prev_score = -1
171 | total = 1
172 | for c in ranking:
173 |     # if c has a strictly higher (worse) energy than the previous concept
174 |     if scores[c] > prev_score:
175 |       # move its rank down to the current position (tied concepts share a rank)
176 |       prev_rank = total
177 |       # remember this energy for the tie comparison on the next concept
178 |       prev_score = scores[c]
179 | total += 1
180 | rank_map[c] = prev_rank
181 |
182 | return scores, rank_map
183 |
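# Worked example of the ranking scheme above (competition-style ranking in which tied
# concepts share a rank): energies {a: 0.10, b: 0.10, c: 0.30} yield rank_map {a: 1, b: 1, c: 3}.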
184 |
185 | def save_ranks_sn():
186 | random.seed(1337)
187 | config.batch_size = 4096
188 | assert not config.no_semantic_network
189 |
190 | data = {}
191 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.)
192 | id2cui = {v: k for k, v in cui2id.iteritems()}
193 | _ = data_util.load_semantic_network_data(config.data_dir, data)
194 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj']
195 | n = len(rel)
196 | print('Loaded %d sn triples from %s' % (n, config.data_dir))
197 |
198 | valid_triples = set()
199 | for trip in zip(subj, rel, obj):
200 | valid_triples.add(trip)
201 | print('%d valid triples' % len(valid_triples))
202 |
203 | model_name = config.run_name
204 | if config.mode == 'gan':
205 | scope = config.dis_run_name
206 | model_name += '/discriminator'
207 | config.mode = 'disc'
208 | else:
209 | scope = config.run_name
210 |
211 | with tf.Graph().as_default(), tf.Session() as session:
212 | # init model
213 | with tf.variable_scope(scope):
214 | model = train.init_model(config, None)
215 |
216 | tf.global_variables_initializer().run()
217 | tf.local_variables_initializer().run()
218 |
219 | # init saver
220 | tf_saver = tf.train.Saver(max_to_keep=10)
221 |
222 | # load model
223 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name))
224 | print('Loading checkpoint: %s' % ckpt)
225 | tf_saver.restore(session, ckpt)
226 | tf.get_default_graph().finalize()
227 |
228 | if not config.save_ranks:
229 | print('WARNING: ranks will not be saved! This run should only be for debugging purposes!')
230 |
231 | outdir = os.path.join(config.eval_dir, config.run_name)
232 | if not os.path.exists(outdir):
233 | os.makedirs(outdir)
234 |
235 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='gan')
236 | global ranks
237 | global first
238 | ranks = []
239 | first = True
240 | open_file = lambda: open(os.path.join(outdir, 'sn_ranks.json'), 'w+') \
241 | if config.save_ranks else open("/dev/null")
242 | with open_file() as f, tqdm(total=600) as pbar:
243 | if config.save_ranks:
244 | f.write('[')
245 |
246 | # define thread function
247 | def calculate(triple):
248 | s, r, o = triple
249 | # subj
250 | invalid_concepts = [s] + synchronized(lambda: sampler.invalid_concepts(s, r, o, True))
251 | subj_scores, subj_ranking = calculate_scores_sn(s, r, o, True, invalid_concepts, session, model, config.batch_size)
252 | srank = subj_ranking[s] if s in subj_ranking else -1
253 |
254 | # obj
255 | invalid_concepts = [o] + synchronized(lambda: sampler.invalid_concepts(s, r, o, False))
256 | obj_scores, obj_ranking = calculate_scores_sn(s, r, o, False, invalid_concepts, session, model, config.batch_size)
257 | orank = obj_ranking[o] if o in obj_ranking else -1
258 | json_string = json.dumps([str(srank),
259 | str(orank),
260 | str(obj_scores[o]),
261 | id2cui[s], id2cui[r], id2cui[o]])
262 | return srank, orank, json_string
263 |
264 | # define save function
265 | def save(future):
266 | global ranks
267 | global first
268 |
269 | srank, orank, json_string = future.result()
270 | ranks += [srank, orank]
271 | if config.save_ranks:
272 | if not first:
273 | f.write(',')
274 | first = False
275 | f.write(json_string)
276 | npr = np.asarray(ranks, dtype=np.float)
277 | pbar.set_description('Srank: %5d. Orank: %5d. MRR: %.4f. H@10: %.4f' %
278 | (srank, orank, mrr(npr), hits_at_10(npr)))
279 | pbar.update()
280 |
281 | with np.load(os.path.join(config.data_dir, 'semnet', 'test.npz')) as sn_npz:
282 | subj, rel, obj = sn_npz['subj'], sn_npz['rel'], sn_npz['obj']
283 | parallel_stream(zip(subj, rel, obj),
284 | parallelizable_fn=calculate,
285 | future_consumer=save)
286 |
287 | # finish json
288 | if config.save_ranks:
289 | f.write(']')
290 |
291 | ranks_np = np.asarray(ranks, dtype=np.float)
292 | print('Mean Reciprocal Rank: %.4f' % (mrr(ranks_np)))
293 | print('Mean Rank: %.2f' % (mr(ranks_np)))
294 | print('Hits @ 10: %.4f' % (hits_at_10(ranks_np)))
295 |
296 |
297 | def calculate_scores_sn(subj, rel, obj, replace_subject, concept_ids, session, model, batch_size):
298 | num_batches = int(math.ceil(float(len(concept_ids))/batch_size))
299 | subjects = np.full(batch_size, subj, dtype=np.int32)
300 | relations = np.full(batch_size, rel, dtype=np.int32)
301 | objects = np.full(batch_size, obj, dtype=np.int32)
302 | scores = {}
303 |
304 | for b in xrange(num_batches):
305 | concepts = concept_ids[b*batch_size:(b+1)*batch_size]
306 | feed_dict = {model.smoothing_placeholders['sn_pos_subj']: subjects,
307 | model.smoothing_placeholders['sn_relations']: relations,
308 | model.smoothing_placeholders['sn_pos_obj']: objects}
309 |
310 | # pad concepts if necessary
311 | if len(concepts) < batch_size:
312 | concepts = np.pad(concepts, (0, batch_size - len(concepts)), mode='constant', constant_values=0)
313 |
314 | # replace subj/obj in feed dict
315 | if replace_subject:
316 | feed_dict[model.smoothing_placeholders['sn_pos_subj']] = concepts
317 | else:
318 | feed_dict[model.smoothing_placeholders['sn_pos_obj']] = concepts
319 |
320 | # calculate energies
321 | energies = session.run(model.sn_pos_energy, feed_dict)
322 |
323 | # store scores
324 | for i, cid in enumerate(concept_ids[b*batch_size:(b+1)*batch_size]):
325 | scores[cid] = energies[i]
326 |
327 | ranking = sorted(scores.keys(), key=lambda k: scores[k])
328 |
329 | rank_map = {}
330 | prev_rank = 0
331 | prev_score = -1
332 | total = 1
333 | for c in ranking:
334 |     # if c has a strictly higher (worse) energy than the previous concept
335 |     if scores[c] > prev_score:
336 |       # move its rank down to the current position (tied concepts share a rank)
337 |       prev_rank = total
338 |       # remember this energy for the tie comparison on the next concept
339 |       prev_score = scores[c]
340 | total += 1
341 | rank_map[c] = prev_rank
342 |
343 | return scores, rank_map
344 |
345 |
346 | def mrr(ranks_np):
347 | return float(np.mean(1. / ranks_np))
348 |
349 |
350 | def mr(ranks_np):
351 | return float(np.mean(ranks_np))
352 |
353 |
354 | def hits_at_10(ranks_np):
355 | return float(len(ranks_np[ranks_np <= 10])) / len(ranks_np)
356 |
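# Worked example for the three metrics above: ranks [1, 4, 20] give
#   MRR  = (1/1 + 1/4 + 1/20) / 3 ~= 0.4333
#   MR   = (1 + 4 + 20) / 3       ~= 8.33
#   H@10 = 2/3                    ~= 0.6667   (two of the three ranks are <= 10)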
357 |
358 | def calculate_ranking_evals():
359 | outdir = os.path.join(config.eval_dir, config.run_name)
360 |
361 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip())
362 | print('PPA: %.4f' % ppa)
363 |
364 | ranks = []
365 | for fname in os.listdir(outdir):
366 | full_path = os.path.join(outdir, fname)
367 | if os.path.isfile(full_path) and fname.startswith('ranks_'):
368 | for fields in json.load(open(full_path)):
369 | ranks += [fields[0], fields[1]]
370 | ranks_np = np.asarray(ranks, dtype=np.float)
371 | mrr_ = mrr(ranks_np)
372 | mr_ = mr(ranks_np)
373 | hat10 = hits_at_10(ranks_np)
374 | print('MRR: %.4f' % mrr_)
375 | print('MR: %.2f' % mr_)
376 | print('H@10: %.4f' % hat10)
377 |
378 | with open(os.path.join(outdir, 'ranking_evals.tsv'), 'w+') as f:
379 | f.write('ppa\t%f\n' % ppa)
380 | f.write('mrr\t%f\n' % mrr_)
381 | f.write('mr\t%f\n' % mr_)
382 | f.write('h@10\t%f' % hat10)
383 |
384 |
385 | def calculate_ranking_evals_per_rel():
386 | outdir = os.path.join(config.eval_dir, config.run_name)
387 |
388 | ranks = defaultdict(list)
389 | for fname in os.listdir(outdir):
390 | full_path = os.path.join(outdir, fname)
391 | if os.path.isfile(full_path) and fname.startswith('ranks_'):
392 | for fields in json.load(open(full_path)):
393 | ranks[fields[4]] += [fields[0], fields[1]]
394 | print('Gathered %d rankings' % len(ranks))
395 |
396 | for rel, rl in ranks.iteritems():
397 | ranks_np = np.asarray(rl, dtype=np.float)
398 | ranks[rel] = [mrr(ranks_np), mr(ranks_np), hits_at_10(ranks_np), len(rl)]
399 |
400 | relations = ranks.keys()
401 | relations.sort(key=lambda x: ranks[x][1])
402 |
403 | with open(os.path.join(outdir, 'ranking_evals_per_rel.tsv'), 'w+') as f:
404 | for rel in relations:
405 | [mrr_, mr_, hat10, count] = ranks[rel]
406 | print('----------%s(%d)----------' % (rel, count))
407 | print('MRR: %.4f' % mrr_)
408 | print('MR: %.2f' % mr_)
409 | print('H@10: %.4f' % hat10)
410 |
411 | f.write('%s (%d)\n' % (rel, count))
412 | f.write('mrr\t%f\n' % mrr_)
413 | f.write('mr\t%f\n' % mr_)
414 | f.write('h@10\t%f\n\n' % hat10)
415 |
416 |
417 | def split_ranking_evals():
418 | outdir = os.path.join(config.eval_dir, config.run_name)
419 |
420 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip())
421 | print('PPA: %.4f' % ppa)
422 |
423 | s_ranks = []
424 | o_ranks = []
425 | for fname in os.listdir(outdir):
426 | full_path = os.path.join(outdir, fname)
427 | if os.path.isfile(full_path) and fname.startswith('ranks_'):
428 | for fields in json.load(open(full_path)):
429 | s_ranks += [fields[0]]
430 | o_ranks += [fields[1]]
431 |
432 | def report(ranks):
433 | ranks_np = np.asarray(ranks, dtype=np.float)
434 | mrr_ = mrr(ranks_np)
435 | mr_ = mr(ranks_np)
436 | hat10 = hits_at_10(ranks_np)
437 | print('MRR: %.4f' % mrr_)
438 | print('MR: %.2f' % mr_)
439 | print('H@10: %.4f' % hat10)
440 | report(s_ranks)
441 | report(o_ranks)
442 |
443 |
444 | def fix_json():
445 | config = Config.flags
446 | outdir = os.path.join(config.eval_dir, config.run_name)
447 |
448 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip())
449 | print('PPA: %.4f' % ppa)
450 |
451 | ranks = []
452 | for fname in os.listdir(outdir):
453 | full_path = os.path.join(outdir, fname)
454 | if os.path.isfile(full_path) and fname.startswith('ranks_'):
455 | json_string = '[' + str(next(open(full_path))).strip()
456 | with open(full_path, 'w+') as f:
457 | f.write(json_string)
458 | for fields in json.load(open(full_path)):
459 |         ranks += [fields[0], fields[1]]  # srank and orank (fields[2] is the positive triple's score)
460 | ranks_np = np.asarray(ranks, dtype=np.float)
461 | mrr_ = mrr(ranks_np)
462 | mr_ = mr(ranks_np)
463 | hat10 = hits_at_10(ranks_np)
464 | print('MRR: %.4f' % mrr_)
465 | print('MR: %.2f' % mr_)
466 | print('H@10: %.4f' % hat10)
467 |
468 | with open(os.path.join(outdir, 'ranking_evals.tsv'), 'w+') as f:
469 | f.write('ppa\t%f\n' % ppa)
470 | f.write('mrr\t%f\n' % mrr_)
471 | f.write('mr\t%f\n' % mr_)
472 | f.write('h@10\t%f' % hat10)
473 |
474 |
475 | if __name__ == "__main__":
476 | if config.eval_mode == "save":
477 | save_ranks()
478 | elif config.eval_mode == "save-sn":
479 | save_ranks_sn()
480 | elif config.eval_mode == "calc":
481 | calculate_ranking_evals()
482 | elif config.eval_mode == "calc-rel":
483 | calculate_ranking_evals_per_rel()
484 | else:
485 | raise Exception('Unrecognized eval_mode: %s' % config.eval_mode)
486 |
--------------------------------------------------------------------------------
/python/eukg/tf_util/ModelSaver.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class Policy(Enum):
5 | EPOCH = 1
6 | TIMED = 2
7 |
8 |
9 | class ModelSaver:
10 | def __init__(self, tf_saver, session, model_file, policy):
11 | self.tf_saver = tf_saver
12 | self.session = session
13 | self.model_file = model_file
14 | self.policy = policy
15 |
16 | def save(self, global_step, policy, **kwargs):
17 | if self.save_condition(policy, **kwargs):
18 | print("Saving model to %s at step %d" % (self.model_file, global_step))
19 | self.tf_saver.save(self.session, self.model_file, global_step=global_step)
20 |
21 | def save_condition(self, policy, **kwargs):
22 | return policy == self.policy
23 |
24 |
25 | class TimedSaver(ModelSaver):
26 | def __init__(self, tf_saver, session, model_file, seconds_per_save):
27 | ModelSaver.__init__(self, tf_saver, session, model_file, Policy.TIMED)
28 | self.seconds_per_save = seconds_per_save
29 |
30 | def save_condition(self, policy, **kwargs):
31 | return policy == self.policy and kwargs['seconds_since_last_save'] > self.seconds_per_save
32 |
33 |
34 | class EpochSaver(ModelSaver):
35 | def __init__(self, tf_saver, session, model_file, save_every_x_epochs=1):
36 | ModelSaver.__init__(self, tf_saver, session, model_file, Policy.EPOCH)
37 | self.save_every_x_epochs = save_every_x_epochs
38 |
39 | def save_condition(self, policy, **kwargs):
40 | return 'epoch' in kwargs and kwargs['epoch'] % self.save_every_x_epochs == 0
41 |
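# Usage sketch (mirrors tf_util/Trainer.py and train.py; the checkpoint path is illustrative):
#   saver = EpochSaver(tf.train.Saver(max_to_keep=10), session, 'model/transd')
#   saver.save(global_step, Policy.EPOCH, epoch=ep)                       # saves once per epoch
#   saver.save(global_step, Policy.TIMED, seconds_since_last_save=120.0)  # no-op for an EpochSaver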
--------------------------------------------------------------------------------
/python/eukg/tf_util/Trainable.py:
--------------------------------------------------------------------------------
1 | class Trainable:
2 | def __init__(self):
3 | pass
4 |
5 | def fetches(self, is_training, verbose=False):
6 | """
7 | Returns a list of fetches to be passed to session.run()
8 |     :param is_training: flag indicating if the model is in training or testing mode
9 | :param verbose: flag indicating if a more verbose set of variables should be fetched (usually for debugging or
10 | progress updates)
11 | :return: a list of fetches to be passed to session.run()
12 | """
13 | raise NotImplementedError("to be implemented by subclass")
14 |
15 | def prepare_feed_dict(self, batch, is_training, **kwargs):
16 | """
17 | Turns a list of tensors into a dict of model_parameter: tensor
18 | :param batch: list of data tensors to be passed to the model
19 | :param is_training: flag indicating if the model is in training or testing mode
20 | :param kwargs: optional other params
21 | :return: the feed dict to be passed to session.run()
22 | """
23 | raise NotImplementedError("to be implemented by subclass")
24 |
25 | def progress_update(self, batch, fetched, **kwargs):
26 | """
27 | Prepares a progress update
28 | :param batch: batch data passed to prepare_feed_dict()
29 | :param fetched: tensors returned by session.run()
30 | :param kwargs: optional other params
31 | :return: String progress update
32 | """
33 | raise NotImplementedError("to be implemented by subclass")
34 |
35 | def data_provider(self, config, is_training, **kwargs):
36 | """
37 | Provides access to a data generator that generates batches of data
38 | :param config: dict of config flags to values (usually tf.flags.FLAGS)
39 | :param is_training: flag indicating if the model is in training or testing mode
40 | :param kwargs: optional other params
41 | :return: A generator that generates batches of data
42 | """
43 | raise NotImplementedError("to be implemented by subclass")
44 |
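# Contract sketch for implementers (derived from tf_util/Trainer.py; the attribute names
# below are hypothetical): train_epoch() and validate() treat fetches()[0] as a TensorBoard
# summary and fetches()[1] as the scalar loss shown in the progress bar, e.g.
#
#   def fetches(self, is_training, verbose=False):
#     return [self.summary, self.loss] + ([self.accuracy] if verbose else [])
#   def prepare_feed_dict(self, batch, is_training, **kwargs):
#     return {self.subj: batch[0], self.rel: batch[1], self.obj: batch[2]}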
--------------------------------------------------------------------------------
/python/eukg/tf_util/Trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from tqdm import tqdm
4 | import time
5 |
6 | from Trainable import Trainable
7 | from ModelSaver import ModelSaver, Policy
8 |
9 |
10 | def train(config, session, model, saver,
11 | train_post_step=None,
12 | train_post_epoch=None,
13 | val_post_step=None,
14 | val_post_epoch=None,
15 | global_step=0,
16 | max_batches_per_epoch=None):
17 | """
18 | Trains and validates
19 | :param config: config
20 | :type config: dict
21 | :param session: tf session
22 | :param model: Trainable
23 | :type model: Trainable
24 |   :param saver: ModelSaver used to checkpoint the model
25 |   :type saver: ModelSaver
26 |   :param train_post_step: functions to execute after each train step,
27 |     each called as fn(global_step, batch)
28 | :param train_post_epoch: functions to execute after each train epoch
29 | :param val_post_step: functions to execute after each validation step
30 | :param val_post_epoch: functions to execute after each validation epoch
31 | :param global_step: optional argument to pass a non-zero global step
32 | :param max_batches_per_epoch: optional argument to limit a training epoch to some number of batches
33 | """
34 | # init summary directories and summary writers
35 | if not os.path.exists(os.path.join(config['summaries_dir'], 'train')):
36 | os.makedirs(os.path.join(config['summaries_dir'], 'train'))
37 | train_summary_writer = tf.summary.FileWriter(os.path.join(config['summaries_dir'], 'train'))
38 | if not os.path.exists(os.path.join(config['summaries_dir'], 'val')):
39 | os.makedirs(os.path.join(config['summaries_dir'], 'val'))
40 | val_summary_writer = tf.summary.FileWriter(os.path.join(config['summaries_dir'], 'val'))
41 |
42 | # determine if data_provider should be bounded
43 | if max_batches_per_epoch is None:
44 | print('Each training epoch will pass through the full dataset.')
45 | train_data_provider = lambda: model.data_provider(config, True)
46 | else:
47 | print('Each training epoch will process at most %d batches.' % max_batches_per_epoch)
48 | global it
49 | it = iter(model.data_provider(config, True))
50 |
51 | def bounded_train_data_provider():
52 | global it
53 | for b in xrange(max_batches_per_epoch):
54 | try:
55 | yield it.next()
56 | except StopIteration:
57 | print('WARNING: reached the end of training data. Looping over it again.')
58 | it = iter(model.data_provider(config, True))
59 | yield it.next()
60 | train_data_provider = bounded_train_data_provider
61 |
62 | # train
63 | for ep in xrange(config['num_epochs']):
64 | print('\nBegin training epoch %d' % ep)
65 | global_step = train_epoch(config, session, model, train_summary_writer, train_post_step, global_step,
66 | train_data_provider, saver)
67 | saver.save(global_step, Policy.EPOCH, epoch=ep)
68 | if train_post_epoch:
69 | for post_epoch in train_post_epoch:
70 | post_epoch()
71 |
72 | print('\nDone epoch %d. Begin validation' % ep)
73 | validate(config, session, model, val_summary_writer, val_post_step, global_step)
74 | if val_post_epoch:
75 | for post_epoch in val_post_epoch:
76 | post_epoch()
77 |
78 |
79 | def train_epoch(config, session, model, summary_writer, post_step, global_step, batch_generator, saver):
80 | console_update_interval = config['progress_update_interval']
81 | pbar = tqdm(total=console_update_interval)
82 | start = time.time()
83 | for b, batch in enumerate(batch_generator()):
84 | verbose_batch = b > 0 and b % console_update_interval == 0
85 |
86 | # training batch
87 | fetched = session.run(model.fetches(True, verbose=verbose_batch), model.prepare_feed_dict(batch, True))
88 |
89 | # update tensorboard summary
90 | summary_writer.add_summary(fetched[0], global_step)
91 | global_step += 1
92 |
93 | # perform post steps
94 | if post_step is not None:
95 | for step in post_step:
96 | step(global_step, batch)
97 |
98 |     # update progress bar
99 | pbar.set_description("Training Batch: %d. Loss: %.4f" % (b, fetched[1]))
100 | pbar.update()
101 |
102 | if verbose_batch:
103 | pbar.close()
104 | model.progress_update(batch, fetched)
105 | pbar = tqdm(total=console_update_interval)
106 |
107 | saver.save(global_step, Policy.TIMED, seconds_since_last_save=(time.time() - start))
108 | pbar.close()
109 |
110 | return global_step
111 |
112 |
113 | def validate(config, session, model, summary_writer, post_step, global_step):
114 | console_update_interval = config['val_progress_update_interval']
115 | pbar = tqdm(total=console_update_interval)
116 | # validation epoch
117 | for b, batch in enumerate(model.data_provider(config, False)):
118 | verbose_batch = b > 0 and b % console_update_interval == 0
119 |
120 | fetched = session.run(model.fetches(False, verbose=verbose_batch), model.prepare_feed_dict(batch, False))
121 |
122 | # update tensorboard summary
123 | summary_writer.add_summary(fetched[0], global_step)
124 | global_step += 1
125 |
126 | # perform post steps
127 | if post_step is not None:
128 | for step in post_step:
129 | step(b, batch)
130 |
131 |     # update progress bar
132 | pbar.set_description("Validation Batch: %d. Loss: %.4f" % (b, fetched[1]))
133 | pbar.update()
134 |
135 | if verbose_batch:
136 | pbar.close()
137 | model.progress_update(batch, fetched)
138 | pbar = tqdm(total=console_update_interval)
139 | pbar.close()
140 |
141 | return global_step
142 |
--------------------------------------------------------------------------------
/python/eukg/tf_util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/tf_util/__init__.py
--------------------------------------------------------------------------------
/python/eukg/threading_util.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import threading
3 | from concurrent.futures import ThreadPoolExecutor
4 |
5 |
6 | condition = threading.Condition()
7 | def synchronized(func):  # run func() under the shared condition's lock, serializing calls across worker threads
8 | condition.acquire()
9 | r = func()
10 | condition.notify()
11 | condition.release()
12 | return r
13 |
14 |
15 | def parallel_stream(iterable, parallelizable_fn, future_consumer=None):
16 | """
17 | Executes parallelizable_fn on each element of iterable in parallel. When each thread finishes, its future is passed
18 | to future_consumer if provided.
19 | :param iterable: an iterable of objects to be processed by parallelizable_fn
20 | :param parallelizable_fn: a function that operates on elements of iterable
21 | :param future_consumer: optional consumer of objects returned by parallelizable_fn
22 | :return: void
23 | """
24 | if future_consumer is None:
25 | future_consumer = lambda f: f.result()
26 |
27 | num_threads = multiprocessing.cpu_count()
28 | executor = ThreadPoolExecutor(max_workers=num_threads)
29 | futures = []
30 | for tup in iterable:
31 | # all cpus are in use, wait until at least 1 thread finishes before spawning new threads
32 | if len(futures) >= num_threads:
33 | full_queue = True
34 | while full_queue:
35 | incomplete_futures = []
36 | for fut in futures:
37 | if fut.done():
38 | future_consumer(fut)
39 | full_queue = False
40 | else:
41 | incomplete_futures.append(fut)
42 | futures = incomplete_futures
43 | # spawn new thread
44 | futures += [executor.submit(parallelizable_fn, tup)]
45 |
46 | # ensure all threads have finished
47 | for fut in futures:
48 | future_consumer(fut)
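# Usage sketch (this is how ranking_evals.save_ranks drives it): one thread-pool call per
# test triple, with the consumer invoked from the submitting loop as futures complete:
#
#   parallel_stream(zip(subjs, rels, objs),
#                   parallelizable_fn=calculate,  # (s, r, o) -> (srank, orank, json_string)
#                   future_consumer=save)         # receives each finished Future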
--------------------------------------------------------------------------------
/python/eukg/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | import numpy as np
4 | import math
5 |
6 | from tf_util import Trainer, ModelSaver
7 |
8 | from emb import EmbeddingModel
9 | from gan import Generator, train_gan, Discriminator
10 | import Config
11 | from data import data_util, DataGenerator
12 |
13 |
14 | def train():
15 | config = Config.flags
16 |
17 | if config.mode == 'gan':
18 | train_gan.train()
19 | exit()
20 |
21 | # init model dir
22 | config.model_dir = os.path.join(config.model_dir, config.model, config.run_name)
23 | if not os.path.exists(config.model_dir):
24 | os.makedirs(config.model_dir)
25 |
26 | # init summaries dir
27 | config.summaries_dir = os.path.join(config.summaries_dir, config.run_name)
28 | if not os.path.exists(config.summaries_dir):
29 | os.makedirs(config.summaries_dir)
30 |
31 | # save the config
32 | data_util.save_config(config.model_dir, config)
33 |
34 | # load data
35 | cui2id, data, train_idx, val_idx = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion)
36 | config.val_progress_update_interval = int(math.floor(float(len(val_idx)) / config.batch_size))
37 | config.batches_per_epoch = int(math.floor(float(len(train_idx)) / config.batch_size))
38 | if not config.no_semantic_network:
39 | type2cuis = data_util.load_semantic_network_data(config.data_dir, data)
40 | else:
41 | type2cuis = None
42 | data_generator = DataGenerator.DataGenerator(data, train_idx, val_idx, config, type2cuis)
43 |
44 | # config map
45 | config_map = config.flag_values_dict()
46 | config_map['data'] = data
47 | config_map['train_idx'] = train_idx
48 | config_map['val_idx'] = val_idx
49 | if not config_map['no_semantic_network']:
50 | config_map['type2cuis'] = type2cuis
51 |
52 | with tf.Graph().as_default(), tf.Session() as session:
53 | # init model
54 | with tf.variable_scope(config.run_name):
55 | model = init_model(config, data_generator)
56 | # session.run(model.train_init_op)
57 |
58 | tf.global_variables_initializer().run()
59 | tf.local_variables_initializer().run()
60 |
61 | # init saver
62 | tf_saver = tf.train.Saver(max_to_keep=10)
63 | saver = init_saver(config, tf_saver, session)
64 |
65 | # load model
66 | global_step = 0
67 | if config.load:
68 | ckpt = tf.train.latest_checkpoint(config.model_dir)
69 | print('Loading checkpoint: %s' % ckpt)
70 | global_step = int(os.path.split(ckpt)[-1].split('-')[-1])
71 | tf_saver.restore(session, ckpt)
72 |
73 | # finalize graph
74 | tf.get_default_graph().finalize()
75 |
76 |     # define normalization step: run model.norm_op over only the embedding ids seen in this batch
77 | def find_unique(tensor_list):
78 | if max([len(t.shape) for t in tensor_list[:10]]) == 1:
79 | return np.unique(np.concatenate(tensor_list[:10]))
80 | else:
81 | return np.unique(np.concatenate([t.flatten() for t in tensor_list[:10]]))
82 | normalize = lambda _, batch: session.run(model.norm_op,
83 | {model.ids_to_update: find_unique(batch)})
84 |
85 | # define streaming_accuracy reset per epoch
86 | print('local variables that will be reinitialized every epoch: %s' % tf.local_variables())
87 | reset_local_vars = lambda: session.run(model.reset_streaming_metrics_op)
88 |
89 | # train
90 | Trainer.train(config_map, session, model, saver,
91 | train_post_step=[normalize],
92 | train_post_epoch=[reset_local_vars],
93 | val_post_epoch=[reset_local_vars],
94 | global_step=global_step,
95 | max_batches_per_epoch=config_map['max_batches_per_epoch'])
96 |
97 |
98 | def init_model(config, data_generator):
99 | print('Initializing %s embedding model in %s mode...' % (config.model, config.mode))
100 | npz = np.load(config.embedding_file) if config.load_embeddings else None
101 |
102 | if config.model == 'transe':
103 | em = EmbeddingModel.TransE(config, embeddings_dict=npz)
104 | elif config.model == 'transd':
105 | config.embedding_size = config.embedding_size / 2
106 | em = EmbeddingModel.TransD(config, embeddings_dict=npz)
107 | elif config.model == 'distmult':
108 | em = EmbeddingModel.DistMult(config, embeddings_dict=npz)
109 | else:
110 | raise ValueError('Unrecognized model type: %s' % config.model)
111 |
112 | if config.mode == 'disc':
113 | model = Discriminator.BaseModel(config, em, data_generator)
114 | elif config.mode == 'gen':
115 | model = Generator.Generator(config, em, data_generator)
116 | else:
117 | raise ValueError('Unrecognized mode: %s' % config.mode)
118 |
119 | if npz:
120 | # noinspection PyUnresolvedReferences
121 | npz.close()
122 |
123 | model.build()
124 | print('Built model.')
125 | print('use semnet: %s' % model.use_semantic_network)
126 | return model
127 |
128 |
129 | def init_saver(config, tf_saver, session):
130 | model_file = os.path.join(config.model_dir, config.model)
131 | if config.save_strategy == 'timed':
132 | print('Models will be saved every %d seconds' % config.save_interval)
133 | return ModelSaver.TimedSaver(tf_saver, session, model_file, config.save_interval)
134 | elif config.save_strategy == 'epoch':
135 | print('Models will be saved every training epoch')
136 | return ModelSaver.EpochSaver(tf_saver, session, model_file)
137 | else:
138 | raise ValueError('Unrecognized save strategy: %s' % config.save_strategy)
139 |
140 |
141 | if __name__ == "__main__":
142 | train()
143 |
--------------------------------------------------------------------------------