├── .gitignore ├── LICENSE ├── README.md ├── data ├── prepare-all.sh ├── prepare-cbt-data.sh └── prepare-embedding.sh ├── dataset ├── __init__.py ├── cbt.py ├── data_file_pairs.py ├── rc_dataset.py └── squad.py ├── main.py ├── models ├── __init__.py ├── attention_over_attention_reader.py ├── attention_sum_reader.py ├── model_data_pairs.py ├── nlp_base.py ├── r_net.py └── rc_base.py ├── requirements.txt ├── test ├── dataset_test.py └── notebook │ ├── test_aoa.ipynb │ └── test_as_reader.ipynb ├── utils ├── __init__.py └── log.py └── weights ├── AS-reader ├── best-CBT-CN │ ├── args.json │ └── result.json ├── best-CBT-NE │ ├── args.json │ └── result.json └── best-best-CBT-NE │ ├── args.json │ └── result.json └── AoA-reader ├── best-CBT-CN ├── args.json └── result.json ├── best-CBT-NE ├── args.json └── result.json └── best-best-CBT-NE ├── args.json └── result.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | ### JetBrains template 97 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 98 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 99 | 100 | # Gradle: 101 | .idea/**/gradle.xml 102 | .idea/**/libraries 103 | 104 | # Mongo Explorer plugin: 105 | .idea/**/mongoSettings.xml 106 | 107 | ## File-based project format: 108 | *.iws 109 | 110 | ## Plugin-specific files: 111 | 112 | # IntelliJ 113 | /out/ 114 | 115 | # mpeltonen/sbt-idea plugin 116 | .idea_modules/ 117 | 118 | # JIRA plugin 119 | atlassian-ide-plugin.xml 120 | 121 | # Crashlytics plugin (for Android Studio and IntelliJ) 122 | com_crashlytics_export_strings.xml 123 | crashlytics.properties 124 | crashlytics-build.properties 125 | fabric.properties 126 | logs/ 127 | .idea/ 128 | 129 | ### Tensorflow checkpoint 130 | checkpoint 131 | *.meta 132 | *.index 133 | *.data-00000-of-00001 134 | data/CBTest/ 135 | data/glove.6B/ 136 | data/SQuAD/ 137 | 
weights/args\.json 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 
33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. 
To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 
102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 
133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. 
You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 
229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reading Comprehension Experiments 2 | 3 | ## About 4 | 5 | This is the tensorflow version implementation/reproduce of some reading comprehension models in some reading comprehension datasets including the following: 6 | 7 | Models: 8 | 9 | - Attention Sum Reader model as presented in "Text Comprehension with the Attention Sum Reader Network" (ACL2016) available at [http://arxiv.org/abs/1603.01547](http://arxiv.org/abs/1603.01547). 
10 | 11 | ![](http://7xpqrs.com1.z0.glb.clouddn.com/FjmgZjrmBJ5w8WdDU2v9BMRj21r8) 12 | 13 | - Attention over Attention Reader model as presented in "Attention-over-Attention Neural Networks for Reading Comprehension" (arXiv2016.7) available at https://arxiv.org/abs/1607.04423. 14 | 15 | ![](http://7xpqrs.com1.z0.glb.clouddn.com/FupB-rvxCvGvPTwa8UC4u3QUgqKI) 16 | 17 | Datasets: 18 | 19 | - CBT, Children’s Book Test.http://lanl.arxiv.org/pdf/1506.03340.pdf 20 | 21 | ## Start To Use 22 | 23 | #### 1.Clone the code 24 | 25 | ```shell 26 | git clone https://github.com/zhanghaoyu1993/RC-experiments.git 27 | ``` 28 | 29 | 30 | 31 | #### 2.Get needed data 32 | 33 | - Download and extract the dataset used in this repo. 34 | 35 | ```shell 36 | cd data 37 | ./prepare-all.sh 38 | ``` 39 | 40 | 41 | 42 | #### 3.Environment Preparation 43 | 44 | - Python-64bit >= v3.5. 45 | - Install require libraries using the following command. 46 | 47 | ```shell 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | - Install tensorflow >= 1.1.0. 52 | 53 | ```shell 54 | pip install tensorflow-gpu --upgrade 55 | ``` 56 | 57 | - Install nltk punkt for tokenizer. 58 | 59 | ```shell 60 | python -m nltk.downloader punkt 61 | ``` 62 | 63 | 64 | 65 | #### 4.Set model, dataset and other command parameters 66 | 67 | - What is the entrance of the program? 68 | 69 | The main.py file in root directory. 70 | 71 | - How can I specify a model in command line? 72 | 73 | Type a command like above, the *model_class* is the class name of model, usually named in cambak-style: 74 | 75 | ```shell 76 | python main.py [model_class] 77 | ``` 78 | 79 | For example, if you want to use AttentionSumReader: 80 | 81 | ```shell 82 | python main.py AttentionSumReader 83 | ``` 84 | 85 | - How can I specify the dataset? 
86 | 87 | Type a command like above, the *dataset_class* is the class name of dataset: 88 | 89 | ```shell 90 | python main.py [model_class] --dataset [dataset_class] 91 | ``` 92 | 93 | For example, if you want to use CBT: 94 | 95 | ```shell 96 | python main.py [model_class] --dataset CBT 97 | ``` 98 | 99 | You don't need to specify the data_root and train valid test file name in most cases, just specify the dataset. 100 | 101 | - How can I know all the parameters? 102 | 103 | The program use [argparse](https://docs.python.org/3/library/argparse.html) to deal with parameters, you can type the following command to get help: 104 | 105 | ```shell 106 | python main.py --help 107 | ``` 108 | 109 | or: 110 | 111 | ```shell 112 | python main.py -h 113 | ``` 114 | 115 | - The command parameters is so long! 116 | 117 | The parameters will be stored into a file named args.json when executed, so next time you can type the following simplified command: 118 | 119 | ```shell 120 | python main.py [model_class] --args_file [args.json] 121 | ``` 122 | 123 | 124 | 125 | #### 5.Train and test the model 126 | 127 | First, modify the parameters in the args.json. 128 | 129 | You can now train and test the model by entering the following commands. The params in [] should be determined by the real situation. 130 | 131 | - Train: 132 | 133 | ```shell 134 | python main.py [model_class] --args_file [args.json] --train 1 --test 0 135 | ``` 136 | 137 | After train, the parameters are stored in `weight_path/args.json` and the model checkpoints are stored in `weight_path`. 138 | 139 | - Test: 140 | 141 | ```shell 142 | python main.py [model_class] --args_file [args.json] --train 0 --test 1 143 | ``` 144 | 145 | After test, the performance of model are stored in `weight_path/result.json`. 
146 | 147 | 148 | 149 | #### 6.model performance 150 | 151 | All the trained results and corresponding config params are saved in sub directories of weight_path(by default the `weight` folder) named `args.json` and `result.json`. 152 | 153 | You should know that the implementation of some models are **slightly different** from the original, but the basic ideas are same, so the results are for reference only. 154 | 155 | The best results of implemented models are listed below: 156 | 157 | - best result **we achieve**(with little hyper-parameter tune in single model) 158 | - best result listed in original paper(in the brackets) 159 | 160 | | | CBT-NE | CBT-CN | 161 | | ---------- | ----------- | ----------- | 162 | | AS-Reader | 69.88(68.6) | 65.0(63.4) | 163 | | AoA-Reader | 71.0(72.0) | 68.12(69.4) | 164 | 165 | 166 | 167 | #### 7.FAQ 168 | 169 | - How do I use args_file argument in the shell? 170 | 171 | Once you enter a command in the shell(maybe a long one), the config will be stored in weight_path/args.json where weight_path is defined by another argument, after the command execute you can use --args.json to simplify the following command: 172 | ```shell 173 | python main.py [model_class] --args_file [args.json] 174 | ``` 175 | And the priorities of arguments typed in the command line is higher than those stored in args.json, so you can change some arguments. 
-------------------------------------------------------------------------------- /data/prepare-all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ./prepare-embedding.sh 4 | ./prepare-cbt-data.sh -------------------------------------------------------------------------------- /data/prepare-cbt-data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # prepares the Children's Book Test datasets 3 | 4 | # get CBT data 5 | wget http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz 6 | 7 | # unpack all files 8 | tar -zxvf CBTest.tgz 9 | 10 | -------------------------------------------------------------------------------- /data/prepare-embedding.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # get glove embedding 4 | wget http://nlp.stanford.edu/data/glove.6B.zip 5 | 6 | # unpack all files 7 | unzip glove.6B.zip -d glove.6B 8 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .cbt import CBT 2 | from .squad import SQuAD 3 | 4 | CBT_NE = CBT 5 | CBT_CN = CBT 6 | 7 | __all__ = ["CBT_NE", "CBT_CN", "SQuAD"] 8 | -------------------------------------------------------------------------------- /dataset/cbt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import reduce 3 | 4 | import numpy as np 5 | from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences 6 | from tensorflow.python.platform import gfile 7 | from tensorflow.python.platform.gfile import FastGFile 8 | 9 | from dataset.rc_dataset import RCDataset 10 | from utils.log import logger 11 | 12 | 13 | class CBT(RCDataset): 14 | def __init__(self, args): 15 | self.A_len = 10 16 | 
super().__init__(args) 17 | 18 | def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples): 19 | data = { 20 | "questions_bt:0": dataset[0][_slice], 21 | "documents_bt:0": dataset[1][_slice], 22 | "candidates_bi:0": dataset[2][_slice], 23 | "y_true_bi:0": dataset[3][_slice] 24 | } 25 | return data, samples 26 | 27 | def cbt_data_to_token_ids(self, data_file, target_file, vocab_file, max_count=None): 28 | """ 29 | 22 lines for one sample. 30 | first 20 lines:documents with line number in the front. 31 | 21st line:line-number question\tAnswer\t\tCandidate1|...|Candidate10. 32 | 22nd line:blank. 33 | """ 34 | if gfile.Exists(target_file): 35 | return 36 | logger("Tokenizing data in {}".format(data_file)) 37 | word_dict = self.load_vocab(vocab_file) 38 | counter = 0 39 | 40 | with gfile.FastGFile(data_file) as f: 41 | with gfile.FastGFile(target_file, mode="wb") as tokens_file: 42 | for line in f: 43 | counter += 1 44 | if counter % 100000 == 0: 45 | logger("Tokenizing line %d" % counter) 46 | if max_count and counter > max_count: 47 | break 48 | if counter % 22 == 21: 49 | q, a, _, A = line.split("\t") 50 | token_ids_q = self.sentence_to_token_ids(q, word_dict)[1:] 51 | token_ids_A = [word_dict.get(a.lower(), self.UNK_ID) for a in A.rstrip("\n").split("|")] 52 | tokens_file.write(" ".join([str(tok) for tok in token_ids_q]) + "\t" 53 | + str(word_dict.get(a.lower(), self.UNK_ID)) + "\t" 54 | + "|".join([str(tok) for tok in token_ids_A]) + "\n") 55 | else: 56 | token_ids = self.sentence_to_token_ids(line, word_dict) 57 | token_ids = token_ids[1:] if token_ids else token_ids 58 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 59 | 60 | def prepare_data(self, data_dir, train_file, valid_file, test_file, max_vocab_num, output_dir=""): 61 | """ 62 | build vocabulary and translate CBT data to id format. 
63 | """ 64 | if not gfile.Exists(os.path.join(data_dir, output_dir)): 65 | os.mkdir(os.path.join(data_dir, output_dir)) 66 | os_train_file = os.path.join(data_dir, train_file) 67 | os_valid_file = os.path.join(data_dir, valid_file) 68 | os_test_file = os.path.join(data_dir, test_file) 69 | idx_train_file = os.path.join(data_dir, output_dir, train_file + ".%d.idx" % max_vocab_num) 70 | idx_valid_file = os.path.join(data_dir, output_dir, valid_file + ".%d.idx" % max_vocab_num) 71 | idx_test_file = os.path.join(data_dir, output_dir, test_file + ".%d.idx" % max_vocab_num) 72 | vocab_file = os.path.join(data_dir, output_dir, "vocab.%d" % max_vocab_num) 73 | 74 | if not gfile.Exists(vocab_file): 75 | word_counter = self.gen_vocab(os_train_file, max_count=self.args.max_count) 76 | word_counter = self.gen_vocab(os_valid_file, old_counter=word_counter, max_count=self.args.max_count) 77 | word_counter = self.gen_vocab(os_test_file, old_counter=word_counter, max_count=self.args.max_count) 78 | self.save_vocab(word_counter, vocab_file, max_vocab_num) 79 | 80 | # translate train/valid/test files to id format 81 | self.cbt_data_to_token_ids(os_train_file, idx_train_file, vocab_file, max_count=self.args.max_count) 82 | self.cbt_data_to_token_ids(os_valid_file, idx_valid_file, vocab_file, max_count=self.args.max_count) 83 | self.cbt_data_to_token_ids(os_test_file, idx_test_file, vocab_file, max_count=self.args.max_count) 84 | 85 | return vocab_file, idx_train_file, idx_valid_file, idx_test_file 86 | 87 | def read_cbt_data(self, file, max_count=None): 88 | """ 89 | read CBT data in id format. 90 | :return: (documents,questions,answers,candidates) each elements is a numpy array. 
91 | """ 92 | documents, questions, answers, candidates = [], [], [], [] 93 | with FastGFile(file, mode="r") as f: 94 | counter = 0 95 | d, q, a, A = [], [], [], [] 96 | for line in f: 97 | counter += 1 98 | if max_count and counter > max_count: 99 | break 100 | if counter % 100000 == 0: 101 | logger("Reading line %d in %s" % (counter, file)) 102 | if counter % 22 == 21: 103 | tmp = line.strip().split("\t") 104 | q = tmp[0].split(" ") + [self.EOS_ID] 105 | a = [1 if tmp[1] == i else 0 for i in d] 106 | A = [a for a in tmp[2].split("|")] 107 | A.remove(tmp[1]) 108 | A.insert(0, tmp[1]) # put right answer in the first of candidate 109 | elif counter % 22 == 0: 110 | documents.append(d) 111 | questions.append(q) 112 | answers.append(a) 113 | candidates.append(A) 114 | d, q, a, A = [], [], [], [] 115 | else: 116 | d.extend(line.strip().split(" ") + [self.EOS_ID]) # add EOS ID in the end of each sentence 117 | 118 | d_lens = [len(i) for i in documents] 119 | q_lens = [len(i) for i in questions] 120 | avg_d_len = reduce(lambda x, y: x + y, d_lens) / len(documents) 121 | logger("Document average length: %d." % avg_d_len) 122 | logger("Document midden length: %d." % len(sorted(documents, key=len)[len(documents) // 2])) 123 | avg_q_len = reduce(lambda x, y: x + y, q_lens) / len(questions) 124 | logger("Question average length: %d." % avg_q_len) 125 | logger("Question midden length: %d." % len(sorted(questions, key=len)[len(questions) // 2])) 126 | 127 | return documents, questions, answers, candidates 128 | 129 | def preprocess_input_sequences(self, data): 130 | """ 131 | preprocess,pad to fixed length. 
132 | """ 133 | documents, questions, answer, candidates = data 134 | 135 | questions_ok = pad_sequences(questions, maxlen=self.q_len, dtype="int32", padding="post", truncating="post") 136 | documents_ok = pad_sequences(documents, maxlen=self.d_len, dtype="int32", padding="post", truncating="post") 137 | candidates_ok = pad_sequences(candidates, maxlen=self.A_len, dtype="int32", padding="post", truncating="post") 138 | y_true = np.zeros_like(candidates_ok) 139 | y_true[:, 0] = 1 140 | return questions_ok, documents_ok, candidates_ok, y_true 141 | 142 | # noinspection PyAttributeOutsideInit 143 | def get_data_stream(self): 144 | # prepare data 145 | self.vocab_file, idx_train_file, idx_valid_file, idx_test_file = self.prepare_data( 146 | self.args.data_root, self.args.train_file, self.args.valid_file, 147 | self.args.test_file, self.args.max_vocab_num, 148 | output_dir=self.args.tmp_dir) 149 | 150 | # read data 151 | self.train_data = self.read_cbt_data(idx_train_file, max_count=self.args.max_count) 152 | self.valid_data = self.read_cbt_data(idx_valid_file, max_count=self.args.max_count) 153 | 154 | def get_max_length(d_bt): 155 | lens = [len(i) for i in d_bt] 156 | return max(lens) 157 | 158 | # data statistics 159 | self.d_len = get_max_length(self.train_data[0]) 160 | self.q_len = get_max_length(self.train_data[1]) 161 | self.train_sample_num = len(self.train_data[0]) 162 | self.valid_sample_num = len(self.valid_data[0]) 163 | self.train_idx = np.random.permutation(self.train_sample_num // self.args.batch_size) 164 | self.test_sample_num = 0 165 | 166 | if self.args.test: 167 | self.test_data = self.read_cbt_data(idx_test_file, max_count=self.args.max_count) 168 | self.test_sample_num = len(self.test_data[0]) 169 | 170 | return self.d_len, self.q_len, self.train_sample_num, self.valid_sample_num, self.test_sample_num 171 | -------------------------------------------------------------------------------- /dataset/data_file_pairs.py: 
-------------------------------------------------------------------------------- 1 | dataset_files_pairs = { 2 | "CBT_NE": [ 3 | "data/CBTest/CBTest/data/", 4 | "cbtest_NE_train.txt", 5 | "cbtest_NE_valid_2000ex.txt", 6 | "cbtest_NE_test_2500ex.txt"], 7 | "CBT_CN": [ 8 | "data/CBTest/CBTest/data/", 9 | "cbtest_CN_train.txt", 10 | "cbtest_CN_valid_2000ex.txt", 11 | "cbtest_CN_test_2500ex.txt"], 12 | "SQuAD": [ 13 | "data/SQuAD", 14 | "train-v1.1.json", 15 | "dev-v1.1.json", 16 | "dev-v1.1.json"] 17 | } 18 | -------------------------------------------------------------------------------- /dataset/rc_dataset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import codecs 3 | import re 4 | from collections import Counter 5 | 6 | import nltk 7 | import numpy as np 8 | from tensorflow.python.platform import gfile 9 | 10 | from utils.log import logger 11 | 12 | 13 | def default_tokenizer(sentence): 14 | _DIGIT_RE = re.compile(r"\d+") 15 | sentence = _DIGIT_RE.sub("0", sentence) 16 | sentence = " ".join(sentence.split("|")) 17 | return nltk.word_tokenize(sentence.lower()) 18 | 19 | 20 | # noinspection PyAttributeOutsideInit 21 | class RCDataset(object, metaclass=abc.ABCMeta): 22 | def __init__(self, args): 23 | self.args = args 24 | # padding,start of sentence,end of sentence,unk,end of question 25 | self._PAD = "_PAD" 26 | self._BOS = "_BOS" 27 | self._EOS = "_EOS" 28 | self._UNK = "_UNK" 29 | self._EOQ = "_EOQ" 30 | self._START_VOCAB = [self._PAD, self._BOS, self._EOS, self._UNK, self._EOQ] 31 | self.PAD_ID = 0 32 | self.BOS_ID = 1 33 | self.EOS_ID = 2 34 | self.UNK_ID = 3 35 | self.EOQ_ID = 4 36 | 37 | self._BLANK = "XXXXX" 38 | 39 | # special character of char embedding: pad and unk 40 | self._CHAR_PAD = "γ" 41 | self._CHAR_UNK = "δ" 42 | self.CHAR_PAD_ID = 0 43 | self.CHAR_UNK_ID = 1 44 | self._CHAR_START_VOCAB = [self._CHAR_PAD, self._CHAR_UNK] 45 | 46 | @property 47 | def train_idx(self): 48 | return 
self._train_idx 49 | 50 | @train_idx.setter 51 | def train_idx(self, value): 52 | self._train_idx = value 53 | 54 | @property 55 | def train_sample_num(self): 56 | return self._train_sample_num 57 | 58 | @train_sample_num.setter 59 | def train_sample_num(self, value): 60 | self._train_sample_num = value 61 | 62 | @property 63 | def valid_sample_num(self): 64 | return self._valid_sample_num 65 | 66 | @valid_sample_num.setter 67 | def valid_sample_num(self, value): 68 | self._valid_sample_num = value 69 | 70 | @property 71 | def test_sample_num(self): 72 | return self._test_sample_num 73 | 74 | @test_sample_num.setter 75 | def test_sample_num(self, value): 76 | self._test_sample_num = value 77 | 78 | def shuffle(self): 79 | logger("Shuffle the dataset.") 80 | np.random.shuffle(self.train_idx) 81 | 82 | def get_next_batch(self, mode, idx): 83 | """ 84 | return next batch of data samples 85 | """ 86 | batch_size = self.args.batch_size 87 | if mode == "train": 88 | dataset = self.train_data 89 | sample_num = self.train_sample_num 90 | elif mode == "valid": 91 | dataset = self.valid_data 92 | sample_num = self.valid_sample_num 93 | else: 94 | dataset = self.test_data 95 | sample_num = self.test_sample_num 96 | if mode == "train": 97 | start = self.train_idx[idx] * batch_size 98 | stop = (self.train_idx[idx] + 1) * batch_size 99 | else: 100 | start = idx * batch_size 101 | stop = (idx + 1) * batch_size if start < sample_num and (idx + 1) * batch_size < sample_num else -1 102 | samples = batch_size if stop != -1 else len(dataset[0]) - start 103 | _slice = np.index_exp[start:stop] 104 | return self.next_batch_feed_dict_by_dataset(dataset, _slice, samples) 105 | 106 | @staticmethod 107 | def gen_embeddings(word_dict, embed_dim, in_file=None, init=np.zeros): 108 | """ 109 | Init embedding matrix with (or without) pre-trained word embeddings. 
110 | """ 111 | num_words = max(word_dict.values()) + 1 112 | embedding_matrix = init(-0.05, 0.05, (num_words, embed_dim)) 113 | logger('Embeddings: %d x %d' % (num_words, embed_dim)) 114 | 115 | if not in_file: 116 | return embedding_matrix 117 | 118 | def get_dim(file): 119 | first = gfile.FastGFile(file, mode='r').readline() 120 | return len(first.split()) - 1 121 | 122 | assert get_dim(in_file) == embed_dim 123 | logger('Loading embedding file: %s' % in_file) 124 | pre_trained = 0 125 | for line in codecs.open(in_file, encoding="utf-8"): 126 | sp = line.split() 127 | if sp[0] in word_dict: 128 | pre_trained += 1 129 | embedding_matrix[word_dict[sp[0]]] = np.asarray([float(x) for x in sp[1:]], dtype=np.float32) 130 | logger("Pre-trained: {}, {:.3f}%".format(pre_trained, pre_trained * 100.0 / num_words)) 131 | return embedding_matrix 132 | 133 | def sentence_to_token_ids(self, sentence, word_dict, tokenizer=default_tokenizer): 134 | """ 135 | Turn sentence to token ids. 136 | sentence: ["I", "have", "a", "dog"] 137 | word_list: {"I": 1, "have": 2, "a": 4, "dog": 7"} 138 | return: [1, 2, 4, 7] 139 | """ 140 | return [word_dict.get(token, self.UNK_ID) for token in tokenizer(sentence)] 141 | 142 | def get_embedding_matrix(self, vocab_file, is_char_embedding=False): 143 | """ 144 | :param is_char_embedding: is the function called for generate char embedding 145 | :param vocab_file: file containing saved vocabulary. 146 | :return: a dict with each key as a word, each value as its corresponding embedding vector. 
147 | """ 148 | word_dict = self.load_vocab(vocab_file) 149 | embedding_file = None if is_char_embedding else self.args.embedding_file 150 | embedding_dim = self.args.char_embedding_dim if is_char_embedding else self.args.embedding_dim 151 | embedding_matrix = self.gen_embeddings(word_dict, 152 | embedding_dim, 153 | embedding_file, 154 | init=np.random.uniform) 155 | return embedding_matrix 156 | 157 | def sort_by_length(self, data): 158 | # TODO: sort data array according to sequence length in order to speed up training 159 | pass 160 | 161 | @staticmethod 162 | def gen_char_vocab(data_file, tokenizer=default_tokenizer, old_counter=None): 163 | """ 164 | generate character level vocabulary according to train corpus. 165 | """ 166 | logger("Creating character dict from data {}.".format(data_file)) 167 | char_counter = old_counter if old_counter else Counter() 168 | with gfile.FastGFile(data_file) as f: 169 | for line in f: 170 | tokens = tokenizer(line.rstrip("\n")) 171 | char_counter.update([char for word in tokens for char in word]) 172 | 173 | # summary statistics 174 | total_chars = sum(char_counter.values()) 175 | distinct_chars = len(list(char_counter)) 176 | 177 | logger("STATISTICS" + "-" * 20) 178 | logger("Total characters: " + str(total_chars)) 179 | logger("Total distinct characters: " + str(distinct_chars)) 180 | return char_counter 181 | 182 | @staticmethod 183 | def gen_vocab(data_file, tokenizer=default_tokenizer, old_counter=None, max_count=None): 184 | """ 185 | generate vocabulary according to train corpus. 186 | """ 187 | logger("Creating word dict from data {}.".format(data_file)) 188 | word_counter = old_counter if old_counter else Counter() 189 | counter = 0 190 | with gfile.FastGFile(data_file) as f: 191 | for line in f: 192 | counter += 1 193 | if max_count and counter > max_count: 194 | break 195 | tokens = tokenizer(line.rstrip('\n')) 196 | word_counter.update(tokens) 197 | if counter % 100000 == 0: 198 | logger("Process line %d Done." 
% counter) 199 | 200 | # summary statistics 201 | total_words = sum(word_counter.values()) 202 | distinct_words = len(list(word_counter)) 203 | 204 | logger("STATISTICS" + "-" * 20) 205 | logger("Total words: " + str(total_words)) 206 | logger("Total distinct words: " + str(distinct_words)) 207 | 208 | return word_counter 209 | 210 | def save_char_vocab(self, char_counter, char_vocab_file, max_vocab_num=None): 211 | """ 212 | Save character vocabulary. 213 | We need two special vo 214 | """ 215 | with gfile.FastGFile(char_vocab_file, "w") as f: 216 | for char in self._CHAR_START_VOCAB: 217 | f.write(char + "\n") 218 | for char in list(map(lambda x: x[0], char_counter.most_common(max_vocab_num))): 219 | f.write(char + "\n") 220 | 221 | def save_vocab(self, word_counter, vocab_file, max_vocab_num=None): 222 | with gfile.FastGFile(vocab_file, "w") as f: 223 | for word in self._START_VOCAB: 224 | f.write(word + "\n") 225 | for word in list(map(lambda x: x[0], word_counter.most_common(max_vocab_num))): 226 | f.write(word + "\n") 227 | 228 | @staticmethod 229 | def load_vocab(vocab_file): 230 | """ 231 | load word(or char) vocabulary file to word/char dict 232 | """ 233 | if not gfile.Exists(vocab_file): 234 | raise ValueError("Vocabulary file %s not found.", vocab_file) 235 | word_dict = {} 236 | word_id = 0 237 | for line in codecs.open(vocab_file, encoding="utf-8"): 238 | word_dict.update({line.strip(): word_id}) 239 | word_id += 1 240 | return word_dict 241 | 242 | # noinspection PyAttributeOutsideInit 243 | def preprocess(self): 244 | self.train_data = self.preprocess_input_sequences(self.train_data) 245 | self.valid_data = self.preprocess_input_sequences(self.valid_data) 246 | if self.args.test: 247 | self.test_data = self.preprocess_input_sequences(self.test_data) 248 | 249 | @abc.abstractmethod 250 | def preprocess_input_sequences(self, data): 251 | """ 252 | Preprocess train/valid/test data. Should be specified by sub class. 
253 | """ 254 | pass 255 | 256 | @abc.abstractmethod 257 | def get_data_stream(self): 258 | """ 259 | Get data statistics. 260 | """ 261 | pass 262 | 263 | @abc.abstractmethod 264 | def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples): 265 | """ 266 | How to specify feed dict according to _slice. 267 | """ 268 | pass 269 | -------------------------------------------------------------------------------- /dataset/squad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences 6 | from tensorflow.python.platform import gfile 7 | 8 | from dataset.rc_dataset import RCDataset 9 | from utils.log import logger 10 | 11 | 12 | class SQuAD(RCDataset): 13 | def __init__(self, args): 14 | super(SQuAD, self).__init__(args) 15 | self.w_len = 10 16 | 17 | def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples): 18 | data = { 19 | "documents_bt:0": dataset[0][_slice], 20 | "questions_bt:0": dataset[1][_slice], 21 | # TODO: substitute with real data 22 | "documents_btk:0": np.zeros([samples, self.d_len, self.w_len]), 23 | "questions_btk:0": np.zeros([samples, self.q_len, self.w_len]), 24 | "answer_start:0": dataset[2][_slice], 25 | "answer_end:0": dataset[3][_slice] 26 | } 27 | return data, samples 28 | 29 | def preprocess_input_sequences(self, data): 30 | documents, questions, answer_spans = data 31 | documents_ok = pad_sequences(documents, maxlen=self.d_len, dtype="int32", padding="post", truncating="post") 32 | questions_ok = pad_sequences(questions, maxlen=self.q_len, dtype="int32", padding="post", truncating="post") 33 | answer_start = [np.array([int(i == answer_span[0]) for i in range(self.d_len)]) for answer_span in answer_spans] 34 | answer_end = [np.array([int(i == answer_span[1]) for i in range(self.d_len)]) for answer_span in answer_spans] 35 | return documents_ok, 
questions_ok, np.asarray(answer_start), np.asarray(answer_end) 36 | 37 | def prepare_data(self, data_dir, train_file, valid_file, max_vocab_num, output_dir=""): 38 | """ 39 | build word vocabulary and character vocabulary. 40 | """ 41 | if not gfile.Exists(os.path.join(data_dir, output_dir)): 42 | os.mkdir(os.path.join(data_dir, output_dir)) 43 | os_train_file = os.path.join(data_dir, train_file) 44 | os_valid_file = os.path.join(data_dir, valid_file) 45 | vocab_file = os.path.join(data_dir, output_dir, "vocab.%d" % max_vocab_num) 46 | char_vocab_file = os.path.join(data_dir, output_dir, "char_vocab") 47 | 48 | vocab_data_file = os.path.join(data_dir, output_dir, "data.txt") 49 | 50 | def save_data(d_data, q_data): 51 | """ 52 | save all data to a file and use it build vocabulary. 53 | """ 54 | with open(vocab_data_file, mode="w", encoding="utf-8") as f: 55 | f.write("\t".join(d_data) + "\n") 56 | f.write("\t".join(q_data) + "\n") 57 | 58 | if not gfile.Exists(vocab_data_file): 59 | d, q, _ = self.read_squad_data(os_train_file) 60 | v_d, v_q, _ = self.read_squad_data(os_valid_file) 61 | save_data(d, q) 62 | save_data(v_d, v_q) 63 | if not gfile.Exists(vocab_file): 64 | logger("Start create vocabulary.") 65 | word_counter = self.gen_vocab(vocab_data_file, max_count=self.args.max_count) 66 | self.save_vocab(word_counter, vocab_file, max_vocab_num) 67 | if not gfile.Exists(char_vocab_file): 68 | logger("Start create character vocabulary.") 69 | char_counter = self.gen_char_vocab(vocab_data_file) 70 | self.save_char_vocab(char_counter, char_vocab_file, max_vocab_num=70) 71 | 72 | return os_train_file, os_valid_file, vocab_file, char_vocab_file 73 | 74 | def read_squad_data(self, file): 75 | """ 76 | read squad data file in string form 77 | :return tuple of (documents, questions, answer_spans) 78 | """ 79 | logger("Reading SQuAD data.") 80 | 81 | def extract(sample_data): 82 | document = sample_data["context"] 83 | for qas in sample_data["qas"]: 84 | question = 
qas["question"] 85 | for ans in qas["answers"]: 86 | answer_len = len(ans["text"]) 87 | answer_span = [ans["answer_start"], ans["answer_start"] + answer_len] 88 | assert (ans["text"] == document[ans["answer_start"]:(ans["answer_start"] + answer_len)]) 89 | documents.append(document) 90 | questions.append(question) 91 | answer_spans.append(answer_span) 92 | 93 | documents, questions, answer_spans = [], [], [] 94 | f = json.load(open(file, encoding="utf-8")) 95 | data_list, version = f["data"], f["version"] 96 | logger("SQuAD version: {}".format(version)) 97 | [extract(sample) for data in data_list for sample in data["paragraphs"]] 98 | if self.args.debug: 99 | documents, questions, answer_spans = documents[:500], questions[:500], answer_spans[:500] 100 | 101 | return documents, questions, answer_spans 102 | 103 | def squad_data_to_idx(self, vocab_file, *args): 104 | """ 105 | convert string list to index list form. 106 | """ 107 | logger("Convert string data to index.") 108 | word_dict = self.load_vocab(vocab_file) 109 | res_data = [0, ] * len(args) 110 | for idx, i in enumerate(args): 111 | tmp = [self.sentence_to_token_ids(document, word_dict) for document in i] 112 | res_data[idx] = tmp.copy() 113 | logger("Convert string2index done.") 114 | return res_data 115 | 116 | # noinspection PyAttributeOutsideInit 117 | def get_data_stream(self): 118 | # prepare data 119 | os_train_file, os_valid_file, self.vocab_file, self.char_vocab_file = self.prepare_data(self.args.data_root, 120 | self.args.train_file, 121 | self.args.valid_file, 122 | self.args.max_vocab_num, 123 | self.args.tmp_dir) 124 | 125 | # read data 126 | documents, questions, answer_spans = self.read_squad_data(os_train_file) 127 | v_documents, v_questions, v_answer_spans = self.read_squad_data(os_valid_file) 128 | documents, questions, v_documents, v_questions = self.squad_data_to_idx(self.vocab_file, documents, questions, 129 | v_documents, v_questions) 130 | # SQuAD cannot access the test data 131 | # 
def get_model_class():
    """
    Resolve and instantiate the model class named by the first CLI argument.

    Usage: ``python main.py <ModelClassName> [args...]``. Returns a bare
    ``NLPBase`` for ``-h``/``--help`` (so argparse can print usage); otherwise
    pops the class name from ``sys.argv`` and instantiates the model. Prints
    the supported model list and exits(1) when the name is unknown.
    """
    if sys.argv[1] == "--help" or sys.argv[1] == "-h":
        return NLPBase()
    class_obj, class_name = None, sys.argv[1]
    try:
        import models
        class_obj = getattr(sys.modules["models"], class_name)
        sys.argv.pop(1)
    except (AttributeError, IndexError):
        # BUG fix: ``except AttributeError or IndexError`` evaluated the
        # boolean expression first and therefore only caught AttributeError;
        # a tuple is required to catch multiple exception types.
        print("Model [{}] not found.\nSupported models:\n\n\t\t{}\n".format(class_name, sys.modules["models"].__all__))
        exit(1)
    return class_obj()
class AoAReader(RcBase):
    """
    Attention-over-Attention reader in "Attention-over-Attention Neural Networks for Reading Comprehension"
    (arXiv2016.7) available at https://arxiv.org/abs/1607.04423.
    """

    # noinspection PyAttributeOutsideInit
    def create_model(self):
        """
        Build the AoA graph: BiRNN encoders for question and document, a
        pair-wise matching matrix, attention-over-attention pooling, and an
        attention-sum readout over the candidate answers.
        Sets ``self.loss`` and ``self.correct_prediction`` as side effects.
        """
        #########################
        # b ... position of the example within the batch
        # t ... position of the word within the document/question
        # ... d for max length of document
        # ... q for max length of question
        # f ... features of the embedding vector or the encoded feature vector
        # i ... position of the word in candidates list
        # v ... position of the word in vocabulary
        #########################
        _EPSILON = 10e-8
        num_layers = self.args.num_layers
        hidden_size = self.args.hidden_size
        cell = LSTMCell if self.args.use_lstm else GRUCell

        # model input
        questions_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.q_len), name="questions_bt")
        documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.d_len), name="documents_bt")
        candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, self.dataset.A_len), name="candidates_bi")
        y_true_bi = tf.placeholder(shape=(None, self.dataset.A_len), dtype=tf.float32, name="y_true_bi")
        keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")

        # embedding variable initialized from the pre-trained matrix;
        # dropout is applied to the whole embedding table
        init_embedding = tf.constant(self.embedding_matrix, dtype=tf.float32, name="embedding_init")
        embedding = tf.get_variable(initializer=init_embedding,
                                    name="embedding_matrix",
                                    dtype=tf.float32)
        embedding = tf.nn.dropout(embedding, keep_prob)

        # shape=(None) the length of inputs
        # (id 0 is the padding token, so counting non-zero ids gives lengths)
        document_lengths = tf.reduce_sum(tf.sign(tf.abs(documents_bt)), 1)
        question_lengths = tf.reduce_sum(tf.sign(tf.abs(questions_bt)), 1)
        document_mask_bt = tf.sequence_mask(document_lengths, self.d_len, dtype=tf.float32)
        question_mask_bt = tf.sequence_mask(question_lengths, self.q_len, dtype=tf.float32)

        with tf.variable_scope('q_encoder', initializer=tf.orthogonal_initializer()):
            # encode question to fixed length of vector
            # output shape: (None, max_q_length, embedding_dim)
            question_embed_btf = tf.nn.embedding_lookup(embedding, questions_bt)
            logger("q_embed_btf shape {}".format(question_embed_btf.get_shape()))
            q_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            q_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=q_cell_bw,
                                                                   cell_fw=q_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=question_lengths,
                                                                   inputs=question_embed_btf,
                                                                   swap_memory=True)
            # q_encoder output shape: (None, max_t_length, hidden_size * 2)
            # (forward and backward per-step outputs concatenated on features)
            q_encoded_bqf = tf.concat(outputs, axis=-1)
            logger("q_encoded_bqf shape {}".format(q_encoded_bqf.get_shape()))

        with tf.variable_scope('d_encoder', initializer=tf.orthogonal_initializer()):
            # encode each document(context) word to fixed length vector
            # output shape: (None, max_d_length, embedding_dim)
            d_embed_btf = tf.nn.embedding_lookup(embedding, documents_bt)
            logger("d_embed_btf shape {}".format(d_embed_btf.get_shape()))
            d_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            d_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=d_cell_bw,
                                                                   cell_fw=d_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=document_lengths,
                                                                   inputs=d_embed_btf,
                                                                   swap_memory=True)
            # d_encoder output shape: (None, max_d_length, hidden_size * 2)
            d_encoded_bdf = tf.concat(outputs, axis=-1)
            logger("d_encoded_bdf shape {}".format(d_encoded_bdf.get_shape()))

        # mask of the pair-wise matrix: 1 only where both positions are real tokens
        M_mask = tf.einsum("bi,bj->bij", document_mask_bt, question_mask_bt)
        # batch pair-wise matching: M[b, d, q] = <d_encoded[b, d], q_encoded[b, q]>
        M_bdq = tf.matmul(d_encoded_bdf, q_encoded_bqf, adjoint_b=True)

        # individual attentions
        # alpha: masked softmax over axis 1 (document positions)
        alpha_bdq = self.softmax_with_mask(M_bdq, 1, M_mask, name="alpha")
        # beta: masked softmax over axis 2 (question positions)
        beta_bdq = self.softmax_with_mask(M_bdq, 2, M_mask, name="beta")
        # average beta over document positions (dividing by the true document
        # length, not d_len) -> one weight per question word
        beta_bq1 = tf.expand_dims(tf.reduce_sum(beta_bdq, 1) / tf.to_float(tf.expand_dims(document_lengths, -1)), -1)
        logger("beta_bq1 shape:{}".format(beta_bq1.get_shape()))
        # document-level attention: alpha weighted by the averaged beta
        # ("attention over attention"); shape (None, d_len)
        s_bd = tf.squeeze(tf.einsum("bdq,bqi->bdi", alpha_bdq, beta_bq1), -1)

        vocab_size = self.embedding_matrix.shape[0]
        # attention sum operation and gather within candidate_index:
        # per example, sum the attention mass of all document positions that
        # hold the same word id (unsorted_segment_sum over the vocabulary),
        # then gather the candidates' totals
        y_hat_bi = tf.scan(fn=lambda prev, cur: tf.gather(tf.unsorted_segment_sum(cur[0], cur[1], vocab_size), cur[2]),
                           elems=[s_bd, documents_bt, candidates_bi],
                           initializer=tf.Variable([0] * self.dataset.A_len, dtype="float32"))

        # manual computation of crossentropy: renormalize over candidates and
        # clip away 0/1 so tf.log stays finite
        output_bi = y_hat_bi / tf.reduce_sum(y_hat_bi, axis=-1, keep_dims=True)
        epsilon = tf.convert_to_tensor(_EPSILON, output_bi.dtype.base_dtype, name="epsilon")
        output_bi = tf.clip_by_value(output_bi, epsilon, 1. - epsilon)

        # loss and correct number
        self.loss = tf.reduce_mean(- tf.reduce_sum(y_true_bi * tf.log(output_bi), axis=-1))
        self.correct_prediction = tf.reduce_sum(
            tf.sign(tf.cast(tf.equal(tf.argmax(output_bi, 1),
                                     tf.argmax(y_true_bi, 1)), "float")))

    @staticmethod
    def softmax_with_mask(logits, axis, mask, epsilon=10e-8, name=None):
        """
        Numerically-stable masked softmax along ``axis``.

        Masked positions receive (near-)zero probability; ``epsilon`` guards
        the division when an entire slice is masked out.
        """
        with tf.name_scope(name, 'softmax', [logits, mask]):
            # subtract the per-slice max before exp for numerical stability
            max_axis = tf.reduce_max(logits, axis, keep_dims=True)
            target_exp = tf.exp(logits - max_axis) * mask
            normalize = tf.reduce_sum(target_exp, axis, keep_dims=True)
            softmax = target_exp / (normalize + epsilon)
            return softmax

    def get_batch_data(self, mode, idx):
        """
        Fetch a batch from the dataset and add the dropout keep probability
        (``args.keep_prob`` during training, 1.0 otherwise).
        """
        data, samples = self.dataset.get_next_batch(mode, idx)
        if mode == "train":
            data.update({"keep_prob:0": self.args.keep_prob})
        else:
            data.update({"keep_prob:0": 1.0})
        return data, samples
    # noinspection PyAttributeOutsideInit
    def create_model(self):
        """
        Build the AS Reader graph: BiRNN-encode the question into one vector
        and the document per word, dot-product attention between them, then
        sum the attention mass per candidate word (attention-sum).
        Sets ``self.loss`` and ``self.correct_prediction`` as side effects.
        """
        #########################
        # b ... position of the example within the batch
        # t ... position of the word within the document/question
        # f ... features of the embedding vector or the encoded feature vector
        # i ... position of the word in candidates list
        #########################
        num_layers = self.args.num_layers
        hidden_size = self.args.hidden_size
        cell = LSTMCell if self.args.use_lstm else GRUCell

        # model input
        questions_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.q_len), name="questions_bt")
        documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.d_len), name="documents_bt")
        candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, self.dataset.A_len), name="candidates_bi")
        y_true_bi = tf.placeholder(shape=(None, self.dataset.A_len), dtype=tf.float32, name="y_true_bi")

        # shape=(None) the length of inputs
        # (id 0 is the padding token, so counting non-zero ids gives lengths)
        context_lengths = tf.reduce_sum(tf.sign(tf.abs(documents_bt)), 1)
        question_lengths = tf.reduce_sum(tf.sign(tf.abs(questions_bt)), 1)
        context_mask_bt = tf.sequence_mask(context_lengths, self.d_len, dtype=tf.float32)

        init_embedding = tf.constant(self.embedding_matrix, dtype=tf.float32, name="embedding_init")
        embedding = tf.get_variable(initializer=init_embedding,
                                    name="embedding_matrix",
                                    dtype=tf.float32)

        with tf.variable_scope('q_encoder', initializer=tf.orthogonal_initializer()):
            # encode question to fixed length of vector
            # output shape: (None, max_q_length, embedding_dim)
            question_embed_btf = tf.nn.embedding_lookup(embedding, questions_bt)
            logger("q_embed_btf shape {}".format(question_embed_btf.get_shape()))
            q_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            q_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=q_cell_bw,
                                                                   cell_fw=q_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=question_lengths,
                                                                   inputs=question_embed_btf,
                                                                   swap_memory=True)
            # q_encoder output shape: (None, hidden_size * 2)
            # final fw/bw states of the last RNN layer, concatenated.
            # NOTE(review): with use_lstm=True each layer state is an (c, h)
            # LSTMStateTuple, so [0][-1] is a tuple -- confirm the concat
            # behaves as intended for the LSTM configuration.
            q_encoded_bf = tf.concat([last_states[0][-1], last_states[1][-1]], axis=-1)
            logger("q_encoded_bf shape {}".format(q_encoded_bf.get_shape()))

        with tf.variable_scope('d_encoder', initializer=tf.orthogonal_initializer()):
            # encode each document(context) word to fixed length vector
            # output shape: (None, max_d_length, embedding_dim)
            d_embed_btf = tf.nn.embedding_lookup(embedding, documents_bt)
            logger("d_embed_btf shape {}".format(d_embed_btf.get_shape()))
            d_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            d_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=d_cell_bw,
                                                                   cell_fw=d_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=context_lengths,
                                                                   inputs=d_embed_btf,
                                                                   swap_memory=True)
            # d_encoder output shape: (None, max_d_length, hidden_size * 2)
            d_encoded_btf = tf.concat(outputs, axis=-1)
            logger("d_encoded_btf shape {}".format(d_encoded_btf.get_shape()))

        def att_dot(x):
            """attention dot product function: score each document position
            against the single question vector -> (None, d_len)"""
            d_btf, q_bf = x
            res = tf.matmul(tf.expand_dims(q_bf, -1), d_btf, adjoint_a=True, adjoint_b=True)
            return tf.reshape(res, [-1, self.d_len])

        with tf.variable_scope('merge'):
            mem_attention_pre_soft_bt = att_dot([d_encoded_btf, q_encoded_bf])
            # zero out padding positions before the softmax
            mem_attention_pre_soft_masked_bt = tf.multiply(mem_attention_pre_soft_bt,
                                                           context_mask_bt,
                                                           name="attention_mask")
            mem_attention_bt = tf.nn.softmax(logits=mem_attention_pre_soft_masked_bt, name="softmax_attention")

        # attention-sum process
        def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):
            # total attention mass assigned to every occurrence of word_ix
            word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))
            return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))

        # noinspection PyUnusedLocal
        def sum_probs_single_sentence(prev, cur):
            # for one example: per-candidate summed attention (inner scan)
            candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur
            result = tf.scan(
                fn=lambda previous, x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),
                elems=[candidate_indices_i],
                initializer=tf.constant(0., dtype="float32"))
            return result

        def sum_probs_batch(candidate_indices_bi, sentence_ixs_bt, sentence_attention_probs_bt):
            # outer scan over the batch dimension
            result = tf.scan(
                fn=sum_probs_single_sentence,
                elems=[candidate_indices_bi, sentence_ixs_bt, sentence_attention_probs_bt],
                initializer=tf.Variable([0] * self.dataset.A_len, dtype="float32"))
            return result

        # output shape: (None, i) i = max_candidate_length = 10
        y_hat = sum_probs_batch(candidates_bi, documents_bt, mem_attention_bt)

        # crossentropy: renormalize over candidates
        output = y_hat / tf.reduce_sum(y_hat, axis=-1, keep_dims=True)
        # manual computation of crossentropy; clip away 0/1 so tf.log stays finite
        epsilon = tf.convert_to_tensor(_EPSILON, output.dtype.base_dtype, name="epsilon")
        output = tf.clip_by_value(output, epsilon, 1. - epsilon)
        self.loss = tf.reduce_mean(- tf.reduce_sum(y_true_bi * tf.log(output), axis=-1))

        # correct prediction nums
        self.correct_prediction = tf.reduce_sum(tf.sign(tf.cast(tf.equal(tf.argmax(y_hat, 1),
                                                                         tf.argmax(y_true_bi, 1)), "float")))
args define in the code 46 | [middle] ... args define in args_file 47 | [high] ... args define in command line 48 | """ 49 | 50 | def str2bool(v): 51 | if v.lower() in ("yes", "true", "t", "y", "1"): 52 | return True 53 | if v.lower() in ("no", "false", "f", "n", "0", "none"): 54 | return False 55 | else: 56 | raise argparse.ArgumentTypeError('Boolean value expected.') 57 | 58 | def str_or_none(v): 59 | if not v or v.lower() in ("no", "false", "f", "n", "0", "none", "null"): 60 | return None 61 | return v 62 | 63 | def int_or_none(v): 64 | if not v or v.lower() in ("no", "false", "f", "n", "0", "none", "null"): 65 | return None 66 | return int(v) 67 | 68 | # TODO:Implement ensemble test 69 | parser = argparse.ArgumentParser(description="Reading Comprehension Experiment Code Base.") 70 | # ----------------------------------------------------------------------------------------------------------- 71 | group1 = parser.add_argument_group("1.Basic options") 72 | # basis argument 73 | group1.add_argument("--debug", default=False, type=str2bool, help="is debug mode on or off") 74 | 75 | group1.add_argument("--train", default=True, type=str2bool, help="train or not") 76 | 77 | group1.add_argument("--test", default=False, type=str2bool, help="test or not") 78 | 79 | group1.add_argument("--ensemble", default=False, type=str2bool, help="ensemble test or not") 80 | 81 | group1.add_argument("--random_seed", default=2088, type=int, help="random seed") 82 | 83 | group1.add_argument("--log_file", default=None, type=str_or_none, 84 | help="which file to save the log,if None,use screen") 85 | 86 | group1.add_argument("--weight_path", default="weights", help="path to save all trained models") 87 | 88 | group1.add_argument("--args_file", default=None, type=str_or_none, help="json file of current args") 89 | 90 | group1.add_argument("--print_every_n", default=10, type=int, help="print performance every n steps") 91 | 92 | # data specific argument 93 | group2 = 
parser.add_argument_group("2.Data specific options") 94 | # noinspection PyUnresolvedReferences 95 | import dataset 96 | group2.add_argument("--dataset", default="CBT", choices=sys.modules['dataset'].__all__, type=str, 97 | help='type of the dataset to load') 98 | 99 | group2.add_argument("--embedding_file", default="data/glove.6B/glove.6B.200d.txt", 100 | type=str_or_none, help="pre-trained embedding file") 101 | 102 | group2.add_argument("--max_vocab_num", default=100000, type=int, help="the max number of words in vocabulary") 103 | 104 | subgroup = group2.add_argument_group("Some default options related to dataset, don't change if it works") 105 | 106 | subgroup.add_argument("--data_root", default="data/CBTest/CBTest/data/", 107 | help="root path of the dataset") 108 | 109 | subgroup.add_argument("--tmp_dir", default="tmp", help="dataset specific tmp folder") 110 | 111 | subgroup.add_argument("--train_file", default="cbtest_NE_train.txt", help="train file") 112 | 113 | subgroup.add_argument("--valid_file", default="cbtest_NE_valid_2000ex.txt", help="validation file") 114 | 115 | subgroup.add_argument("--test_file", default="cbtest_NE_test_2500ex.txt", help="test file") 116 | 117 | subgroup.add_argument("--max_count", default=None, type=int_or_none, 118 | help="read n lines of data file, if None, read all data") 119 | 120 | # hyper-parameters 121 | group3 = parser.add_argument_group("3.Hyper parameters shared by all models") 122 | 123 | group3.add_argument("--use_char_embedding", default=False, type=str2bool, 124 | help="use character embedding or not") 125 | 126 | group3.add_argument("--char_embedding_dim", default=100, type=int, help="dimension of char embeddings") 127 | 128 | group3.add_argument("--embedding_dim", default=200, type=int, help="dimension of word embeddings") 129 | 130 | group3.add_argument("--hidden_size", default=128, type=int, help="RNN hidden size") 131 | 132 | group3.add_argument("--grad_clipping", default=10, type=int, help="the threshold 
value of gradient clip") 133 | 134 | group3.add_argument("--lr", default=0.001, type=float, help="learning rate") 135 | 136 | group3.add_argument("--keep_prob", default=0.9, type=float, help="dropout,percentage to keep during training") 137 | 138 | group3.add_argument("--l2", default=0.0001, type=float, help="l2 regularization weight") 139 | 140 | group3.add_argument("--num_layers", default=1, type=int, help="RNN layer number") 141 | 142 | group3.add_argument("--use_lstm", default=False, type=str2bool, 143 | help="RNN kind, if False, use GRU else LSTM") 144 | 145 | group3.add_argument("--batch_size", default=32, type=int, help="batch_size") 146 | 147 | group3.add_argument("--optimizer", default="ADAM", choices=["SGD", "ADAM"], 148 | help="optimize algorithms, SGD or Adam") 149 | 150 | group3.add_argument("--evaluate_every_n", default=400, type=int, 151 | help="evaluate performance on validation set and possibly saving the best model") 152 | 153 | group3.add_argument("--num_epoches", default=10, type=int, help="max epoch iterations") 154 | 155 | group3.add_argument("--patience", default=5, type=int, help="early stopping patience") 156 | # ----------------------------------------------------------------------------------------------------------- 157 | group4 = parser.add_argument_group("4.model [{}] specific parameters".format(self.model_name)) 158 | 159 | self.add_args(group4) 160 | 161 | args = parser.parse_args() 162 | 163 | setup_from_args_file(args.args_file) 164 | 165 | args = parser.parse_args() 166 | 167 | # set debug params 168 | args.max_count = 7392 if args.debug else args.max_count 169 | args.evaluate_every_n = 5 if args.debug else args.evaluate_every_n 170 | args.num_epoches = 2 if args.debug else args.num_epoches 171 | 172 | args = self.tune_args(args) 173 | 174 | return args 175 | 176 | @staticmethod 177 | def tune_args(args): 178 | """ 179 | tune the dataset specific args so train_file or test_file need not be changed 180 | """ 181 | try: 182 | files 
= dataset_files_pairs.get(args.dataset) 183 | args.data_root, args.train_file, args.valid_file, args.test_file = files 184 | return args 185 | except AssertionError: 186 | err("Error. Cannot find the specific key -> {} in dataset_files_pairs.".format(args.dataset)) 187 | -------------------------------------------------------------------------------- /models/r_net.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cairoHy/RC-experiments/0262f83481c364f29a43ac7cfc28da88d31f5adc/models/r_net.py -------------------------------------------------------------------------------- /models/rc_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import sys 4 | 5 | import tensorflow as tf 6 | 7 | # noinspection PyUnresolvedReferences 8 | import dataset 9 | from models import models_in_datasets 10 | from models.nlp_base import NLPBase 11 | from utils.log import logger, save_obj_to_json, err 12 | 13 | 14 | # noinspection PyAttributeOutsideInit 15 | class RcBase(NLPBase, metaclass=abc.ABCMeta): 16 | """ 17 | Base class of reading comprehension experiments. 18 | Reads different reading comprehension datasets according to specific class. 19 | creates a model and starts training it. 20 | Any deep learning model should inherit from this class and implement the create_model method. 
    @property
    def loss(self):
        # scalar training-loss tensor; must be created by create_model()
        return self._loss

    @loss.setter
    def loss(self, value):
        self._loss = value

    @property
    def correct_prediction(self):
        # tensor counting correct answers in a batch; created by create_model()
        return self._correct_prediction

    @correct_prediction.setter
    def correct_prediction(self, value):
        self._correct_prediction = value

    def get_train_op(self):
        """
        define optimization operation

        Builds ``self.train_op``: SGD or Adam on ``self.loss`` with
        per-gradient norm clipping at ``args.grad_clipping``; the global step
        ``self.step`` is incremented by apply_gradients.

        :raises NotImplementedError: for any other optimizer name
        """
        if self.args.optimizer == "SGD":
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.args.lr)
        elif self.args.optimizer == "ADAM":
            optimizer = tf.train.AdamOptimizer(learning_rate=self.args.lr)
        else:
            raise NotImplementedError("Other Optimizer Not Implemented.-_-||")

        # gradient clip (skip variables with no gradient)
        grad_vars = optimizer.compute_gradients(self.loss)
        grad_vars = [
            (tf.clip_by_norm(grad, self.args.grad_clipping), var)
            if grad is not None else (grad, var)
            for grad, var in grad_vars]
        self.train_op = optimizer.apply_gradients(grad_vars, self.step)
        return

    @abc.abstractmethod
    def create_model(self):
        """
        should be override by sub-class and create some operations include [loss, correct_prediction]
        as class attributes.
        """
        return

    def execute(self):
        """
        main method to train and test
        """
        self.confirm_model_dataset_fitness()

        # instantiate the dataset class named by --dataset
        self.dataset = getattr(sys.modules["dataset"], self.args.dataset)(self.args)

        # Get the statistics of data
        # [document length] and [question length] to build the model
        # train/valid/test sample number to train and validate and test the model
        statistics = self.dataset.get_data_stream()
        self.d_len, self.q_len, self.train_nums, self.valid_nums, self.test_num = statistics
        self.dataset.preprocess()

        # Get the word embedding and character embedding(if necessary)
        self.embedding_matrix = self.dataset.get_embedding_matrix(self.dataset.vocab_file)
        if self.args.use_char_embedding and getattr(self.dataset, "char_vocab_file"):
            self.char_embedding_matrix = self.dataset.get_embedding_matrix(self.dataset.char_vocab_file, True)

        self.create_model()

        self.make_sure_model_is_valid()

        self.saver = tf.train.Saver(max_to_keep=20)

        if self.args.train:
            self.train()
        if self.args.test:
            self.test()

        self.sess.close()

    def get_batch_data(self, mode, idx):
        """
        Get batch data and feed it to tensorflow graph
        Modify it in sub-class if needed.
        """
        return self.dataset.get_next_batch(mode, idx)

    def train(self):
        """
        train model

        Runs up to ``num_epoches`` epochs of mini-batches, logging every
        ``print_every_n`` steps and validating (with early stopping) every
        ``evaluate_every_n`` steps. Resumes from a checkpoint if one exists.
        """
        self.step = tf.Variable(0, name="global_step", trainable=False)
        batch_size = self.args.batch_size
        epochs = self.args.num_epoches
        self.get_train_op()
        self.sess.run(tf.global_variables_initializer())
        self.load_weight()

        # early stopping params, by default val_acc is the metric
        self.patience, self.best_val_acc = self.args.patience, 0.
        # Start training
        corrects_in_epoch, samples_in_epoch, loss_in_epoch = 0, 0, 0
        batch_num = self.train_nums // batch_size
        logger("Train on {} batches, {} samples per batch, {} total.".format(batch_num, batch_size, self.train_nums))

        step = self.sess.run(self.step)
        while step < batch_num * epochs:
            # re-read the global step each iteration (apply_gradients advances it)
            step = self.sess.run(self.step)
            # on Epoch start: reset running statistics and reshuffle the data
            if step % batch_num == 0:
                corrects_in_epoch, samples_in_epoch, loss_in_epoch = 0, 0, 0
                logger("{}Epoch : {}{}".format("-" * 40, step // batch_num + 1, "-" * 40))
                self.dataset.shuffle()

            data, samples = self.get_batch_data("train", step % batch_num)
            loss, _, corrects_in_batch = self.sess.run([self.loss, self.train_op, self.correct_prediction],
                                                       feed_dict=data)
            # loss is a per-sample mean, so weight it by the batch size
            corrects_in_epoch += corrects_in_batch
            loss_in_epoch += loss * samples
            samples_in_epoch += samples

            # logger
            if step % self.args.print_every_n == 0:
                logger("Samples : {}/{}.\tStep : {}/{}.\tLoss : {:.4f}.\tAccuracy : {:.4f}".format(
                    samples_in_epoch, self.train_nums,
                    step % batch_num, batch_num,
                    loss_in_epoch / samples_in_epoch, corrects_in_epoch / samples_in_epoch))

            # evaluate on the valid set and early stopping
            if step and step % self.args.evaluate_every_n == 0:
                val_acc, val_loss = self.validate()
                self.early_stopping(val_acc, val_loss, step)

    def validate(self):
        """
        Evaluate on the whole validation set.

        :return: (val_acc, val_loss) averaged over all validation samples
        """
        batch_size = self.args.batch_size
        v_batch_num = self.valid_nums // batch_size
        # ensure the entire valid set is selected
        v_batch_num = v_batch_num + 1 if (self.valid_nums % batch_size) != 0 else v_batch_num
        logger("Validate on {} batches, {} samples per batch, {} total."
               .format(v_batch_num, batch_size, self.valid_nums))
        val_num, val_corrects, v_loss = 0, 0, 0
        for i in range(v_batch_num):
            data, samples = self.get_batch_data("valid", i)
            # the final batch may be empty when valid_nums % batch_size == 0
            if samples != 0:
                loss, v_correct = self.sess.run([self.loss, self.correct_prediction], feed_dict=data)
                val_num += samples
                val_corrects += v_correct
                v_loss += loss * samples
        assert (val_num == self.valid_nums)
        val_acc = val_corrects / val_num
        val_loss = v_loss / val_num
        logger("Evaluate on : {}/{}.\tVal acc : {:.4f}.\tVal Loss : {:.4f}".format(val_num,
                                                                                   self.valid_nums,
                                                                                   val_acc,
                                                                                   val_loss))
        return val_acc, val_loss

    # noinspection PyUnusedLocal
    def early_stopping(self, val_acc, val_loss, step):
        """
        Early stopping on validation accuracy: checkpoint on improvement,
        otherwise decrement patience and terminate when it runs out.
        """
        if val_acc > self.best_val_acc:
            self.patience = self.args.patience
            self.best_val_acc = val_acc
            self.save_weight(val_acc, step)
        elif self.patience == 1:
            logger("Oh u, stop training.")
            exit(0)
        else:
            self.patience -= 1
            logger("Remaining/Patience : {}/{} .".format(self.patience, self.args.patience))

    def save_weight(self, val_acc, step):
        """Save a checkpoint named after the model and its validation accuracy."""
        path = self.saver.save(self.sess,
                               os.path.join(self.args.weight_path,
                                            "{}-val_acc-{:.4f}.models".format(self.model_name, val_acc)),
                               global_step=step)
        logger("Save models to {}.".format(path))

    def load_weight(self):
        """Restore the latest checkpoint from ``args.weight_path`` if present."""
        ckpt = tf.train.get_checkpoint_state(self.args.weight_path)
        if ckpt is not None:
            logger("Load models from {}.".format(ckpt.model_checkpoint_path))
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logger("No previous models.")
% batch_size) != 0 else batch_num 212 | correct_num, total_num = 0, 0 213 | for i in range(batch_num): 214 | data, samples = self.get_batch_data("test", i) 215 | if samples != 0: 216 | correct, = self.sess.run([self.correct_prediction], feed_dict=data) 217 | correct_num, total_num = correct_num + correct, total_num + samples 218 | assert (total_num == self.test_num) 219 | logger("Test on : {}/{}".format(total_num, self.test_num)) 220 | test_acc = correct_num / total_num 221 | logger("Test accuracy is : {:.5f}".format(test_acc)) 222 | res = { 223 | "model": self.model_name, 224 | "test_acc": test_acc 225 | } 226 | save_obj_to_json(self.args.weight_path, res, "result.json") 227 | 228 | def confirm_model_dataset_fitness(self): 229 | # make sure the models_in_datasets var is correct 230 | try: 231 | assert (models_in_datasets.get(self.args.dataset, None) is not None) 232 | except AssertionError: 233 | err("Models_in_datasets doesn't have the specified dataset key: {}.".format(self.args.dataset)) 234 | self.sess.close() 235 | exit(1) 236 | # make sure the model fit the dataset 237 | try: 238 | assert (self.model_name in models_in_datasets.get(self.args.dataset, None)) 239 | except AssertionError: 240 | err("The model -> {} doesn't support the dataset -> {}".format(self.model_name, self.args.dataset)) 241 | self.sess.close() 242 | exit(1) 243 | 244 | def make_sure_model_is_valid(self): 245 | """ 246 | check if the model has necessary attributes 247 | """ 248 | try: 249 | _ = self.loss 250 | _ = self.correct_prediction 251 | except AttributeError as e: 252 | err("Your model {} doesn't have enough attributes.\nError Message:\n\t{}".format(self.model_name, e)) 253 | self.sess.close() 254 | exit(1) 255 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.2.1 2 | numpy>=1.12.1 3 | 
# --------------------------------------------------------------------------------
# /test/dataset_test.py:
# --------------------------------------------------------------------------------
import argparse
import logging
import sys
import unittest

# noinspection PyUnresolvedReferences
import dataset


def _str2bool(value):
    """argparse-friendly bool parser.

    FIX: `type=bool` is a classic argparse pitfall — bool("False") is True,
    because any non-empty string is truthy.  Parse the usual spellings instead.
    """
    return str(value).lower() in ("true", "t", "1", "yes", "y")


class TestDataset(unittest.TestCase):
    """Base test case: builds the argparse namespace shared by dataset tests."""

    def setUp(self):
        logging.basicConfig(filename=None,
                            level=logging.DEBUG,
                            format='%(asctime)s %(message)s', datefmt='%y-%m-%d %H:%M')
        parser = argparse.ArgumentParser()
        parser.add_argument("--debug", default=True, type=_str2bool, help="is debug mode on or off")

        parser.add_argument("--data_root", default="../data/SQuAD/",
                            help="root path of the dataset")

        parser.add_argument("--tmp_dir", default="tmp", help="dataset specific tmp folder")

        parser.add_argument("--train_file", default="train-v1.1.json", help="train file")

        parser.add_argument("--valid_file", default="dev-v1.1.json", help="validation file")

        parser.add_argument("--max_count", default=None, type=int,
                            help="read n lines of data file, if None, read all data")

        parser.add_argument("--max_vocab_num", default=100000, type=int, help="the max number of words in vocabulary")

        parser.add_argument("--batch_size", default=32, type=int, help="batch_size")

        parser.add_argument("--train", default=True, type=_str2bool, help="train or not")

        parser.add_argument("--test", default=True, type=_str2bool, help="test or not")

        # parse_known_args so the test runner's own CLI flags don't break the parse
        self.args = parser.parse_known_args()[0]


class TestCBT(TestDataset):
    def runTest(self):
        # Point the shared namespace at the CBT named-entity split.
        self.args.data_root = "../data/CBTest/CBTest/data/"
        self.args.train_file = "cbtest_NE_train.txt"
        self.args.valid_file = "cbtest_NE_valid_2000ex.txt"
        self.args.test_file = "cbtest_NE_test_2500ex.txt"
        self.dataset = getattr(sys.modules["dataset"], "CBT")(self.args)
        statistics = self.dataset.get_data_stream()
        # every statistic after the first must be a positive int (sample counts etc.)
        for i in statistics[1:]:
            self.assertEqual(type(i), int, "Some data statistic not int.")
            self.assertGreater(i, 0, "Some data number not greater than zero.")


class TestSQuAD(TestDataset):
    def runTest(self):
        self.dataset = getattr(sys.modules["dataset"], "SQuAD")(self.args)
        data_dir, train_file, valid_file = self.args.data_root, self.args.train_file, self.args.valid_file
        max_vocab_num, output_dir = self.args.max_vocab_num, self.args.tmp_dir

        os_train_file, os_valid_file, vocab_file, char_vocab_file = self.dataset.prepare_data(data_dir, train_file,
                                                                                              valid_file, max_vocab_num,
                                                                                              output_dir)

        documents, questions, _ = self.dataset.read_squad_data(os_train_file)
        v_documents, v_questions, _ = self.dataset.read_squad_data(os_valid_file)
        data = self.dataset.squad_data_to_idx(vocab_file, documents, questions,
                                              v_documents, v_questions)
        # make sure that each one of (d,q,v_d,v_q) is a list, and each element is a list too.
69 | for i in data: 70 | self.assertEqual(type(i), list, "some data in train set or valid set is not a list.") 71 | self.assertGreater(len(i), 0, "some data in train set or valid set is None.") 72 | self.assertEqual(type(i[0]), list, "some elements in train set or valid set is not a list.") 73 | for word in i[0]: 74 | self.assertEqual(type(word), int, "Not all the word is index form.") 75 | self.assertGreaterEqual(word, 0, "Invalid index for some word.") 76 | -------------------------------------------------------------------------------- /test/notebook/test_aoa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "### 1.test M_mask calculation" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import tensorflow as tf" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 25, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "q_mask shape:(5, 5)\nd_mask shape:(5, 10)\n" 31 | ] 32 | }, 33 | { 34 | "data": { 35 | "text/plain": [ 36 | "TensorShape([Dimension(5), Dimension(10), Dimension(5)])" 37 | ] 38 | }, 39 | "execution_count": 25, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | } 43 | ], 44 | "source": [ 45 | "q_len, d_len = 5, 10\n", 46 | "q_lens = tf.constant([3, 2, 1, 3, 4], dtype=tf.int32)\n", 47 | "d_lens = tf.constant([7, 8, 9, 6, 6], dtype=tf.int32)\n", 48 | "q_mask = tf.sequence_mask(q_lens, q_len, dtype=tf.float32)\n", 49 | "d_mask = tf.sequence_mask(d_lens, d_len, dtype=tf.float32)\n", 50 | "\n", 51 | "print(\"q_mask shape:{}\".format(q_mask.get_shape()))\n", 52 | "print(\"d_mask shape:{}\".format(d_mask.get_shape()))\n", 53 | "M_mask = tf.einsum(\"bi,bj->bij\", d_mask, q_mask)\n", 54 | "M_mask.get_shape()" 55 | 
] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 26, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "[[ 1. 1. 1. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 1. 0.]]\n--------------------------------------------------\n[[ 1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]]\n--------------------------------------------------\n[[[ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 
0.]]]\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "with tf.Session() as sess:\n", 72 | " sess.run(tf.global_variables_initializer())\n", 73 | " print(sess.run(q_mask))\n", 74 | " print(\"-\" * 50)\n", 75 | " print(sess.run(d_mask))\n", 76 | " print(\"-\" * 50)\n", 77 | " print(sess.run(M_mask))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### 2.test attention sum" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 1, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import tensorflow as tf\n", 94 | "import numpy as np" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 7, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "[[ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]\n(3, 10)\n--------------------------------------------------\n[[ 4 3 18 19 7 9 13 1 10 11]\n [ 5 16 11 6 4 0 16 1 11 8]\n [ 3 5 2 12 2 8 14 1 15 11]]\n(3, 10)\n--------------------------------------------------\n[[ 3 11 19 11]\n [ 9 1 14 11]\n [13 10 10 13]]\n(3, 4)\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "batch_size = 3\n", 112 | "d_len = 10\n", 113 | "vocab_size = 20\n", 114 | "A_len = 4\n", 115 | "true_s_bd = np.array([0.1]*d_len*batch_size).reshape(batch_size,d_len)\n", 116 | "true_documents_bt = np.random.randint(0,vocab_size,size=(batch_size,d_len))\n", 117 | "true_candidates_bi = np.random.randint(0,vocab_size,size=(batch_size,A_len))\n", 118 | "\n", 119 | "print(true_s_bd)\n", 120 | "print(true_s_bd.shape)\n", 121 | "print(\"-\"*50)\n", 122 | "print(true_documents_bt)\n", 123 | "print(true_documents_bt.shape)\n", 124 | "print(\"-\"*50)\n", 125 | "print(true_candidates_bi)\n", 126 | "print(true_candidates_bi.shape)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 8, 132 | "metadata": 
{ 133 | "collapsed": false 134 | }, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "[[ 0.1 0.1 0.1 0.1]\n [ 0. 0.1 0. 0.2]\n [ 0. 0. 0. 0. ]]\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "s_bd = tf.placeholder(dtype=tf.float32, shape=(None, d_len), name=\"s_bd\")\n", 146 | "documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, d_len), name=\"documents_bt\")\n", 147 | "candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, A_len), name=\"candidates_bi\")\n", 148 | "y_hat_bi = tf.scan(fn=lambda prev, cur:\n", 149 | "tf.gather(tf.unsorted_segment_sum(cur[0], cur[1], vocab_size), cur[2]),\n", 150 | " elems=[s_bd, documents_bt, candidates_bi],\n", 151 | " initializer=tf.Variable([0.] * A_len,dtype=tf.float32))\n", 152 | "with tf.Session() as sess:\n", 153 | " sess.run(tf.global_variables_initializer())\n", 154 | " data = {\n", 155 | " s_bd:true_s_bd,\n", 156 | " documents_bt:true_documents_bt,\n", 157 | " candidates_bi:true_candidates_bi\n", 158 | " }\n", 159 | " print(sess.run(y_hat_bi,feed_dict=data))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "" 169 | ] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 2", 175 | "language": "python", 176 | "name": "python2" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 2.0 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython2", 188 | "version": "2.7.6" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 0 193 | } -------------------------------------------------------------------------------- /test/notebook/test_as_reader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 
| "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 测试注意力向量计算" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "Using TensorFlow backend.\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "import tensorflow as tf\n", 25 | "import numpy as np\n", 26 | "import keras.backend as K" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 76, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "def my_dot(x):\n", 36 | " c = [tf.reduce_sum(tf.multiply(x[0][:, inx, :], x[1]), -1, keep_dims=True) for inx in range(3)]\n", 37 | " return tf.concat(c, -1)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 77, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def my_dot_v2(x, y):\n", 47 | " \"\"\"注意力点乘函数,快速版本\"\"\"\n", 48 | " res = K.batch_dot(tf.expand_dims(y, -1),x, (1, 2))\n", 49 | " return K.reshape(res, [-1, 3])" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def att_dot(x):\n", 59 | " \"\"\"注意力点乘函数\"\"\"\n", 60 | " d_btf, q_bf = x\n", 61 | " res = K.batch_dot(tf.expand_dims(y, -1),x, (1, 2))\n", 62 | " return tf.reshape(res, [-1, 3])" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 12, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def new_att_dot(x):\n", 72 | " d_btf, q_bf = x\n", 73 | " res = tf.matmul(tf.expand_dims(q_bf, -1), d_btf, adjoint_a=True,adjoint_b=True)\n", 74 | " return tf.reshape(res, [-1, 3])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "[[[ 0 1 2 3]\n [ 4 5 6 7]\n [ 8 9 10 11]]\n\n [[12 13 14 15]\n [16 17 18 19]\n [20 21 22 23]]]\n--------------------\n(2, 3, 
4)\n--------------------\n(2, 4)\n--------------------\n[[0 1 2 3]\n [4 5 6 7]]\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "a = tf.placeholder(tf.float32,shape=(None,3,4))\n", 92 | "b = tf.placeholder(tf.float32,shape=(None,4))\n", 93 | "true_a = np.arange(24).reshape(2,3,4)\n", 94 | "true_b = np.arange(8).reshape(2,4)\n", 95 | "print(true_a)\n", 96 | "print('-'*20)\n", 97 | "print(true_a.shape)\n", 98 | "print('-'*20)\n", 99 | "print(true_b.shape)\n", 100 | "print('-'*20)\n", 101 | "print(true_b)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 14, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "TensorShape([Dimension(None), Dimension(3)])" 113 | ] 114 | }, 115 | "execution_count": 14, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "d = att_dot([a,b])\n", 122 | "d.get_shape()\n", 123 | "e = new_att_dot([a,b])\n", 124 | "e.get_shape()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 15, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "[[ 14. 38. 62.]\n [ 302. 390. 478.]]\n--------------------------------------------------\n[[ 14. 38. 62.]\n [ 302. 390. 
478.]]\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "with tf.Session() as sess:\n", 142 | " sess.run(tf.global_variables_initializer())\n", 143 | " print(sess.run(d, {a: true_a, b: true_b}))\n", 144 | " print(\"-\"*50)\n", 145 | " print(sess.run(e, {a: true_a, b: true_b}))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 80, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "(?, 4, 1)\n(?, 3, 4)\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "x = tf.expand_dims(b,-1)\n", 163 | "print(x.get_shape())\n", 164 | "y = a\n", 165 | "print(y.get_shape())" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 81, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "TensorShape([Dimension(None), Dimension(3)])" 177 | ] 178 | }, 179 | "execution_count": 81, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "res = K.batch_dot(x,y,(1,2))\n", 186 | "res = tf.reshape(res,[-1,3])\n", 187 | "res.get_shape()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 82, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "TensorShape([Dimension(None), Dimension(3)])" 199 | ] 200 | }, 201 | "execution_count": 82, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "d = my_dot_v2(a,b)\n", 208 | "d.get_shape()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 63, 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "478" 220 | ] 221 | }, 222 | "execution_count": 63, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "np.arange(20,24).dot(np.array([4,5,6,7]))" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 
234 | "source": [ 235 | "## 测试tensorflow的scan" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 4, 241 | "metadata": { 242 | "collapsed": true 243 | }, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "(3, 4)\nTensor(\"scan_1/while/TensorArrayReadV3:0\", shape=(3, 4), dtype=int32)\nTensor(\"scan_1/while/Identity_1:0\", shape=(3, 4), dtype=int32)\n" 250 | ] 251 | }, 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "(?, 1)\n[ 27. 5. 7.]\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "sentence_ids = tf.Variable([5, 8, 1, 3, 0, 34, 8, 7, 3, 8])\n", 262 | "attentions = tf.Variable([5, 9, 1, 3, 0, 34, 9, 7, 3, 9],dtype=\"float32\")\n", 263 | "word_id = tf.Variable(8)\n", 264 | "word_ids = tf.Variable([8,5,7])\n", 265 | "aaa = tf.equal(sentence_ids, word_id)\n", 266 | "ccc = tf.where(aaa)\n", 267 | "qqq = tf.reduce_sum(tf.gather(attentions,ccc))\n", 268 | "\n", 269 | "\n", 270 | "def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):\n", 271 | " word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))\n", 272 | " return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))\n", 273 | "\n", 274 | "test_func = lambda x:sum_prob_of_word(x,sentence_ids,attentions)\n", 275 | "\n", 276 | "ppp = test_func(word_id)\n", 277 | "\n", 278 | "def sum_probs_single_sentence(prev,cur):\n", 279 | " candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur\n", 280 | " result = tf.scan(\n", 281 | " fn=lambda prev,x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),\n", 282 | " elems=[candidate_indices_i],\n", 283 | " initializer=tf.Variable(0.,dtype=\"float32\"))\n", 284 | " return result\n", 285 | "\n", 286 | "zzz = sum_probs_single_sentence(None,[word_ids,sentence_ids,attentions])\n", 287 | "\n", 288 | "def func(prev, cur):\n", 289 | " print(cur.get_shape())\n", 290 | " print(cur)\n", 291 | " 
print(prev)\n", 292 | " return cur\n", 293 | "\n", 294 | "v = tf.Variable(np.arange(24).reshape(2, 3, 4))\n", 295 | "# print(v.get_shape())\n", 296 | "\n", 297 | "bbb = tf.scan(func, elems=v)\n", 298 | "with tf.Session() as sess:\n", 299 | " sess.run(tf.global_variables_initializer())\n", 300 | " print(ccc.get_shape())\n", 301 | " print(sess.run(zzz))\n", 302 | " # print(sess.run(bbb))" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 1, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "[[ 0.60000002 0.2 0. ]\n [ 0. 0. 0. ]\n [ 0.1 0. 0.1 ]]\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "import tensorflow as tf\n", 320 | "import numpy as np\n", 321 | "\n", 322 | "\n", 323 | "def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):\n", 324 | " word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))\n", 325 | " return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))\n", 326 | "\n", 327 | "def sum_probs_single_sentence(prev,cur):\n", 328 | " candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur\n", 329 | " result = tf.scan(\n", 330 | " fn=lambda prev,x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),\n", 331 | " elems=[candidate_indices_i],\n", 332 | " initializer=tf.constant(0.,dtype=\"float32\"))\n", 333 | " return result\n", 334 | "\n", 335 | "def sum_probs_batch(candidate_indices_bt, sentence_ixs_bt, sentence_attention_probs_bt):\n", 336 | " result = tf.scan(\n", 337 | " fn=sum_probs_single_sentence,\n", 338 | " elems=[candidate_indices_bt, sentence_ixs_bt, sentence_attention_probs_bt],\n", 339 | " initializer=tf.Variable([1,2,3],dtype=\"float32\"))\n", 340 | " return result\n", 341 | "\n", 342 | "candidate_idx = tf.Variable([\n", 343 | " [16, 21, 8],\n", 344 | " [13, 19, 26],\n", 345 | " [23, 9, 23]\n", 346 | "])\n", 347 | "\n", 348 | "sentence_idx = tf.Variable([\n", 349 | " 
[16, 21, 23, 16, 8, 9, 21],\n", 350 | " [16, 21, 23, 16, 8, 9, 21],\n", 351 | " [16, 21, 23, 16, 8, 9, 21],\n", 352 | "])\n", 353 | "\n", 354 | "attention_idx = tf.Variable([\n", 355 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0],\n", 356 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0],\n", 357 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0]\n", 358 | "],dtype=\"float32\")\n", 359 | "\n", 360 | "o = sum_probs_batch(candidate_idx, sentence_idx, attention_idx)\n", 361 | "with tf.Session() as sess:\n", 362 | " sess.run(tf.global_variables_initializer())\n", 363 | " print(sess.run(o))" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 1, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stderr", 373 | "output_type": "stream", 374 | "text": [ 375 | "Using TensorFlow backend.\n" 376 | ] 377 | }, 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "array([[ 1., 0., 0., 0., 0., 0., 0.],\n [ 1., 1., 0., 0., 0., 0., 0.],\n [ 1., 1., 1., 0., 0., 0., 0.],\n [ 1., 1., 1., 1., 0., 0., 0.],\n [ 1., 1., 1., 1., 1., 0., 0.]], dtype=float32)" 382 | ] 383 | }, 384 | "execution_count": 1, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "import tensorflow as tf\n", 391 | "import keras.backend as K\n", 392 | "\n", 393 | "a = [1,2,3,4,5]\n", 394 | "K.eval(tf.sequence_mask(a,7,dtype=tf.float32))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 2, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "import numpy as np" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 1, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "batch_size = 4\n", 413 | "A_size = 5" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "### 测试获取dynamic_rnn输出的有效位" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 1, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "import tensorflow as 
tf\n", 430 | "import numpy as np\n", 431 | "from tensorflow.contrib.rnn import LSTMCell, GRUCell, MultiRNNCell\n", 432 | "\n", 433 | "i = np.random.rand(1000)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 2, 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "with tf.variable_scope('q_encoder'):\n", 443 | " cell = MultiRNNCell(cells=[GRUCell(2)] * 1)\n", 444 | " x = tf.placeholder(dtype=tf.float32, shape=(2, 5, 100), name=\"x\")\n", 445 | " q_lens = tf.placeholder(dtype=tf.int32,shape=(2))\n", 446 | " outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=cell,\n", 447 | " cell_fw=cell,\n", 448 | " dtype=\"float32\",\n", 449 | " sequence_length=q_lens,\n", 450 | " inputs=x,\n", 451 | " swap_memory=True)\n", 452 | " q_enc = tf.gather_nd(outputs, tf.stack([tf.range(q_lens.get_shape()[0]), q_lens - 1], axis=1))\n", 453 | " q_enc_c = tf.concat(outputs,axis=-1)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 3, 459 | "metadata": { 460 | "collapsed": false 461 | }, 462 | "outputs": [ 463 | { 464 | "name": "stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "outputs:(array([[[ 0.0560709 , -0.10033438],\n [ 0.07781161, -0.220449 ],\n [ 0.06069176, -0.23297542],\n [ 0. , 0. ],\n [ 0. , 0. ]],\n\n [[ 0.02716292, -0.05908116],\n [-0.03233079, -0.02452423],\n [ 0. , 0. ],\n [ 0. , 0. ],\n [ 0. , 0. ]]], dtype=float32), array([[[ 0.08484475, 0.06869046],\n [ 0.07913341, 0.17418024],\n [ 0.11813057, -0.0282253 ],\n [ 0. , 0. ],\n [ 0. , 0. ]],\n\n [[-0.12557939, -0.38799691],\n [-0.09159085, -0.26809931],\n [ 0. , 0. ],\n [ 0. , 0. ],\n [ 0. , 0. ]]], dtype=float32))\nstates:((array([[ 0.06069176, -0.23297542],\n [-0.03233079, -0.02452423]], dtype=float32),), (array([[ 0.08484475, 0.06869046],\n [-0.12557939, -0.38799691]], dtype=float32),))\nq_enc:[[[ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]]\n\n [[-0.12557939 -0.38799691]\n [-0.09159085 -0.26809931]\n [ 0. 0. 
]\n [ 0. 0. ]\n [ 0. 0. ]]]\nq_enc_c:[[[ 0.0560709 -0.10033438 0.08484475 0.06869046]\n [ 0.07781161 -0.220449 0.07913341 0.17418024]\n [ 0.06069176 -0.23297542 0.11813057 -0.0282253 ]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]]\n\n [[ 0.02716292 -0.05908116 -0.12557939 -0.38799691]\n [-0.03233079 -0.02452423 -0.09159085 -0.26809931]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]]]\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "with tf.Session() as sess:\n", 473 | " sess.run(tf.global_variables_initializer())\n", 474 | " a, b, c, d = sess.run([outputs, last_states, q_enc, q_enc_c], feed_dict={\n", 475 | " q_lens: (3, 2),\n", 476 | " x: i.reshape(2, 5, 100)\n", 477 | " })\n", 478 | " print(\"outputs:{}\\nstates:{}\\nq_enc:{}\\nq_enc_c:{}\".format(a, b, c, d))" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "### 测试参数初始化" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 3, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "import tensorflow as tf\n", 495 | "import numpy as np\n", 496 | "from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell, GRUCell" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 9, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "name": "stdout", 506 | "output_type": "stream", 507 | "text": [ 508 | "[[ 0.89980626 0.31822035 0.29847053]\n [-0.29329553 0.94766068 -0.12616226]\n [-0.3229962 0.02598153 0.94604361]]\n[[ 0. 0. 0.]\n [ 0. 0. 0.]\n [ 0. 0. 
0.]]\n" 509 | ] 510 | } 511 | ], 512 | "source": [ 513 | "from tensorflow import variable_scope, get_variable\n", 514 | "\n", 515 | "with variable_scope(\"cc\",initializer=tf.orthogonal_initializer()):\n", 516 | " with variable_scope(\"ddd\"):\n", 517 | " a = get_variable(\"weight\",shape=(3,3),dtype=tf.float32)\n", 518 | "\n", 519 | "with variable_scope(\"cccc\",initializer=tf.zeros_initializer()):\n", 520 | " with variable_scope(\"ccc\"):\n", 521 | " b = get_variable(\"weight\",shape=(3,3),dtype=tf.float32)\n", 522 | "\n", 523 | "with tf.Session() as sess:\n", 524 | " sess.run(tf.global_variables_initializer())\n", 525 | " print(sess.run(a))\n", 526 | " print(sess.run(b))" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [ 535 | "" 536 | ] 537 | } 538 | ], 539 | "metadata": { 540 | "kernelspec": { 541 | "display_name": "Python 2", 542 | "language": "python", 543 | "name": "python2" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": { 547 | "name": "ipython", 548 | "version": 2.0 549 | }, 550 | "file_extension": ".py", 551 | "mimetype": "text/x-python", 552 | "name": "python", 553 | "nbconvert_exporter": "python", 554 | "pygments_lexer": "ipython2", 555 | "version": "2.7.6" 556 | } 557 | }, 558 | "nbformat": 4, 559 | "nbformat_minor": 0 560 | } -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cairoHy/RC-experiments/0262f83481c364f29a43ac7cfc28da88d31f5adc/utils/__init__.py -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import sys 6 | from pprint import pprint 7 | 8 | logger = logging.info 9 | err = logging.error 10 | 
def setup_from_args_file(file):
    """Load CLI arguments from a JSON file and splice them into ``sys.argv``.

    Each key/value pair in the JSON object becomes a ``--key value`` pair
    inserted after the program name and before the real command-line
    arguments, so flags given explicitly on the command line can still
    override the file's values when the parser takes the last occurrence.

    :param file: path to a JSON file mapping argument name -> value, or a
        falsy value to leave ``sys.argv`` untouched.
    """
    if not file:
        return
    # Context manager closes the handle even if parsing fails
    # (the original left the file object open: json.load(open(...))).
    with open(file, encoding="utf-8") as fp:
        json_dict = json.load(fp)
    args = [sys.argv[0]]
    for k, v in json_dict.items():
        args.append("--{}".format(k))
        args.append(str(v))
    # Concatenation already builds a new list, so the extra .copy() the
    # original made was redundant.
    sys.argv = args + sys.argv[1:]


def save_args(args):
    """Persist all parsed arguments to ``<weight_path>/args.json`` and echo
    them to stdout for the run log.

    :param args: an ``argparse.Namespace``-like object; ``args.weight_path``
        is the directory the JSON file is written into.
    """
    save_obj_to_json(args.weight_path, vars(args), filename="args.json")
    pprint(vars(args), indent=4)


def save_obj_to_json(path, obj, filename):
    """Serialize ``obj`` as pretty-printed, key-sorted JSON to ``path/filename``.

    :param path: target directory; created (including missing parents) if absent.
    :param obj: any JSON-serializable object.
    :param filename: name of the output file inside ``path``.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(path, exist_ok=True)
    file = os.path.join(path, filename)
    with open(file, "w", encoding="utf-8") as fp:
        json.dump(obj, fp, sort_keys=True, indent=4)
-------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_sum_reader.py", 3 | "test_acc": 0.65 4 | } -------------------------------------------------------------------------------- /weights/AS-reader/best-CBT-NE/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "args_file": "weights/AS-reader/best-CBT-NE/args.json", 3 | "batch_size": 32, 4 | "data_root": "data/CBTest/CBTest/data/", 5 | "dataset": "cbt", 6 | "debug": false, 7 | "embedding_dim": 200, 8 | "embedding_file": "data/glove.6B/glove.6B.200d.txt", 9 | "ensemble": false, 10 | "evaluate_every_n": 400, 11 | "grad_clipping": 10, 12 | "hidden_size": 128, 13 | "keep_prob": 0.5, 14 | "l2": 0.0001, 15 | "log_file": null, 16 | "lr": 0.001, 17 | "max_count": null, 18 | "max_vocab_num": 100000, 19 | "num_epoches": 10, 20 | "num_layers": 1, 21 | "optimizer": "ADAM", 22 | "patience": 5, 23 | "print_every_n": 10, 24 | "random_seed": 2088, 25 | "test": true, 26 | "test_file": "cbtest_NE_test_2500ex.txt", 27 | "tmp_dir": "tmp", 28 | "train": false, 29 | "train_file": "cbtest_NE_train.txt", 30 | "use_lstm": false, 31 | "valid_file": "cbtest_NE_valid_2000ex.txt", 32 | "weight_path": "weights/AS-reader/best-CBT-NE" 33 | } -------------------------------------------------------------------------------- /weights/AS-reader/best-CBT-NE/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_sum_reader.py", 3 | "test_acc": 0.6896 4 | } -------------------------------------------------------------------------------- /weights/AS-reader/best-best-CBT-NE/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "args_file": "weights/AS-reader/best-best-CBT-NE/args.json", 3 | "batch_size": 32, 4 | "data_root": "data/CBTest/CBTest/data/", 5 | "dataset": "cbt", 6 | "debug": false, 7 | "embedding_dim": 300, 8 | 
"embedding_file": "data/glove.6B/glove.6B.300d.txt", 9 | "ensemble": false, 10 | "evaluate_every_n": 400, 11 | "grad_clipping": 10, 12 | "hidden_size": 128, 13 | "keep_prob": 0.5, 14 | "l2": 0.0001, 15 | "log_file": null, 16 | "lr": 0.001, 17 | "max_count": null, 18 | "max_vocab_num": 100000, 19 | "num_epoches": 10, 20 | "num_layers": 1, 21 | "optimizer": "ADAM", 22 | "patience": 5, 23 | "print_every_n": 10, 24 | "random_seed": 2088, 25 | "test": true, 26 | "test_file": "cbtest_NE_test_2500ex.txt", 27 | "tmp_dir": "tmp", 28 | "train": false, 29 | "train_file": "cbtest_NE_train.txt", 30 | "use_lstm": false, 31 | "valid_file": "cbtest_NE_valid_2000ex.txt", 32 | "weight_path": "weights/AS-reader/best-best-CBT-NE" 33 | } -------------------------------------------------------------------------------- /weights/AS-reader/best-best-CBT-NE/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_sum_reader.py", 3 | "test_acc": 0.6988 4 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-CBT-CN/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "args_file": "weights/args.json", 3 | "batch_size": 32, 4 | "data_root": "data/CBTest/data/", 5 | "dataset": "cbt", 6 | "debug": false, 7 | "embedding_dim": 300, 8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt", 9 | "ensemble": false, 10 | "evaluate_every_n": 400, 11 | "grad_clipping": 10, 12 | "hidden_size": 128, 13 | "keep_prob": 0.5, 14 | "l2": 0.0001, 15 | "log_file": null, 16 | "lr": 0.001, 17 | "max_count": null, 18 | "max_vocab_num": 100000, 19 | "num_epoches": 10, 20 | "num_layers": 1, 21 | "optimizer": "ADAM", 22 | "patience": 5, 23 | "print_every_n": 10, 24 | "random_seed": 2088, 25 | "test": true, 26 | "test_file": "cbtest_CN_test_2500ex.txt", 27 | "tmp_dir": "tmp", 28 | "train": false, 29 | "train_file": "cbtest_CN_train.txt", 30 | 
"use_lstm": false, 31 | "valid_file": "cbtest_CN_valid_2000ex.txt", 32 | "weight_path": "weights/" 33 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-CBT-CN/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_over_attention_reader.py", 3 | "test_acc": 0.6812 4 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-CBT-NE/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "args_file": "weights/AoA-reader/best-CBT-NE/args.json", 3 | "batch_size": 32, 4 | "data_root": "data/CBTest/CBTest/data/", 5 | "dataset": "cbt", 6 | "debug": false, 7 | "embedding_dim": 200, 8 | "embedding_file": "data/glove.6B/glove.6B.200d.txt", 9 | "ensemble": false, 10 | "evaluate_every_n": 400, 11 | "grad_clipping": 10, 12 | "hidden_size": 128, 13 | "keep_prob": 1.0, 14 | "l2": 0.0001, 15 | "log_file": null, 16 | "lr": 0.001, 17 | "max_count": null, 18 | "max_vocab_num": 100000, 19 | "num_epoches": 10, 20 | "num_layers": 1, 21 | "optimizer": "ADAM", 22 | "patience": 5, 23 | "print_every_n": 10, 24 | "random_seed": 2088, 25 | "test": true, 26 | "test_file": "cbtest_NE_test_2500ex.txt", 27 | "tmp_dir": "tmp", 28 | "train": false, 29 | "train_file": "cbtest_NE_train.txt", 30 | "use_lstm": false, 31 | "valid_file": "cbtest_NE_valid_2000ex.txt", 32 | "weight_path": "weights/AoA-reader/best-CBT-NE" 33 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-CBT-NE/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_over_attention_reader.py", 3 | "test_acc": 0.7088 4 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-best-CBT-NE/args.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "args_file": "weights/args.json", 3 | "batch_size": 32, 4 | "data_root": "data/CBTest/CBTest/data/", 5 | "dataset": "cbt", 6 | "debug": false, 7 | "embedding_dim": 300, 8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt", 9 | "ensemble": false, 10 | "evaluate_every_n": 400, 11 | "grad_clipping": 10, 12 | "hidden_size": 128, 13 | "keep_prob": 0.5, 14 | "l2": 0.0001, 15 | "log_file": null, 16 | "lr": 0.001, 17 | "max_count": null, 18 | "max_vocab_num": 100000, 19 | "num_epoches": 10, 20 | "num_layers": 1, 21 | "optimizer": "ADAM", 22 | "patience": 5, 23 | "print_every_n": 10, 24 | "random_seed": 2088, 25 | "test": true, 26 | "test_file": "cbtest_NE_test_2500ex.txt", 27 | "tmp_dir": "tmp", 28 | "train": false, 29 | "train_file": "cbtest_NE_train.txt", 30 | "use_lstm": false, 31 | "valid_file": "cbtest_NE_valid_2000ex.txt", 32 | "weight_path": "weights/" 33 | } -------------------------------------------------------------------------------- /weights/AoA-reader/best-best-CBT-NE/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "attention_over_attention_reader.py", 3 | "test_acc": 0.71 4 | } --------------------------------------------------------------------------------