├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── embeddings.csv.zip
└── python
    ├── .gitignore
    ├── __init__.py
    ├── bin
    │   └── rtc_eval.sh
    ├── data
    │   └── valid_rels.txt
    └── eukg
        ├── Config.py
        ├── __init__.py
        ├── data
        │   ├── DataGenerator.py
        │   ├── __init__.py
        │   ├── create_semnet_triples.py
        │   ├── create_test_set.py
        │   ├── create_triples.py
        │   └── data_util.py
        ├── emb
        │   ├── EmbeddingModel.py
        │   ├── Smoothing.py
        │   └── __init__.py
        ├── gan
        │   ├── Discriminator.py
        │   ├── Generator.py
        │   ├── Generator.pyc
        │   ├── __init__.py
        │   ├── __init__.pyc
        │   ├── train_gan.py
        │   └── train_gan.pyc
        ├── save_embeddings.py
        ├── test
        │   ├── __init__.py
        │   ├── classification.py
        │   ├── nearest_neighbors.py
        │   ├── ppa.py
        │   └── ranking_evals.py
        ├── tf_util
        │   ├── ModelSaver.py
        │   ├── Trainable.py
        │   ├── Trainer.py
        │   └── __init__.py
        ├── threading_util.py
        └── train.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | embeddings.csv.zip filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #################### 2 | # Custom Rules 3 | #################### 4 | output/ 5 | out/ 6 | input/ 7 | tmp/ 8 | temp/ 9 | work/ 10 | working/ 11 | /.tgitconfig 12 | *~dev 13 | **/.idea/ 14 | # 15 | #################### 16 | # Logging 17 | #################### 18 | # logs 19 | *.errlog 20 | *.log 21 | *.log.zip 22 | log*.txt 23 | log.txt 24 | log/ 25 | logs/ 26 | # 27 | #################### 28 | # Kirkness 29 | #################### 30 | # kirkness 31 | *.svm_annotations 32 | *.svm_model 33 | *.svm_model.maps 34 | # 35 | #################### 36 | # Java 37 | #################### 38 | *.class 39 | # 40 | # Mobile Tools for Java (J2ME) 41 | .mtj.tmp/ 42 | # 43 | # Package Files # 44 | *.jar 45 | *.war 46 | *.ear 47 | # 48 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 49 | hs_err_pid* 50 | # 51 | #################### 52 | # Scala 53 | #################### 54 | *.class 55 | *.log 56 | # 57 | # sbt specific 58 | .cache 59 | .history 60 | .lib/ 61 | dist/* 62 | target/ 63 | lib_managed/ 64 | src_managed/ 65 | **/project/boot/ 66 | **/project/plugins/project/ 67 | # 68 | # Scala-IDE specific 69 | .scala_dependencies 70 | .worksheet 71 | # 72 | # ENSIME specific 73 | .ensime_cache/ 74 | .ensime 75 | # 76 | #################### 77 | # SBT 78 | #################### 79 | # Simple Build Tool 80 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 81 | # 82 | target/ 83 | lib_managed/ 84 | src_managed/ 85 | project/boot/ 86 | .history 87 | .cache 88 | # 89 | #################### 90 | # Python 91 | #################### 92 | # Byte-compiled / optimized / DLL files 93 | __pycache__/ 94 | *.py[cod] 95 | *$py.class 96 | # 97 | # C extensions 98 | *.so 99 | # 100 | # Distribution / packaging 101 | .Python 102 | env/ 103 | build/ 104 | develop-eggs/ 105 | dist/ 106 | downloads/ 107 | eggs/ 108 | .eggs/ 109 | #lib/ 110 | #lib64/ 111 | parts/ 112 | sdist/ 113 | var/ 114 | *.egg-info/ 115 | .installed.cfg 116 | *.egg 117 | # 118 | # PyInstaller 119 | # Usually these files are written by a python script from a template 120 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
121 | *.manifest 122 | *.spec 123 | # 124 | # Installer logs 125 | pip-log.txt 126 | pip-delete-this-directory.txt 127 | # 128 | # Unit test / coverage reports 129 | htmlcov/ 130 | .tox/ 131 | .coverage 132 | .coverage.* 133 | .cache 134 | nosetests.xml 135 | coverage.xml 136 | *,cover 137 | .hypothesis/ 138 | # 139 | # Translations 140 | *.mo 141 | *.pot 142 | # 143 | # Django stuff: 144 | *.log 145 | local_settings.py 146 | # 147 | # Flask stuff: 148 | instance/ 149 | .webassets-cache 150 | # 151 | # Scrapy stuff: 152 | .scrapy 153 | # 154 | # Sphinx documentation 155 | **/docs/_build/ 156 | # 157 | # PyBuilder 158 | target/ 159 | # 160 | # Jupyter Notebook 161 | .ipynb_checkpoints 162 | # 163 | # pyenv 164 | .python-version 165 | # 166 | # celery beat schedule file 167 | celerybeat-schedule 168 | # 169 | # dotenv 170 | .env 171 | # 172 | # virtualenv 173 | .venv/ 174 | venv/ 175 | ENV/ 176 | # 177 | # Spyder project settings 178 | .spyderproject 179 | # 180 | # Rope project settings 181 | .ropeproject 182 | # 183 | #################### 184 | # Windows 185 | #################### 186 | # Windows image file caches 187 | Thumbs.db 188 | ehthumbs.db 189 | # 190 | # Folder config file 191 | Desktop.ini 192 | # 193 | # Recycle Bin used on file shares 194 | $RECYCLE.BIN/ 195 | # 196 | # Windows Installer files 197 | *.cab 198 | *.msi 199 | *.msm 200 | *.msp 201 | # 202 | # Windows shortcuts 203 | *.lnk 204 | # 205 | #################### 206 | # OSX 207 | #################### 208 | .DS_Store 209 | .AppleDouble 210 | .LSOverride 211 | # 212 | # Icon must end with two \r 213 | Icon 214 | # 215 | # 216 | # Thumbnails 217 | ._* 218 | # 219 | # Files that might appear in the root of a volume 220 | .DocumentRevisions-V100 221 | .fseventsd 222 | .Spotlight-V100 223 | .TemporaryItems 224 | .Trashes 225 | .VolumeIcon.icns 226 | # 227 | # Directories potentially created on remote AFP share 228 | .AppleDB 229 | .AppleDesktop 230 | Network Trash Folder 231 | Temporary Items 232 | .apdisk 233 | # 234 | #################### 235 | # Linux 236 | #################### 237 | *~ 238 | # 239 | # temporary files which can be created if a process still has a handle open of a deleted file 240 | .fuse_hidden* 241 | # 242 | # KDE directory preferences 243 | .directory 244 | # 245 | # Linux trash folder which might appear on any partition or disk 246 | .Trash-* 247 | # 248 | # .nfs files are created when an open file is removed but is still being accessed 249 | .nfs* 250 | # 251 | #################### 252 | # Sublime Text 253 | #################### 254 | # cache files for sublime text 255 | *.tmlanguage.cache 256 | *.tmPreferences.cache 257 | *.stTheme.cache 258 | # 259 | # workspace files are user-specific 260 | *.sublime-workspace 261 | # 262 | # project files should be checked into the repository, unless a significant 263 | # proportion of contributors will probably not be using SublimeText 264 | # *.sublime-project 265 | # 266 | # sftp configuration file 267 | sftp-config.json 268 | # 269 | # Package control specific files 270 | Package Control.last-run 271 | Package Control.ca-list 272 | Package Control.ca-bundle 273 | Package Control.system-ca-bundle 274 | Package Control.cache/ 275 | Package Control.ca-certs/ 276 | bh_unicode_properties.cache 277 | # 278 | # Sublime-github package stores a github token in this file 279 | # https://packagecontrol.io/packages/sublime-github 280 | GitHub.sublime-settings 281 | # 282 | #################### 283 | # Vim 284 | #################### 285 | # swap 286 | [._]*.s[a-w][a-z] 287 
| [._]s[a-w][a-z] 288 | # session 289 | Session.vim 290 | # temporary 291 | .netrwhist 292 | *~ 293 | # auto-generated tag files 294 | tags 295 | # 296 | #################### 297 | # Matlab 298 | #################### 299 | ##--------------------------------------------------- 300 | ## Remove autosaves generated by the Matlab editor 301 | ## We have git for backups! 302 | ##--------------------------------------------------- 303 | # 304 | # Windows default autosave extension 305 | *.asv 306 | # 307 | # OSX / *nix default autosave extension 308 | *.m~ 309 | # 310 | # Compiled MEX binaries (all platforms) 311 | *.mex* 312 | # 313 | # Simulink Code Generation 314 | slprj/ 315 | # 316 | # Session info 317 | octave-workspace 318 | # 319 | 320 | #################### 321 | # JetBrainzzzzz 322 | #################### 323 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 324 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 325 | # 326 | # User-specific stuff: 327 | .idea 328 | **/.idea/workspace.xml 329 | **/.idea/tasks.xml 330 | # 331 | # Sensitive or high-churn files: 332 | **/.idea/dataSources.ids 333 | **/.idea/dataSources.xml 334 | **/.idea/dataSources.local.xml 335 | **/.idea/sqlDataSources.xml 336 | **/.idea/dynamic.xml 337 | **/.idea/uiDesigner.xml 338 | # 339 | # Gradle: 340 | **/.idea/gradle.xml 341 | **/.idea/libraries 342 | # SBT 343 | **/.idea/libraries/ 344 | # 345 | # Mongo Explorer plugin: 346 | **/.idea/mongoSettings.xml 347 | # 348 | ## File-based project format: 349 | *.iws 350 | *.iml 351 | # 352 | ## Plugin-specific files: 353 | # 354 | # IntelliJ 355 | /out/ 356 | # 357 | # mpeltonen/sbt-idea plugin 358 | .idea_modules/ 359 | # 360 | # JIRA plugin 361 | atlassian-ide-plugin.xml 362 | # 363 | # Crashlytics plugin (for Android Studio and IntelliJ) 364 | com_crashlytics_export_strings.xml 365 | crashlytics.properties 366 | crashlytics-build.properties 367 | fabric.properties 368 | # 369 | #################### 370 | # iPython Notebooks 371 | #################### 372 | # Temporary data 373 | .ipynb_checkpoints/ 374 | # 375 | #################### 376 | # Eclipse 377 | #################### 378 | *.pydevproject 379 | .metadata 380 | .gradle 381 | tmp/ 382 | *.tmp 383 | *.bak 384 | *.swp 385 | *~.nib 386 | local.properties 387 | .settings/ 388 | .loadpath 389 | # 390 | # Eclipse Core 391 | .project 392 | # 393 | # External tool builders 394 | .externalToolBuilders/ 395 | # 396 | # Locally stored "Eclipse launch configurations" 397 | *.launch 398 | # 399 | # CDT-specific 400 | .cproject 401 | # 402 | # JDT-specific (Eclipse Java Development Tools) 403 | .classpath 404 | # 405 | # Java annotation processor (APT) 406 | .factorypath 407 | # 408 | # PDT-specific 409 | .buildpath 410 | # 411 | # sbteclipse plugin 412 | .target 413 | # 414 | # TeXlipse plugin 415 | .texlipse 416 | # 417 | #################### 418 | # NetBeans 419 | #################### 420 | **/nbproject/private/ 421 | build/ 422 | nbbuild/ 423 | dist/ 424 | nbdist/ 425 | nbactions.xml 426 | nb-configuration.xml 427 | .nb-gradle/ 428 | # 429 | #################### 430 | # btsync 431 | #################### 432 | .sync/ 433 | .sync 434 | # 435 | #################### 436 | # Maven 437 | #################### 438 | target/ 439 | pom.xml.tag 440 | pom.xml.releaseBackup 441 | pom.xml.versionsBackup 442 | pom.xml.next 443 | release.properties 444 | dependency-reduced-pom.xml 445 | buildNumber.properties 446 | **/.mvn/timing.properties 447 | # 
448 | #################### 449 | # Archives 450 | #################### 451 | # It's better to unpack these files and commit the raw source because 452 | # git has its own built in compression methods. 453 | *.7z 454 | *.jar 455 | *.rar 456 | *.zip 457 | *.gz 458 | *.bzip 459 | *.bz2 460 | *.xz 461 | *.lzma 462 | *.cab 463 | # 464 | #packing-only formats 465 | *.iso 466 | *.tar 467 | # 468 | #package management formats 469 | *.dmg 470 | *.xpi 471 | *.gem 472 | *.egg 473 | *.deb 474 | *.rpm 475 | *.msi 476 | *.msm 477 | *.msp 478 | # 479 | #################### 480 | # Dropbox 481 | #################### 482 | # Dropbox settings and caches 483 | .dropbox 484 | .dropbox.attr 485 | .dropbox.cache 486 | # 487 | #################### 488 | # Unison 489 | #################### 490 | # unison 491 | .unison* 492 | # 493 | #################### 494 | # SVN 495 | #################### 496 | .svn/ 497 | # 498 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | Code for the paper [Adversarial Learning of Knowledge Embeddings for the Unified Medical Language System](http://www.hlt.utdallas.edu/~ramon/papers/amia_cri_2019.pdf) to be presented at the AMIA Informatics Summit 2019.
2 |
3 | Requires TensorFlow version 1.9.
4 |
5 | #### Data Preprocessing:
6 | 1. First, download and extract the [UMLS](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html). This project assumes the UMLS files are laid out as follows:
7 | ```
8 | /
9 |   META/
10 |     MRCONSO.RRF
11 |     MRSTY.RRF
12 |   NET/
13 |     SRSTRE1
14 | ```
15 | 2. Create the Metathesaurus triples:
16 |
17 | ```bash
18 | python -m eukg.data.create_triples
19 | ```
20 | This will create the Metathesaurus train/test triples in data/metathesaurus.
21 | 3. Create the Semantic Network triples:
22 | ```bash
23 | python -m eukg.data.create_semnet_triples
24 | ```
25 |
26 | #### Training:
27 | To train the Metathesaurus Discriminator:
28 | ```bash
29 | python -m eukg.train --mode=disc --model=transd --run_name=transd-disc --no_semantic_network
30 | ```
31 | To train both the Metathesaurus and Semantic Network Discriminators:
32 | ```bash
33 | python -m eukg.train --mode=disc --model=transd --run_name=transd-sn-disc
34 | ```
35 | To train the Metathesaurus Generator:
36 | ```bash
37 | python -m eukg.train --mode=gen --model=distmult --run_name=dm-gen --no_semantic_network --learning_rate=1e-3
38 | ```
39 | To train both the Metathesaurus and Semantic Network Generators:
40 | ```bash
41 | python -m eukg.train --mode=gen --model=distmult --run_name=dm-sn-gen --learning_rate=1e-3
42 | ```
43 | To train the full GAN model:
44 | ```bash
45 | python -m eukg.train --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc --gen_run_name=dm-sn-gen
46 | ```
47 | Note that the GAN model requires a pretrained discriminator and generator.
48 | -------------------------------------------------------------------------------- /embeddings.csv.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:40bf01ac241199bb26f3163fa0ecf86aebaba9c526a8fe363dbe1b8cae026da2 3 | size 823572063 4 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | #################### 2 | # Custom Rules 3 | #################### 4 | output/ 5 | out/ 6 | input/ 7 | tmp/ 8 | temp/ 9 | work/ 10 | working/ 11 | /.tgitconfig 12 | *~dev 13 | **/.idea/ 14 | # 15 | #################### 16 | # Logging 17 | #################### 18 | # logs 19 | *.errlog 20 | *.log 21 | *.log.zip 22 | log*.txt 23 | log.txt 24 | log/ 25 | logs/ 26 | # 27 | #################### 28 | # Kirkness 29 | #################### 30 | # kirkness 31 | *.svm_annotations 32 | *.svm_model 33 | *.svm_model.maps 34 | # 35 | #################### 36 | # Java 37 | #################### 38 | *.class 39 | # 40 | # Mobile Tools for Java (J2ME) 41 | .mtj.tmp/ 42 | # 43 | # Package Files # 44 | *.jar 45 | *.war 46 | *.ear 47 | # 48 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 49 | hs_err_pid* 50 | # 51 | #################### 52 | # Scala 53 | #################### 54 | *.class 55 | *.log 56 | # 57 | # sbt specific 58 | .cache 59 | .history 60 | .lib/ 61 | dist/* 62 | target/ 63 | lib_managed/ 64 | src_managed/ 65 | **/project/boot/ 66 | **/project/plugins/project/ 67 | # 68 | # Scala-IDE specific 69 | .scala_dependencies 70 | .worksheet 71 | # 72 | # ENSIME specific 73 | .ensime_cache/ 74 | .ensime 75 | # 76 | #################### 77 | # SBT 78 | #################### 79 | # Simple Build Tool 80 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 81 | # 82 | target/ 83 | lib_managed/ 84 | src_managed/ 85 | project/boot/ 86 | .history 87 | .cache 88 | # 89 | #################### 90 | # Python 91 | #################### 92 | # Byte-compiled / optimized / DLL files 93 | __pycache__/ 94 | *.py[cod] 95 | *$py.class 96 | # 97 | # C extensions 98 | *.so 99 | # 100 | # Distribution / packaging 101 | .Python 102 | env/ 103 | build/ 104 | develop-eggs/ 105 | dist/ 106 | downloads/ 107 | eggs/ 108 | .eggs/ 109 | #lib/ 110 | #lib64/ 111 | parts/ 112 | 
sdist/ 113 | var/ 114 | *.egg-info/ 115 | .installed.cfg 116 | *.egg 117 | # 118 | # PyInstaller 119 | # Usually these files are written by a python script from a template 120 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 121 | *.manifest 122 | *.spec 123 | # 124 | # Installer logs 125 | pip-log.txt 126 | pip-delete-this-directory.txt 127 | # 128 | # Unit test / coverage reports 129 | htmlcov/ 130 | .tox/ 131 | .coverage 132 | .coverage.* 133 | .cache 134 | nosetests.xml 135 | coverage.xml 136 | *,cover 137 | .hypothesis/ 138 | # 139 | # Translations 140 | *.mo 141 | *.pot 142 | # 143 | # Django stuff: 144 | *.log 145 | local_settings.py 146 | # 147 | # Flask stuff: 148 | instance/ 149 | .webassets-cache 150 | # 151 | # Scrapy stuff: 152 | .scrapy 153 | # 154 | # Sphinx documentation 155 | **/docs/_build/ 156 | # 157 | # PyBuilder 158 | target/ 159 | # 160 | # Jupyter Notebook 161 | .ipynb_checkpoints 162 | # 163 | # pyenv 164 | .python-version 165 | # 166 | # celery beat schedule file 167 | celerybeat-schedule 168 | # 169 | # dotenv 170 | .env 171 | # 172 | # virtualenv 173 | .venv/ 174 | venv/ 175 | ENV/ 176 | # 177 | # Spyder project settings 178 | .spyderproject 179 | # 180 | # Rope project settings 181 | .ropeproject 182 | # 183 | #################### 184 | # Windows 185 | #################### 186 | # Windows image file caches 187 | Thumbs.db 188 | ehthumbs.db 189 | # 190 | # Folder config file 191 | Desktop.ini 192 | # 193 | # Recycle Bin used on file shares 194 | $RECYCLE.BIN/ 195 | # 196 | # Windows Installer files 197 | *.cab 198 | *.msi 199 | *.msm 200 | *.msp 201 | # 202 | # Windows shortcuts 203 | *.lnk 204 | # 205 | #################### 206 | # OSX 207 | #################### 208 | .DS_Store 209 | .AppleDouble 210 | .LSOverride 211 | # 212 | # Icon must end with two \r 213 | Icon 214 | # 215 | # 216 | # Thumbnails 217 | ._* 218 | # 219 | # Files that might appear in the root of a volume 220 | .DocumentRevisions-V100 221 | .fseventsd 222 | .Spotlight-V100 223 | .TemporaryItems 224 | .Trashes 225 | .VolumeIcon.icns 226 | # 227 | # Directories potentially created on remote AFP share 228 | .AppleDB 229 | .AppleDesktop 230 | Network Trash Folder 231 | Temporary Items 232 | .apdisk 233 | # 234 | #################### 235 | # Linux 236 | #################### 237 | *~ 238 | # 239 | # temporary files which can be created if a process still has a handle open of a deleted file 240 | .fuse_hidden* 241 | # 242 | # KDE directory preferences 243 | .directory 244 | # 245 | # Linux trash folder which might appear on any partition or disk 246 | .Trash-* 247 | # 248 | # .nfs files are created when an open file is removed but is still being accessed 249 | .nfs* 250 | # 251 | #################### 252 | # Sublime Text 253 | #################### 254 | # cache files for sublime text 255 | *.tmlanguage.cache 256 | *.tmPreferences.cache 257 | *.stTheme.cache 258 | # 259 | # workspace files are user-specific 260 | *.sublime-workspace 261 | # 262 | # project files should be checked into the repository, unless a significant 263 | # proportion of contributors will probably not be using SublimeText 264 | # *.sublime-project 265 | # 266 | # sftp configuration file 267 | sftp-config.json 268 | # 269 | # Package control specific files 270 | Package Control.last-run 271 | Package Control.ca-list 272 | Package Control.ca-bundle 273 | Package Control.system-ca-bundle 274 | Package Control.cache/ 275 | Package Control.ca-certs/ 276 | bh_unicode_properties.cache 277 | # 278 | # 
Sublime-github package stores a github token in this file 279 | # https://packagecontrol.io/packages/sublime-github 280 | GitHub.sublime-settings 281 | # 282 | #################### 283 | # Vim 284 | #################### 285 | # swap 286 | [._]*.s[a-w][a-z] 287 | [._]s[a-w][a-z] 288 | # session 289 | Session.vim 290 | # temporary 291 | .netrwhist 292 | *~ 293 | # auto-generated tag files 294 | tags 295 | # 296 | #################### 297 | # Matlab 298 | #################### 299 | ##--------------------------------------------------- 300 | ## Remove autosaves generated by the Matlab editor 301 | ## We have git for backups! 302 | ##--------------------------------------------------- 303 | # 304 | # Windows default autosave extension 305 | *.asv 306 | # 307 | # OSX / *nix default autosave extension 308 | *.m~ 309 | # 310 | # Compiled MEX binaries (all platforms) 311 | *.mex* 312 | # 313 | # Simulink Code Generation 314 | slprj/ 315 | # 316 | # Session info 317 | octave-workspace 318 | # 319 | 320 | #################### 321 | # JetBrainzzzzz 322 | #################### 323 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 324 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 325 | # 326 | # User-specific stuff: 327 | .idea 328 | **/.idea/workspace.xml 329 | **/.idea/tasks.xml 330 | # 331 | # Sensitive or high-churn files: 332 | **/.idea/dataSources.ids 333 | **/.idea/dataSources.xml 334 | **/.idea/dataSources.local.xml 335 | **/.idea/sqlDataSources.xml 336 | **/.idea/dynamic.xml 337 | **/.idea/uiDesigner.xml 338 | # 339 | # Gradle: 340 | **/.idea/gradle.xml 341 | **/.idea/libraries 342 | # SBT 343 | **/.idea/libraries/ 344 | # 345 | # Mongo Explorer plugin: 346 | **/.idea/mongoSettings.xml 347 | # 348 | ## File-based project format: 349 | *.iws 350 | *.iml 351 | # 352 | ## Plugin-specific files: 353 | # 354 | # IntelliJ 355 | /out/ 356 | # 357 | # mpeltonen/sbt-idea plugin 358 | .idea_modules/ 359 | # 360 | # JIRA plugin 361 | atlassian-ide-plugin.xml 362 | # 363 | # Crashlytics plugin (for Android Studio and IntelliJ) 364 | com_crashlytics_export_strings.xml 365 | crashlytics.properties 366 | crashlytics-build.properties 367 | fabric.properties 368 | # 369 | #################### 370 | # iPython Notebooks 371 | #################### 372 | # Temporary data 373 | .ipynb_checkpoints/ 374 | # 375 | #################### 376 | # Eclipse 377 | #################### 378 | *.pydevproject 379 | .metadata 380 | .gradle 381 | tmp/ 382 | *.tmp 383 | *.bak 384 | *.swp 385 | *~.nib 386 | local.properties 387 | .settings/ 388 | .loadpath 389 | # 390 | # Eclipse Core 391 | .project 392 | # 393 | # External tool builders 394 | .externalToolBuilders/ 395 | # 396 | # Locally stored "Eclipse launch configurations" 397 | *.launch 398 | # 399 | # CDT-specific 400 | .cproject 401 | # 402 | # JDT-specific (Eclipse Java Development Tools) 403 | .classpath 404 | # 405 | # Java annotation processor (APT) 406 | .factorypath 407 | # 408 | # PDT-specific 409 | .buildpath 410 | # 411 | # sbteclipse plugin 412 | .target 413 | # 414 | # TeXlipse plugin 415 | .texlipse 416 | # 417 | #################### 418 | # NetBeans 419 | #################### 420 | **/nbproject/private/ 421 | build/ 422 | nbbuild/ 423 | dist/ 424 | nbdist/ 425 | nbactions.xml 426 | nb-configuration.xml 427 | .nb-gradle/ 428 | # 429 | #################### 430 | # btsync 431 | #################### 432 | .sync/ 433 | .sync 434 | # 435 | #################### 436 | # Maven 
437 | #################### 438 | target/ 439 | pom.xml.tag 440 | pom.xml.releaseBackup 441 | pom.xml.versionsBackup 442 | pom.xml.next 443 | release.properties 444 | dependency-reduced-pom.xml 445 | buildNumber.properties 446 | **/.mvn/timing.properties 447 | # 448 | #################### 449 | # Archives 450 | #################### 451 | # It's better to unpack these files and commit the raw source because 452 | # git has its own built in compression methods. 453 | *.7z 454 | *.jar 455 | *.rar 456 | *.zip 457 | *.gz 458 | *.bzip 459 | *.bz2 460 | *.xz 461 | *.lzma 462 | *.cab 463 | # 464 | #packing-only formats 465 | *.iso 466 | *.tar 467 | # 468 | #package management formats 469 | *.dmg 470 | *.xpi 471 | *.gem 472 | *.egg 473 | *.deb 474 | *.rpm 475 | *.msi 476 | *.msm 477 | *.msp 478 | # 479 | #################### 480 | # Dropbox 481 | #################### 482 | # Dropbox settings and caches 483 | .dropbox 484 | .dropbox.attr 485 | .dropbox.cache 486 | # 487 | #################### 488 | # Unison 489 | #################### 490 | # unison 491 | .unison* 492 | # 493 | #################### 494 | # SVN 495 | #################### 496 | .svn/ 497 | # 498 | -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '/home/rmm120030/code/python') 3 | -------------------------------------------------------------------------------- /python/bin/rtc_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # transe 3 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-disc 4 | # transd 5 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-disc_e0init --embedding_size=100 6 | # transe-sn 7 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-sn-disc 8 | # transd-sn 9 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-sn-disc_e0init --embedding_size=100 10 | # transe-sn gan 11 | pytf -m eukg.test.classification --mode=gan --model=transe --run_name=sn-gan --dis_run_name=transe-sn-disc 12 | # transd-sn gan 13 | pytf -m eukg.test.classification --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc_e0init --embedding_size=100 14 | ################ SN ################## 15 | # transe-sn 16 | pytf -m eukg.test.classification --mode=disc --model=transe --run_name=transe-sn-disc --sn_eval --eval_mode=sn 17 | # transd-sn 18 | pytf -m eukg.test.classification --mode=disc --model=transd --run_name=transd-sn-disc_e0init --embedding_size=100 --sn_eval --eval_mode=sn 19 | # transe-sn gan 20 | pytf -m eukg.test.classification --mode=gan --model=transe --run_name=sn-gan --dis_run_name=transe-sn-disc --sn_eval --eval_mode=sn 21 | # transd-sn gan 22 | pytf -m eukg.test.classification --mode=gan --model=transd --run_name=gan --dis_run_name=transd-sn-disc_e0init --embedding_size=100 --sn_eval --eval_mode=sn 23 | -------------------------------------------------------------------------------- /python/data/valid_rels.txt: -------------------------------------------------------------------------------- 1 | has_tradename 2 | has_ingredients 3 | organism_has_gene 4 | has_presentation_strength_numerator_value 5 | replaces 6 | chemical_or_drug_affects_gene_product 7 | has_supersystem 8 | disease_excludes_normal_tissue_origin 9 | has_specialty 10 | has_branch 11 | possibly_equivalent_to 12 | 
sib_in_isa 13 | isa 14 | procedure_site_of 15 | has_time_aspect 16 | component_of 17 | regimen_has_accepted_use_for_disease 18 | add_on_code_for 19 | has_basic_dose_form 20 | has_quotient 21 | partially_excised_anatomy_has_procedure 22 | disease_has_cytogenetic_abnormality 23 | direct_device_of 24 | disease_excludes_normal_cell_origin 25 | has_at_risk_population 26 | has_subject_of_information 27 | is_location_of_anatomic_structure 28 | primary_mapped_to 29 | may_be_finding_of_disease 30 | subject_relationship_context_of 31 | form_of 32 | expected_outcome_of 33 | actual_outcome_of 34 | process_includes_biological_process 35 | has_target 36 | is_metastatic_anatomic_site_of_disease 37 | has_alternative 38 | distal_to 39 | challenge_of 40 | state_of_matter_of 41 | chemical_or_drug_initiates_biological_process 42 | has_contraindicated_physiologic_effect 43 | has_method 44 | drains_into 45 | allele_in_chromosomal_location 46 | biomarker_type_includes_gene 47 | anterolateral_to 48 | clinical_course_of 49 | gene_product_is_biomarker_of 50 | approach_of_possibly_included 51 | posteroinferior_to 52 | during 53 | disease_has_associated_gene 54 | attaches_to 55 | is_chemical_classification_of_gene_product 56 | has_nerve_supply 57 | has_precise_ingredient 58 | has_access 59 | has_pathological_process 60 | has_measurement_method 61 | has_exam 62 | has_laterality 63 | refers_to 64 | was_a 65 | completely_excised_anatomy_has_procedure 66 | reformulated_to 67 | disease_may_have_associated_disease 68 | is_physiologic_effect_of_chemical_or_drug 69 | enzyme_metabolizes_chemical_or_drug 70 | exhibited_by 71 | has_dose_form_intended_site 72 | has_suffix 73 | has_context_binding 74 | has_modality_subtype 75 | has_interpretation 76 | has_specimen 77 | has_possibly_included_pathology 78 | has_free_acid_or_base_form 79 | connected_to 80 | uses_energy 81 | has_arterial_supply 82 | has_count 83 | classifies 84 | constitutes 85 | positively_regulates 86 | has_lymphatic_drainage 87 | has_active_ingredient 88 | has_finding_informer 89 | has_extent 90 | has_possibly_included_procedure_site 91 | may_be_normal_cell_origin_of_disease 92 | surrounds 93 | has_recipient_category 94 | has_access_instrument 95 | icd_asterisk 96 | has_quantified_form 97 | has_subject 98 | chemical_or_drug_is_product_of_biological_process 99 | disease_is_marked_by_gene 100 | has_location 101 | has_excluded_specimen 102 | has_direct_site 103 | inferolateral_to 104 | fuses_with 105 | superomedial_to 106 | develops_into 107 | endogenous_product_related_to 108 | adjacent_to 109 | has_clinician_form 110 | has_intent 111 | chemical_or_drug_affects_abnormal_cell 112 | process_involves_gene 113 | has_focus 114 | has_object_guidance 115 | has_doseformgroup 116 | has_common_name 117 | has_component 118 | has_entire_anatomy_structure 119 | has_diagnostic_criteria 120 | has_sign_or_symptom 121 | has_maneuver_type 122 | gene_product_has_biochemical_function 123 | has_segmental_supply 124 | pathway_has_gene_element 125 | has_associated_etiologic_finding 126 | procedure_has_target_anatomy 127 | regulates 128 | sib_in_tributary_of 129 | has_secondary_segmental_supply 130 | has_primary_segmental_supply 131 | do_not_code_with 132 | has_alias 133 | has_finding_context 134 | has_associated_morphology 135 | has_expanded_form 136 | exhibits 137 | has_excluded_method 138 | has_indirect_procedure_site 139 | print_name_of 140 | has_property 141 | has_route_of_administration 142 | has_manifestation 143 | biological_process_has_associated_location 144 | 
complex_has_physical_part 145 | has_scale_type 146 | may_be_cytogenetic_abnormality_of_disease 147 | has_dose_form_transformation 148 | disease_is_stage 149 | has_locale 150 | has_british_form 151 | disease_mapped_to_chromosome 152 | pathogenesis_of_disease_involves_gene 153 | disease_excludes_abnormal_cell 154 | has_definitional_manifestation 155 | has_causative_agent 156 | has_origin 157 | has_allelic_variant 158 | has_fragments_for_synonyms 159 | has_regional_part 160 | disease_excludes_primary_anatomic_site 161 | continuous_distally_with 162 | precondition_of 163 | has_development_type 164 | cell_type_is_associated_with_eo_disease 165 | has_possibly_included_component 166 | has_possibly_included_surgical_extent 167 | forms 168 | temporally_related_to 169 | happens_during 170 | has_related_developmental_entity 171 | has_technique 172 | has_possibly_included_patient_type 173 | associated_genetic_condition 174 | chemical_or_drug_plays_role_in_biological_process 175 | may_qualify 176 | gene_product_has_abnormality 177 | has_possibly_included_associated_finding 178 | has_ingredient 179 | eo_disease_has_property_or_attribute 180 | has_presentation_strength_numerator_unit 181 | has_presentation_strength_denominator_unit 182 | has_presentation_strength_denominator_value 183 | has_basis_of_strength_substance 184 | has_measured_component 185 | has_presence_guidance 186 | has_temporal_context 187 | disease_mapped_to_gene 188 | has_surgical_approach 189 | biomarker_type_includes_gene_product 190 | has_part_anatomy_structure 191 | has_specimen_source_identity 192 | has_possibly_included_procedure_device 193 | tissue_is_expression_site_of_gene_product 194 | has_conceptual_part 195 | has_multi_level_category 196 | has_episodicity 197 | derives 198 | has_system 199 | has_indirect_device 200 | lateral_to 201 | has_time_modifier 202 | has_excluded_procedure_device 203 | has_pharmacokinetics 204 | activity_of_allele 205 | homonym_for 206 | manufactures 207 | has_lateral_location_presence 208 | has_given_pharmaceutical_substance 209 | has_indirect_morphology 210 | has_scale 211 | mth_has_expanded_form 212 | related_part 213 | disease_has_finding 214 | has_snomed_parent 215 | uses 216 | see 217 | posterior_to 218 | continuous_with 219 | use 220 | has_imaging_focus 221 | genetic_biomarker_related_to 222 | characterizes 223 | allele_plays_altered_role_in_process 224 | has_panel_element 225 | has_venous_drainage 226 | chemotherapy_regimen_has_component 227 | includes 228 | has_consumer_friendly_form 229 | excised_anatomy_has_procedure 230 | has_approach 231 | has_lateral_anatomic_location 232 | has_physiologic_effect 233 | sib_in_part_of 234 | sib_in_branch_of 235 | has_dose_form 236 | has_instrumentation 237 | chromosome_involved_in_cytogenetic_abnormality 238 | has_active_moiety 239 | has_revision_status 240 | has_direct_substance 241 | disease_has_primary_anatomic_site 242 | has_procedure_device 243 | has_member 244 | temporally_follows 245 | disease_has_normal_cell_origin 246 | biological_process_has_result_anatomy 247 | classifies_class_code 248 | biological_process_results_from_biological_process 249 | has_associated_function 250 | has_answer 251 | has_patient_type 252 | has_adjustment 253 | has_phenotype 254 | related_object 255 | gene_product_sequence_variation_encoded_by_gene_mutant 256 | consider 257 | has_surgical_extent 258 | icd_dagger 259 | gene_product_malfunction_associated_with_disease 260 | has_associated_condition 261 | has_timing_of 262 | mth_has_xml_form 263 | mth_has_plain_text_form 264 
| superior_to 265 | has_snomed_synonym 266 | disease_excludes_molecular_abnormality 267 | eo_disease_has_associated_eo_anatomy 268 | part_of 269 | allele_has_abnormality 270 | has_developmental_stage 271 | has_property_type 272 | uses_possibly_included_substance 273 | related_to 274 | has_direct_procedure_site 275 | has_view_type 276 | is_grade_of_disease 277 | measures 278 | contains 279 | has_pathology 280 | same_as 281 | has_procedure_context 282 | moved_to 283 | has_action_guidance 284 | bounds 285 | has_divisor 286 | disease_has_associated_disease 287 | occurs_in 288 | has_modality_type 289 | disease_excludes_finding 290 | has_specimen_substance 291 | has_defining_characteristic 292 | supported_concept_relationship_in 293 | has_risk_factor 294 | receives_input_from 295 | disease_may_have_molecular_abnormality 296 | modifies 297 | gene_has_abnormality 298 | has_dose_form_administration_method 299 | has_supported_concept_property 300 | has_dose_form_release_characteristic 301 | has_chemical_structure 302 | receives_projection 303 | afferent_to 304 | has_multipart 305 | has_specimen_procedure 306 | has_inactive_ingredient 307 | has_excluded_procedure_site 308 | has_constitutional_part 309 | has_disposition 310 | articulates_with 311 | has_specimen_source_topography 312 | interprets 313 | has_excluded_pathology 314 | posterolateral_to 315 | has_possibly_included_associated_procedure 316 | has_onset 317 | has_aggregation_view 318 | disease_may_have_abnormal_cell 319 | biological_process_has_initiator_process 320 | has_lab_number 321 | has_course 322 | mth_british_form_of 323 | uses_device 324 | disease_has_normal_tissue_origin 325 | has_tributary 326 | disease_excludes_cytogenetic_abnormality 327 | chemical_or_drug_has_mechanism_of_action 328 | has_evaluation 329 | chemical_or_drug_affects_cell_type_or_tissue 330 | mapped_to 331 | has_approach_guidance 332 | concept_in_subset 333 | gene_product_plays_role_in_biological_process 334 | occurs_before 335 | gene_encodes_gene_product 336 | disease_has_abnormal_cell 337 | has_direct_morphology 338 | negatively_regulates 339 | gene_product_is_element_in_pathway 340 | molecular_abnormality_involves_gene 341 | has_severity 342 | has_finding_method 343 | has_communication_with_wound 344 | has_associated_finding 345 | gene_product_has_gene_product_variant 346 | has_specimen_source_morphology 347 | has_mechanism_of_action 348 | has_projection 349 | has_possibly_included_method 350 | has_imaged_location 351 | gene_product_has_organism_source 352 | eo_disease_maps_to_human_disease 353 | has_associated_procedure 354 | analyzes 355 | has_insertion 356 | has_contraindicated_mechanism_of_action 357 | disease_has_molecular_abnormality 358 | gene_product_has_structural_domain_or_motif 359 | gene_product_has_associated_anatomy 360 | has_inherent_attribute 361 | uses_substance 362 | has_physical_part_of_anatomic_structure 363 | has_finding_site 364 | has_parent 365 | has_procedure_morphology 366 | cause_of 367 | disease_has_associated_anatomic_site 368 | matures_into 369 | has_excluded_associated_procedure 370 | corresponds_to 371 | transforms_into 372 | has_continuation_branch 373 | posterosuperior_to 374 | site_of_metabolism 375 | has_cdrh_parent 376 | associated_with 377 | has_pharmaceutical_route 378 | has_inheritance_type 379 | direct_right_of 380 | has_related_factor 381 | uses_excluded_substance 382 | has_class 383 | gene_in_chromosomal_location 384 | other_mapped_to 385 | has_priority 386 | uses_access_device 387 | has_therapeutic_class 388 | 
has_nichd_parent 389 | -------------------------------------------------------------------------------- /python/eukg/Config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # control options 4 | tf.flags.DEFINE_string("data_dir", "/users/rmm120030/working/umls-mke/new", "training data directory") 5 | tf.flags.DEFINE_string("model_dir", "/users/rmm120030/working/umls-mke/new/model", "model directory") 6 | tf.flags.DEFINE_string("model", "transd", "Model: [transe, transd, distmult]") 7 | tf.flags.DEFINE_string("mode", "disc", "Mode: [disc, gen, gan]") 8 | tf.flags.DEFINE_string("run_name", None, "Run name") 9 | tf.flags.DEFINE_string("summaries_dir", "/users/rmm120030/working/umls-mke/new/model", "model summary dir") 10 | tf.flags.DEFINE_integer("batch_size", 564, "batch size") 11 | tf.flags.DEFINE_bool("load", False, "Load model?") 12 | tf.flags.DEFINE_bool("load_embeddings", False, "Load embeddings?") 13 | tf.flags.DEFINE_string("embedding_file", None, "Embedding matrix npz") 14 | tf.flags.DEFINE_integer("num_epochs", 100, "Number of epochs?") 15 | tf.flags.DEFINE_float("val_proportion", 0.1, "Proportion of training data to hold out for validation") 16 | tf.flags.DEFINE_integer("progress_update_interval", 1000, "Number of batches between progress updates") 17 | tf.flags.DEFINE_integer("val_progress_update_interval", 1800, "Number of batches between progress updates") 18 | tf.flags.DEFINE_integer("save_interval", 1800, "Number of seconds between model saves") 19 | tf.flags.DEFINE_integer("batches_per_epoch", 1, "Number of batches per training epoch") 20 | tf.flags.DEFINE_integer("max_batches_per_epoch", None, "Maximum number of batches per training epoch") 21 | tf.flags.DEFINE_string("embedding_device", "gpu", "Device to do embedding lookups on [gpu, cpu]") 22 | tf.flags.DEFINE_string("optimizer", "adam", "Optimizer [adam, sgd]") 23 | tf.flags.DEFINE_string("save_strategy", "epoch", "Save every epoch or saved every" 24 | " flags.save_interval seconds [epoch, timed]") 25 | 26 | # eval control options 27 | tf.flags.DEFINE_string("eval_mode", "save", "Evaluation mode: [save, calc]") 28 | tf.flags.DEFINE_string("eval_dir", "eval", "directory for evaluation outputs") 29 | tf.flags.DEFINE_integer("shard", 1, "Shard number for distributed eval.") 30 | tf.flags.DEFINE_integer("num_shards", 1, "Total number of shards for distributed eval.") 31 | tf.flags.DEFINE_bool("save_ranks", True, "Save ranks? 
(turn off while debugging)") 32 | 33 | # gan control options 34 | tf.flags.DEFINE_string("dis_run_name", None, "Run name for the discriminator model") 35 | tf.flags.DEFINE_string("gen_run_name", "dm-gen", "Run name for the generator model") 36 | tf.flags.DEFINE_string("sn_gen_run_name", "dm-sn-gen", "Run name for the semnet generator model") 37 | 38 | # model params 39 | tf.flags.DEFINE_float("learning_rate", 1e-5, "Starting learning rate") 40 | tf.flags.DEFINE_float("decay_rate", 0.96, "LR decay rate") 41 | tf.flags.DEFINE_float("momentum", 0.9, "Momentum") 42 | tf.flags.DEFINE_float("gamma", 0.1, "Margin parameter for loss") 43 | tf.flags.DEFINE_integer("vocab_size", 3210965, "Number of unique concepts+relations") 44 | tf.flags.DEFINE_integer("embedding_size", 50, "embedding size") 45 | tf.flags.DEFINE_integer("energy_norm_ord", 1, 46 | "Order of the normalization function used to quantify difference between h+r and t") 47 | tf.flags.DEFINE_integer("max_concepts_per_type", 1000, "Maximum number of concepts to average for semtype loss") 48 | tf.flags.DEFINE_integer("num_generator_samples", 100, "Number of negative samples for each generator example") 49 | tf.flags.DEFINE_string("p_init", "zeros", 50 | "Projection vectors initializer: [zeros, xavier, uniform]. Uniform is in [-.1,.1]") 51 | 52 | # semnet params 53 | tf.flags.DEFINE_bool("no_semantic_network", False, "Do not add semantic network loss to the graph?") 54 | tf.flags.DEFINE_float("semnet_alignment_param", 0.5, "Parameter to control semantic network loss") 55 | tf.flags.DEFINE_float("semnet_energy_param", 0.5, "Parameter to control semantic network loss") 56 | tf.flags.DEFINE_bool("sn_eval", False, "Train this model with subset of sn to evaluate the SN embeddings?") 57 | 58 | # distmult params 59 | tf.flags.DEFINE_float("regularization_parameter", 1e-4, "Regularization term weight") 60 | tf.flags.DEFINE_string("energy_activation", 'sigmoid', 61 | "Energy activation function [None, tanh, relu, sigmoid]") 62 | 63 | 64 | flags = tf.flags.FLAGS 65 | -------------------------------------------------------------------------------- /python/eukg/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '/home/rmm120030/code/python') 3 | -------------------------------------------------------------------------------- /python/eukg/data/DataGenerator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from itertools import izip 5 | from collections import defaultdict 6 | from tqdm import tqdm 7 | import os 8 | import ujson 9 | 10 | import data_util 11 | 12 | 13 | class DataGenerator: 14 | def __init__(self, data, train_idx, val_idx, config, type2cuis=None, test_mode=False): 15 | self.data = data 16 | self.train_idx = train_idx 17 | self.val_idx = val_idx 18 | self.config = config 19 | self._sampler = None 20 | self._sn_sampler = None 21 | if not config.no_semantic_network: 22 | assert type2cuis 23 | self._sn_sampler = NegativeSampler(data['sn_subj'], data['sn_rel'], data['sn_obj'], 'semnet') 24 | # if we wish to train this model for sn eval, only use 25 | if config.sn_eval: 26 | with np.load(os.path.join(config.data_dir, 'semnet', 'train.npz')) as sn_npz: 27 | for key, val in sn_npz.iteritems(): 28 | self.data['sn_' + key] = val 29 | 30 | self.type2cuis = type2cuis 31 | self.test_mode = test_mode 32 | 33 | @property 34 | def sn_sampler(self): 35 | # if self._sn_sampler is 
None: 36 | # self._sn_sampler = self.init_sn_sampler() 37 | return self._sn_sampler 38 | 39 | @property 40 | def sampler(self): 41 | if self._sampler is None: 42 | self._sampler = self.init_sampler() 43 | return self._sampler 44 | 45 | # must include test data in negative sampler 46 | def init_sampler(self): 47 | if self.test_mode: 48 | _, test_data, _, _ = data_util.load_metathesaurus_data(self.config.data_dir, 0.) 49 | else: 50 | test_data = data_util.load_metathesaurus_test_data(self.config.data_dir) 51 | valid_triples = set() 52 | for s, r, o in zip(self.data['subj'], self.data['rel'], self.data['obj']): 53 | valid_triples.add((s, r, o)) 54 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']): 55 | valid_triples.add((s, r, o)) 56 | 57 | return NegativeSampler(valid_triples=valid_triples, name='mt') 58 | 59 | def generate_mt(self, is_training): 60 | idxs = self.train_idx if is_training else self.val_idx 61 | batch_size = self.config.batch_size 62 | subj, rel, obj = self.data['subj'], self.data['rel'], self.data['obj'] 63 | nsubj, nobj = self.sampler.sample(subj, rel, obj) 64 | num_batches = int(math.floor(float(len(idxs)) / batch_size)) 65 | print('\n\ngenerating %d batches' % num_batches) 66 | for b in xrange(num_batches): 67 | idx = idxs[b * batch_size: (b + 1) * batch_size] 68 | yield rel[idx], subj[idx], obj[idx], nsubj[idx], nobj[idx] 69 | 70 | def generate_mt_gen_mode(self, is_training): 71 | idxs = self.train_idx if is_training else self.val_idx 72 | batch_size = self.config.batch_size 73 | subj, rel, obj = self.data['subj'], self.data['rel'], self.data['obj'] 74 | num_batches = int(math.floor(float(len(idxs)) / batch_size)) 75 | print('\n\ngenerating %d batches in generation mode' % num_batches) 76 | for b in xrange(num_batches): 77 | idx = idxs[b * batch_size: (b + 1) * batch_size] 78 | sampl_subj, sampl_obj = self.sampler.sample_for_generator(subj[idx], rel[idx], obj[idx], 79 | self.config.num_generator_samples) 80 | yield rel[idx], subj[idx], obj[idx], sampl_subj, sampl_obj 81 | 82 | def generate_sn(self, is_training): 83 | print('\n\ngenerating SN data') 84 | if is_training: 85 | idxs = np.random.permutation(np.arange(len(self.data['sn_rel']))) 86 | data = self.data 87 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj'] 88 | nsubj, nobj = self.sn_sampler.sample(subj, rel, obj) 89 | sn_offset = 0 90 | batch_size = self.config.batch_size / 4 91 | while True: 92 | idx, sn_offset = get_next_k_idxs(idxs, batch_size, sn_offset) 93 | types = np.unique([subj[idx], obj[idx], nsubj[idx], nobj[idx]]) 94 | concepts = np.zeros([len(types), self.config.max_concepts_per_type], dtype=np.int32) 95 | concept_lens = np.zeros([len(types)], dtype=np.int32) 96 | for i, tid in enumerate(types): 97 | concepts_of_type_t = self.type2cuis[tid] if tid in self.type2cuis else [] 98 | random.shuffle(concepts_of_type_t) 99 | concepts_of_type_t = concepts_of_type_t[:self.config.max_concepts_per_type] 100 | concept_lens[i] = len(concepts_of_type_t) 101 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t 102 | 103 | yield rel[idx], subj[idx], obj[idx], nsubj[idx], nobj[idx], \ 104 | concepts, concept_lens, types 105 | else: 106 | while True: 107 | yield [0], [0], [0], [0], [0], np.zeros([1, 1000], dtype=np.int32), [1], [0] 108 | 109 | def generate_sn_gen_mode(self, is_training): 110 | print('\n\ngenerating SN data in generation mode') 111 | num_samples = 10 112 | if is_training: 113 | idxs = np.random.permutation(np.arange(len(self.data['sn_rel']))) 114 | subj, 
rel, obj = self.data['sn_subj'], self.data['sn_rel'], self.data['sn_obj'] 115 | sn_offset = 0 116 | batch_size = self.config.batch_size / 4 117 | while True: 118 | idx, sn_offset = get_next_k_idxs(idxs, batch_size, sn_offset) 119 | subj_ = subj[idx] 120 | rel_ = rel[idx] 121 | obj_ = obj[idx] 122 | subj_samples, obj_samples = self.sn_sampler.sample_for_generator(subj_, rel_, obj_, num_samples) 123 | 124 | types = np.unique(np.concatenate([subj_, obj_, subj_samples.flatten(), obj_samples.flatten()])) 125 | concepts = np.zeros([len(types), self.config.max_concepts_per_type], dtype=np.int32) 126 | concept_lens = np.zeros([len(types)], dtype=np.int32) 127 | for i, tid in enumerate(types): 128 | concepts_of_type_t = self.type2cuis[tid] if tid in self.type2cuis else [] 129 | random.shuffle(concepts_of_type_t) 130 | concepts_of_type_t = concepts_of_type_t[:self.config.max_concepts_per_type] 131 | concept_lens[i] = len(concepts_of_type_t) 132 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t 133 | 134 | yield rel_, subj_, obj_, subj_samples, obj_samples, \ 135 | concepts, concept_lens, types 136 | else: 137 | while True: 138 | yield [0], [0], [0], [0, 0], [0, 0], np.zeros([1, 1000], dtype=np.int32), [1], [0] 139 | 140 | def num_train_batches(self): 141 | return int(math.floor(float(len(self.train_idx)) / self.config.batch_size)) 142 | 143 | def num_val_batches(self): 144 | return int(math.floor(float(len(self.val_idx)) / self.config.batch_size)) 145 | 146 | 147 | class NegativeSampler: 148 | def __init__(self, subj=None, rel=None, obj=None, name=None, cachedir="/home/rmm120030/working/umls-mke/.cache", 149 | valid_triples=None): 150 | # cachedir = os.path.join(cachedir, name) 151 | # if os.path.exists(cachedir): 152 | # start = time.time() 153 | # print('loading negative sampler maps from %s' % cachedir) 154 | # self.sr2o = load_dict(os.path.join(cachedir, 'sr2o')) 155 | # self.or2s = load_dict(os.path.join(cachedir, 'or2s')) 156 | # self.concepts = ujson.load(open(os.path.join(cachedir, 'concepts.json'))) 157 | # print('done! 
Took %.2f seconds' % (time.time() - start)) 158 | # else: 159 | self.sr2o = defaultdict(set) 160 | self.or2s = defaultdict(set) 161 | concepts = set() 162 | triples = zip(subj, rel, obj) if valid_triples is None else valid_triples 163 | for s, r, o in tqdm(triples, desc='building triple maps', total=len(triples)): 164 | # s, r, o = int(s), int(r), int(o) 165 | self.sr2o[(s, r)].add(o) 166 | self.or2s[(o, r)].add(s) 167 | concepts.update([s, o]) 168 | self.concepts = list(concepts) 169 | 170 | # print('\n\ncaching negative sampler maps to %s' % cachedir) 171 | # os.makedirs(cachedir) 172 | # save_dict(self.sr2o, os.path.join(cachedir, 'sr2o')) 173 | # save_dict(self.or2s, os.path.join(cachedir, 'or2s')) 174 | # ujson.dump(self.concepts, open(os.path.join(cachedir, 'concepts.json'), 'w+')) 175 | # print('done!') 176 | 177 | def _neg_sample(self, s_, r_, o_, replace_s): 178 | while True: 179 | c = random.choice(self.concepts) 180 | if replace_s and c not in self.or2s[(o_, r_)]: 181 | return c, o_ 182 | elif not replace_s and c not in self.sr2o[(s_, r_)]: 183 | return s_, c 184 | 185 | def sample(self, subj, rel, obj): 186 | neg_subj = [] 187 | neg_obj = [] 188 | print("\n") 189 | for s, r, o in tqdm(zip(subj, rel, obj), desc='negative sampling', total=len(subj)): 190 | ns, no = self._neg_sample(s, r, o, random.random() > 0.5) 191 | neg_subj.append(ns) 192 | neg_obj.append(no) 193 | 194 | return np.asarray(neg_subj, dtype=np.int32), np.asarray(neg_obj, dtype=np.int32) 195 | 196 | def _sample_k(self, subj, rel, obj, k): 197 | neg_subj = [] 198 | neg_obj = [] 199 | for i in xrange(k): 200 | ns, no = self._neg_sample(subj, rel, obj, random.random() > 0.5) 201 | neg_subj.append(ns) 202 | neg_obj.append(no) 203 | 204 | return neg_subj, neg_obj 205 | 206 | def sample_for_generator(self, subj_array, rel_array, obj_array, k): 207 | subj_samples = [] 208 | obj_samples = [] 209 | for s, r, o in zip(subj_array, rel_array, obj_array): 210 | ns, no = self._sample_k(s, r, o, k) 211 | subj_samples.append(ns) 212 | obj_samples.append(no) 213 | 214 | return np.asarray(subj_samples, dtype=np.int32), np.asarray(obj_samples, dtype=np.int32) 215 | 216 | def invalid_concepts(self, subj, rel, obj, replace_subj): 217 | if replace_subj: 218 | return [c for c in self.concepts if c not in self.or2s[(obj, rel)]] 219 | else: 220 | return [c for c in self.concepts if c not in self.sr2o[(subj, rel)]] 221 | 222 | 223 | def get_next_k_idxs(all_idxs, k, offset): 224 | if offset + k < len(all_idxs): 225 | idx = all_idxs[offset: offset + k] 226 | offset += k 227 | else: 228 | random.shuffle(all_idxs) 229 | offset = k 230 | idx = all_idxs[:offset] 231 | return idx, offset 232 | 233 | 234 | def wrap_generators(mt_gen, sn_gen, is_training): 235 | if is_training: 236 | for mt_batch, sn_batch in izip(mt_gen(True), sn_gen(True)): 237 | yield mt_batch + sn_batch 238 | else: 239 | for b in mt_gen(True): 240 | yield b 241 | 242 | 243 | def save_dict(d, savepath): 244 | keys, values = [], [] 245 | for k, v in d.iteritems(): 246 | keys.append(list(k)) 247 | values.append(list(v)) 248 | 249 | ujson.dump(keys, open(savepath + '_keys.json', 'w+')) 250 | ujson.dump(values, open(savepath + '_values.json', 'w+')) 251 | 252 | 253 | def load_dict(savepath): 254 | keys = ujson.load(open(savepath + '_keys.json')) 255 | values = ujson.load(open(savepath + '_values.json')) 256 | return {tuple(k): set(v) for k, v in izip(keys, values)} 257 | -------------------------------------------------------------------------------- /python/eukg/data/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/data/__init__.py -------------------------------------------------------------------------------- /python/eukg/data/create_semnet_triples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | import numpy as np 8 | 9 | from create_test_set import split 10 | 11 | 12 | def process_mapping(umls_dir, data_dir): 13 | print('Creating mapping of semtypes to cuis and vice versa...') 14 | name2id = json.load(open(os.path.join(data_dir, 'name2id.json'))) 15 | cui2semtypes = defaultdict(list) 16 | semtype2cuis = defaultdict(list) 17 | with open(os.path.join(umls_dir, 'META', 'MRSTY.RRF'), 'r') as f: 18 | reader = csv.reader(f, delimiter='|') 19 | for row in tqdm(reader, desc="reading", total=3395307): 20 | cui = row[0].strip() 21 | if cui in name2id: 22 | cid = name2id[cui] 23 | tui = row[1] 24 | if tui in name2id: 25 | tid = name2id[tui] 26 | else: 27 | tid = len(name2id) 28 | name2id[tui] = tid 29 | cui2semtypes[cid].append(tid) 30 | semtype2cuis[tid].append(cid) 31 | print('Processed type mappings for %d semantic types' % len(semtype2cuis)) 32 | c2s_lens = sorted([len(l) for l in cui2semtypes.values()]) 33 | print('Maximum # semtypes for a cui: %d' % max(c2s_lens)) 34 | print('Average # semtypes for a cui: %.2f' % (float(sum(c2s_lens)) / len(c2s_lens))) 35 | print('Median # semtypes for a cui: %.2f' % c2s_lens[len(c2s_lens)/2]) 36 | 37 | s2c_lens = sorted([len(l) for l in semtype2cuis.values()]) 38 | print('Maximum # cuis for a semtype: %d' % max(s2c_lens)) 39 | print('Average # cuis for a semtype: %.2f' % (float(sum(s2c_lens)) / len(s2c_lens))) 40 | print('Median # cuis for a semtype: %.2f' % s2c_lens[len(s2c_lens)/2]) 41 | print('%% under 1k: %.4f' % (float(len([l for l in s2c_lens if l < 1000])) / len(s2c_lens))) 42 | print('%% under 2k: %.4f' % (float(len([l for l in s2c_lens if l < 2000])) / len(s2c_lens))) 43 | 44 | # json.dump(name2id, open('/home/rmm120030/working/umls-mke/data/name2id.json', 'w+')) 45 | json.dump(cui2semtypes, open(os.path.join(data_dir, 'semnet', 'cui2semtpyes.json'), 'w+')) 46 | json.dump(semtype2cuis, open(os.path.join(data_dir, 'semnet', 'semtype2cuis.json'), 'w+')) 47 | 48 | 49 | def semnet_triples(umls_dir, data_dir): 50 | print('Creating semantic network triples...') 51 | name2id = json.load(open(os.path.join(data_dir, 'name2id.json'))) 52 | 53 | total_relations = 0 54 | new_relations = 0 55 | tui2id = {} 56 | # relations which have a metathesaurus analog are mapped to the MT embedding 57 | with open(os.path.join(umls_dir, 'NET', 'SRDEF')) as f: 58 | reader = csv.reader(f, delimiter='|') 59 | for row in reader: 60 | tui = row[1] 61 | if row[0] == 'RL': 62 | total_relations += 1 63 | rel = row[2] 64 | if rel in name2id: 65 | print('reusing relation embedding for %s' % rel) 66 | name2id[tui] = name2id[rel] 67 | elif tui not in name2id: 68 | new_relations += 1 69 | name2id[tui] = len(name2id) 70 | else: 71 | if tui not in name2id: 72 | name2id[tui] = len(name2id) 73 | tui2id[tui] = name2id[tui] 74 | 75 | print('Created %d of %d new relations' % (new_relations, total_relations)) 76 | print('%d total embeddings' % len(name2id)) 77 | json.dump(name2id, open(os.path.join(data_dir, 'name2id.json'), 'w+')) 78 | 
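  # name2id now also covers semantic-network TUIs (relation TUIs whose name matches a Metathesaurus
  # relation reuse that relation's embedding id); tui2id holds just the TUI -> embedding id subset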
json.dump(tui2id, open(os.path.join(data_dir, 'semnet', 'tui2id.json'), 'w+')) 79 | 80 | subj, rel, obj = [], [], [] 81 | with open(os.path.join(umls_dir, 'NET', 'SRSTRE1'), 'r') as f: 82 | reader = csv.reader(f, delimiter='|') 83 | for row in reader: 84 | subj.append(name2id[row[0]]) 85 | rel.append(name2id[row[1]]) 86 | obj.append(name2id[row[2]]) 87 | 88 | print('Saving the %d triples of the semantic network graph' % len(rel)) 89 | split(np.asarray(subj, dtype=np.int32), 90 | np.asarray(rel, dtype=np.int32), 91 | np.asarray(obj, dtype=np.int32), 92 | os.path.join(data_dir, 'semnet'), 93 | 'semnet', 94 | 600) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser(description='Extract relation triples into a compressed numpy file from MRCONSO.RRF') 99 | parser.add_argument('umls_dir', help='UMLS directory') 100 | parser.add_argument('--output', default='data', help='the compressed numpy file to be created') 101 | 102 | args = parser.parse_args() 103 | semnet_triples(args.umls_dir, args.output) 104 | process_mapping(args.umls_dir, args.output) 105 | -------------------------------------------------------------------------------- /python/eukg/data/create_test_set.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import os 4 | 5 | 6 | def split(subj, rel, obj, data_dir, graph_name='metathesaurus', num_test=100000): 7 | valid_triples = set() 8 | for s, r, o in zip(subj, rel, obj): 9 | valid_triples.add((s, r, o)) 10 | subj = np.asarray([s for s, _, _ in valid_triples]) 11 | rel = np.asarray([r for _, r, _ in valid_triples]) 12 | obj = np.asarray([o for _, _, o in valid_triples]) 13 | 14 | perm = np.random.permutation(np.arange(len(rel))) 15 | test_idx = perm[:num_test] 16 | train_idx = perm[num_test:] 17 | print('created train/test splits (%d/%d)' % (len(train_idx), len(test_idx))) 18 | 19 | np.savez_compressed(os.path.join(data_dir, graph_name, 'train.npz'), 20 | subj=subj[train_idx], 21 | rel=rel[train_idx], 22 | obj=obj[train_idx]) 23 | print('saved train set') 24 | 25 | np.savez_compressed(os.path.join(data_dir, graph_name, 'test.npz'), 26 | subj=subj[test_idx], 27 | rel=rel[test_idx], 28 | obj=obj[test_idx]) 29 | print('saved test set') 30 | 31 | 32 | def from_train_file(data_dir, graph_name='metathesaurus', num_test=100000): 33 | npz = np.load(os.path.join(data_dir, graph_name, 'triples.npz')) 34 | data = dict(npz.iteritems()) 35 | npz.close() 36 | print('read all data') 37 | split(data['subj'], data['rel'], data['obj'], data_dir, graph_name, num_test) 38 | 39 | 40 | def from_train_test_files(): 41 | data_dir = sys.argv[1] 42 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'train.npz')) 43 | data = dict(npz.iteritems()) 44 | npz.close() 45 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'test.npz')) 46 | subj = np.concatenate((data['subj'], npz['subj'])) 47 | rel = np.concatenate((data['rel'], npz['rel'])) 48 | obj = np.concatenate((data['obj'], npz['obj'])) 49 | npz.close() 50 | print('read all data') 51 | valid_triples = set() 52 | for s, r, o in zip(subj, rel, obj): 53 | valid_triples.add((s, r, o)) 54 | subj = np.asarray([s for s, _, _ in valid_triples]) 55 | rel = np.asarray([r for _, r, _ in valid_triples]) 56 | obj = np.asarray([o for _, _, o in valid_triples]) 57 | 58 | perm = np.random.permutation(np.arange(len(rel))) 59 | num_test = 100000 60 | test_idx = perm[:num_test] 61 | train_idx = perm[num_test:] 62 | print('created train/test splits (%d/%d)' % 
(len(train_idx), len(test_idx))) 63 | 64 | np.savez_compressed(os.path.join(data_dir, 'metathesaurus', 'train.npz'), 65 | subj=subj[train_idx], 66 | rel=rel[train_idx], 67 | obj=obj[train_idx]) 68 | print('saved train set') 69 | 70 | np.savez_compressed(os.path.join(data_dir, 'metathesaurus', 'test.npz'), 71 | subj=subj[test_idx], 72 | rel=rel[test_idx], 73 | obj=obj[test_idx]) 74 | print('saved test set') 75 | 76 | 77 | if __name__ == "__main__": 78 | if len(sys.argv) > 2: 79 | from_train_file(sys.argv[1], sys.argv[2], int(sys.argv[3])) 80 | else: 81 | from_train_file(sys.argv[1]) 82 | -------------------------------------------------------------------------------- /python/eukg/data/create_triples.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import csv 4 | import json 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | from create_test_set import split 9 | 10 | 11 | def metathesaurus_triples(rrf_file, output_dir, valid_relations): 12 | triples = set() 13 | conc2id = {} 14 | 15 | def add_concept(conc): 16 | if conc in conc2id: 17 | cid = conc2id[conc] 18 | else: 19 | cid = len(conc2id) 20 | conc2id[conc] = cid 21 | return cid 22 | 23 | with open(rrf_file, 'r') as f: 24 | reader = csv.reader(f, delimiter='|') 25 | for row in tqdm(reader, desc="reading", total=37207861): 26 | if row[7] in valid_relations: 27 | sid = add_concept(row[0]) 28 | rid = add_concept(row[7]) 29 | oid = add_concept(row[4]) 30 | triples.add((sid, rid, oid)) 31 | 32 | subjs, rels, objs = zip(*triples) 33 | snp = np.asarray(subjs, dtype=np.int32) 34 | rnp = np.asarray(rels, dtype=np.int32) 35 | onp = np.asarray(objs, dtype=np.int32) 36 | 37 | id2conc = {v: k for k, v in conc2id.iteritems()} 38 | concepts = [id2conc[i] for i in np.unique(np.concatenate((subjs, objs)))] 39 | relations = [id2conc[i] for i in set(rels)] 40 | 41 | print("Saving %d unique triples to %s. 
%d concepts spanning %d relations" % (rnp.shape[0], output_dir, len(concepts), 42 | len(relations))) 43 | 44 | split(snp, rnp, onp, output_dir) 45 | 46 | json.dump(conc2id, open(os.path.join(output_dir, 'name2id.json'), 'w+')) 47 | json.dump(concepts, open(os.path.join(output_dir, 'concept_vocab.json'), 'w+')) 48 | json.dump(relations, open(os.path.join(output_dir, 'relation_vocab.json'), 'w+')) 49 | 50 | 51 | def metathesaurus_triples_trimmed(rrf_file, output_dir, valid_concepts, valid_relations, important_concepts=None): 52 | triples = set() 53 | conc2id = {} 54 | 55 | def add_concept(conc): 56 | if conc in conc2id: 57 | cid = conc2id[conc] 58 | else: 59 | cid = len(conc2id) 60 | conc2id[conc] = cid 61 | return cid 62 | 63 | with open(rrf_file, 'r') as f: 64 | reader = csv.reader(f, delimiter='|') 65 | for row in tqdm(reader, desc="reading", total=37207861): 66 | if (row[0] in valid_concepts and row[4] in valid_concepts and row[7] in valid_relations) or \ 67 | (important_concepts is not None and row[7] != '' and (row[0] in important_concepts or row[4] in important_concepts)): 68 | sid = add_concept(row[0]) 69 | rid = add_concept(row[7]) 70 | oid = add_concept(row[4]) 71 | triples.add((sid, rid, oid)) 72 | 73 | subjs, rels, objs = zip(*triples) 74 | snp = np.asarray(subjs, dtype=np.int32) 75 | rnp = np.asarray(rels, dtype=np.int32) 76 | onp = np.asarray(objs, dtype=np.int32) 77 | 78 | id2conc = {v: k for k, v in conc2id.iteritems()} 79 | concepts = [id2conc[i] for i in np.unique(np.concatenate((subjs, objs)))] 80 | relations = [id2conc[i] for i in rels] 81 | 82 | print("Saving %d unique triples to %s" % (rnp.shape[0], output_dir)) 83 | 84 | np.savez_compressed(os.path.join(output_dir, 'triples'), 85 | subj=snp, 86 | rel=rnp, 87 | obj=onp) 88 | json.dump(conc2id, open(os.path.join(output_dir, 'name2id.json'), 'w+')) 89 | json.dump(concepts, open(os.path.join(concepts, 'concept_vocab.json'), 'w+')) 90 | json.dump(relations, open(os.path.join(relations, 'relation_vocab.json'), 'w+')) 91 | 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser(description='Extract relation triples into a compressed numpy file from MRCONSO.RRF') 95 | parser.add_argument('umls_dir', help='UMLS MRCONSO.RRF file containing metathesaurus relations') 96 | parser.add_argument('--output', default='data', help='the compressed numpy file to be created') 97 | parser.add_argument('--valid_relations', default='data/valid_rels.txt', 98 | help='plaintext list of relations we want to extract triples for, one per line.') 99 | 100 | args = parser.parse_args() 101 | 102 | valid_relations = set([rel.strip() for rel in open(args.valid_relations)]) 103 | 104 | metathesaurus_triples(os.path.join(args.umls_dir, 'META', 'MRCONSO.RRF'), args.output, valid_relations) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /python/eukg/data/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | import sys 5 | import math 6 | import os 7 | import json 8 | import random 9 | 10 | 11 | def negative_sampling_from_file(np_file): 12 | npz = np.load(np_file) 13 | subj, rel, obj = npz['subj'], npz['rel'], npz['obj'] 14 | 15 | sampler = NegativeSampler(subj, rel, obj) 16 | neg_subj, neg_obj = sampler.sample(subj, rel, obj) 17 | 18 | npz.close() 19 | np.savez_compressed(np_file, 20 | subj=subj, 21 | rel=rel, 22 | obj=obj, 23 
| nsubj=neg_subj, 24 | nobj=neg_obj) 25 | 26 | 27 | def load_metathesaurus_data(data_dir, val_proportion): 28 | random.seed(1337) 29 | 30 | cui2id = json.load(open(os.path.join(data_dir, 'name2id.json'))) 31 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'train.npz')) 32 | data = dict(npz.iteritems()) 33 | npz.close() 34 | 35 | perm = np.random.permutation(np.arange(len(data['rel']))) 36 | num_val = int(math.ceil(len(perm) * val_proportion)) 37 | val_idx = perm[:num_val] 38 | train_idx = perm[num_val:] 39 | 40 | return cui2id, data, train_idx, val_idx 41 | 42 | 43 | def load_semantic_network_data(data_dir, data_map): 44 | type2cuis = json.load(open(os.path.join(data_dir, 'semnet', 'semtype2cuis.json'))) 45 | npz = np.load(os.path.join(data_dir, 'semnet', 'triples.npz')) 46 | for key, val in npz.iteritems(): 47 | data_map['sn_' + key] = val 48 | npz.close() 49 | 50 | return type2cuis 51 | 52 | 53 | def load_metathesaurus_test_data(data_dir): 54 | npz = np.load(os.path.join(data_dir, 'metathesaurus', 'test.npz')) 55 | data = dict(npz.iteritems()) 56 | npz.close() 57 | 58 | return data 59 | 60 | 61 | def save_config(outdir, config): 62 | print('saving config to %s' % outdir) 63 | with open('%s/config.json' % outdir, 'w+') as f: 64 | json.dump(config.flag_values_dict(), f) 65 | 66 | 67 | def main(): 68 | negative_sampling_from_file(sys.argv[1]) 69 | 70 | 71 | class NegativeSampler: 72 | def __init__(self, subj, rel, obj, name, cachedir="/tmp"): 73 | cachedir = os.path.join(cachedir, name) 74 | # if os.path.exists(cachedir): 75 | # print('loading negative sampler maps from %s' % cachedir) 76 | # self.sr2o = defaultdict(list, json.load(open(os.path.join(cachedir, 'sr2o.json')))) 77 | # self.or2s = defaultdict(list, json.load(open(os.path.join(cachedir, 'or2s.json')))) 78 | # self.concepts = json.load(open(os.path.join(cachedir, 'concepts.json'))) 79 | # else: 80 | self.sr2o = defaultdict(set) 81 | self.or2s = defaultdict(set) 82 | concepts = set() 83 | print("\n") 84 | for s, r, o in tqdm(zip(subj, rel, obj), desc='building triple maps', total=len(subj)): 85 | self.sr2o[(s, r)].add(o) 86 | self.or2s[(o, r)].add(s) 87 | concepts.update([s, o]) 88 | self.concepts = list(concepts) 89 | 90 | # os.makedirs(cachedir) 91 | # json.dump({k: list(v) for k, v in self.sr2o.iteritems()}, open(os.path.join(cachedir, 'sr2o.json'), 'w+')) 92 | # json.dump({k: list(v) for k, v in self.or2s.iteritems()}, open(os.path.join(cachedir, 'or2s.json'), 'w+')) 93 | # json.dump(self.concepts, open(os.path.join(cachedir, 'concepts.json'), 'w+')) 94 | 95 | def _neg_sample(self, s_, r_, o_, replace_s): 96 | while True: 97 | c = random.choice(self.concepts) 98 | if replace_s and c not in self.or2s[(o_, r_)]: 99 | return c, o_ 100 | elif not replace_s and c not in self.sr2o[(s_, r_)]: 101 | return s_, c 102 | 103 | def sample(self, subj, rel, obj): 104 | neg_subj = [] 105 | neg_obj = [] 106 | print("\n") 107 | for s, r, o in tqdm(zip(subj, rel, obj), desc='negative sampling', total=len(subj)): 108 | ns, no = self._neg_sample(s, r, o, random.random > 0.5) 109 | neg_subj.append(ns) 110 | neg_obj.append(no) 111 | 112 | return np.asarray(neg_subj, dtype=np.int32), np.asarray(neg_obj, dtype=np.int32) 113 | 114 | def _sample_k(self, subj, rel, obj, k): 115 | neg_subj = [] 116 | neg_obj = [] 117 | for i in xrange(k): 118 | ns, no = self._neg_sample(subj, rel, obj, random.random > 0.5) 119 | neg_subj.append(ns) 120 | neg_obj.append(no) 121 | 122 | return neg_subj, neg_obj 123 | 124 | def sample_for_generator(self, subj_array, 
rel_array, obj_array, k): 125 | subj_samples = [] 126 | obj_samples = [] 127 | for s, r, o in zip(subj_array, rel_array, obj_array): 128 | ns, no = self._sample_k(s, r, o, k) 129 | subj_samples.append(ns) 130 | obj_samples.append(no) 131 | 132 | return np.asarray(subj_samples), np.asarray(obj_samples) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /python/eukg/emb/EmbeddingModel.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.contrib import layers 4 | 5 | 6 | class BaseModel: 7 | def __init__(self, config): 8 | self.embedding_size = config.embedding_size 9 | self.embedding_device = config.embedding_device 10 | self.vocab_size = config.vocab_size 11 | 12 | self.embeddings = None 13 | self.ids_to_update = tf.placeholder(dtype=tf.int32, shape=[None], name='used_vectors') 14 | self.norm_op = None 15 | 16 | def energy(self, head, rel, tail, norm_ord='euclidean'): 17 | """ 18 | Calculates the energies of the batch of triples corresponding to head, rel and tail. 19 | Energy should be in (-inf, 0] 20 | :param head: [batch_size] vector of head entity ids 21 | :param rel: [batch_size] vector of relation ids 22 | :param tail: [batch_size] vector of tail entity ids 23 | :param norm_ord: order of the normalization function 24 | :return: [batch_size] vector of energies of the passed triples 25 | """ 26 | raise NotImplementedError("subclass should implement") 27 | 28 | def embedding_lookup(self, ids): 29 | """ 30 | returns embedding vectors or tuple of embedding vectors for the passed ids 31 | :param ids: ids of embedding vectors in an embedding matrix 32 | :return: embedding vectors or tuple of embedding vectors for the passed ids 33 | """ 34 | raise NotImplementedError("subclass should implement") 35 | 36 | def normalize_parameters(self): 37 | """ 38 | Returns the op that enforces the constraint that each embedding vector have a norm <= 1 39 | :return: the op that enforces the constraint that each embedding vector have a norm <= 1 40 | """ 41 | raise NotImplementedError("subclass should implement") 42 | 43 | # noinspection PyMethodMayBeStatic 44 | def normalize(self, rows, mat, ids): 45 | """ 46 | Normalizes the rows of the matrix mat corresponding to ids s.t. |row|_2 <= 1 for each row. 47 | :param rows: rows from the matrix mat, embedding vectors 48 | :param mat: matrix of embedding vectors 49 | :param ids: ids of rows in mat 50 | :return: the scatter update op that updates only the rows of mat corresponding to ids 51 | """ 52 | norm = tf.norm(rows) 53 | scaling = 1. / tf.maximum(norm, 1.) 
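    # `tf.norm(rows)` above is called without an `axis`, so `norm` is a single Frobenius norm over the
    # whole [len(ids), embedding_size] slice and one shared scaling factor is applied to every row; a
    # strictly per-row version of the ||row||_2 <= 1 constraint described in the docstring would instead
    # use per-row norms, e.g. (hypothetical alternative, shown only for comparison):
    #   norms = tf.norm(rows, axis=-1, keepdims=True)
    #   scaling = 1. / tf.maximum(norms, tf.ones_like(norms))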
54 | scaled = scaling * rows 55 | return tf.scatter_update(mat, ids, scaled) 56 | 57 | 58 | class TransE(BaseModel): 59 | def __init__(self, config, embeddings_dict=None): 60 | BaseModel.__init__(self, config) 61 | with tf.device("/%s:0" % self.embedding_device): 62 | if embeddings_dict is None: 63 | self.embeddings = tf.get_variable("embeddings", 64 | shape=[self.vocab_size, self.embedding_size], 65 | dtype=tf.float32, 66 | initializer=layers.xavier_initializer()) 67 | else: 68 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings") 69 | 70 | def energy(self, head, rel, tail, norm_ord='euclidean'): 71 | h = self.embedding_lookup(head) 72 | r = self.embedding_lookup(rel) 73 | t = self.embedding_lookup(tail) 74 | 75 | return tf.norm(h + r - t, 76 | ord=norm_ord, 77 | axis=-1, 78 | keepdims=False, 79 | name='energy') 80 | 81 | def embedding_lookup(self, ids): 82 | with tf.device("/%s:0" % self.embedding_device): 83 | return tf.nn.embedding_lookup(self.embeddings, ids) 84 | 85 | def normalize_parameters(self): 86 | """ 87 | Enforces the contraint that the embedding vectors corresponding to ids_to_update <= 1.0 88 | """ 89 | params1 = self.embedding_lookup(self.ids_to_update) 90 | self.norm_op = self.normalize(params1, self.embeddings, self.ids_to_update) 91 | 92 | return self.norm_op 93 | 94 | 95 | class TransD(BaseModel): 96 | def __init__(self, config, embeddings_dict=None): 97 | BaseModel.__init__(self, config) 98 | with tf.device("/%s:0" % self.embedding_device): 99 | if embeddings_dict is None: 100 | print('Initializing embeddings.') 101 | self.embeddings = tf.get_variable("embeddings", 102 | shape=[self.vocab_size, self.embedding_size], # 364373 103 | dtype=tf.float32, 104 | initializer=layers.xavier_initializer()) 105 | else: 106 | print('Loading embeddings.') 107 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings") 108 | with tf.device("/%s:%d" % (self.embedding_device, 0 if self.embedding_device == 'cpu' else 0)): 109 | if embeddings_dict is None or 'p_embs' not in embeddings_dict: 110 | print('Initializing projection embeddings.') 111 | if config.p_init == 'zeros': 112 | p_init = tf.initializers.zeros() 113 | elif config.p_init == 'xavier': 114 | p_init = layers.xavier_initializer() 115 | elif config.p_init == 'uniform': 116 | p_init = tf.initializers.random_uniform(minval=-0.1, maxval=0.1, dtype=tf.float32) 117 | else: 118 | raise Exception('unrecognized p initializer: %s' % config.p_init) 119 | 120 | # projection embeddings initialized to zeros 121 | self.p_embeddings = tf.get_variable("p_embeddings", 122 | shape=[self.vocab_size, self.embedding_size], 123 | dtype=tf.float32, 124 | initializer=p_init) 125 | else: 126 | print('Loading projection embeddings.') 127 | self.p_embeddings = tf.Variable(embeddings_dict['p_embs'], name="p_embeddings") 128 | 129 | def energy(self, head, rel, tail, norm_ord='euclidean'): 130 | """ 131 | Computes the TransD energy of a relation triple 132 | :param head: head concept embedding ids [batch_size] 133 | :param rel: relation embedding ids [batch_size] 134 | :param tail: tail concept embedding ids [batch_size] 135 | :param norm_ord: norm order ['euclidean', 'fro', 'inf', 1, 2, 3, etc.] 
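    The energy is the norm of the translation error in the relation-specific projected space,
        E(h, r, t) = || h_p + r - t_p ||,  where  c_p = c + (c_proj . c) * r_proj
    (see `project` below); lower energy indicates a more plausible triple.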
136 | :return: [batch_size] vector of energies 137 | """ 138 | # x & x_proj both [batch_size, embedding_size] 139 | h, h_proj = self.embedding_lookup(head) 140 | r, r_proj = self.embedding_lookup(rel) 141 | t, t_proj = self.embedding_lookup(tail) 142 | 143 | # [batch_size] 144 | return tf.norm(self.project(h, h_proj, r_proj) + r - self.project(t, t_proj, r_proj), 145 | ord=norm_ord, 146 | axis=1, 147 | keepdims=False, 148 | name="energy") 149 | 150 | # noinspection PyMethodMayBeStatic 151 | def project(self, c, c_proj, r_proj): 152 | """ 153 | Computes the projected concept embedding for relation r according to TransD: 154 | (c_proj^T*c)*r_proj + c 155 | :param c: concept embeddings [batch_size, embedding_size] 156 | :param c_proj: concept projection embeddings [batch_size, embedding_size] 157 | :param r_proj: relation projection embeddings [batch_size, embedding_size] 158 | :return: projected concept embedding [batch_size, embedding_size] 159 | """ 160 | return c + tf.reduce_sum(c * c_proj, axis=-1, keepdims=True) * r_proj 161 | 162 | def embedding_lookup(self, ids): 163 | with tf.device("/%s:0" % self.embedding_device): 164 | params1 = tf.nn.embedding_lookup(self.embeddings, ids) 165 | with tf.device("/%s:%d" % (self.embedding_device, 1 if self.embedding_device == 'cpu' else 0)): 166 | params2 = tf.nn.embedding_lookup(self.p_embeddings, ids) 167 | return params1, params2 168 | 169 | def normalize_parameters(self): 170 | """ 171 | Normalizes the vectors of embeddings corresponding to the passed ids 172 | :return: the normalization op 173 | """ 174 | with tf.device("/%s:0" % self.embedding_device): 175 | params1 = tf.nn.embedding_lookup(self.embeddings, self.ids_to_update) 176 | with tf.device("/%s:%d" % (self.embedding_device, 1 if self.embedding_device == 'cpu' else 0)): 177 | params2 = tf.nn.embedding_lookup(self.p_embeddings, self.ids_to_update) 178 | 179 | n1 = self.normalize(params1, self.embeddings, self.ids_to_update) 180 | n2 = self.normalize(params2, self.p_embeddings, self.ids_to_update) 181 | self.norm_op = n1, n2 182 | 183 | return self.norm_op 184 | 185 | 186 | class DistMult(TransE): 187 | def __init__(self, config, embeddings_dict=None): 188 | BaseModel.__init__(self, config) 189 | if config.energy_activation == 'relu': 190 | self.energy_activation = tf.nn.relu 191 | elif config.energy_activation == 'tanh': 192 | self.energy_activation = tf.nn.tanh 193 | elif config.energy_activation == 'sigmoid': 194 | self.energy_activation = tf.nn.sigmoid 195 | elif config.energy_activation is None: 196 | self.energy_activation = lambda x: x 197 | else: 198 | raise Exception('Unrecognized activation: %s' % config.energy_activation) 199 | 200 | with tf.device("/%s:0" % self.embedding_device): 201 | if embeddings_dict is None: 202 | self.embeddings = tf.get_variable("embeddings", 203 | shape=[self.vocab_size, self.embedding_size], 204 | dtype=tf.float32, 205 | initializer=tf.initializers.random_uniform(minval=-0.5, 206 | maxval=0.5, 207 | dtype=tf.float32)) 208 | else: 209 | self.embeddings = tf.Variable(embeddings_dict['embs'], name="embeddings") 210 | 211 | def energy(self, head, rel, tail, norm_ord='euclidean'): 212 | h = self.embedding_lookup(head) 213 | r = self.embedding_lookup(rel) 214 | t = self.embedding_lookup(tail) 215 | 216 | return self.energy_activation(tf.reduce_sum(h * r * t, 217 | axis=-1, 218 | keepdims=False)) 219 | 220 | def normalize_parameters(self): 221 | return tf.no_op() 222 | 223 | def regularization(self, parameters): 224 | reg_term = 0 225 | for p in parameters: 
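      # each `p` here is a tensor of embedding ids; the norm of its looked-up embedding vectors is
      # added to `reg_term` (presumably weighted by the regularization_parameter flag in Config.py)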
226 | reg_term += tf.reduce_sum(tf.norm(self.embedding_lookup(p))) 227 | return reg_term 228 | -------------------------------------------------------------------------------- /python/eukg/emb/Smoothing.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # from .BaseModel import BaseModel 4 | 5 | 6 | def add_semantic_network_loss(model): 7 | """ 8 | Adds the semantic network loss to the graph 9 | :param model: 10 | :type model: BaseModel.BaseModel 11 | :return: the semantic network loss - 1D real valued tensor 12 | """ 13 | print('Adding semantic network to graph') 14 | # dataset3: sr, ss, so, sns, sno, t, conc, counts 15 | model.smoothing_placeholders['sn_relations'] = rel = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_relations') 16 | model.smoothing_placeholders['sn_pos_subj'] = psubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_subj') 17 | model.smoothing_placeholders['sn_pos_obj'] = pobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_obj') 18 | model.smoothing_placeholders['sn_neg_subj'] = nsubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_neg_subj') 19 | model.smoothing_placeholders['sn_neg_obj'] = nobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_neg_obj') 20 | model.smoothing_placeholders['sn_types'] = types = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_types') 21 | model.smoothing_placeholders['sn_concepts'] = concepts = tf.placeholder(dtype=tf.int32, 22 | shape=[None, model.max_concepts_per_type], 23 | name='sn_concepts') 24 | model.smoothing_placeholders['sn_conc_counts'] = counts = tf.placeholder(dtype=tf.int32, 25 | shape=[None], 26 | name='sn_conc_counts') 27 | 28 | with tf.variable_scope("energy"): 29 | model.sn_pos_energy = pos_energy = model.embedding_model.energy(psubj, rel, pobj, model.energy_norm) 30 | with tf.variable_scope("energy", reuse=True): 31 | model.sn_neg_energy = neg_energy = model.embedding_model.energy(nsubj, rel, nobj, model.energy_norm) 32 | energy_loss = tf.reduce_mean(tf.nn.relu(model.gamma - neg_energy + pos_energy)) 33 | model.sn_reward = -tf.reduce_mean(neg_energy, name='sn_reward') 34 | 35 | # [batch_size, embedding_size] 36 | type_embeddings = model.embedding_model.embedding_lookup(types) 37 | # [batch_size, max_concepts_per_type, embedding_size] 38 | concepts_embeddings = model.embedding_model.embedding_lookup(concepts) 39 | 40 | def calc_alignment_loss(_type_embeddings, _concept_embeddings): 41 | mask = tf.expand_dims(tf.sequence_mask(counts, maxlen=model.max_concepts_per_type, dtype=tf.float32), axis=-1) 42 | sum_ = tf.reduce_sum(mask * _concept_embeddings, axis=1, keepdims=False) 43 | float_counts = tf.to_float(tf.expand_dims(counts, axis=-1)) 44 | avg_conc_embeddings = sum_ / tf.maximum(float_counts, tf.ones_like(float_counts)) 45 | return tf.reduce_mean(tf.abs(_type_embeddings - avg_conc_embeddings)) 46 | 47 | if isinstance(type_embeddings, tuple): 48 | alignment_loss = 0 49 | for type_embeddings_i, concepts_embeddings_i in zip(type_embeddings, concepts_embeddings): 50 | alignment_loss += calc_alignment_loss(type_embeddings_i, concepts_embeddings_i) 51 | else: 52 | alignment_loss = calc_alignment_loss(type_embeddings, concepts_embeddings) 53 | 54 | # summary 55 | tf.summary.scalar('sn_energy_loss', energy_loss) 56 | tf.summary.scalar('sn_alignment_loss', alignment_loss) 57 | # tf.summary.scalar('sn_accuracy', accuracy) 58 | avg_pos_energy = tf.reduce_mean(pos_energy) 59 | tf.summary.scalar('sn_pos_energy', 
avg_pos_energy) 60 | avg_neg_energy = tf.reduce_mean(neg_energy) 61 | tf.summary.scalar('sn_neg_energy', avg_neg_energy) 62 | 63 | # combined semantic network loss: weighted energy loss plus weighted alignment loss 64 | loss = model.semnet_energy_param * energy_loss + model.semnet_alignment_param * alignment_loss 65 | return loss 66 | 67 | 68 | def add_gen_semantic_network(model): 69 | """ 70 | Adds the semantic network loss to the graph 71 | :param model: 72 | :type model: ..gan.Generator.Generator 73 | :return: tuple of (energy_loss, alignment_loss) - scalar real valued tensors 74 | """ 75 | print('Adding semantic network to graph') 76 | # dataset4: sr, ss, so, sns, sno, t, conc, counts 77 | model.smoothing_placeholders['sn_relations'] = rel = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_relations') 78 | model.smoothing_placeholders['sn_pos_subj'] = psubj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_subj') 79 | model.smoothing_placeholders['sn_pos_obj'] = pobj = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_pos_obj') 80 | model.smoothing_placeholders['sn_neg_subj'] = nsubj = tf.placeholder(dtype=tf.int32, shape=[None, None], 81 | name='sn_neg_subj') 82 | model.smoothing_placeholders['sn_neg_obj'] = nobj = tf.placeholder(dtype=tf.int32, shape=[None, None], 83 | name='sn_neg_obj') 84 | model.smoothing_placeholders['sn_concepts'] = concepts = tf.placeholder(dtype=tf.int32, 85 | shape=[None, model.max_concepts_per_type], 86 | name='sn_concepts') 87 | model.smoothing_placeholders['sn_conc_counts'] = counts = tf.placeholder(dtype=tf.int32, 88 | shape=[None], 89 | name='sn_conc_counts') 90 | model.smoothing_placeholders['sn_types'] = types = tf.placeholder(dtype=tf.int32, shape=[None], name='sn_types') 91 | 92 | true_energy = model.embedding_model.energy(psubj, rel, pobj) 93 | sampl_energy = model.embedding_model.energy(nsubj, tf.expand_dims(rel, axis=-1), nobj, model.energy_norm) 94 | if model.gan_mode: 95 | model.sampl_distributions = tf.nn.softmax(-sampl_energy, axis=-1) 96 | model.type_probabilities = tf.gather_nd(model.sampl_distributions, model.gan_loss_sample, name='sampl_probs') 97 | else: 98 | sm_numerator = tf.exp(-true_energy) 99 | exp_sampl_energies = tf.exp(-sampl_energy) 100 | sm_denominator = tf.reduce_sum(exp_sampl_energies, axis=-1) + sm_numerator 101 | model.type_probabilities = sm_numerator / sm_denominator 102 | energy_loss = -tf.reduce_mean(tf.log(model.type_probabilities)) 103 | 104 | # [batch_size, embedding_size] 105 | type_embeddings = model.embedding_model.embedding_lookup(types) 106 | # [batch_size, max_concepts_per_type, embedding_size] 107 | concepts_embeddings = model.embedding_model.embedding_lookup(concepts) 108 | 109 | def calc_alignment_loss(_type_embeddings, _concept_embeddings): 110 | mask = tf.expand_dims(tf.sequence_mask(counts, maxlen=model.max_concepts_per_type, dtype=tf.float32), axis=-1) 111 | sum_ = tf.reduce_sum(mask * _concept_embeddings, axis=1, keepdims=False) 112 | float_counts = tf.to_float(tf.expand_dims(counts, axis=-1)) 113 | avg_conc_embeddings = sum_ / tf.maximum(float_counts, tf.ones_like(float_counts)) 114 | return tf.reduce_mean(tf.abs(_type_embeddings - avg_conc_embeddings)) 115 | 116 | if isinstance(type_embeddings, tuple): 117 | alignment_loss = 0 118 | for type_embeddings_i, concepts_embeddings_i in zip(type_embeddings, concepts_embeddings): 119 | alignment_loss += calc_alignment_loss(type_embeddings_i, concepts_embeddings_i) 120 | else: 121 | alignment_loss = calc_alignment_loss(type_embeddings, concepts_embeddings) 122 | 123 | return energy_loss, alignment_loss 124 | 
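Editor's note: the alignment term in Smoothing.py above is just the mean absolute (L1) difference between each semantic-type embedding and the masked average of that type's concept embeddings. The following minimal NumPy sketch is not part of this repository; shapes, values, and variable names are illustrative only, but it mirrors what calc_alignment_loss computes for one batch:

import numpy as np

batch_size, max_concepts_per_type, embedding_size = 2, 4, 3

# [batch_size, embedding_size] semantic-type embeddings (random stand-ins)
type_embs = np.random.rand(batch_size, embedding_size).astype(np.float32)
# [batch_size, max_concepts_per_type, embedding_size] zero-padded concept embeddings per type
concept_embs = np.random.rand(batch_size, max_concepts_per_type, embedding_size).astype(np.float32)
# [batch_size] number of real (non-padding) concepts per type
counts = np.array([3, 1], dtype=np.int32)

# sequence mask over the padded concept axis: [batch_size, max_concepts_per_type, 1]
mask = (np.arange(max_concepts_per_type)[None, :] < counts[:, None]).astype(np.float32)[:, :, None]

# masked mean of concept embeddings, guarding against empty types: [batch_size, embedding_size]
avg_concept_embs = (mask * concept_embs).sum(axis=1) / np.maximum(counts[:, None], 1).astype(np.float32)

# mean absolute difference between each type embedding and its concepts' centroid (a scalar)
alignment_loss = np.abs(type_embs - avg_concept_embs).mean()
print(alignment_loss)

The callers in Discriminator.py and Generator.py then weight this alignment term with semnet_alignment_param and the energy term with semnet_energy_param before folding both into the model's training loss.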
-------------------------------------------------------------------------------- /python/eukg/emb/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/rmm120030/code/python/tf_util') 3 | -------------------------------------------------------------------------------- /python/eukg/gan/Discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from ..data import DataGenerator 5 | from ..emb import Smoothing 6 | 7 | from tf_util import Trainable 8 | 9 | 10 | class BaseModel(Trainable): 11 | def __init__(self, config, embedding_model, data_generator=None): 12 | """ 13 | :param config: config map 14 | :param embedding_model: KG Embedding Model 15 | :type embedding_model: EmbeddingModel.BaseModel 16 | :param data_generator: data generator 17 | :type data_generator: DataGenerator.DataGenerator 18 | """ 19 | Trainable.__init__(self) 20 | self.model = config.model 21 | self.embedding_model = embedding_model 22 | self.data_generator = data_generator 23 | 24 | # class variable declarations 25 | self.batch_size = config.batch_size 26 | # self.embedding_size = config.embedding_size 27 | self.vocab_size = config.vocab_size 28 | self.gamma = config.gamma 29 | self.embedding_device = config.embedding_device 30 | self.max_concepts_per_type = config.max_concepts_per_type 31 | # self.is_training = is_training 32 | self.energy_norm = config.energy_norm_ord 33 | self.use_semantic_network = not config.no_semantic_network 34 | self.semnet_alignment_param = config.semnet_alignment_param 35 | self.semnet_energy_param = config.semnet_energy_param 36 | self.regulatization_parameter = config.regularization_parameter 37 | 38 | # optimization 39 | self.learning_rate = tf.train.exponential_decay(config.learning_rate, 40 | tf.train.get_or_create_global_step(), 41 | config.batches_per_epoch, 42 | config.decay_rate, 43 | staircase=True) 44 | if config.optimizer == "adam": 45 | self.optimizer = lambda: tf.train.AdamOptimizer(self.learning_rate) 46 | else: 47 | self.optimizer = lambda: tf.train.MomentumOptimizer(self.learning_rate, 48 | config.momentum, 49 | use_nesterov=True) 50 | 51 | # input placeholders 52 | self.relations = tf.placeholder(dtype=tf.int32, shape=[None], name='relations') 53 | self.pos_subj = tf.placeholder(dtype=tf.int32, shape=[None], name='pos_subj') 54 | self.pos_obj = tf.placeholder(dtype=tf.int32, shape=[None], name='pos_obj') 55 | self.neg_subj = tf.placeholder(dtype=tf.int32, shape=[None], name='neg_subj') 56 | self.neg_obj = tf.placeholder(dtype=tf.int32, shape=[None], name='neg_obj') 57 | self.labels = tf.ones_like(self.relations, dtype=tf.int32) 58 | self.smoothing_placeholders = {} 59 | 60 | self.pos_energy = None 61 | self.neg_energy = None 62 | self.predictions = None 63 | self.reward = None 64 | self.sn_reward = None 65 | self.loss = None 66 | self.streaming_accuracy = None 67 | self.accuracy = None 68 | self.avg_pos_energy = None 69 | self.avg_neg_energy = None 70 | self.summary = None 71 | self.train_op = None 72 | 73 | # define reset op (for resetting counts for streaming metrics after each validation epoch) 74 | self.reset_streaming_metrics_op = tf.variables_initializer(tf.local_variables()) 75 | 76 | # define norm op 77 | self.norm_op = self.embedding_model.normalize_parameters() 78 | self.ids_to_update = self.embedding_model.ids_to_update 79 | 80 | def build(self): 81 | # energies 82 | with 
tf.variable_scope("energy"): 83 | self.pos_energy = self.embedding_model.energy(self.pos_subj, self.relations, self.pos_obj, self.energy_norm) 84 | with tf.variable_scope("energy", reuse=True): 85 | self.neg_energy = self.embedding_model.energy(self.neg_subj, self.relations, self.neg_obj, self.energy_norm) 86 | self.predictions = tf.argmax(tf.stack([self.pos_energy, self.neg_energy], axis=1), axis=1, output_type=tf.int32) 87 | self.reward = tf.reduce_mean(self.neg_energy, name='reward') 88 | 89 | # loss 90 | self.loss = tf.reduce_mean(tf.nn.relu(self.gamma - self.neg_energy + self.pos_energy), name='loss') 91 | 92 | if self.model == 'distmult': 93 | reg = self.regulatization_parameter * self.embedding_model.regularization([self.pos_subj, self.pos_obj, 94 | self.neg_subj, self.neg_obj, 95 | self.relations]) 96 | tf.summary.scalar('reg', reg) 97 | tf.summary.scalar('margin_loss', self.loss) 98 | self.loss += reg 99 | 100 | if self.use_semantic_network: 101 | semnet_loss = Smoothing.add_semantic_network_loss(self) 102 | self.loss += semnet_loss 103 | tf.summary.scalar('sn_loss', semnet_loss) 104 | # backprop 105 | self.train_op = self.optimizer().minimize(self.loss, tf.train.get_or_create_global_step()) 106 | 107 | # summary 108 | tf.summary.scalar('loss', self.loss) 109 | _, self.streaming_accuracy = tf.metrics.accuracy(labels=self.labels, predictions=self.predictions) 110 | tf.summary.scalar('streaming_accuracy', self.streaming_accuracy) 111 | self.accuracy = tf.reduce_mean(tf.to_float(tf.equal(self.predictions, self.labels))) 112 | tf.summary.scalar('accuracy', self.accuracy) 113 | self.avg_pos_energy = tf.reduce_mean(self.pos_energy) 114 | tf.summary.scalar('pos_energy', self.avg_pos_energy) 115 | self.avg_neg_energy = tf.reduce_mean(self.neg_energy) 116 | tf.summary.scalar('neg_energy', self.avg_neg_energy) 117 | tf.summary.scalar('margin', self.avg_neg_energy - self.avg_pos_energy) 118 | self.summary = tf.summary.merge_all() 119 | 120 | def fetches(self, is_training, verbose=False): 121 | fetches = [self.summary, self.loss] 122 | if verbose: 123 | if is_training: 124 | fetches += [self.accuracy] 125 | else: 126 | fetches += [self.streaming_accuracy] 127 | fetches += [self.avg_pos_energy, self.avg_neg_energy] 128 | if is_training: 129 | fetches += [self.train_op] 130 | return fetches 131 | 132 | def prepare_feed_dict(self, batch, is_training, **kwargs): 133 | # return {} 134 | if self.use_semantic_network: 135 | if is_training: 136 | rel, psub, pobj, nsub, nobj, sn_rel, sn_psub, sn_pobj, sn_nsub, sn_nobj, conc, c_lens, types = batch 137 | return {self.relations: rel, 138 | self.pos_subj: psub, 139 | self.pos_obj: pobj, 140 | self.neg_subj: nsub, 141 | self.neg_obj: nobj, 142 | self.smoothing_placeholders['sn_relations']: sn_rel, 143 | self.smoothing_placeholders['sn_pos_subj']: sn_psub, 144 | self.smoothing_placeholders['sn_pos_obj']: sn_pobj, 145 | self.smoothing_placeholders['sn_neg_subj']: sn_nsub, 146 | self.smoothing_placeholders['sn_neg_obj']: sn_nobj, 147 | self.smoothing_placeholders['sn_concepts']: conc, 148 | self.smoothing_placeholders['sn_conc_counts']: c_lens, 149 | self.smoothing_placeholders['sn_types']: types} 150 | else: 151 | rel, psub, pobj, nsub, nobj = batch 152 | return {self.relations: rel, 153 | self.pos_subj: psub, 154 | self.pos_obj: pobj, 155 | self.neg_subj: nsub, 156 | self.neg_obj: nobj, 157 | self.smoothing_placeholders['sn_relations']: [0], 158 | self.smoothing_placeholders['sn_pos_subj']: [0], 159 | self.smoothing_placeholders['sn_pos_obj']: [0], 160 | 
self.smoothing_placeholders['sn_neg_subj']: [0], 161 | self.smoothing_placeholders['sn_neg_obj']: [0], 162 | self.smoothing_placeholders['sn_concepts']: np.zeros([1, 1000], dtype=np.int32), 163 | self.smoothing_placeholders['sn_conc_counts']: [1], 164 | self.smoothing_placeholders['sn_types']: [0]} 165 | else: 166 | rel, psub, pobj, nsub, nobj = batch 167 | return {self.relations: rel, 168 | self.pos_subj: psub, 169 | self.pos_obj: pobj, 170 | self.neg_subj: nsub, 171 | self.neg_obj: nobj} 172 | 173 | def progress_update(self, batch, fetched, **kwargs): 174 | print('Avg loss of last batch: %.4f' % np.average(fetched[1])) 175 | print('Accuracy of last batch: %.4f' % np.average(fetched[2])) 176 | print('Avg pos energy of last batch: %.4f' % np.average(fetched[3])) 177 | print('Avg neg energy of last batch: %.4f' % np.average(fetched[4])) 178 | 179 | def data_provider(self, config, is_training, **kwargs): 180 | if self.use_semantic_network: 181 | return DataGenerator.wrap_generators(self.data_generator.generate_mt, 182 | self.data_generator.generate_sn, is_training) 183 | else: 184 | return self.data_generator.generate_mt(is_training) 185 | -------------------------------------------------------------------------------- /python/eukg/gan/Generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from Discriminator import BaseModel 5 | from ..emb import Smoothing 6 | from ..data import DataGenerator 7 | 8 | 9 | class Generator(BaseModel): 10 | def __init__(self, config, embedding_model, data_generator=None): 11 | BaseModel.__init__(self, config, embedding_model, data_generator) 12 | self.num_samples = config.num_generator_samples 13 | self.gan_mode = False 14 | 15 | # dataset2: s, r, o, ns, no 16 | self.neg_subj = tf.placeholder(dtype=tf.int32, shape=[None, self.num_samples], name="neg_subj") 17 | self.neg_obj = tf.placeholder(dtype=tf.int32, shape=[None, self.num_samples], name="neg_obj") 18 | self.discounted_reward = tf.placeholder(dtype=tf.float32, shape=[], name="discounted_reward") 19 | self.gan_loss_sample = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="gan_loss_sample") 20 | 21 | self.probabilities = None 22 | self.sampl_energies = None 23 | self.true_energies = None 24 | self.probability_distributions = None 25 | 26 | # semantic network vars 27 | self.type_probabilities = None 28 | 29 | def build(self): 30 | summary = [] 31 | # [batch_size, num_samples] 32 | self.sampl_energies = self.embedding_model.energy(self.neg_subj, 33 | tf.expand_dims(self.relations, axis=1), 34 | self.neg_obj) 35 | # [batch_size] 36 | self.true_energies = self.embedding_model.energy(self.pos_subj, 37 | self.relations, 38 | self.pos_obj) 39 | 40 | # backprop 41 | optimizer = self.optimizer() 42 | # [batch_size] 43 | sm_numerator = tf.exp(-self.true_energies) 44 | # [batch_size] 45 | exp_sampl_nergies = tf.exp(-self.sampl_energies) 46 | sm_denominator = tf.reduce_sum(exp_sampl_nergies, axis=-1) + sm_numerator 47 | # [batch_size] 48 | self.probabilities = sm_numerator / sm_denominator 49 | self.loss = -tf.reduce_mean(tf.log(self.probabilities)) 50 | 51 | # regularization for distmult 52 | if self.model == "distmult": 53 | reg = self.regulatization_parameter * self.embedding_model.regularization([self.pos_subj, self.pos_obj, 54 | self.neg_subj, self.neg_obj, 55 | self.relations]) 56 | summary += [tf.summary.scalar('reg', reg), 57 | tf.summary.scalar('log_prob', self.loss)] 58 | self.loss += reg 59 | 60 | if 
self.use_semantic_network: 61 | sn_energy_loss, sn_alignment_loss = Smoothing.add_gen_semantic_network(self) 62 | self.loss += self.semnet_energy_param * sn_energy_loss + self.semnet_alignment_param * sn_alignment_loss 63 | summary += [tf.summary.scalar('sn_energy_loss', sn_energy_loss / self.batch_size), 64 | tf.summary.scalar('sn_alignment_loss', sn_alignment_loss / self.batch_size)] 65 | 66 | self.avg_pos_energy = tf.reduce_mean(self.true_energies) 67 | self.avg_neg_energy = tf.reduce_mean(self.sampl_energies) 68 | summary += [tf.summary.scalar('loss', self.loss), 69 | tf.summary.scalar('avg_prob', tf.reduce_mean(self.probabilities)), 70 | tf.summary.scalar('min_prob', tf.reduce_min(self.probabilities)), 71 | tf.summary.scalar('max_prob', tf.reduce_max(self.probabilities)), 72 | tf.summary.scalar('pos_energy', self.avg_pos_energy), 73 | tf.summary.scalar('neg_energy', self.avg_neg_energy), 74 | tf.summary.scalar('margin', self.avg_pos_energy - self.avg_neg_energy)] 75 | self.train_op = optimizer.minimize(self.loss, tf.train.get_or_create_global_step()) 76 | 77 | # summary 78 | self.summary = tf.summary.merge(summary) 79 | 80 | def fetches(self, is_training, verbose=False): 81 | fetches = [self.summary, self.loss] 82 | if verbose: 83 | fetches += [self.probabilities, self.avg_pos_energy, self.avg_neg_energy] 84 | if is_training: 85 | fetches += [self.train_op] 86 | return fetches 87 | 88 | def prepare_feed_dict(self, batch, is_training, **kwargs): 89 | if self.use_semantic_network: 90 | if is_training: 91 | rel, psub, pobj, nsub, nobj, sn_rel, sn_psub, sn_pobj, sn_nsub, sn_nobj, conc, c_lens, types = batch 92 | return {self.relations: rel, 93 | self.pos_subj: psub, 94 | self.pos_obj: pobj, 95 | self.neg_subj: nsub, 96 | self.neg_obj: nobj, 97 | self.smoothing_placeholders['sn_relations']: sn_rel, 98 | self.smoothing_placeholders['sn_pos_subj']: sn_psub, 99 | self.smoothing_placeholders['sn_pos_obj']: sn_pobj, 100 | self.smoothing_placeholders['sn_neg_subj']: sn_nsub, 101 | self.smoothing_placeholders['sn_neg_obj']: sn_nobj, 102 | self.smoothing_placeholders['sn_concepts']: conc, 103 | self.smoothing_placeholders['sn_conc_counts']: c_lens, 104 | self.smoothing_placeholders['sn_types']: types} 105 | else: 106 | rel, psub, pobj, nsub, nobj = batch 107 | return {self.relations: rel, 108 | self.pos_subj: psub, 109 | self.pos_obj: pobj, 110 | self.neg_subj: nsub, 111 | self.neg_obj: nobj, 112 | self.smoothing_placeholders['sn_relations']: [0], 113 | self.smoothing_placeholders['sn_pos_subj']: [0], 114 | self.smoothing_placeholders['sn_pos_obj']: [0], 115 | self.smoothing_placeholders['sn_neg_subj']: [[0]], 116 | self.smoothing_placeholders['sn_neg_obj']: [[0]], 117 | self.smoothing_placeholders['sn_concepts']: np.zeros([1, 1000], dtype=np.int32), 118 | self.smoothing_placeholders['sn_conc_counts']: [1], 119 | self.smoothing_placeholders['sn_types']: [0]} 120 | else: 121 | rel, psub, pobj, nsub, nobj = batch 122 | return {self.relations: rel, 123 | self.pos_subj: psub, 124 | self.pos_obj: pobj, 125 | self.neg_subj: nsub, 126 | self.neg_obj: nobj} 127 | 128 | def progress_update(self, batch, fetched, **kwargs): 129 | print('Avg loss of last batch: %.4f' % np.average(fetched[1])) 130 | print('Avg probability of last batch: %.4f' % np.average(fetched[2])) 131 | print('Avg pos energy of last batch: %.4f' % np.average(fetched[3])) 132 | print('Avg neg energy of last batch: %.4f' % np.average(fetched[4])) 133 | 134 | def data_provider(self, config, is_training, **kwargs): 135 | if 
self.use_semantic_network: 136 | return DataGenerator.wrap_generators(self.data_generator.generate_mt_gen_mode, 137 | self.data_generator.generate_sn_gen_mode, is_training) 138 | else: 139 | return self.data_generator.generate_mt_gen_mode(is_training) 140 | 141 | 142 | class GanGenerator(Generator): 143 | def __init__(self, config, embedding_model, data_generator=None): 144 | Generator.__init__(self, config, embedding_model, data_generator) 145 | self.gan_mode = True 146 | self.sampl_distributions = None 147 | 148 | def build(self): 149 | # [batch_size, num_samples] 150 | self.sampl_energies = self.embedding_model.energy(self.neg_subj, 151 | tf.expand_dims(self.relations, axis=1), 152 | self.neg_obj) 153 | # [batch_size] 154 | self.true_energies = self.embedding_model.energy(self.pos_subj, 155 | self.relations, 156 | self.pos_obj) 157 | 158 | optimizer = self.optimizer() 159 | if self.use_semantic_network: 160 | # this method also adds values for self.sampl_distributions and self.type_probabilities 161 | loss, _ = Smoothing.add_gen_semantic_network(self) 162 | grads_and_vars = optimizer.compute_gradients(loss) 163 | vars_with_grad = [v for g, v in grads_and_vars if g is not None] 164 | if not vars_with_grad: 165 | raise ValueError( 166 | "No gradients provided for any variable, check your graph for ops" 167 | " that do not support gradients, between variables %s and loss %s." % 168 | ([str(v) for _, v in grads_and_vars], loss)) 169 | discounted_grads_and_vars = [(self.discounted_reward * g, v) for g, v in grads_and_vars if g is not None] 170 | self.train_op = optimizer.apply_gradients(discounted_grads_and_vars, 171 | global_step=tf.train.get_or_create_global_step()) 172 | summary = [tf.summary.scalar('avg_st_prob', tf.reduce_mean(self.type_probabilities)), 173 | tf.summary.scalar('sn_loss', loss / self.batch_size), 174 | tf.summary.scalar('reward', self.discounted_reward)] 175 | else: 176 | # [batch_size, num_samples] - this is for sampling during GAN training 177 | self.probability_distributions = tf.nn.softmax(self.sampl_energies, axis=-1) 178 | self.probabilities = tf.gather_nd(self.probability_distributions, self.gan_loss_sample, name='sampl_probs') 179 | loss = -tf.reduce_sum(tf.log(self.probabilities)) 180 | summary = [tf.summary.scalar('avg_sampled_prob', tf.reduce_mean(self.probabilities))] 181 | 182 | # if training as part of a GAN, gradients should be scaled by discounted_reward 183 | grads_and_vars = optimizer.compute_gradients(loss) 184 | vars_with_grad = [v for g, v in grads_and_vars if g is not None] 185 | if not vars_with_grad: 186 | raise ValueError( 187 | "No gradients provided for any variable, check your graph for ops" 188 | " that do not support gradients, between variables %s and loss %s." 
% 189 | ([str(v) for _, v in grads_and_vars], loss)) 190 | discounted_grads_and_vars = [(self.discounted_reward * g, v) for g, v in grads_and_vars if g is not None] 191 | self.train_op = optimizer.apply_gradients(discounted_grads_and_vars, 192 | global_step=tf.train.get_or_create_global_step()) 193 | 194 | # reporting loss 195 | self.loss = loss / self.batch_size 196 | # summary 197 | self.summary = tf.summary.merge(summary) 198 | -------------------------------------------------------------------------------- /python/eukg/gan/Generator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/Generator.pyc -------------------------------------------------------------------------------- /python/eukg/gan/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/rmm120030/code/python/tf_util') -------------------------------------------------------------------------------- /python/eukg/gan/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/__init__.pyc -------------------------------------------------------------------------------- /python/eukg/gan/train_gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import math 4 | import random 5 | from tqdm import tqdm 6 | import numpy as np 7 | from itertools import izip 8 | 9 | from .. import Config 10 | from ..data import data_util, DataGenerator 11 | from ..emb import EmbeddingModel 12 | from Generator import GanGenerator 13 | import Discriminator 14 | 15 | 16 | # noinspection PyUnboundLocalVariable 17 | def train(): 18 | config = Config.flags 19 | 20 | use_semnet = not config.no_semantic_network 21 | 22 | # init model dir 23 | gan_model_dir = os.path.join(config.model_dir, config.model, config.run_name) 24 | if not os.path.exists(gan_model_dir): 25 | os.makedirs(gan_model_dir) 26 | 27 | # init summaries dir 28 | config.summaries_dir = os.path.join(config.summaries_dir, config.run_name) 29 | if not os.path.exists(config.summaries_dir): 30 | os.makedirs(config.summaries_dir) 31 | 32 | # save the config 33 | data_util.save_config(gan_model_dir, config) 34 | 35 | # load data 36 | cui2id, data, train_idx, val_idx = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion) 37 | config.val_progress_update_interval = int(math.floor(float(len(val_idx)) / config.batch_size)) 38 | config.batches_per_epoch = int(math.floor(float(len(train_idx)) / config.batch_size)) 39 | if not config.no_semantic_network: 40 | type2cuis = data_util.load_semantic_network_data(config.data_dir, data) 41 | else: 42 | type2cuis = None 43 | data_generator = DataGenerator.DataGenerator(data, train_idx, val_idx, config, type2cuis) 44 | 45 | with tf.Graph().as_default(), tf.Session() as session: 46 | # init models 47 | with tf.variable_scope(config.dis_run_name): 48 | discriminator = init_model(config, 'disc') 49 | with tf.variable_scope(config.gen_run_name): 50 | config.no_semantic_network = True 51 | config.learning_rate = 1e-1 52 | generator = init_model(config, 'gen') 53 | if use_semnet: 54 | with tf.variable_scope(config.sn_gen_run_name): 55 | config.no_semantic_network = False 56 | sn_generator = 
init_model(config, 'sn_gen') 57 | 58 | tf.global_variables_initializer().run() 59 | tf.local_variables_initializer().run() 60 | 61 | # init saver 62 | dis_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=config.dis_run_name), 63 | max_to_keep=10) 64 | gen_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=config.gen_run_name), 65 | max_to_keep=10) 66 | 67 | # load models 68 | dis_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, config.dis_run_name)) 69 | dis_saver.restore(session, dis_ckpt) 70 | gen_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, "distmult", config.gen_run_name)) 71 | gen_saver.restore(session, gen_ckpt) 72 | if use_semnet: 73 | sn_gen_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 74 | scope=config.sn_gen_run_name), 75 | max_to_keep=10) 76 | sn_gen_ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, "distmult", config.sn_gen_run_name)) 77 | sn_gen_saver.restore(session, sn_gen_ckpt) 78 | 79 | # finalize graph 80 | tf.get_default_graph().finalize() 81 | 82 | # define streaming_accuracy reset per epoch 83 | print('local variables that will be reinitialized every epoch: %s' % tf.local_variables()) 84 | reset_local_vars = lambda: session.run(discriminator.reset_streaming_metrics_op) 85 | 86 | # init summary directories and summary writers 87 | if not os.path.exists(os.path.join(config.summaries_dir, 'train')): 88 | os.makedirs(os.path.join(config.summaries_dir, 'train')) 89 | train_summary_writer = tf.summary.FileWriter(os.path.join(config.summaries_dir, 'train')) 90 | if not os.path.exists(os.path.join(config.summaries_dir, 'val')): 91 | os.makedirs(os.path.join(config.summaries_dir, 'val')) 92 | val_summary_writer = tf.summary.FileWriter(os.path.join(config.summaries_dir, 'val')) 93 | 94 | # config_map = config.flag_values_dict() 95 | # config_map['data'] = data 96 | # config_map['train_idx'] = train_idx 97 | # config_map['val_idx'] = val_idx 98 | 99 | global_step = 0 100 | for ep in xrange(config.num_epochs): 101 | print('----------------------------') 102 | print('Begin Train Epoch %d' % ep) 103 | if use_semnet: 104 | global_step = train_epoch_sn(session, discriminator, generator, sn_generator, config, data_generator, 105 | train_summary_writer, global_step) 106 | sn_gen_saver.save(session, os.path.join(gan_model_dir, 'sn_generator', config.model), global_step=global_step) 107 | else: 108 | global_step = train_epoch(session, discriminator, generator, config, data_generator, train_summary_writer, 109 | global_step) 110 | print("Saving models to %s at step %d" % (gan_model_dir, global_step)) 111 | dis_saver.save(session, os.path.join(gan_model_dir, 'discriminator', config.model), global_step=global_step) 112 | gen_saver.save(session, os.path.join(gan_model_dir, 'generator', config.model), global_step=global_step) 113 | reset_local_vars() 114 | print('----------------------------') 115 | print('Begin Validation Epoch %d' % ep) 116 | validation_epoch(session, discriminator, config, data_generator, val_summary_writer, global_step) 117 | 118 | 119 | def init_model(config, mode): 120 | print('Initializing %s model...' 
% mode) 121 | 122 | if mode == 'disc': 123 | if config.model == 'transe': 124 | em = EmbeddingModel.TransE(config) 125 | elif config.model == 'transd': 126 | # config.embedding_size = config.embedding_size / 2 127 | em = EmbeddingModel.TransD(config) 128 | # config.embedding_size = config.embedding_size * 2 129 | else: 130 | raise ValueError('Unrecognized model type: %s' % config.model) 131 | model = Discriminator.BaseModel(config, em) 132 | elif mode == 'gen': 133 | em = EmbeddingModel.DistMult(config) 134 | model = GanGenerator(config, em) 135 | elif mode == 'sn_gen': 136 | em = EmbeddingModel.DistMult(config) 137 | model = GanGenerator(config, em) 138 | else: 139 | raise ValueError('Unrecognized mode: %s' % config.mode) 140 | 141 | model.build() 142 | return model 143 | 144 | 145 | def find_unique(tensor_list): 146 | if max([len(t.shape) for t in tensor_list[:10]]) == 1: 147 | return np.unique(np.concatenate(tensor_list[:10])) 148 | else: 149 | return np.unique(np.concatenate([t.flatten() for t in tensor_list[:10]])) 150 | 151 | 152 | def sample_corrupted_triples(sampl_sub, sampl_obj, probability_distributions, idx_np): 153 | nsub = [] 154 | nobj = [] 155 | sampl_idx = [] 156 | for i, dist in enumerate(probability_distributions): 157 | [j] = np.random.choice(idx_np, [1], p=dist) 158 | nsub.append(sampl_sub[i, j]) 159 | nobj.append(sampl_obj[i, j]) 160 | sampl_idx.append([i, j]) 161 | nsub = np.asarray(nsub) 162 | nobj = np.asarray(nobj) 163 | return nsub, nobj, sampl_idx 164 | 165 | 166 | def train_epoch(session, discriminator, generator, config, data_generator, summary_writer, global_step): 167 | baseline = 0. 168 | console_update_interval = config.progress_update_interval 169 | pbar = tqdm(total=console_update_interval) 170 | idx_np = np.arange(config.num_generator_samples) 171 | for b, batch in enumerate(data_generator.generate_mt_gen_mode(True)): 172 | verbose_batch = b > 0 and b % console_update_interval == 0 173 | 174 | # generation 175 | gen_feed_dict = generator.prepare_feed_dict(batch, True) 176 | probability_distributions = session.run(generator.probability_distributions, gen_feed_dict) 177 | rel, psub, pobj, sampl_sub, sampl_obj = batch 178 | nsub = [] 179 | nobj = [] 180 | sampl_idx = [] 181 | for i, dist in enumerate(probability_distributions): 182 | [j] = np.random.choice(idx_np, [1], p=dist) 183 | nsub.append(sampl_sub[i, j]) 184 | nobj.append(sampl_obj[i, j]) 185 | sampl_idx.append([i, j]) 186 | nsub = np.asarray(nsub) 187 | nobj = np.asarray(nobj) 188 | 189 | # discrimination 190 | dis_fetched = session.run(discriminator.fetches(True, verbose_batch) + [discriminator.reward], 191 | {discriminator.relations: rel, 192 | discriminator.pos_subj: psub, 193 | discriminator.pos_obj: pobj, 194 | discriminator.neg_subj: nsub, 195 | discriminator.neg_obj: nobj}) 196 | 197 | # generation reward 198 | discounted_reward = dis_fetched[-1] - baseline 199 | baseline = dis_fetched[-1] 200 | gen_feed_dict[generator.discounted_reward] = discounted_reward 201 | gen_feed_dict[generator.gan_loss_sample] = np.asarray(sampl_idx) 202 | gen_fetched = session.run([generator.summary, generator.loss, generator.probabilities, generator.train_op], 203 | gen_feed_dict) 204 | # assert gloss == gen_fetched[1], \ 205 | # "Forward pass for generation step does not match forward pass for generator learning! 
%f != %f" \ 206 | # % (gloss, gen_fetched[1]) 207 | 208 | # update tensorboard summary 209 | summary_writer.add_summary(dis_fetched[0], global_step) 210 | summary_writer.add_summary(gen_fetched[0], global_step) 211 | global_step += 1 212 | 213 | # perform normalization 214 | session.run([generator.norm_op, discriminator.norm_op], 215 | {generator.ids_to_update: find_unique(batch), 216 | discriminator.ids_to_update: find_unique([rel, psub, pobj, nsub, nobj])}) 217 | 218 | # update progress bar 219 | pbar.set_description("Training Batch: %d. GLoss: %.4f. DLoss: %.4f. Reward: %.4f" % 220 | (b, gen_fetched[1], dis_fetched[1], discounted_reward)) 221 | pbar.update() 222 | 223 | if verbose_batch: 224 | print('Discriminator:') 225 | discriminator.progress_update(batch, dis_fetched) 226 | print('Generator:') 227 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(gen_fetched[2])) 228 | pbar.close() 229 | pbar = tqdm(total=console_update_interval) 230 | pbar.close() 231 | 232 | return global_step 233 | 234 | 235 | def train_epoch_sn(sess, discriminator, generator, sn_generator, config, data_generator, summary_writer, global_step): 236 | baseline = 0. 237 | sn_baseline = 0. 238 | pbar = None 239 | console_update_interval = config.progress_update_interval 240 | idx_np = np.arange(config.num_generator_samples) 241 | sn_idx_np = np.arange(10) 242 | for b, (mt_batch, sn_batch) in enumerate(izip(data_generator.generate_mt_gen_mode(True), 243 | data_generator.generate_sn_gen_mode(True))): 244 | verbose_batch = b > 0 and b % console_update_interval == 0 245 | 246 | # mt generation 247 | gen_feed_dict = generator.prepare_feed_dict(mt_batch, True) 248 | concept_distributions = sess.run(generator.probability_distributions, gen_feed_dict) 249 | rel, psub, pobj, sampl_sub, sampl_obj = mt_batch 250 | nsub, nobj, sampl_idx = sample_corrupted_triples(sampl_sub, sampl_obj, concept_distributions, idx_np) 251 | 252 | # sn generation 253 | sn_gen_feed_dict = {sn_generator.smoothing_placeholders['sn_relations']: sn_batch[0], 254 | sn_generator.smoothing_placeholders['sn_neg_subj']: sn_batch[3], 255 | sn_generator.smoothing_placeholders['sn_neg_obj']: sn_batch[4]} 256 | type_distributions = sess.run(sn_generator.sampl_distributions, sn_gen_feed_dict) 257 | sn_nsub, sn_nobj, sn_sampl_idx = sample_corrupted_triples(sn_batch[3], sn_batch[4], type_distributions, sn_idx_np) 258 | types = np.unique(np.concatenate([sn_batch[1], sn_batch[2], sn_nsub, sn_nobj])) 259 | concepts = np.zeros([len(types), config.max_concepts_per_type], dtype=np.int32) 260 | concept_lens = np.zeros([len(types)], dtype=np.int32) 261 | for i, tid in enumerate(types): 262 | concepts_of_type_t = data_generator.type2cuis[tid] if tid in data_generator.type2cuis else [] 263 | random.shuffle(concepts_of_type_t) 264 | concepts_of_type_t = concepts_of_type_t[:config.max_concepts_per_type] 265 | concept_lens[i] = len(concepts_of_type_t) 266 | concepts[i, :len(concepts_of_type_t)] = concepts_of_type_t 267 | 268 | # discrimination 269 | dis_fetched = sess.run(discriminator.fetches(True, verbose_batch) + [discriminator.sn_reward, discriminator.reward], 270 | {discriminator.relations: rel, 271 | discriminator.pos_subj: psub, 272 | discriminator.pos_obj: pobj, 273 | discriminator.neg_subj: nsub, 274 | discriminator.neg_obj: nobj, 275 | discriminator.smoothing_placeholders['sn_relations']: sn_batch[0], 276 | discriminator.smoothing_placeholders['sn_pos_subj']: sn_batch[1], 277 | discriminator.smoothing_placeholders['sn_pos_obj']: 
sn_batch[2], 278 | discriminator.smoothing_placeholders['sn_neg_subj']: sn_nsub, 279 | discriminator.smoothing_placeholders['sn_neg_obj']: sn_nobj, 280 | discriminator.smoothing_placeholders['sn_types']: types, 281 | discriminator.smoothing_placeholders['sn_concepts']: concepts, 282 | discriminator.smoothing_placeholders['sn_conc_counts']: concept_lens}) 283 | 284 | # generation reward 285 | discounted_reward = dis_fetched[-1] - baseline 286 | baseline = dis_fetched[-1] 287 | gen_feed_dict[generator.discounted_reward] = discounted_reward 288 | gen_feed_dict[generator.gan_loss_sample] = np.asarray(sampl_idx) 289 | gen_fetched = sess.run([generator.summary, generator.loss, generator.probabilities, generator.train_op], 290 | gen_feed_dict) 291 | 292 | # sn generation reward 293 | sn_discounted_reward = dis_fetched[-2] - sn_baseline 294 | sn_baseline = dis_fetched[-2] 295 | sn_gen_feed_dict[sn_generator.discounted_reward] = sn_discounted_reward 296 | sn_gen_feed_dict[sn_generator.gan_loss_sample] = np.asarray(sn_sampl_idx) 297 | sn_gen_fetched = sess.run([sn_generator.summary, sn_generator.loss, 298 | sn_generator.type_probabilities, sn_generator.train_op], 299 | sn_gen_feed_dict) 300 | 301 | # update tensorboard summary 302 | summary_writer.add_summary(dis_fetched[0], global_step) 303 | summary_writer.add_summary(gen_fetched[0], global_step) 304 | summary_writer.add_summary(sn_gen_fetched[0], global_step) 305 | global_step += 1 306 | 307 | # perform normalization 308 | sess.run([generator.norm_op, discriminator.norm_op, sn_generator.norm_op], 309 | {generator.ids_to_update: find_unique(mt_batch + sn_batch), 310 | discriminator.ids_to_update: find_unique([rel, psub, pobj, nsub, nobj]), 311 | sn_generator.ids_to_update: sn_batch[7]}) 312 | 313 | # update progress bar 314 | pbar = tqdm(total=console_update_interval) if pbar is None else pbar 315 | pbar.set_description("Training Batch: %d. GLoss: %.4f. SN_GLoss: %.4f. DLoss: %.4f." % 316 | (b, gen_fetched[1], sn_gen_fetched[1], dis_fetched[1])) 317 | pbar.update() 318 | 319 | if verbose_batch: 320 | print('Discriminator:') 321 | discriminator.progress_update(mt_batch, dis_fetched) 322 | print('Generator:') 323 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(gen_fetched[2])) 324 | print('SN Generator:') 325 | print('Avg probability of sampled negative examples from last batch: %.4f' % np.average(sn_gen_fetched[2])) 326 | pbar.close() 327 | pbar = tqdm(total=console_update_interval) 328 | if pbar: 329 | pbar.close() 330 | 331 | return global_step 332 | 333 | 334 | def validation_epoch(session, model, config, data_generator, summary_writer, global_step): 335 | console_update_interval = config.val_progress_update_interval 336 | pbar = tqdm(total=console_update_interval) 337 | # validation epoch 338 | for b, batch in enumerate(data_generator.generate_mt(False)): 339 | verbose_batch = b > 0 and b % console_update_interval == 0 340 | 341 | fetched = session.run(model.fetches(False, verbose=verbose_batch), model.prepare_feed_dict(batch, False)) 342 | 343 | # update tensorboard summary 344 | summary_writer.add_summary(fetched[0], global_step) 345 | global_step += 1 346 | 347 | # update progress bar 348 | pbar.set_description("Validation Batch: %d. 
Loss: %.4f" % (b, fetched[1])) 349 | pbar.update() 350 | 351 | if verbose_batch: 352 | model.progress_update(batch, fetched) 353 | pbar.close() 354 | pbar = tqdm(total=console_update_interval) 355 | pbar.close() 356 | -------------------------------------------------------------------------------- /python/eukg/gan/train_gan.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/gan/train_gan.pyc -------------------------------------------------------------------------------- /python/eukg/save_embeddings.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import numpy as np 4 | 5 | import Config 6 | from train import init_model 7 | 8 | 9 | def main(): 10 | config = Config.flags 11 | 12 | if config.mode == 'gan': 13 | scope = config.dis_run_name 14 | run_name = config.run_name + "/discriminator" 15 | config.mode = 'disc' 16 | else: 17 | scope = config.run_name 18 | run_name = config.run_name 19 | 20 | with tf.Graph().as_default(), tf.Session() as session: 21 | with tf.variable_scope(scope): 22 | model = init_model(config, None) 23 | saver = tf.train.Saver([var for var in tf.global_variables() if 'embeddings' in var.name]) 24 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, run_name)) 25 | print('Loading checkpoint: %s' % ckpt) 26 | saver.restore(session, ckpt) 27 | 28 | embeddings = session.run(model.embedding_model.embeddings) 29 | if config.model == 'transd': 30 | embeddings = np.concatenate((embeddings, session.run(model.embedding_model.p_embeddings)), axis=1) 31 | 32 | np.savez_compressed(config.embedding_file, 33 | embs=embeddings) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /python/eukg/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/test/__init__.py -------------------------------------------------------------------------------- /python/eukg/test/classification.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import os 5 | from tqdm import tqdm 6 | 7 | from ..data import data_util, DataGenerator 8 | from .. 
import Config, train 9 | from sklearn.svm import LinearSVC 10 | from sklearn import metrics 11 | 12 | config = Config.flags 13 | 14 | 15 | def evaluate(): 16 | random.seed(1337) 17 | np.random.seed(1337) 18 | config.no_semantic_network = True 19 | config.batch_size = 2000 20 | 21 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion) 22 | test_data = data_util.load_metathesaurus_test_data(config.data_dir) 23 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir)) 24 | concept_ids = np.unique(np.concatenate([train_data['subj'], train_data['obj'], test_data['subj'], test_data['obj']])) 25 | print('%d total unique concepts' % len(concept_ids)) 26 | val_idx = np.random.permutation(np.arange(len(train_data['rel'])))[:100000] 27 | val_data_generator = DataGenerator.DataGenerator(train_data, 28 | train_idx=val_idx, 29 | val_idx=[], 30 | config=config, 31 | test_mode=True) 32 | 33 | valid_triples = set() 34 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']): 35 | valid_triples.add((s, r, o)) 36 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']): 37 | valid_triples.add((s, r, o)) 38 | print('%d valid triples' % len(valid_triples)) 39 | 40 | model_name = config.run_name 41 | if config.mode == 'gan': 42 | scope = config.dis_run_name 43 | model_name += '/discriminator' 44 | config.mode = 'disc' 45 | else: 46 | scope = config.run_name 47 | 48 | with tf.Graph().as_default(), tf.Session() as session: 49 | # init model 50 | with tf.variable_scope(scope): 51 | model = train.init_model(config, None) 52 | 53 | tf.global_variables_initializer().run() 54 | tf.local_variables_initializer().run() 55 | 56 | # init saver 57 | tf_saver = tf.train.Saver(max_to_keep=10) 58 | 59 | # load model 60 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 61 | print('Loading checkpoint: %s' % ckpt) 62 | tf_saver.restore(session, ckpt) 63 | tf.get_default_graph().finalize() 64 | 65 | scores = [] 66 | labels = [] 67 | for r, s, o, ns, no in tqdm(val_data_generator.generate_mt(True), total=val_data_generator.num_train_batches()): 68 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r, 69 | model.pos_subj: s, 70 | model.pos_obj: o, 71 | model.neg_subj: ns, 72 | model.neg_obj: no}) 73 | scores += pscores.tolist() 74 | labels += np.ones_like(pscores, dtype=np.int).tolist() 75 | scores += nscores.tolist() 76 | labels += np.zeros_like(nscores, dtype=np.int).tolist() 77 | print('Calculated scores. 
Training SVM.') 78 | svm = LinearSVC(dual=False) 79 | svm.fit(np.asarray(scores).reshape(-1, 1), labels) 80 | print('Done.') 81 | 82 | data_generator = DataGenerator.DataGenerator(test_data, 83 | train_idx=np.arange(len(test_data['rel'])), 84 | val_idx=[], 85 | config=config, 86 | test_mode=True) 87 | data_generator._sampler = val_data_generator.sampler 88 | scores, labels = [], [] 89 | for r, s, o, ns, no in tqdm(data_generator.generate_mt(True), desc='classifying', 90 | total=data_generator.num_train_batches()): 91 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r, 92 | model.pos_subj: s, 93 | model.pos_obj: o, 94 | model.neg_subj: ns, 95 | model.neg_obj: no}) 96 | scores += pscores.tolist() 97 | labels += np.ones_like(pscores, dtype=np.int).tolist() 98 | scores += nscores.tolist() 99 | labels += np.zeros_like(nscores, dtype=np.int).tolist() 100 | predictions = svm.predict(np.asarray(scores).reshape(-1, 1)) 101 | print('pred: %s' % predictions.shape) 102 | print('lbl: %d' % len(labels)) 103 | print('Relation Triple Classification Accuracy: %.4f' % metrics.accuracy_score(labels, predictions)) 104 | print('Relation Triple Classification Precision: %.4f' % metrics.precision_score(labels, predictions)) 105 | print('Relation Triple Classification Recall: %.4f' % metrics.recall_score(labels, predictions)) 106 | print(metrics.classification_report(labels, predictions)) 107 | 108 | 109 | def evaluate_sn(): 110 | random.seed(1337) 111 | config.no_semantic_network = False 112 | 113 | data = {} 114 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.) 115 | _ = data_util.load_semantic_network_data(config.data_dir, data) 116 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj'] 117 | print('Loaded %d sn triples from %s' % (len(rel), config.data_dir)) 118 | 119 | valid_triples = set() 120 | for trip in zip(subj, rel, obj): 121 | valid_triples.add(trip) 122 | print('%d valid triples' % len(valid_triples)) 123 | idxs = np.random.permutation(np.arange(len(rel))) 124 | idxs = idxs[:600] 125 | subj, rel, obj = subj[idxs], rel[idxs], obj[idxs] 126 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='???') 127 | nsubj, nobj = sampler.sample(subj, rel, obj) 128 | 129 | model_name = config.run_name 130 | if config.mode == 'gan': 131 | scope = config.dis_run_name 132 | model_name += '/discriminator' 133 | config.mode = 'disc' 134 | else: 135 | scope = config.run_name 136 | 137 | with tf.Graph().as_default(), tf.Session() as session: 138 | # init model 139 | with tf.variable_scope(scope): 140 | model = train.init_model(config, None) 141 | 142 | tf.global_variables_initializer().run() 143 | tf.local_variables_initializer().run() 144 | 145 | # init saver 146 | tf_saver = tf.train.Saver(max_to_keep=10) 147 | 148 | # load model 149 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 150 | print('Loading checkpoint: %s' % ckpt) 151 | tf_saver.restore(session, ckpt) 152 | tf.get_default_graph().finalize() 153 | 154 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel, 155 | model.smoothing_placeholders['sn_pos_subj']: subj, 156 | model.smoothing_placeholders['sn_pos_obj']: obj, 157 | model.smoothing_placeholders['sn_neg_subj']: nsubj, 158 | model.smoothing_placeholders['sn_neg_obj']: nobj} 159 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict) 160 | scores = np.concatenate((pscores, nscores)) 161 | labels = 
np.concatenate((np.ones_like(pscores, dtype=np.int), np.zeros_like(nscores, dtype=np.int))) 162 | print('Calculated scores. Training SVM.') 163 | svm = LinearSVC(dual=False) 164 | svm.fit(scores.reshape(-1, 1), labels) 165 | print('Done.') 166 | 167 | with np.load(os.path.join(config.data_dir, 'semnet', 'triples.npz')) as npz: 168 | subj = npz['subj'] 169 | rel = npz['rel'] 170 | obj = npz['obj'] 171 | nsubj, nobj = sampler.sample(subj, rel, obj) 172 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel, 173 | model.smoothing_placeholders['sn_pos_subj']: subj, 174 | model.smoothing_placeholders['sn_pos_obj']: obj, 175 | model.smoothing_placeholders['sn_neg_subj']: nsubj, 176 | model.smoothing_placeholders['sn_neg_obj']: nobj} 177 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict) 178 | predictions = svm.predict(pscores.reshape(-1, 1)).tolist() 179 | labels = np.ones_like(pscores, dtype=np.int).tolist() 180 | predictions += svm.predict(nscores.reshape(-1, 1)).tolist() 181 | labels += np.zeros_like(nscores, dtype=np.int).tolist() 182 | print('SN Relation Triple Classification Accuracy: %.4f' % metrics.accuracy_score(labels, predictions)) 183 | print('SN Relation Triple Classification Precision: %.4f' % metrics.precision_score(labels, predictions)) 184 | print('SN Relation Triple Classification Recall: %.4f' % metrics.recall_score(labels, predictions)) 185 | print(metrics.classification_report(labels, predictions)) 186 | 187 | 188 | 189 | if __name__ == "__main__": 190 | if config.eval_mode == 'sn': 191 | evaluate_sn() 192 | else: 193 | evaluate() 194 | -------------------------------------------------------------------------------- /python/eukg/test/nearest_neighbors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import json 4 | import os 5 | import csv 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | csv.field_size_limit(sys.maxsize) 10 | 11 | 12 | def knn(embeddings, idx, k): 13 | k = k+1 14 | e = np.expand_dims(embeddings[idx], axis=0) 15 | distances = np.linalg.norm(e - embeddings, axis=1) 16 | print(distances.shape) 17 | idxs = list(range(embeddings.shape[0])) 18 | idxs.sort(key=lambda i_: distances[i_]) 19 | 20 | return zip(idxs[:k], [distances[i] for i in idxs[:k]]) 21 | 22 | 23 | def main(): 24 | cui2names = defaultdict(list) 25 | with open('/home/rmm120030/working/umls-mke/umls/META/MRCONSO.RRF', 'r') as f: 26 | reader = csv.reader(f, delimiter='|') 27 | for row in tqdm(reader, desc="reading mrconso", total=8157818): 28 | cui2names[row[0]].append(row[14]) 29 | 30 | cui2id = json.load(open(os.path.join('/home/rmm120030/working/umls-mke/data', 'name2id.json'))) 31 | id2cui = {v: k for k, v in cui2id.iteritems()} 32 | with np.load(sys.argv[1]) as npz: 33 | embeddings = npz['embs'] 34 | print('Loaded embedding matrix: %s' % str(embeddings.shape)) 35 | 36 | while True: 37 | cui = raw_input('enter CUI (or \'exit\' to stop): ') 38 | if cui == "exit": 39 | exit() 40 | if cui in cui2id: 41 | neihbors = knn(embeddings, cui2id[cui], 10) 42 | for i, dist in neihbors: 43 | c = id2cui[i] 44 | print(' %.6f - %s - %s' % (dist, c, cui2names[c][:5])) 45 | else: 46 | print('No embedding for CUI %s' % cui) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /python/eukg/test/ppa.py: -------------------------------------------------------------------------------- 1 | 
import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | 8 | from ..data import data_util, DataGenerator 9 | from .. import Config, train 10 | 11 | config = Config.flags 12 | 13 | 14 | def evaluate(): 15 | random.seed(1337) 16 | config.no_semantic_network = True 17 | 18 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion) 19 | id2cui = {v: k for k, v in cui2id.iteritems()} 20 | test_data = data_util.load_metathesaurus_test_data(config.data_dir) 21 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir)) 22 | concept_ids = np.unique(np.concatenate([train_data['subj'], train_data['obj'], test_data['subj'], test_data['obj']])) 23 | print('%d total unique concepts' % len(concept_ids)) 24 | data_generator = DataGenerator.DataGenerator(test_data, 25 | train_idx=np.arange(len(test_data['rel'])), 26 | val_idx=[], 27 | config=config, 28 | test_mode=True) 29 | 30 | valid_triples = set() 31 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']): 32 | valid_triples.add((s, r, o)) 33 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']): 34 | valid_triples.add((s, r, o)) 35 | print('%d valid triples' % len(valid_triples)) 36 | 37 | model_name = config.run_name 38 | if config.mode == 'gan': 39 | scope = config.dis_run_name 40 | model_name += '/discriminator' 41 | config.mode = 'disc' 42 | else: 43 | scope = config.run_name 44 | 45 | with tf.Graph().as_default(), tf.Session() as session: 46 | # init model 47 | with tf.variable_scope(scope): 48 | model = train.init_model(config, data_generator) 49 | 50 | tf.global_variables_initializer().run() 51 | tf.local_variables_initializer().run() 52 | 53 | # init saver 54 | tf_saver = tf.train.Saver(max_to_keep=10) 55 | 56 | # load model 57 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 58 | print('Loading checkpoint: %s' % ckpt) 59 | tf_saver.restore(session, ckpt) 60 | tf.get_default_graph().finalize() 61 | 62 | def decode_triple(s_, r_, o_, score): 63 | return id2cui[s_], id2cui[r_], id2cui[o_], str(score) 64 | 65 | incorrect = [] # list of tuples of (triple, corrupted_triple) 66 | num_correct = 0 67 | total = 0 68 | for r, s, o, ns, no in tqdm(data_generator.generate_mt(True), total=data_generator.num_train_batches()): 69 | pscores, nscores = session.run([model.pos_energy, model.neg_energy], {model.relations: r, 70 | model.pos_subj: s, 71 | model.pos_obj: o, 72 | model.neg_subj: ns, 73 | model.neg_obj: no}) 74 | total += len(r) 75 | for pscore, nscore, rel, subj, obj, nsubj, nobj in zip(pscores, nscores, r, s, o, ns, no): 76 | if pscore < nscore: 77 | num_correct += 1 78 | else: 79 | incorrect.append((decode_triple(subj, rel, obj, pscore), decode_triple(nsubj, rel, nobj, nscore))) 80 | ppa = float(num_correct)/total 81 | print('PPA: %.4f' % ppa) 82 | outdir = os.path.join(config.eval_dir, config.run_name) 83 | if not os.path.exists(outdir): 84 | os.makedirs(outdir) 85 | json.dump(incorrect, open(os.path.join(outdir, 'ppa_incorrect.json'), 'w+')) 86 | with open(os.path.join(outdir, 'ppa.txt'), 'w+') as f: 87 | f.write(str(ppa)) 88 | 89 | 90 | def evaluate_sn(): 91 | random.seed(1337) 92 | config.no_semantic_network = False 93 | 94 | data = {} 95 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.) 
96 | id2cui = {v: k for k, v in cui2id.iteritems()} 97 | _ = data_util.load_semantic_network_data(config.data_dir, data) 98 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj'] 99 | print('Loaded %d sn triples from %s' % (len(rel), config.data_dir)) 100 | 101 | valid_triples = set() 102 | for trip in zip(subj, rel, obj): 103 | valid_triples.add(trip) 104 | print('%d valid triples' % len(valid_triples)) 105 | idxs = np.arange(len(rel)) 106 | np.random.shuffle(idxs) 107 | idxs = idxs[:600] 108 | subj, rel, obj = subj[idxs], rel[idxs], obj[idxs] 109 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='???') 110 | nsubj, nobj = sampler.sample(subj, rel, obj) 111 | 112 | model_name = config.run_name 113 | if config.mode == 'gan': 114 | scope = config.dis_run_name 115 | model_name += '/discriminator' 116 | config.mode = 'disc' 117 | else: 118 | scope = config.run_name 119 | 120 | with tf.Graph().as_default(), tf.Session() as session: 121 | # init model 122 | with tf.variable_scope(scope): 123 | model = train.init_model(config, None) 124 | 125 | tf.global_variables_initializer().run() 126 | tf.local_variables_initializer().run() 127 | 128 | # init saver 129 | tf_saver = tf.train.Saver(max_to_keep=10) 130 | 131 | # load model 132 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 133 | print('Loading checkpoint: %s' % ckpt) 134 | tf_saver.restore(session, ckpt) 135 | tf.get_default_graph().finalize() 136 | 137 | def decode_triple(s_, r_, o_, score): 138 | return id2cui[s_], id2cui[r_], id2cui[o_], str(score) 139 | 140 | incorrect = [] # list of tuples of (triple, corrupted_triple) 141 | num_correct = 0 142 | total = len(rel) 143 | feed_dict = {model.smoothing_placeholders['sn_relations']: rel, 144 | model.smoothing_placeholders['sn_pos_subj']: subj, 145 | model.smoothing_placeholders['sn_pos_obj']: obj, 146 | model.smoothing_placeholders['sn_neg_subj']: nsubj, 147 | model.smoothing_placeholders['sn_neg_obj']: nobj} 148 | pscores, nscores = session.run([model.sn_pos_energy, model.sn_neg_energy], feed_dict) 149 | for pscore, nscore, r, s, o, ns, no in zip(pscores, nscores, rel, subj, obj, nsubj, nobj): 150 | if pscore < nscore: 151 | num_correct += 1 152 | else: 153 | incorrect.append((decode_triple(s, r, o, pscore), decode_triple(ns, r, no, nscore))) 154 | ppa = float(num_correct)/total 155 | print('PPA: %.4f' % ppa) 156 | outdir = os.path.join(config.eval_dir, config.run_name) 157 | if not os.path.exists(outdir): 158 | os.makedirs(outdir) 159 | json.dump(incorrect, open(os.path.join(outdir, 'sn_ppa_incorrect.json'), 'w+')) 160 | with open(os.path.join(outdir, 'sn_ppa.txt'), 'w+') as f: 161 | f.write(str(ppa)) 162 | 163 | 164 | if __name__ == "__main__": 165 | if config.eval_mode == 'sn': 166 | evaluate_sn() 167 | else: 168 | evaluate() 169 | -------------------------------------------------------------------------------- /python/eukg/test/ranking_evals.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import math 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | 10 | from ..data import data_util, DataGenerator 11 | from .. 
import Config, train 12 | from ..threading_util import synchronized, parallel_stream 13 | 14 | 15 | config = Config.flags 16 | 17 | 18 | def save_ranks(): 19 | random.seed(1337) 20 | config.batch_size = 4096 21 | 22 | cui2id, train_data, _, _ = data_util.load_metathesaurus_data(config.data_dir, config.val_proportion) 23 | id2cui = {v: k for k, v in cui2id.iteritems()} 24 | test_data = data_util.load_metathesaurus_test_data(config.data_dir) 25 | print('Loaded %d test triples from %s' % (len(test_data['rel']), config.data_dir)) 26 | 27 | valid_triples = set() 28 | for s, r, o in zip(train_data['subj'], train_data['rel'], train_data['obj']): 29 | valid_triples.add((s, r, o)) 30 | for s, r, o in zip(test_data['subj'], test_data['rel'], test_data['obj']): 31 | valid_triples.add((s, r, o)) 32 | print('%d valid triples' % len(valid_triples)) 33 | 34 | model_name = config.run_name 35 | if config.mode == 'gan': 36 | scope = config.dis_run_name 37 | model_name += '/discriminator' 38 | config.mode = 'disc' 39 | else: 40 | scope = config.run_name 41 | 42 | with tf.Graph().as_default(), tf.Session() as session: 43 | # init model 44 | with tf.variable_scope(scope): 45 | model = train.init_model(config, None) 46 | 47 | tf.global_variables_initializer().run() 48 | tf.local_variables_initializer().run() 49 | 50 | # init saver 51 | tf_saver = tf.train.Saver(max_to_keep=10) 52 | 53 | # load model 54 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 55 | print('Loading checkpoint: %s' % ckpt) 56 | tf_saver.restore(session, ckpt) 57 | tf.get_default_graph().finalize() 58 | 59 | if not config.save_ranks: 60 | print('WARNING: ranks will not be saved! This run should only be for debugging purposes!') 61 | 62 | outdir = os.path.join(config.eval_dir, config.run_name) 63 | if not os.path.exists(outdir): 64 | os.makedirs(outdir) 65 | 66 | chunk_size = (1. 
/ config.num_shards) * len(test_data['rel']) 67 | print('Chunk size: %f' % chunk_size) 68 | start = int((config.shard - 1) * chunk_size) 69 | end = (len(test_data['rel']) if config.shard == int(config.num_shards) else int(config.shard * chunk_size)) \ 70 | if config.save_ranks else start + 1000 71 | print('Processing data from idx %d to %d' % (start, end)) 72 | 73 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='gan') 74 | global ranks 75 | global first 76 | ranks = [] 77 | first = True 78 | open_file = lambda: open(os.path.join(outdir, 'ranks_%d.json' % config.shard), 'w+') \ 79 | if config.save_ranks else open("/dev/null") 80 | with open_file() as f, tqdm(total=(end-start)) as pbar: 81 | if config.save_ranks: 82 | f.write('[') 83 | 84 | # define thread function 85 | def calculate(triple): 86 | s, r, o = triple 87 | # subj 88 | invalid_concepts = [s] + synchronized(lambda: sampler.invalid_concepts(s, r, o, True)) 89 | subj_scores, subj_ranking = calculate_scores(s, r, o, True, invalid_concepts, session, model, config.batch_size) 90 | srank = subj_ranking[s] if s in subj_ranking else -1 91 | 92 | # obj 93 | invalid_concepts = [o] + synchronized(lambda: sampler.invalid_concepts(s, r, o, False)) 94 | obj_scores, obj_ranking = calculate_scores(s, r, o, False, invalid_concepts, session, model, config.batch_size) 95 | orank = obj_ranking[o] if o in obj_ranking else -1 96 | json_string = json.dumps([str(srank), 97 | str(orank), 98 | str(obj_scores[o]), 99 | id2cui[s], id2cui[r], id2cui[o]]) 100 | return srank, orank, json_string 101 | 102 | # define save function 103 | def save(future): 104 | global ranks 105 | global first 106 | 107 | srank, orank, json_string = future.result() 108 | ranks += [srank, orank] 109 | if config.save_ranks: 110 | if not first: 111 | f.write(',') 112 | first = False 113 | f.write(json_string) 114 | npr = np.asarray(ranks, dtype=np.float) 115 | if config.save_ranks: 116 | pbar.set_description('Srank: %5d. Orank: %5d.' % (srank, orank)) 117 | else: 118 | pbar.set_description('Srank: %5d. Orank: %5d. MRR: %.4f. 
H@10: %.4f' % 119 | (srank, orank, mrr(npr), hits_at_10(npr))) 120 | pbar.update() 121 | 122 | parallel_stream(zip(test_data['subj'][start:end], test_data['rel'][start:end], test_data['obj'][start:end]), 123 | parallelizable_fn=calculate, 124 | future_consumer=save) 125 | 126 | # finish json 127 | if config.save_ranks: 128 | f.write(']') 129 | 130 | ranks_np = np.asarray(ranks, dtype=np.float) 131 | print('Mean Reciprocal Rank of shard %d: %.4f' % (config.shard, mrr(ranks_np))) 132 | print('Mean Rank of shard %d: %.2f' % (config.shard, mr(ranks_np))) 133 | print('Hits @ 10 of shard %d: %.4f' % (config.shard, hits_at_10(ranks_np))) 134 | 135 | 136 | def calculate_scores(subj, rel, obj, replace_subject, concept_ids, session, model, batch_size): 137 | num_batches = int(math.ceil(float(len(concept_ids))/batch_size)) 138 | subjects = np.full(batch_size, subj, dtype=np.int32) 139 | relations = np.full(batch_size, rel, dtype=np.int32) 140 | objects = np.full(batch_size, obj, dtype=np.int32) 141 | scores = {} 142 | 143 | for b in xrange(num_batches): 144 | concepts = concept_ids[b*batch_size:(b+1)*batch_size] 145 | feed_dict = {model.pos_subj: subjects, 146 | model.relations: relations, 147 | model.pos_obj: objects} 148 | 149 | # pad concepts if necessary 150 | if len(concepts) < batch_size: 151 | concepts = np.pad(concepts, (0, batch_size - len(concepts)), mode='constant', constant_values=0) 152 | 153 | # replace subj/obj in feed dict 154 | if replace_subject: 155 | feed_dict[model.pos_subj] = concepts 156 | else: 157 | feed_dict[model.pos_obj] = concepts 158 | 159 | # calculate energies 160 | energies = session.run(model.pos_energy, feed_dict) 161 | 162 | # store scores 163 | for i, cid in enumerate(concept_ids[b*batch_size:(b+1)*batch_size]): 164 | scores[cid] = energies[i] 165 | 166 | ranking = sorted(scores.keys(), key=lambda k: scores[k]) 167 | 168 | rank_map = {} 169 | prev_rank = 0 170 | prev_score = -1 171 | total = 1 172 | for c in ranking: 173 | # if c has a lower score than prev 174 | if scores[c] > prev_score: 175 | # increment the rank 176 | prev_rank = total 177 | # update score 178 | prev_score = scores[c] 179 | total += 1 180 | rank_map[c] = prev_rank 181 | 182 | return scores, rank_map 183 | 184 | 185 | def save_ranks_sn(): 186 | random.seed(1337) 187 | config.batch_size = 4096 188 | assert not config.no_semantic_network 189 | 190 | data = {} 191 | cui2id, _, _, _ = data_util.load_metathesaurus_data(config.data_dir, 0.) 
192 | id2cui = {v: k for k, v in cui2id.iteritems()} 193 | _ = data_util.load_semantic_network_data(config.data_dir, data) 194 | subj, rel, obj = data['sn_subj'], data['sn_rel'], data['sn_obj'] 195 | n = len(rel) 196 | print('Loaded %d sn triples from %s' % (n, config.data_dir)) 197 | 198 | valid_triples = set() 199 | for trip in zip(subj, rel, obj): 200 | valid_triples.add(trip) 201 | print('%d valid triples' % len(valid_triples)) 202 | 203 | model_name = config.run_name 204 | if config.mode == 'gan': 205 | scope = config.dis_run_name 206 | model_name += '/discriminator' 207 | config.mode = 'disc' 208 | else: 209 | scope = config.run_name 210 | 211 | with tf.Graph().as_default(), tf.Session() as session: 212 | # init model 213 | with tf.variable_scope(scope): 214 | model = train.init_model(config, None) 215 | 216 | tf.global_variables_initializer().run() 217 | tf.local_variables_initializer().run() 218 | 219 | # init saver 220 | tf_saver = tf.train.Saver(max_to_keep=10) 221 | 222 | # load model 223 | ckpt = tf.train.latest_checkpoint(os.path.join(config.model_dir, config.model, model_name)) 224 | print('Loading checkpoint: %s' % ckpt) 225 | tf_saver.restore(session, ckpt) 226 | tf.get_default_graph().finalize() 227 | 228 | if not config.save_ranks: 229 | print('WARNING: ranks will not be saved! This run should only be for debugging purposes!') 230 | 231 | outdir = os.path.join(config.eval_dir, config.run_name) 232 | if not os.path.exists(outdir): 233 | os.makedirs(outdir) 234 | 235 | sampler = DataGenerator.NegativeSampler(valid_triples=valid_triples, name='gan') 236 | global ranks 237 | global first 238 | ranks = [] 239 | first = True 240 | open_file = lambda: open(os.path.join(outdir, 'sn_ranks.json'), 'w+') \ 241 | if config.save_ranks else open("/dev/null") 242 | with open_file() as f, tqdm(total=600) as pbar: 243 | if config.save_ranks: 244 | f.write('[') 245 | 246 | # define thread function 247 | def calculate(triple): 248 | s, r, o = triple 249 | # subj 250 | invalid_concepts = [s] + synchronized(lambda: sampler.invalid_concepts(s, r, o, True)) 251 | subj_scores, subj_ranking = calculate_scores_sn(s, r, o, True, invalid_concepts, session, model, config.batch_size) 252 | srank = subj_ranking[s] if s in subj_ranking else -1 253 | 254 | # obj 255 | invalid_concepts = [o] + synchronized(lambda: sampler.invalid_concepts(s, r, o, False)) 256 | obj_scores, obj_ranking = calculate_scores_sn(s, r, o, False, invalid_concepts, session, model, config.batch_size) 257 | orank = obj_ranking[o] if o in obj_ranking else -1 258 | json_string = json.dumps([str(srank), 259 | str(orank), 260 | str(obj_scores[o]), 261 | id2cui[s], id2cui[r], id2cui[o]]) 262 | return srank, orank, json_string 263 | 264 | # define save function 265 | def save(future): 266 | global ranks 267 | global first 268 | 269 | srank, orank, json_string = future.result() 270 | ranks += [srank, orank] 271 | if config.save_ranks: 272 | if not first: 273 | f.write(',') 274 | first = False 275 | f.write(json_string) 276 | npr = np.asarray(ranks, dtype=np.float) 277 | pbar.set_description('Srank: %5d. Orank: %5d. MRR: %.4f. 
H@10: %.4f' % 278 | (srank, orank, mrr(npr), hits_at_10(npr))) 279 | pbar.update() 280 | 281 | with np.load(os.path.join(config.data_dir, 'semnet', 'test.npz')) as sn_npz: 282 | subj, rel, obj = sn_npz['subj'], sn_npz['rel'], sn_npz['obj'] 283 | parallel_stream(zip(subj, rel, obj), 284 | parallelizable_fn=calculate, 285 | future_consumer=save) 286 | 287 | # finish json 288 | if config.save_ranks: 289 | f.write(']') 290 | 291 | ranks_np = np.asarray(ranks, dtype=np.float) 292 | print('Mean Reciprocal Rank: %.4f' % (mrr(ranks_np))) 293 | print('Mean Rank: %.2f' % (mr(ranks_np))) 294 | print('Hits @ 10: %.4f' % (hits_at_10(ranks_np))) 295 | 296 | 297 | def calculate_scores_sn(subj, rel, obj, replace_subject, concept_ids, session, model, batch_size): 298 | num_batches = int(math.ceil(float(len(concept_ids))/batch_size)) 299 | subjects = np.full(batch_size, subj, dtype=np.int32) 300 | relations = np.full(batch_size, rel, dtype=np.int32) 301 | objects = np.full(batch_size, obj, dtype=np.int32) 302 | scores = {} 303 | 304 | for b in xrange(num_batches): 305 | concepts = concept_ids[b*batch_size:(b+1)*batch_size] 306 | feed_dict = {model.smoothing_placeholders['sn_pos_subj']: subjects, 307 | model.smoothing_placeholders['sn_relations']: relations, 308 | model.smoothing_placeholders['sn_pos_obj']: objects} 309 | 310 | # pad concepts if necessary 311 | if len(concepts) < batch_size: 312 | concepts = np.pad(concepts, (0, batch_size - len(concepts)), mode='constant', constant_values=0) 313 | 314 | # replace subj/obj in feed dict 315 | if replace_subject: 316 | feed_dict[model.smoothing_placeholders['sn_pos_subj']] = concepts 317 | else: 318 | feed_dict[model.smoothing_placeholders['sn_pos_obj']] = concepts 319 | 320 | # calculate energies 321 | energies = session.run(model.sn_pos_energy, feed_dict) 322 | 323 | # store scores 324 | for i, cid in enumerate(concept_ids[b*batch_size:(b+1)*batch_size]): 325 | scores[cid] = energies[i] 326 | 327 | ranking = sorted(scores.keys(), key=lambda k: scores[k]) 328 | 329 | rank_map = {} 330 | prev_rank = 0 331 | prev_score = -1 332 | total = 1 333 | for c in ranking: 334 | # if c has a lower score than prev 335 | if scores[c] > prev_score: 336 | # increment the rank 337 | prev_rank = total 338 | # update score 339 | prev_score = scores[c] 340 | total += 1 341 | rank_map[c] = prev_rank 342 | 343 | return scores, rank_map 344 | 345 | 346 | def mrr(ranks_np): 347 | return float(np.mean(1. 
/ ranks_np)) 348 | 349 | 350 | def mr(ranks_np): 351 | return float(np.mean(ranks_np)) 352 | 353 | 354 | def hits_at_10(ranks_np): 355 | return float(len(ranks_np[ranks_np <= 10])) / len(ranks_np) 356 | 357 | 358 | def calculate_ranking_evals(): 359 | outdir = os.path.join(config.eval_dir, config.run_name) 360 | 361 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip()) 362 | print('PPA: %.4f' % ppa) 363 | 364 | ranks = [] 365 | for fname in os.listdir(outdir): 366 | full_path = os.path.join(outdir, fname) 367 | if os.path.isfile(full_path) and fname.startswith('ranks_'): 368 | for fields in json.load(open(full_path)): 369 | ranks += [fields[0], fields[1]] 370 | ranks_np = np.asarray(ranks, dtype=np.float) 371 | mrr_ = mrr(ranks_np) 372 | mr_ = mr(ranks_np) 373 | hat10 = hits_at_10(ranks_np) 374 | print('MRR: %.4f' % mrr_) 375 | print('MR: %.2f' % mr_) 376 | print('H@10: %.4f' % hat10) 377 | 378 | with open(os.path.join(outdir, 'ranking_evals.tsv'), 'w+') as f: 379 | f.write('ppa\t%f\n' % ppa) 380 | f.write('mrr\t%f\n' % mrr_) 381 | f.write('mr\t%f\n' % mr_) 382 | f.write('h@10\t%f' % hat10) 383 | 384 | 385 | def calculate_ranking_evals_per_rel(): 386 | outdir = os.path.join(config.eval_dir, config.run_name) 387 | 388 | ranks = defaultdict(list) 389 | for fname in os.listdir(outdir): 390 | full_path = os.path.join(outdir, fname) 391 | if os.path.isfile(full_path) and fname.startswith('ranks_'): 392 | for fields in json.load(open(full_path)): 393 | ranks[fields[4]] += [fields[0], fields[1]] 394 | print('Gathered %d rankings' % len(ranks)) 395 | 396 | for rel, rl in ranks.iteritems(): 397 | ranks_np = np.asarray(rl, dtype=np.float) 398 | ranks[rel] = [mrr(ranks_np), mr(ranks_np), hits_at_10(ranks_np), len(rl)] 399 | 400 | relations = ranks.keys() 401 | relations.sort(key=lambda x: ranks[x][1]) 402 | 403 | with open(os.path.join(outdir, 'ranking_evals_per_rel.tsv'), 'w+') as f: 404 | for rel in relations: 405 | [mrr_, mr_, hat10, count] = ranks[rel] 406 | print('----------%s(%d)----------' % (rel, count)) 407 | print('MRR: %.4f' % mrr_) 408 | print('MR: %.2f' % mr_) 409 | print('H@10: %.4f' % hat10) 410 | 411 | f.write('%s (%d)\n' % (rel, count)) 412 | f.write('mrr\t%f\n' % mrr_) 413 | f.write('mr\t%f\n' % mr_) 414 | f.write('h@10\t%f\n\n' % hat10) 415 | 416 | 417 | def split_ranking_evals(): 418 | outdir = os.path.join(config.eval_dir, config.run_name) 419 | 420 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip()) 421 | print('PPA: %.4f' % ppa) 422 | 423 | s_ranks = [] 424 | o_ranks = [] 425 | for fname in os.listdir(outdir): 426 | full_path = os.path.join(outdir, fname) 427 | if os.path.isfile(full_path) and fname.startswith('ranks_'): 428 | for fields in json.load(open(full_path)): 429 | s_ranks += [fields[0]] 430 | o_ranks += [fields[1]] 431 | 432 | def report(ranks): 433 | ranks_np = np.asarray(ranks, dtype=np.float) 434 | mrr_ = mrr(ranks_np) 435 | mr_ = mr(ranks_np) 436 | hat10 = hits_at_10(ranks_np) 437 | print('MRR: %.4f' % mrr_) 438 | print('MR: %.2f' % mr_) 439 | print('H@10: %.4f' % hat10) 440 | report(s_ranks) 441 | report(o_ranks) 442 | 443 | 444 | def fix_json(): 445 | config = Config.flags 446 | outdir = os.path.join(config.eval_dir, config.run_name) 447 | 448 | ppa = float(str(next(open(os.path.join(outdir, 'ppa.txt')))).strip()) 449 | print('PPA: %.4f' % ppa) 450 | 451 | ranks = [] 452 | for fname in os.listdir(outdir): 453 | full_path = os.path.join(outdir, fname) 454 | if os.path.isfile(full_path) and fname.startswith('ranks_'): 455 | 
json_string = '[' + str(next(open(full_path))).strip() 456 | with open(full_path, 'w+') as f: 457 | f.write(json_string) 458 | for fields in json.load(open(full_path)): 459 | ranks += [fields[0], fields[2]] 460 | ranks_np = np.asarray(ranks, dtype=np.float) 461 | mrr_ = mrr(ranks_np) 462 | mr_ = mr(ranks_np) 463 | hat10 = hits_at_10(ranks_np) 464 | print('MRR: %.4f' % mrr_) 465 | print('MR: %.2f' % mr_) 466 | print('H@10: %.4f' % hat10) 467 | 468 | with open(os.path.join(outdir, 'ranking_evals.tsv'), 'w+') as f: 469 | f.write('ppa\t%f\n' % ppa) 470 | f.write('mrr\t%f\n' % mrr_) 471 | f.write('mr\t%f\n' % mr_) 472 | f.write('h@10\t%f' % hat10) 473 | 474 | 475 | if __name__ == "__main__": 476 | if config.eval_mode == "save": 477 | save_ranks() 478 | elif config.eval_mode == "save-sn": 479 | save_ranks_sn() 480 | elif config.eval_mode == "calc": 481 | calculate_ranking_evals() 482 | elif config.eval_mode == "calc-rel": 483 | calculate_ranking_evals_per_rel() 484 | else: 485 | raise Exception('Unrecognized eval_mode: %s' % config.eval_mode) 486 | -------------------------------------------------------------------------------- /python/eukg/tf_util/ModelSaver.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Policy(Enum): 5 | EPOCH = 1 6 | TIMED = 2 7 | 8 | 9 | class ModelSaver: 10 | def __init__(self, tf_saver, session, model_file, policy): 11 | self.tf_saver = tf_saver 12 | self.session = session 13 | self.model_file = model_file 14 | self.policy = policy 15 | 16 | def save(self, global_step, policy, **kwargs): 17 | if self.save_condition(policy, **kwargs): 18 | print("Saving model to %s at step %d" % (self.model_file, global_step)) 19 | self.tf_saver.save(self.session, self.model_file, global_step=global_step) 20 | 21 | def save_condition(self, policy, **kwargs): 22 | return policy == self.policy 23 | 24 | 25 | class TimedSaver(ModelSaver): 26 | def __init__(self, tf_saver, session, model_file, seconds_per_save): 27 | ModelSaver.__init__(self, tf_saver, session, model_file, Policy.TIMED) 28 | self.seconds_per_save = seconds_per_save 29 | 30 | def save_condition(self, policy, **kwargs): 31 | return policy == self.policy and kwargs['seconds_since_last_save'] > self.seconds_per_save 32 | 33 | 34 | class EpochSaver(ModelSaver): 35 | def __init__(self, tf_saver, session, model_file, save_every_x_epochs=1): 36 | ModelSaver.__init__(self, tf_saver, session, model_file, Policy.EPOCH) 37 | self.save_every_x_epochs = save_every_x_epochs 38 | 39 | def save_condition(self, policy, **kwargs): 40 | return 'epoch' in kwargs and kwargs['epoch'] % self.save_every_x_epochs == 0 41 | -------------------------------------------------------------------------------- /python/eukg/tf_util/Trainable.py: -------------------------------------------------------------------------------- 1 | class Trainable: 2 | def __init__(self): 3 | pass 4 | 5 | def fetches(self, is_training, verbose=False): 6 | """ 7 | Returns a list of fetches to be passed to session.run() 8 | :param is_training: flag indicating if the model is training/testing 9 | :param verbose: flag indicating if a more verbose set of variables should be fetched (usually for debugging or 10 | progress updates) 11 | :return: a list of fetches to be passed to session.run() 12 | """ 13 | raise NotImplementedError("to be implemented by subclass") 14 | 15 | def prepare_feed_dict(self, batch, is_training, **kwargs): 16 | """ 17 | Turns a list of tensors into a dict of model_parameter: tensor 
18 | :param batch: list of data tensors to be passed to the model 19 | :param is_training: flag indicating if the model is in training or testing mode 20 | :param kwargs: optional other params 21 | :return: the feed dict to be passed to session.run() 22 | """ 23 | raise NotImplementedError("to be implemented by subclass") 24 | 25 | def progress_update(self, batch, fetched, **kwargs): 26 | """ 27 | Prepares a progress update 28 | :param batch: batch data passed to prepare_feed_dict() 29 | :param fetched: tensors returned by session.run() 30 | :param kwargs: optional other params 31 | :return: String progress update 32 | """ 33 | raise NotImplementedError("to be implemented by subclass") 34 | 35 | def data_provider(self, config, is_training, **kwargs): 36 | """ 37 | Provides access to a data generator that generates batches of data 38 | :param config: dict of config flags to values (usually tf.flags.FLAGS) 39 | :param is_training: flag indicating if the model is in training or testing mode 40 | :param kwargs: optional other params 41 | :return: A generator that generates batches of data 42 | """ 43 | raise NotImplementedError("to be implemented by subclass") 44 | -------------------------------------------------------------------------------- /python/eukg/tf_util/Trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from tqdm import tqdm 4 | import time 5 | 6 | from Trainable import Trainable 7 | from ModelSaver import ModelSaver, Policy 8 | 9 | 10 | def train(config, session, model, saver, 11 | train_post_step=None, 12 | train_post_epoch=None, 13 | val_post_step=None, 14 | val_post_epoch=None, 15 | global_step=0, 16 | max_batches_per_epoch=None): 17 | """ 18 | Trains and validates 19 | :param config: config 20 | :type config: dict 21 | :param session: tf session 22 | :param model: Trainable 23 | :type model: Trainable 24 | :param train_post_step: functions to execute after each train step 25 | :param saver: ModelSaver 26 | :type saver: ModelSaver 27 | :param train_post_step: functions to execute after each train step 28 | :param train_post_epoch: functions to execute after each train epoch 29 | :param val_post_step: functions to execute after each validation step 30 | :param val_post_epoch: functions to execute after each validation epoch 31 | :param global_step: optional argument to pass a non-zero global step 32 | :param max_batches_per_epoch: optional argument to limit a training epoch to some number of batches 33 | """ 34 | # init summary directories and summary writers 35 | if not os.path.exists(os.path.join(config['summaries_dir'], 'train')): 36 | os.makedirs(os.path.join(config['summaries_dir'], 'train')) 37 | train_summary_writer = tf.summary.FileWriter(os.path.join(config['summaries_dir'], 'train')) 38 | if not os.path.exists(os.path.join(config['summaries_dir'], 'val')): 39 | os.makedirs(os.path.join(config['summaries_dir'], 'val')) 40 | val_summary_writer = tf.summary.FileWriter(os.path.join(config['summaries_dir'], 'val')) 41 | 42 | # determine if data_provider should be bounded 43 | if max_batches_per_epoch is None: 44 | print('Each training epoch will pass through the full dataset.') 45 | train_data_provider = lambda: model.data_provider(config, True) 46 | else: 47 | print('Each training epoch will process at most %d batches.' 
% max_batches_per_epoch) 48 | global it 49 | it = iter(model.data_provider(config, True)) 50 | 51 | def bounded_train_data_provider(): 52 | global it 53 | for b in xrange(max_batches_per_epoch): 54 | try: 55 | yield it.next() 56 | except StopIteration: 57 | print('WARNING: reached the end of training data. Looping over it again.') 58 | it = iter(model.data_provider(config, True)) 59 | yield it.next() 60 | train_data_provider = bounded_train_data_provider 61 | 62 | # train 63 | for ep in xrange(config['num_epochs']): 64 | print('\nBegin training epoch %d' % ep) 65 | global_step = train_epoch(config, session, model, train_summary_writer, train_post_step, global_step, 66 | train_data_provider, saver) 67 | saver.save(global_step, Policy.EPOCH, epoch=ep) 68 | if train_post_epoch: 69 | for post_epoch in train_post_epoch: 70 | post_epoch() 71 | 72 | print('\nDone epoch %d. Begin validation' % ep) 73 | validate(config, session, model, val_summary_writer, val_post_step, global_step) 74 | if val_post_epoch: 75 | for post_epoch in val_post_epoch: 76 | post_epoch() 77 | 78 | 79 | def train_epoch(config, session, model, summary_writer, post_step, global_step, batch_generator, saver): 80 | console_update_interval = config['progress_update_interval'] 81 | pbar = tqdm(total=console_update_interval) 82 | start = time.time() 83 | for b, batch in enumerate(batch_generator()): 84 | verbose_batch = b > 0 and b % console_update_interval == 0 85 | 86 | # training batch 87 | fetched = session.run(model.fetches(True, verbose=verbose_batch), model.prepare_feed_dict(batch, True)) 88 | 89 | # update tensorboard summary 90 | summary_writer.add_summary(fetched[0], global_step) 91 | global_step += 1 92 | 93 | # perform post steps 94 | if post_step is not None: 95 | for step in post_step: 96 | step(global_step, batch) 97 | 98 | # update progress bar 99 | pbar.set_description("Training Batch: %d. Loss: %.4f" % (b, fetched[1])) 100 | pbar.update() 101 | 102 | if verbose_batch: 103 | pbar.close() 104 | model.progress_update(batch, fetched) 105 | pbar = tqdm(total=console_update_interval) 106 | 107 | saver.save(global_step, Policy.TIMED, seconds_since_last_save=(time.time() - start)) 108 | pbar.close() 109 | 110 | return global_step 111 | 112 | 113 | def validate(config, session, model, summary_writer, post_step, global_step): 114 | console_update_interval = config['val_progress_update_interval'] 115 | pbar = tqdm(total=console_update_interval) 116 | # validation epoch 117 | for b, batch in enumerate(model.data_provider(config, False)): 118 | verbose_batch = b > 0 and b % console_update_interval == 0 119 | 120 | fetched = session.run(model.fetches(False, verbose=verbose_batch), model.prepare_feed_dict(batch, False)) 121 | 122 | # update tensorboard summary 123 | summary_writer.add_summary(fetched[0], global_step) 124 | global_step += 1 125 | 126 | # perform post steps 127 | if post_step is not None: 128 | for step in post_step: 129 | step(b, batch) 130 | 131 | # update progress bar 132 | pbar.set_description("Validation Batch: %d.
Loss: %.4f" % (b, fetched[1])) 133 | pbar.update() 134 | 135 | if verbose_batch: 136 | pbar.close() 137 | model.progress_update(batch, fetched) 138 | pbar = tqdm(total=console_update_interval) 139 | pbar.close() 140 | 141 | return global_step 142 | -------------------------------------------------------------------------------- /python/eukg/tf_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-mal/umls-embeddings/1ac29b15f5bedf8edbd2c90d0ccdae094be03518/python/eukg/tf_util/__init__.py -------------------------------------------------------------------------------- /python/eukg/threading_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | 6 | condition = threading.Condition() 7 | def synchronized(func): 8 | condition.acquire() 9 | r = func() 10 | condition.notify() 11 | condition.release() 12 | return r 13 | 14 | 15 | def parallel_stream(iterable, parallelizable_fn, future_consumer=None): 16 | """ 17 | Executes parallelizable_fn on each element of iterable in parallel. When each thread finishes, its future is passed 18 | to future_consumer if provided. 19 | :param iterable: an iterable of objects to be processed by parallelizable_fn 20 | :param parallelizable_fn: a function that operates on elements of iterable 21 | :param future_consumer: optional consumer of objects returned by parallelizable_fn 22 | :return: void 23 | """ 24 | if future_consumer is None: 25 | future_consumer = lambda f: f.result() 26 | 27 | num_threads = multiprocessing.cpu_count() 28 | executor = ThreadPoolExecutor(max_workers=num_threads) 29 | futures = [] 30 | for tup in iterable: 31 | # all cpus are in use, wait until at least 1 thread finishes before spawning new threads 32 | if len(futures) >= num_threads: 33 | full_queue = True 34 | while full_queue: 35 | incomplete_futures = [] 36 | for fut in futures: 37 | if fut.done(): 38 | future_consumer(fut) 39 | full_queue = False 40 | else: 41 | incomplete_futures.append(fut) 42 | futures = incomplete_futures 43 | # spawn new thread 44 | futures += [executor.submit(parallelizable_fn, tup)] 45 | 46 | # ensure all threads have finished 47 | for fut in futures: 48 | future_consumer(fut) -------------------------------------------------------------------------------- /python/eukg/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | 6 | from tf_util import Trainer, ModelSaver 7 | 8 | from emb import EmbeddingModel 9 | from gan import Generator, train_gan, Discriminator 10 | import Config 11 | from data import data_util, DataGenerator 12 | 13 | 14 | def train(): 15 | config = Config.flags 16 | 17 | if config.mode == 'gan': 18 | train_gan.train() 19 | exit() 20 | 21 | # init model dir 22 | config.model_dir = os.path.join(config.model_dir, config.model, config.run_name) 23 | if not os.path.exists(config.model_dir): 24 | os.makedirs(config.model_dir) 25 | 26 | # init summaries dir 27 | config.summaries_dir = os.path.join(config.summaries_dir, config.run_name) 28 | if not os.path.exists(config.summaries_dir): 29 | os.makedirs(config.summaries_dir) 30 | 31 | # save the config 32 | data_util.save_config(config.model_dir, config) 33 | 34 | # load data 35 | cui2id, data, train_idx, val_idx = 
data_util.load_metathesaurus_data(config.data_dir, config.val_proportion) 36 | config.val_progress_update_interval = int(math.floor(float(len(val_idx)) / config.batch_size)) 37 | config.batches_per_epoch = int(math.floor(float(len(train_idx)) / config.batch_size)) 38 | if not config.no_semantic_network: 39 | type2cuis = data_util.load_semantic_network_data(config.data_dir, data) 40 | else: 41 | type2cuis = None 42 | data_generator = DataGenerator.DataGenerator(data, train_idx, val_idx, config, type2cuis) 43 | 44 | # config map 45 | config_map = config.flag_values_dict() 46 | config_map['data'] = data 47 | config_map['train_idx'] = train_idx 48 | config_map['val_idx'] = val_idx 49 | if not config_map['no_semantic_network']: 50 | config_map['type2cuis'] = type2cuis 51 | 52 | with tf.Graph().as_default(), tf.Session() as session: 53 | # init model 54 | with tf.variable_scope(config.run_name): 55 | model = init_model(config, data_generator) 56 | # session.run(model.train_init_op) 57 | 58 | tf.global_variables_initializer().run() 59 | tf.local_variables_initializer().run() 60 | 61 | # init saver 62 | tf_saver = tf.train.Saver(max_to_keep=10) 63 | saver = init_saver(config, tf_saver, session) 64 | 65 | # load model 66 | global_step = 0 67 | if config.load: 68 | ckpt = tf.train.latest_checkpoint(config.model_dir) 69 | print('Loading checkpoint: %s' % ckpt) 70 | global_step = int(os.path.split(ckpt)[-1].split('-')[-1]) 71 | tf_saver.restore(session, ckpt) 72 | 73 | # finalize graph 74 | tf.get_default_graph().finalize() 75 | 76 | # define normalization step 77 | def find_unique(tensor_list): 78 | if max([len(t.shape) for t in tensor_list[:10]]) == 1: 79 | return np.unique(np.concatenate(tensor_list[:10])) 80 | else: 81 | return np.unique(np.concatenate([t.flatten() for t in tensor_list[:10]])) 82 | normalize = lambda _, batch: session.run(model.norm_op, 83 | {model.ids_to_update: find_unique(batch)}) 84 | 85 | # define streaming_accuracy reset per epoch 86 | print('local variables that will be reinitialized every epoch: %s' % tf.local_variables()) 87 | reset_local_vars = lambda: session.run(model.reset_streaming_metrics_op) 88 | 89 | # train 90 | Trainer.train(config_map, session, model, saver, 91 | train_post_step=[normalize], 92 | train_post_epoch=[reset_local_vars], 93 | val_post_epoch=[reset_local_vars], 94 | global_step=global_step, 95 | max_batches_per_epoch=config_map['max_batches_per_epoch']) 96 | 97 | 98 | def init_model(config, data_generator): 99 | print('Initializing %s embedding model in %s mode...' 
% (config.model, config.mode)) 100 | npz = np.load(config.embedding_file) if config.load_embeddings else None 101 | 102 | if config.model == 'transe': 103 | em = EmbeddingModel.TransE(config, embeddings_dict=npz) 104 | elif config.model == 'transd': 105 | config.embedding_size = config.embedding_size / 2 106 | em = EmbeddingModel.TransD(config, embeddings_dict=npz) 107 | elif config.model == 'distmult': 108 | em = EmbeddingModel.DistMult(config, embeddings_dict=npz) 109 | else: 110 | raise ValueError('Unrecognized model type: %s' % config.model) 111 | 112 | if config.mode == 'disc': 113 | model = Discriminator.BaseModel(config, em, data_generator) 114 | elif config.mode == 'gen': 115 | model = Generator.Generator(config, em, data_generator) 116 | else: 117 | raise ValueError('Unrecognized mode: %s' % config.mode) 118 | 119 | if npz: 120 | # noinspection PyUnresolvedReferences 121 | npz.close() 122 | 123 | model.build() 124 | print('Built model.') 125 | print('use semnet: %s' % model.use_semantic_network) 126 | return model 127 | 128 | 129 | def init_saver(config, tf_saver, session): 130 | model_file = os.path.join(config.model_dir, config.model) 131 | if config.save_strategy == 'timed': 132 | print('Models will be saved every %d seconds' % config.save_interval) 133 | return ModelSaver.TimedSaver(tf_saver, session, model_file, config.save_interval) 134 | elif config.save_strategy == 'epoch': 135 | print('Models will be saved every training epoch') 136 | return ModelSaver.EpochSaver(tf_saver, session, model_file) 137 | else: 138 | raise ValueError('Unrecognized save strategy: %s' % config.save_strategy) 139 | 140 | 141 | if __name__ == "__main__": 142 | train() 143 | --------------------------------------------------------------------------------
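Usage note: the ranking metrics reported by ranking_evals.py (mrr, mr, hits_at_10) and the tie-aware rank_map built in calculate_scores() / calculate_scores_sn() are easier to follow outside the flattened dump. Below is a minimal, self-contained sketch of both; the input values are hypothetical and the snippet re-implements the same logic in plain numpy/Python rather than importing the repository code.

import numpy as np

# hypothetical ranks (1 = best); note that a rank of -1 ("entity not found")
# would distort these means if left in
ranks = np.asarray([1, 3, 12, 2, 250, 7], dtype=np.float64)

mrr = float(np.mean(1. / ranks))                      # mean reciprocal rank
mr = float(np.mean(ranks))                            # mean rank
hits_at_10 = float(np.sum(ranks <= 10)) / len(ranks)  # fraction of ranks in the top 10
print('MRR: %.4f  MR: %.2f  H@10: %.4f' % (mrr, mr, hits_at_10))

# competition-style ranking over energies (lower = better), mirroring the
# rank_map loop: tied scores share a rank and the next distinct score skips
# past the tie block, e.g. 1, 2, 2, 4
scores = {'C1': 0.10, 'C2': 0.25, 'C3': 0.25, 'C4': 0.90}
rank_map, prev_rank, prev_score, total = {}, 0, None, 1
for c in sorted(scores, key=scores.get):
  if prev_score is None or scores[c] > prev_score:
    prev_rank = total
    prev_score = scores[c]
  total += 1
  rank_map[c] = prev_rank
print(rank_map)  # {'C1': 1, 'C2': 2, 'C3': 2, 'C4': 4}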
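A small usage sketch of threading_util.parallel_stream. It assumes the package is importable as eukg (e.g. with the python/ directory on the PYTHONPATH and concurrent.futures available, which is a backport under Python 2); the square() work function and the collect() callback are hypothetical stand-ins for the per-triple rank computation and rank-saving callback in ranking_evals.py.

from eukg.threading_util import parallel_stream

results = []

def square(x):
  # hypothetical per-item work; runs on one of the worker threads
  return x * x

def collect(future):
  # invoked from the calling thread as finished futures are drained
  results.append(future.result())

parallel_stream(range(10), parallelizable_fn=square, future_consumer=collect)
print(sorted(results))  # completion order is not guaranteed, hence the sort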
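Similarly, a sketch of how the tf_util.ModelSaver policies gate checkpointing. Only save_condition() is exercised, so no TensorFlow session or real tf.train.Saver is needed and the None arguments are placeholders; it again assumes eukg is importable, and under Python 2 the enum import in ModelSaver.py requires the enum34 backport.

from eukg.tf_util.ModelSaver import EpochSaver, TimedSaver, Policy

# EpochSaver: save every N epochs when called with the EPOCH policy
epoch_saver = EpochSaver(None, None, 'model.ckpt', save_every_x_epochs=2)
print(epoch_saver.save_condition(Policy.EPOCH, epoch=3))  # False (3 is not a multiple of 2)
print(epoch_saver.save_condition(Policy.EPOCH, epoch=4))  # True

# TimedSaver: save once enough wall-clock time has passed since the last save
timed_saver = TimedSaver(None, None, 'model.ckpt', seconds_per_save=600)
print(timed_saver.save_condition(Policy.TIMED, seconds_since_last_save=120))  # False
print(timed_saver.save_condition(Policy.TIMED, seconds_since_last_save=900))  # True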