├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── prepare-all.sh
│   ├── prepare-cbt-data.sh
│   └── prepare-embedding.sh
├── dataset
│   ├── __init__.py
│   ├── cbt.py
│   ├── data_file_pairs.py
│   ├── rc_dataset.py
│   └── squad.py
├── main.py
├── models
│   ├── __init__.py
│   ├── attention_over_attention_reader.py
│   ├── attention_sum_reader.py
│   ├── model_data_pairs.py
│   ├── nlp_base.py
│   ├── r_net.py
│   └── rc_base.py
├── requirements.txt
├── test
│   ├── dataset_test.py
│   └── notebook
│       ├── test_aoa.ipynb
│       └── test_as_reader.ipynb
├── utils
│   ├── __init__.py
│   └── log.py
└── weights
    ├── AS-reader
    │   ├── best-CBT-CN
    │   │   ├── args.json
    │   │   └── result.json
    │   ├── best-CBT-NE
    │   │   ├── args.json
    │   │   └── result.json
    │   └── best-best-CBT-NE
    │       ├── args.json
    │       └── result.json
    └── AoA-reader
        ├── best-CBT-CN
        │   ├── args.json
        │   └── result.json
        ├── best-CBT-NE
        │   ├── args.json
        │   └── result.json
        └── best-best-CBT-NE
            ├── args.json
            └── result.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | env/
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 | ### JetBrains template
97 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
98 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
99 |
100 | # Gradle:
101 | .idea/**/gradle.xml
102 | .idea/**/libraries
103 |
104 | # Mongo Explorer plugin:
105 | .idea/**/mongoSettings.xml
106 |
107 | ## File-based project format:
108 | *.iws
109 |
110 | ## Plugin-specific files:
111 |
112 | # IntelliJ
113 | /out/
114 |
115 | # mpeltonen/sbt-idea plugin
116 | .idea_modules/
117 |
118 | # JIRA plugin
119 | atlassian-ide-plugin.xml
120 |
121 | # Crashlytics plugin (for Android Studio and IntelliJ)
122 | com_crashlytics_export_strings.xml
123 | crashlytics.properties
124 | crashlytics-build.properties
125 | fabric.properties
126 | logs/
127 | .idea/
128 |
129 | ### Tensorflow checkpoint
130 | checkpoint
131 | *.meta
132 | *.index
133 | *.data-00000-of-00001
134 | data/CBTest/
135 | data/glove.6B/
136 | data/SQuAD/
137 | weights/args\.json
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | {one line to give the program's name and a brief idea of what it does.}
635 | Copyright (C) {year} {name of author}
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <http://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | {project} Copyright (C) {year} {fullname}
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <http://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <http://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reading Comprehension Experiments
2 |
3 | ## About
4 |
5 | This is the tensorflow version implementation/reproduce of some reading comprehension models in some reading comprehension datasets including the following:
6 |
7 | Models:
8 |
9 | - Attention Sum Reader model as presented in "Text Comprehension with the Attention Sum Reader Network" (ACL2016) available at [http://arxiv.org/abs/1603.01547](http://arxiv.org/abs/1603.01547).
10 |
11 | 
12 |
13 | - Attention over Attention Reader model as presented in "Attention-over-Attention Neural Networks for Reading Comprehension" (arXiv2016.7) available at https://arxiv.org/abs/1607.04423.
14 |
15 | 
16 |
17 | Datasets:
18 |
19 | - CBT, Children’s Book Test. http://lanl.arxiv.org/pdf/1506.03340.pdf
20 |
21 | ## Start To Use
22 |
23 | #### 1.Clone the code
24 |
25 | ```shell
26 | git clone https://github.com/zhanghaoyu1993/RC-experiments.git
27 | ```
28 |
29 |
30 |
31 | #### 2.Get needed data
32 |
33 | - Download and extract the dataset used in this repo.
34 |
35 | ```shell
36 | cd data
37 | ./prepare-all.sh
38 | ```
39 |
40 |
41 |
42 | #### 3.Environment Preparation
43 |
44 | - Python-64bit >= v3.5.
45 | - Install require libraries using the following command.
46 |
47 | ```shell
48 | pip install -r requirements.txt
49 | ```
50 |
51 | - Install tensorflow >= 1.1.0.
52 |
53 | ```shell
54 | pip install tensorflow-gpu --upgrade
55 | ```
56 |
57 | - Install nltk punkt for tokenizer.
58 |
59 | ```shell
60 | python -m nltk.downloader punkt
61 | ```
62 |
63 |
64 |
65 | #### 4.Set model, dataset and other command parameters
66 |
67 | - What is the entrance of the program?
68 |
69 | The main.py file in root directory.
70 |
71 | - How can I specify a model in command line?
72 |
73 |   Type a command like the one below, where *model_class* is the class name of the model, usually written in camel-case style:
74 |
75 | ```shell
76 | python main.py [model_class]
77 | ```
78 |
79 | For example, if you want to use AttentionSumReader:
80 |
81 | ```shell
82 | python main.py AttentionSumReader
83 | ```
84 |
85 | - How can I specify the dataset?
86 |
87 | Type a command like above, the *dataset_class* is the class name of dataset:
88 |
89 | ```shell
90 | python main.py [model_class] --dataset [dataset_class]
91 | ```
92 |
93 | For example, if you want to use CBT:
94 |
95 | ```shell
96 | python main.py [model_class] --dataset CBT
97 | ```
98 |
99 | You don't need to specify the data_root and train valid test file name in most cases, just specify the dataset.
100 |
101 | - How can I know all the parameters?
102 |
103 | The program use [argparse](https://docs.python.org/3/library/argparse.html) to deal with parameters, you can type the following command to get help:
104 |
105 | ```shell
106 | python main.py --help
107 | ```
108 |
109 | or:
110 |
111 | ```shell
112 | python main.py -h
113 | ```
114 |
115 | - The command parameters are so long!
116 |
117 | The parameters will be stored into a file named args.json when executed, so next time you can type the following simplified command:
118 |
119 | ```shell
120 | python main.py [model_class] --args_file [args.json]
121 | ```
122 |
123 |
124 |
125 | #### 5.Train and test the model
126 |
127 | First, modify the parameters in the args.json.
128 |
129 | You can now train and test the model by entering the following commands. The params in [] should be determined by the real situation.
130 |
131 | - Train:
132 |
133 | ```shell
134 | python main.py [model_class] --args_file [args.json] --train 1 --test 0
135 | ```
136 |
137 | After train, the parameters are stored in `weight_path/args.json` and the model checkpoints are stored in `weight_path`.
138 |
139 | - Test:
140 |
141 | ```shell
142 | python main.py [model_class] --args_file [args.json] --train 0 --test 1
143 | ```
144 |
145 |   After testing, the performance of the model is stored in `weight_path/result.json`.
146 |
147 |
148 |
149 | #### 6.model performance
150 |
151 | All the trained results and corresponding config params are saved in sub directories of weight_path(by default the `weight` folder) named `args.json` and `result.json`.
152 |
153 | You should know that the implementation of some models are **slightly different** from the original, but the basic ideas are same, so the results are for reference only.
154 |
155 | The best results of implemented models are listed below:
156 |
157 | - best result **we achieve**(with little hyper-parameter tune in single model)
158 | - best result listed in original paper(in the brackets)
159 |
160 | | | CBT-NE | CBT-CN |
161 | | ---------- | ----------- | ----------- |
162 | | AS-Reader | 69.88(68.6) | 65.0(63.4) |
163 | | AoA-Reader | 71.0(72.0) | 68.12(69.4) |
164 |
165 |
166 |
167 | #### 7.FAQ
168 |
169 | - How do I use args_file argument in the shell?
170 |
171 |   Once you enter a command in the shell (maybe a long one), the config will be stored in weight_path/args.json, where weight_path is defined by another argument. After the command executes, you can use --args_file args.json to simplify subsequent commands:
172 | ```shell
173 | python main.py [model_class] --args_file [args.json]
174 | ```
175 | And arguments typed on the command line have higher priority than those stored in args.json, so you can still override individual arguments.
--------------------------------------------------------------------------------
/data/prepare-all.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Download and prepare everything the experiments need:
# the GloVe word embeddings and the Children's Book Test dataset.

./prepare-embedding.sh
./prepare-cbt-data.sh
--------------------------------------------------------------------------------
/data/prepare-cbt-data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# prepares the Children's Book Test datasets

# get CBT data (bAbI-hosted mirror of the CBTest corpus)
wget http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz

# unpack all files (creates the CBTest/ directory referenced by dataset/data_file_pairs.py)
tar -zxvf CBTest.tgz
9 |
10 |
--------------------------------------------------------------------------------
/data/prepare-embedding.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# get glove embedding (Stanford NLP 6B-token GloVe vectors)
wget http://nlp.stanford.edu/data/glove.6B.zip

# unpack all files into the glove.6B/ directory
unzip glove.6B.zip -d glove.6B
8 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
from .cbt import CBT
from .squad import SQuAD

# CBT_NE (named entities) and CBT_CN (common nouns) share one loader class;
# the concrete data files are chosen elsewhere (see dataset/data_file_pairs.py).
CBT_NE = CBT
CBT_CN = CBT

# Export the base CBT name too, so `--dataset CBT` (as documented in the
# README) also resolves through `from dataset import *`.
__all__ = ["CBT", "CBT_NE", "CBT_CN", "SQuAD"]
8 |
--------------------------------------------------------------------------------
/dataset/cbt.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import reduce
3 |
4 | import numpy as np
5 | from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences
6 | from tensorflow.python.platform import gfile
7 | from tensorflow.python.platform.gfile import FastGFile
8 |
9 | from dataset.rc_dataset import RCDataset
10 | from utils.log import logger
11 |
12 |
class CBT(RCDataset):
    """Children's Book Test (CBT) dataset loader.

    Converts the raw 22-line-per-sample CBT text files into token-id files,
    reads them back into (documents, questions, answers, candidates) lists,
    and pads everything to fixed lengths for batching.
    """

    def __init__(self, args):
        # Every CBT question comes with exactly 10 candidate answers.
        self.A_len = 10
        super().__init__(args)

    def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples):
        """Build the feed dict for one batch.

        `dataset` is the tuple produced by preprocess_input_sequences —
        (questions, documents, candidates, y_true) — and the keys are the
        graph's placeholder tensor names.
        """
        data = {
            "questions_bt:0": dataset[0][_slice],
            "documents_bt:0": dataset[1][_slice],
            "candidates_bi:0": dataset[2][_slice],
            "y_true_bi:0": dataset[3][_slice]
        }
        return data, samples

    def cbt_data_to_token_ids(self, data_file, target_file, vocab_file, max_count=None):
        """Translate one raw CBT file into a token-id file.

        Raw format, 22 lines per sample:
        lines 1-20: document sentences, each prefixed with a line number;
        line 21: "line-number question TAB answer TAB TAB cand1|...|cand10";
        line 22: blank separator.

        :param data_file: raw CBT text file.
        :param target_file: output id file (skipped entirely if it exists).
        :param vocab_file: saved vocabulary used for the word->id mapping.
        :param max_count: optional cap on input lines (debugging aid).
        """
        if gfile.Exists(target_file):
            return
        logger("Tokenizing data in {}".format(data_file))
        word_dict = self.load_vocab(vocab_file)
        counter = 0

        with gfile.FastGFile(data_file) as f:
            with gfile.FastGFile(target_file, mode="wb") as tokens_file:
                for line in f:
                    counter += 1
                    if counter % 100000 == 0:
                        logger("Tokenizing line %d" % counter)
                    if max_count and counter > max_count:
                        break
                    if counter % 22 == 21:
                        # Question line: "<num> question\tanswer\t\tcand1|...|cand10".
                        q, a, _, A = line.split("\t")
                        # [1:] drops the leading line-number token.
                        token_ids_q = self.sentence_to_token_ids(q, word_dict)[1:]
                        # NOTE: the comprehension variable `a` does not leak in
                        # Python 3, so `a` below is still the answer word.
                        token_ids_A = [word_dict.get(a.lower(), self.UNK_ID) for a in A.rstrip("\n").split("|")]
                        tokens_file.write(" ".join([str(tok) for tok in token_ids_q]) + "\t"
                                          + str(word_dict.get(a.lower(), self.UNK_ID)) + "\t"
                                          + "|".join([str(tok) for tok in token_ids_A]) + "\n")
                    else:
                        # Document line (or blank separator): tokenize, then drop
                        # the leading line-number token if the line is non-empty.
                        token_ids = self.sentence_to_token_ids(line, word_dict)
                        token_ids = token_ids[1:] if token_ids else token_ids
                        tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")

    def prepare_data(self, data_dir, train_file, valid_file, test_file, max_vocab_num, output_dir=""):
        """Build the vocabulary and translate the CBT data to id format.

        :return: (vocab_file, idx_train_file, idx_valid_file, idx_test_file) paths.
        """
        if not gfile.Exists(os.path.join(data_dir, output_dir)):
            os.mkdir(os.path.join(data_dir, output_dir))
        os_train_file = os.path.join(data_dir, train_file)
        os_valid_file = os.path.join(data_dir, valid_file)
        os_test_file = os.path.join(data_dir, test_file)
        # Id files are tagged with the vocabulary size, so different vocab
        # settings do not clash in the cache directory.
        idx_train_file = os.path.join(data_dir, output_dir, train_file + ".%d.idx" % max_vocab_num)
        idx_valid_file = os.path.join(data_dir, output_dir, valid_file + ".%d.idx" % max_vocab_num)
        idx_test_file = os.path.join(data_dir, output_dir, test_file + ".%d.idx" % max_vocab_num)
        vocab_file = os.path.join(data_dir, output_dir, "vocab.%d" % max_vocab_num)

        if not gfile.Exists(vocab_file):
            # The vocabulary is accumulated across train, valid and test splits.
            word_counter = self.gen_vocab(os_train_file, max_count=self.args.max_count)
            word_counter = self.gen_vocab(os_valid_file, old_counter=word_counter, max_count=self.args.max_count)
            word_counter = self.gen_vocab(os_test_file, old_counter=word_counter, max_count=self.args.max_count)
            self.save_vocab(word_counter, vocab_file, max_vocab_num)

        # translate train/valid/test files to id format
        self.cbt_data_to_token_ids(os_train_file, idx_train_file, vocab_file, max_count=self.args.max_count)
        self.cbt_data_to_token_ids(os_valid_file, idx_valid_file, vocab_file, max_count=self.args.max_count)
        self.cbt_data_to_token_ids(os_test_file, idx_test_file, vocab_file, max_count=self.args.max_count)

        return vocab_file, idx_train_file, idx_valid_file, idx_test_file

    def read_cbt_data(self, file, max_count=None):
        """Read a CBT file in id format.

        :return: (documents, questions, answers, candidates) — four parallel
            Python lists (not numpy arrays; the old docstring was wrong), one
            entry per sample.  The correct answer is moved to position 0 of
            each candidate list.
        """
        documents, questions, answers, candidates = [], [], [], []
        with FastGFile(file, mode="r") as f:
            counter = 0
            d, q, a, A = [], [], [], []
            for line in f:
                counter += 1
                if max_count and counter > max_count:
                    break
                if counter % 100000 == 0:
                    logger("Reading line %d in %s" % (counter, file))
                if counter % 22 == 21:
                    # Question line: "question-ids\tanswer-id\tcand|cand|...".
                    tmp = line.strip().split("\t")
                    q = tmp[0].split(" ") + [self.EOS_ID]
                    # Indicator vector over document tokens: 1 where the token
                    # equals the answer id (d holds id strings at this point).
                    a = [1 if tmp[1] == i else 0 for i in d]
                    A = [a for a in tmp[2].split("|")]
                    # NOTE(review): remove() assumes the answer id always appears
                    # in the candidate list — confirm for UNK-mapped answers.
                    A.remove(tmp[1])
                    A.insert(0, tmp[1])  # put right answer in the first of candidate
                elif counter % 22 == 0:
                    # Blank separator line: the sample is complete.
                    documents.append(d)
                    questions.append(q)
                    answers.append(a)
                    candidates.append(A)
                    d, q, a, A = [], [], [], []
                else:
                    d.extend(line.strip().split(" ") + [self.EOS_ID])  # add EOS ID in the end of each sentence

        d_lens = [len(i) for i in documents]
        q_lens = [len(i) for i in questions]
        avg_d_len = reduce(lambda x, y: x + y, d_lens) / len(documents)
        logger("Document average length: %d." % avg_d_len)
        # "midden" here means the median document/question length.
        logger("Document midden length: %d." % len(sorted(documents, key=len)[len(documents) // 2]))
        avg_q_len = reduce(lambda x, y: x + y, q_lens) / len(questions)
        logger("Question average length: %d." % avg_q_len)
        logger("Question midden length: %d." % len(sorted(questions, key=len)[len(questions) // 2]))

        return documents, questions, answers, candidates

    def preprocess_input_sequences(self, data):
        """Pad questions/documents/candidates to fixed lengths.

        The per-token answer indicator is dropped here; supervision comes
        from y_true, which marks position 0 (where read_cbt_data moved the
        correct answer) as the target.
        """
        documents, questions, answer, candidates = data

        questions_ok = pad_sequences(questions, maxlen=self.q_len, dtype="int32", padding="post", truncating="post")
        documents_ok = pad_sequences(documents, maxlen=self.d_len, dtype="int32", padding="post", truncating="post")
        candidates_ok = pad_sequences(candidates, maxlen=self.A_len, dtype="int32", padding="post", truncating="post")
        y_true = np.zeros_like(candidates_ok)
        y_true[:, 0] = 1
        return questions_ok, documents_ok, candidates_ok, y_true

    # noinspection PyAttributeOutsideInit
    def get_data_stream(self):
        """Prepare and read the data, then return its basic statistics."""
        # prepare data
        self.vocab_file, idx_train_file, idx_valid_file, idx_test_file = self.prepare_data(
            self.args.data_root, self.args.train_file, self.args.valid_file,
            self.args.test_file, self.args.max_vocab_num,
            output_dir=self.args.tmp_dir)

        # read data
        self.train_data = self.read_cbt_data(idx_train_file, max_count=self.args.max_count)
        self.valid_data = self.read_cbt_data(idx_valid_file, max_count=self.args.max_count)

        def get_max_length(d_bt):
            # Longest sequence in the list; used as the padding length.
            lens = [len(i) for i in d_bt]
            return max(lens)

        # data statistics
        self.d_len = get_max_length(self.train_data[0])
        self.q_len = get_max_length(self.train_data[1])
        self.train_sample_num = len(self.train_data[0])
        self.valid_sample_num = len(self.valid_data[0])
        # Random order over whole training batches; reshuffled each epoch.
        self.train_idx = np.random.permutation(self.train_sample_num // self.args.batch_size)
        self.test_sample_num = 0

        if self.args.test:
            self.test_data = self.read_cbt_data(idx_test_file, max_count=self.args.max_count)
            self.test_sample_num = len(self.test_data[0])

        return self.d_len, self.q_len, self.train_sample_num, self.valid_sample_num, self.test_sample_num
171 |
--------------------------------------------------------------------------------
/dataset/data_file_pairs.py:
--------------------------------------------------------------------------------
# Maps each dataset class name (as passed via --dataset) to its default
# [data_root, train_file, valid_file, test_file] quadruple.
dataset_files_pairs = {
    "CBT_NE": [
        "data/CBTest/CBTest/data/",
        "cbtest_NE_train.txt",
        "cbtest_NE_valid_2000ex.txt",
        "cbtest_NE_test_2500ex.txt"],
    "CBT_CN": [
        "data/CBTest/CBTest/data/",
        "cbtest_CN_train.txt",
        "cbtest_CN_valid_2000ex.txt",
        "cbtest_CN_test_2500ex.txt"],
    "SQuAD": [
        "data/SQuAD",
        "train-v1.1.json",
        "dev-v1.1.json",
        # NOTE(review): the dev file doubles as the test file — presumably
        # because the official SQuAD test set is hidden; verify intent.
        "dev-v1.1.json"]
}
18 |
--------------------------------------------------------------------------------
/dataset/rc_dataset.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import codecs
3 | import re
4 | from collections import Counter
5 |
6 | import nltk
7 | import numpy as np
8 | from tensorflow.python.platform import gfile
9 |
10 | from utils.log import logger
11 |
12 |
def default_tokenizer(sentence):
    """Lower-case and tokenize a sentence with NLTK.

    Runs of digits are collapsed to "0" (so all numbers share one vocab
    entry) and '|' separators — used between CBT candidates — are treated
    as whitespace before tokenizing.

    :param sentence: raw text line.
    :return: list of lower-cased word tokens.
    """
    # re.sub caches compiled patterns internally, so there is no need to
    # re-compile the regex on every call as the old code did.
    sentence = re.sub(r"\d+", "0", sentence)
    # Equivalent to " ".join(sentence.split("|")).
    sentence = sentence.replace("|", " ")
    return nltk.word_tokenize(sentence.lower())
18 |
19 |
20 | # noinspection PyAttributeOutsideInit
21 | class RCDataset(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        """Store args and define the special vocabulary entries shared by
        all reading-comprehension datasets."""
        self.args = args
        # padding,start of sentence,end of sentence,unk,end of question
        self._PAD = "_PAD"
        self._BOS = "_BOS"
        self._EOS = "_EOS"
        self._UNK = "_UNK"
        self._EOQ = "_EOQ"
        # Order must match the id constants below; save_vocab writes these
        # tokens first, so their file positions equal their ids.
        self._START_VOCAB = [self._PAD, self._BOS, self._EOS, self._UNK, self._EOQ]
        self.PAD_ID = 0
        self.BOS_ID = 1
        self.EOS_ID = 2
        self.UNK_ID = 3
        self.EOQ_ID = 4

        # Placeholder token some cloze-style corpora use for the blank.
        self._BLANK = "XXXXX"

        # special character of char embedding: pad and unk
        self._CHAR_PAD = "γ"
        self._CHAR_UNK = "δ"
        self.CHAR_PAD_ID = 0
        self.CHAR_UNK_ID = 1
        self._CHAR_START_VOCAB = [self._CHAR_PAD, self._CHAR_UNK]
45 |
    # train_idx holds a shuffled order of whole-batch indices for training.
    @property
    def train_idx(self):
        return self._train_idx

    @train_idx.setter
    def train_idx(self, value):
        self._train_idx = value

    # Number of training samples actually loaded.
    @property
    def train_sample_num(self):
        return self._train_sample_num

    @train_sample_num.setter
    def train_sample_num(self, value):
        self._train_sample_num = value

    # Number of validation samples actually loaded.
    @property
    def valid_sample_num(self):
        return self._valid_sample_num

    @valid_sample_num.setter
    def valid_sample_num(self, value):
        self._valid_sample_num = value

    # Number of test samples (0 when testing is disabled).
    @property
    def test_sample_num(self):
        return self._test_sample_num

    @test_sample_num.setter
    def test_sample_num(self, value):
        self._test_sample_num = value

    def shuffle(self):
        # Re-shuffle the order of training batches (typically once per epoch).
        logger("Shuffle the dataset.")
        np.random.shuffle(self.train_idx)
81 |
82 | def get_next_batch(self, mode, idx):
83 | """
84 | return next batch of data samples
85 | """
86 | batch_size = self.args.batch_size
87 | if mode == "train":
88 | dataset = self.train_data
89 | sample_num = self.train_sample_num
90 | elif mode == "valid":
91 | dataset = self.valid_data
92 | sample_num = self.valid_sample_num
93 | else:
94 | dataset = self.test_data
95 | sample_num = self.test_sample_num
96 | if mode == "train":
97 | start = self.train_idx[idx] * batch_size
98 | stop = (self.train_idx[idx] + 1) * batch_size
99 | else:
100 | start = idx * batch_size
101 | stop = (idx + 1) * batch_size if start < sample_num and (idx + 1) * batch_size < sample_num else -1
102 | samples = batch_size if stop != -1 else len(dataset[0]) - start
103 | _slice = np.index_exp[start:stop]
104 | return self.next_batch_feed_dict_by_dataset(dataset, _slice, samples)
105 |
106 | @staticmethod
107 | def gen_embeddings(word_dict, embed_dim, in_file=None, init=np.zeros):
108 | """
109 | Init embedding matrix with (or without) pre-trained word embeddings.
110 | """
111 | num_words = max(word_dict.values()) + 1
112 | embedding_matrix = init(-0.05, 0.05, (num_words, embed_dim))
113 | logger('Embeddings: %d x %d' % (num_words, embed_dim))
114 |
115 | if not in_file:
116 | return embedding_matrix
117 |
118 | def get_dim(file):
119 | first = gfile.FastGFile(file, mode='r').readline()
120 | return len(first.split()) - 1
121 |
122 | assert get_dim(in_file) == embed_dim
123 | logger('Loading embedding file: %s' % in_file)
124 | pre_trained = 0
125 | for line in codecs.open(in_file, encoding="utf-8"):
126 | sp = line.split()
127 | if sp[0] in word_dict:
128 | pre_trained += 1
129 | embedding_matrix[word_dict[sp[0]]] = np.asarray([float(x) for x in sp[1:]], dtype=np.float32)
130 | logger("Pre-trained: {}, {:.3f}%".format(pre_trained, pre_trained * 100.0 / num_words))
131 | return embedding_matrix
132 |
133 | def sentence_to_token_ids(self, sentence, word_dict, tokenizer=default_tokenizer):
134 | """
135 | Turn sentence to token ids.
136 | sentence: ["I", "have", "a", "dog"]
137 | word_list: {"I": 1, "have": 2, "a": 4, "dog": 7"}
138 | return: [1, 2, 4, 7]
139 | """
140 | return [word_dict.get(token, self.UNK_ID) for token in tokenizer(sentence)]
141 |
    def get_embedding_matrix(self, vocab_file, is_char_embedding=False):
        """
        Build the (char) embedding matrix for a saved vocabulary.

        :param is_char_embedding: when True build a character embedding
            (random init only, args.char_embedding_dim wide); otherwise a
            word embedding optionally initialized from args.embedding_file.
        :param vocab_file: file containing saved vocabulary.
        :return: a numpy matrix whose row i is the embedding vector of the
            token with id i.  (The old docstring incorrectly said a dict.)
        """
        word_dict = self.load_vocab(vocab_file)
        embedding_file = None if is_char_embedding else self.args.embedding_file
        embedding_dim = self.args.char_embedding_dim if is_char_embedding else self.args.embedding_dim
        embedding_matrix = self.gen_embeddings(word_dict,
                                               embedding_dim,
                                               embedding_file,
                                               init=np.random.uniform)
        return embedding_matrix
156 |
    def sort_by_length(self, data):
        # TODO: sort data array according to sequence length in order to speed up training
        # Currently a no-op; callers must not rely on any ordering effect.
        pass
160 |
161 | @staticmethod
162 | def gen_char_vocab(data_file, tokenizer=default_tokenizer, old_counter=None):
163 | """
164 | generate character level vocabulary according to train corpus.
165 | """
166 | logger("Creating character dict from data {}.".format(data_file))
167 | char_counter = old_counter if old_counter else Counter()
168 | with gfile.FastGFile(data_file) as f:
169 | for line in f:
170 | tokens = tokenizer(line.rstrip("\n"))
171 | char_counter.update([char for word in tokens for char in word])
172 |
173 | # summary statistics
174 | total_chars = sum(char_counter.values())
175 | distinct_chars = len(list(char_counter))
176 |
177 | logger("STATISTICS" + "-" * 20)
178 | logger("Total characters: " + str(total_chars))
179 | logger("Total distinct characters: " + str(distinct_chars))
180 | return char_counter
181 |
182 | @staticmethod
183 | def gen_vocab(data_file, tokenizer=default_tokenizer, old_counter=None, max_count=None):
184 | """
185 | generate vocabulary according to train corpus.
186 | """
187 | logger("Creating word dict from data {}.".format(data_file))
188 | word_counter = old_counter if old_counter else Counter()
189 | counter = 0
190 | with gfile.FastGFile(data_file) as f:
191 | for line in f:
192 | counter += 1
193 | if max_count and counter > max_count:
194 | break
195 | tokens = tokenizer(line.rstrip('\n'))
196 | word_counter.update(tokens)
197 | if counter % 100000 == 0:
198 | logger("Process line %d Done." % counter)
199 |
200 | # summary statistics
201 | total_words = sum(word_counter.values())
202 | distinct_words = len(list(word_counter))
203 |
204 | logger("STATISTICS" + "-" * 20)
205 | logger("Total words: " + str(total_words))
206 | logger("Total distinct words: " + str(distinct_words))
207 |
208 | return word_counter
209 |
    def save_char_vocab(self, char_counter, char_vocab_file, max_vocab_num=None):
        """
        Save the character vocabulary to `char_vocab_file`, one character per line.

        The special entries in self._CHAR_START_VOCAB are written first so they
        receive the lowest ids, followed by the `max_vocab_num` most common
        characters in descending frequency order.
        """
        with gfile.FastGFile(char_vocab_file, "w") as f:
            for char in self._CHAR_START_VOCAB:
                f.write(char + "\n")
            for char in list(map(lambda x: x[0], char_counter.most_common(max_vocab_num))):
                f.write(char + "\n")
220 |
221 | def save_vocab(self, word_counter, vocab_file, max_vocab_num=None):
222 | with gfile.FastGFile(vocab_file, "w") as f:
223 | for word in self._START_VOCAB:
224 | f.write(word + "\n")
225 | for word in list(map(lambda x: x[0], word_counter.most_common(max_vocab_num))):
226 | f.write(word + "\n")
227 |
228 | @staticmethod
229 | def load_vocab(vocab_file):
230 | """
231 | load word(or char) vocabulary file to word/char dict
232 | """
233 | if not gfile.Exists(vocab_file):
234 | raise ValueError("Vocabulary file %s not found.", vocab_file)
235 | word_dict = {}
236 | word_id = 0
237 | for line in codecs.open(vocab_file, encoding="utf-8"):
238 | word_dict.update({line.strip(): word_id})
239 | word_id += 1
240 | return word_dict
241 |
    # noinspection PyAttributeOutsideInit
    def preprocess(self):
        """
        Run preprocess_input_sequences (subclass-specific) over the train and
        valid splits, and over the test split only when args.test is set.
        """
        self.train_data = self.preprocess_input_sequences(self.train_data)
        self.valid_data = self.preprocess_input_sequences(self.valid_data)
        if self.args.test:
            self.test_data = self.preprocess_input_sequences(self.test_data)
248 |
    @abc.abstractmethod
    def preprocess_input_sequences(self, data):
        """
        Preprocess one data split (e.g. padding, label encoding).
        Must be implemented by the dataset-specific subclass.
        """
        pass
255 |
    @abc.abstractmethod
    def get_data_stream(self):
        """
        Prepare data files, load the train/valid/test splits and return the
        dataset statistics (sequence lengths and sample counts).
        Must be implemented by the dataset-specific subclass.
        """
        pass
262 |
    @abc.abstractmethod
    def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples):
        """
        Build the model feed dict for the batch selected by `_slice`.
        Must be implemented by the dataset-specific subclass.
        """
        pass
269 |
--------------------------------------------------------------------------------
/dataset/squad.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences
6 | from tensorflow.python.platform import gfile
7 |
8 | from dataset.rc_dataset import RCDataset
9 | from utils.log import logger
10 |
11 |
class SQuAD(RCDataset):
    """
    SQuAD reading-comprehension dataset: vocabulary building, data loading and
    batch preprocessing for span-prediction models.
    """

    def __init__(self, args):
        super(SQuAD, self).__init__(args)
        # max word length (in characters) for the char-level input planes
        self.w_len = 10

    def next_batch_feed_dict_by_dataset(self, dataset, _slice, samples):
        """
        Build the feed dict for one batch slice.

        :param dataset: tuple (documents, questions, answer_start, answer_end)
        :param _slice: slice selecting the batch rows
        :param samples: number of samples in the batch
        :return: (feed_dict, samples)
        """
        data = {
            "documents_bt:0": dataset[0][_slice],
            "questions_bt:0": dataset[1][_slice],
            # TODO: substitute with real data
            "documents_btk:0": np.zeros([samples, self.d_len, self.w_len]),
            "questions_btk:0": np.zeros([samples, self.q_len, self.w_len]),
            "answer_start:0": dataset[2][_slice],
            "answer_end:0": dataset[3][_slice]
        }
        return data, samples

    def preprocess_input_sequences(self, data):
        """
        Pad documents/questions to fixed lengths and one-hot encode answer spans.

        NOTE(review): answer_spans carry character offsets into the raw context
        (see read_squad_data) but are matched against positions in [0, d_len)
        here as if they were token indices -- TODO confirm this is intended.
        """
        documents, questions, answer_spans = data
        documents_ok = pad_sequences(documents, maxlen=self.d_len, dtype="int32", padding="post", truncating="post")
        questions_ok = pad_sequences(questions, maxlen=self.q_len, dtype="int32", padding="post", truncating="post")
        answer_start = [np.array([int(i == answer_span[0]) for i in range(self.d_len)]) for answer_span in answer_spans]
        answer_end = [np.array([int(i == answer_span[1]) for i in range(self.d_len)]) for answer_span in answer_spans]
        return documents_ok, questions_ok, np.asarray(answer_start), np.asarray(answer_end)

    def prepare_data(self, data_dir, train_file, valid_file, max_vocab_num, output_dir=""):
        """
        Build the word vocabulary and character vocabulary files if missing.

        :return: (train_file_path, valid_file_path, vocab_file, char_vocab_file)
        """
        if not gfile.Exists(os.path.join(data_dir, output_dir)):
            os.mkdir(os.path.join(data_dir, output_dir))
        os_train_file = os.path.join(data_dir, train_file)
        os_valid_file = os.path.join(data_dir, valid_file)
        vocab_file = os.path.join(data_dir, output_dir, "vocab.%d" % max_vocab_num)
        char_vocab_file = os.path.join(data_dir, output_dir, "char_vocab")

        vocab_data_file = os.path.join(data_dir, output_dir, "data.txt")

        def save_data(d_data, q_data):
            """
            save all data to a file and use it to build the vocabulary.
            """
            with open(vocab_data_file, mode="w", encoding="utf-8") as f:
                f.write("\t".join(d_data) + "\n")
                f.write("\t".join(q_data) + "\n")

        if not gfile.Exists(vocab_data_file):
            d, q, _ = self.read_squad_data(os_train_file)
            v_d, v_q, _ = self.read_squad_data(os_valid_file)
            # BUG FIX: previously save_data was called twice, and the second
            # call re-opened the file in "w" mode, truncating the train data;
            # the vocabulary was then built from the validation set only.
            # Write train + valid data in a single call instead.
            save_data(d + v_d, q + v_q)
        if not gfile.Exists(vocab_file):
            logger("Start create vocabulary.")
            word_counter = self.gen_vocab(vocab_data_file, max_count=self.args.max_count)
            self.save_vocab(word_counter, vocab_file, max_vocab_num)
        if not gfile.Exists(char_vocab_file):
            logger("Start create character vocabulary.")
            char_counter = self.gen_char_vocab(vocab_data_file)
            self.save_char_vocab(char_counter, char_vocab_file, max_vocab_num=70)

        return os_train_file, os_valid_file, vocab_file, char_vocab_file

    def read_squad_data(self, file):
        """
        read squad data file in string form
        :return tuple of (documents, questions, answer_spans), where each span
                is [char_start, char_end) into its document string
        """
        logger("Reading SQuAD data.")

        def extract(sample_data):
            document = sample_data["context"]
            for qas in sample_data["qas"]:
                question = qas["question"]
                for ans in qas["answers"]:
                    answer_len = len(ans["text"])
                    answer_span = [ans["answer_start"], ans["answer_start"] + answer_len]
                    # sanity check: the span must reproduce the answer text
                    assert (ans["text"] == document[ans["answer_start"]:(ans["answer_start"] + answer_len)])
                    documents.append(document)
                    questions.append(question)
                    answer_spans.append(answer_span)

        documents, questions, answer_spans = [], [], []
        f = json.load(open(file, encoding="utf-8"))
        data_list, version = f["data"], f["version"]
        logger("SQuAD version: {}".format(version))
        # plain loops instead of a side-effect-only list comprehension
        for data in data_list:
            for sample in data["paragraphs"]:
                extract(sample)
        if self.args.debug:
            documents, questions, answer_spans = documents[:500], questions[:500], answer_spans[:500]

        return documents, questions, answer_spans

    def squad_data_to_idx(self, vocab_file, *args):
        """
        Convert string lists to index-list form.

        :param vocab_file: vocabulary file giving the token -> id mapping
        :param args: any number of string lists, converted in order
        :return: list with one converted list per input argument
        """
        logger("Convert string data to index.")
        word_dict = self.load_vocab(vocab_file)
        res_data = [0, ] * len(args)
        for idx, i in enumerate(args):
            res_data[idx] = [self.sentence_to_token_ids(document, word_dict) for document in i]
        logger("Convert string2index done.")
        return res_data

    # noinspection PyAttributeOutsideInit
    def get_data_stream(self):
        """
        Prepare vocabularies, load the data and return dataset statistics.

        SQuAD test data is not publicly accessible, so the split used is:
        first 9/10 of train -> train, last 1/10 of train -> valid,
        official valid -> test.
        """
        # prepare data
        os_train_file, os_valid_file, self.vocab_file, self.char_vocab_file = self.prepare_data(self.args.data_root,
                                                                                                self.args.train_file,
                                                                                                self.args.valid_file,
                                                                                                self.args.max_vocab_num,
                                                                                                self.args.tmp_dir)

        # read data
        documents, questions, answer_spans = self.read_squad_data(os_train_file)
        v_documents, v_questions, v_answer_spans = self.read_squad_data(os_valid_file)
        documents, questions, v_documents, v_questions = self.squad_data_to_idx(self.vocab_file, documents, questions,
                                                                                v_documents, v_questions)
        train_num = len(documents) * 9 // 10
        self.train_data = (documents[:train_num], questions[:train_num], answer_spans[:train_num])
        self.valid_data = (documents[train_num:], questions[train_num:], answer_spans[train_num:])
        self.test_data = (v_documents, v_questions, v_answer_spans)

        def get_max_length(d_bt):
            # longest sequence in a list of sequences
            return max(len(i) for i in d_bt)

        # data statistics
        self.d_len = get_max_length(self.train_data[0])
        self.q_len = get_max_length(self.train_data[1])
        self.train_sample_num = len(self.train_data[0])
        self.valid_sample_num = len(self.valid_data[0])
        self.test_sample_num = len(self.test_data[0])
        self.train_idx = np.random.permutation(self.train_sample_num // self.args.batch_size)

        return self.d_len, self.q_len, self.train_sample_num, self.valid_sample_num, self.test_sample_num
152 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from models.nlp_base import NLPBase
5 |
6 |
def get_model_class():
    """
    Resolve the model class named by the first command-line argument.

    Pops the model name from sys.argv so argparse (in the model's base class)
    never sees it. Prints the supported model list and exits(1) when the name
    is unknown.

    :return: an instance of the resolved model class (or NLPBase for --help)
    """
    # guard against a missing first argument (previously raised IndexError)
    if len(sys.argv) < 2 or sys.argv[1] in ("--help", "-h"):
        return NLPBase()
    class_obj, class_name = None, sys.argv[1]
    try:
        import models
        class_obj = getattr(sys.modules["models"], class_name)
        sys.argv.pop(1)
    except (AttributeError, IndexError):
        # BUG FIX: `except AttributeError or IndexError` evaluated to
        # `except AttributeError` only; a tuple is required to catch both.
        print("Model [{}] not found.\nSupported models:\n\n\t\t{}\n".format(class_name, sys.modules["models"].__all__))
        exit(1)
    return class_obj()
19 |
20 |
if __name__ == '__main__':
    # reduce TensorFlow C++ logging verbosity before any TF work starts
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"
    # resolve the model class from argv, then run its train/test loop
    model = get_model_class()
    model.execute()
25 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.model_data_pairs import models_in_datasets
2 | from .attention_over_attention_reader import AoAReader
3 | from .attention_sum_reader import AttentionSumReader
4 | from .r_net import RNet
5 |
# deduplicated list of every model name registered for any dataset
__all__ = list({model for models in models_in_datasets.values() for model in models})
7 |
--------------------------------------------------------------------------------
/models/attention_over_attention_reader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, LSTMCell
3 |
4 | from models.rc_base import RcBase
5 | from utils.log import logger
6 |
7 |
class AoAReader(RcBase):
    """
    Attention-over-Attention reader in "Attention-over-Attention Neural Networks for Reading Comprehension"
    (arXiv2016.7) available at https://arxiv.org/abs/1607.04423.
    """

    # noinspection PyAttributeOutsideInit
    def create_model(self):
        """
        Build the AoA graph: input placeholders, shared (dropped-out) embedding,
        bidirectional RNN encoders for question and document, the pair-wise
        matching matrix with attention-over-attention, and finally the
        `self.loss` / `self.correct_prediction` ops required by RcBase.
        """
        #########################
        # b ... position of the example within the batch
        # t ... position of the word within the document/question
        # ... d for max length of document
        # ... q for max length of question
        # f ... features of the embedding vector or the encoded feature vector
        # i ... position of the word in candidates list
        # v ... position of the word in vocabulary
        #########################
        # numerical-stability constant (10e-8 == 1e-7)
        _EPSILON = 10e-8
        num_layers = self.args.num_layers
        hidden_size = self.args.hidden_size
        cell = LSTMCell if self.args.use_lstm else GRUCell

        # model input placeholders (fed by name via the dataset's feed dict)
        questions_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.q_len), name="questions_bt")
        documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.d_len), name="documents_bt")
        candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, self.dataset.A_len), name="candidates_bi")
        y_true_bi = tf.placeholder(shape=(None, self.dataset.A_len), dtype=tf.float32, name="y_true_bi")
        keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")

        # trainable embedding initialized from the pre-computed matrix
        init_embedding = tf.constant(self.embedding_matrix, dtype=tf.float32, name="embedding_init")
        embedding = tf.get_variable(initializer=init_embedding,
                                    name="embedding_matrix",
                                    dtype=tf.float32)
        # dropout applied to the embedding table itself (keep_prob fed per mode)
        embedding = tf.nn.dropout(embedding, keep_prob)

        # shape=(None) the length of inputs; id 0 is treated as padding
        document_lengths = tf.reduce_sum(tf.sign(tf.abs(documents_bt)), 1)
        question_lengths = tf.reduce_sum(tf.sign(tf.abs(questions_bt)), 1)
        document_mask_bt = tf.sequence_mask(document_lengths, self.d_len, dtype=tf.float32)
        question_mask_bt = tf.sequence_mask(question_lengths, self.q_len, dtype=tf.float32)

        with tf.variable_scope('q_encoder', initializer=tf.orthogonal_initializer()):
            # encode question to fixed length of vector
            # output shape: (None, max_q_length, embedding_dim)
            question_embed_btf = tf.nn.embedding_lookup(embedding, questions_bt)
            logger("q_embed_btf shape {}".format(question_embed_btf.get_shape()))
            q_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            q_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=q_cell_bw,
                                                                   cell_fw=q_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=question_lengths,
                                                                   inputs=question_embed_btf,
                                                                   swap_memory=True)
            # q_encoder output shape: (None, max_t_length, hidden_size * 2)
            q_encoded_bqf = tf.concat(outputs, axis=-1)
            logger("q_encoded_bqf shape {}".format(q_encoded_bqf.get_shape()))

        with tf.variable_scope('d_encoder', initializer=tf.orthogonal_initializer()):
            # encode each document(context) word to fixed length vector
            # output shape: (None, max_d_length, embedding_dim)
            d_embed_btf = tf.nn.embedding_lookup(embedding, documents_bt)
            logger("d_embed_btf shape {}".format(d_embed_btf.get_shape()))
            d_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            d_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=d_cell_bw,
                                                                   cell_fw=d_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=document_lengths,
                                                                   inputs=d_embed_btf,
                                                                   swap_memory=True)
            # d_encoder output shape: (None, max_d_length, hidden_size * 2)
            d_encoded_bdf = tf.concat(outputs, axis=-1)
            logger("d_encoded_bdf shape {}".format(d_encoded_bdf.get_shape()))

        # mask of the pair-wise matrix: 1 where both doc and question are real words
        M_mask = tf.einsum("bi,bj->bij", document_mask_bt, question_mask_bt)
        # batch pair-wise matching
        M_bdq = tf.matmul(d_encoded_bdf, q_encoded_bqf, adjoint_b=True)

        # individual attentions: alpha over documents (axis 1), beta over questions (axis 2)
        alpha_bdq = self.softmax_with_mask(M_bdq, 1, M_mask, name="alpha")
        beta_bdq = self.softmax_with_mask(M_bdq, 2, M_mask, name="beta")
        # average beta over the (unpadded) document positions -> one weight per question word
        beta_bq1 = tf.expand_dims(tf.reduce_sum(beta_bdq, 1) / tf.to_float(tf.expand_dims(document_lengths, -1)), -1)
        logger("beta_bq1 shape:{}".format(beta_bq1.get_shape()))
        # document-level attention: alpha weighted by the averaged beta
        s_bd = tf.squeeze(tf.einsum("bdq,bqi->bdi", alpha_bdq, beta_bq1), -1)

        vocab_size = self.embedding_matrix.shape[0]
        # attention sum operation and gather within candidate_index:
        # per example, sum attention mass per vocabulary id, then pick candidates
        y_hat_bi = tf.scan(fn=lambda prev, cur: tf.gather(tf.unsorted_segment_sum(cur[0], cur[1], vocab_size), cur[2]),
                           elems=[s_bd, documents_bt, candidates_bi],
                           initializer=tf.Variable([0] * self.dataset.A_len, dtype="float32"))

        # manual computation of crossentropy (renormalize, clip away 0/1 extremes)
        output_bi = y_hat_bi / tf.reduce_sum(y_hat_bi, axis=-1, keep_dims=True)
        epsilon = tf.convert_to_tensor(_EPSILON, output_bi.dtype.base_dtype, name="epsilon")
        output_bi = tf.clip_by_value(output_bi, epsilon, 1. - epsilon)

        # loss and correct number (correct_prediction counts matches in the batch)
        self.loss = tf.reduce_mean(- tf.reduce_sum(y_true_bi * tf.log(output_bi), axis=-1))
        self.correct_prediction = tf.reduce_sum(
            tf.sign(tf.cast(tf.equal(tf.argmax(output_bi, 1),
                                     tf.argmax(y_true_bi, 1)), "float")))

    @staticmethod
    def softmax_with_mask(logits, axis, mask, epsilon=10e-8, name=None):
        """
        Numerically-stable softmax over `axis` where masked-out (mask==0)
        positions receive zero probability; epsilon avoids division by zero
        when an entire slice is masked.
        """
        with tf.name_scope(name, 'softmax', [logits, mask]):
            max_axis = tf.reduce_max(logits, axis, keep_dims=True)
            target_exp = tf.exp(logits - max_axis) * mask
            normalize = tf.reduce_sum(target_exp, axis, keep_dims=True)
            softmax = target_exp / (normalize + epsilon)
            return softmax

    def get_batch_data(self, mode, idx):
        """
        Fetch the next batch feed dict, adding dropout keep_prob:
        args.keep_prob during training, 1.0 (no dropout) otherwise.
        """
        data, samples = self.dataset.get_next_batch(mode, idx)
        if mode == "train":
            data.update({"keep_prob:0": self.args.keep_prob})
        else:
            data.update({"keep_prob:0": 1.0})
        return data, samples
129 |
--------------------------------------------------------------------------------
/models/attention_sum_reader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell, GRUCell
3 |
4 | from models.rc_base import RcBase
5 | from utils.log import logger
6 |
7 | _EPSILON = 10e-8
8 |
9 |
class AttentionSumReader(RcBase):
    """
    Attention Sum Reader model as presented in "Text Comprehension with the Attention Sum Reader Network"
    (ACL2016) available at http://arxiv.org/abs/1603.01547.
    """

    # noinspection PyAttributeOutsideInit
    def create_model(self):
        """
        Build the AS-Reader graph: input placeholders, shared embedding,
        bidirectional RNN encoders (question -> single vector, document ->
        per-token vectors), dot-product attention and the attention-sum over
        candidate answers; sets `self.loss` and `self.correct_prediction`.
        """
        #########################
        # b ... position of the example within the batch
        # t ... position of the word within the document/question
        # f ... features of the embedding vector or the encoded feature vector
        # i ... position of the word in candidates list
        #########################
        num_layers = self.args.num_layers
        hidden_size = self.args.hidden_size
        cell = LSTMCell if self.args.use_lstm else GRUCell

        # model input placeholders (fed by name via the dataset's feed dict)
        questions_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.q_len), name="questions_bt")
        documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, self.d_len), name="documents_bt")
        candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, self.dataset.A_len), name="candidates_bi")
        y_true_bi = tf.placeholder(shape=(None, self.dataset.A_len), dtype=tf.float32, name="y_true_bi")

        # shape=(None) the length of inputs; id 0 is treated as padding
        context_lengths = tf.reduce_sum(tf.sign(tf.abs(documents_bt)), 1)
        question_lengths = tf.reduce_sum(tf.sign(tf.abs(questions_bt)), 1)
        context_mask_bt = tf.sequence_mask(context_lengths, self.d_len, dtype=tf.float32)

        # trainable embedding initialized from the pre-computed matrix
        init_embedding = tf.constant(self.embedding_matrix, dtype=tf.float32, name="embedding_init")
        embedding = tf.get_variable(initializer=init_embedding,
                                    name="embedding_matrix",
                                    dtype=tf.float32)

        with tf.variable_scope('q_encoder', initializer=tf.orthogonal_initializer()):
            # encode question to fixed length of vector
            # output shape: (None, max_q_length, embedding_dim)
            question_embed_btf = tf.nn.embedding_lookup(embedding, questions_bt)
            logger("q_embed_btf shape {}".format(question_embed_btf.get_shape()))
            q_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            q_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=q_cell_bw,
                                                                   cell_fw=q_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=question_lengths,
                                                                   inputs=question_embed_btf,
                                                                   swap_memory=True)
            # q_encoder output shape: (None, hidden_size * 2)
            # concatenation of the final fw/bw states of the last RNN layer
            q_encoded_bf = tf.concat([last_states[0][-1], last_states[1][-1]], axis=-1)
            logger("q_encoded_bf shape {}".format(q_encoded_bf.get_shape()))

        with tf.variable_scope('d_encoder', initializer=tf.orthogonal_initializer()):
            # encode each document(context) word to fixed length vector
            # output shape: (None, max_d_length, embedding_dim)
            d_embed_btf = tf.nn.embedding_lookup(embedding, documents_bt)
            logger("d_embed_btf shape {}".format(d_embed_btf.get_shape()))
            d_cell_fw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            d_cell_bw = MultiRNNCell(cells=[cell(hidden_size) for _ in range(num_layers)])
            outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=d_cell_bw,
                                                                   cell_fw=d_cell_fw,
                                                                   dtype="float32",
                                                                   sequence_length=context_lengths,
                                                                   inputs=d_embed_btf,
                                                                   swap_memory=True)
            # d_encoder output shape: (None, max_d_length, hidden_size * 2)
            d_encoded_btf = tf.concat(outputs, axis=-1)
            logger("d_encoded_btf shape {}".format(d_encoded_btf.get_shape()))

        def att_dot(x):
            """attention dot product function: question vector vs. each doc token"""
            d_btf, q_bf = x
            res = tf.matmul(tf.expand_dims(q_bf, -1), d_btf, adjoint_a=True, adjoint_b=True)
            return tf.reshape(res, [-1, self.d_len])

        with tf.variable_scope('merge'):
            mem_attention_pre_soft_bt = att_dot([d_encoded_btf, q_encoded_bf])
            # zero out attention on padded document positions before softmax
            mem_attention_pre_soft_masked_bt = tf.multiply(mem_attention_pre_soft_bt,
                                                           context_mask_bt,
                                                           name="attention_mask")
            mem_attention_bt = tf.nn.softmax(logits=mem_attention_pre_soft_masked_bt, name="softmax_attention")

        # attention-sum process
        def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):
            """sum the attention probability over all occurrences of one word"""
            word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))
            return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))

        # noinspection PyUnusedLocal
        def sum_probs_single_sentence(prev, cur):
            """per-example: attention-sum score for each candidate word"""
            candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur
            result = tf.scan(
                fn=lambda previous, x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),
                elems=[candidate_indices_i],
                initializer=tf.constant(0., dtype="float32"))
            return result

        def sum_probs_batch(candidate_indices_bi, sentence_ixs_bt, sentence_attention_probs_bt):
            """batch wrapper around sum_probs_single_sentence via tf.scan"""
            result = tf.scan(
                fn=sum_probs_single_sentence,
                elems=[candidate_indices_bi, sentence_ixs_bt, sentence_attention_probs_bt],
                initializer=tf.Variable([0] * self.dataset.A_len, dtype="float32"))
            return result

        # output shape: (None, i) i = max_candidate_length = 10
        y_hat = sum_probs_batch(candidates_bi, documents_bt, mem_attention_bt)

        # crossentropy: renormalize candidate scores into a distribution
        output = y_hat / tf.reduce_sum(y_hat, axis=-1, keep_dims=True)
        # manual computation of crossentropy (clip away 0/1 extremes first)
        epsilon = tf.convert_to_tensor(_EPSILON, output.dtype.base_dtype, name="epsilon")
        output = tf.clip_by_value(output, epsilon, 1. - epsilon)
        self.loss = tf.reduce_mean(- tf.reduce_sum(y_true_bi * tf.log(output), axis=-1))

        # correct prediction nums (count of argmax matches in the batch)
        self.correct_prediction = tf.reduce_sum(tf.sign(tf.cast(tf.equal(tf.argmax(y_hat, 1),
                                                                         tf.argmax(y_true_bi, 1)), "float")))
125 |
--------------------------------------------------------------------------------
/models/model_data_pairs.py:
--------------------------------------------------------------------------------
# make sure the model supports the dataset you use
# Maps dataset name -> list of model class names that can be trained on it;
# consumed by models/__init__.py (to build __all__) and by the fitness check
# in RcBase before training starts.
models_in_datasets = {
    "CBT_NE": ["AttentionSumReader", "AoAReader"],
    "CBT_CN": ["AttentionSumReader", "AoAReader"],
    "SQuAD": ["RNet"]
}
7 |
--------------------------------------------------------------------------------
/models/nlp_base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import sys
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from dataset.data_file_pairs import dataset_files_pairs
9 | from utils.log import setup_from_args_file, save_args, err
10 |
11 |
class NLPBase(object):
    """
    Base class for NLP experiments based on the tensorflow environment.
    Only does argument reading and serializing work.

    Argument priority (low to high): defaults in code < args_file < command line.
    """

    def __init__(self):
        self.model_name = self.__class__.__name__
        self.sess = tf.Session()
        # get arguments
        self.args = self.get_args()

        # log setup
        logging.basicConfig(filename=self.args.log_file,
                            level=logging.DEBUG,
                            format='%(asctime)s %(message)s', datefmt='%y-%m-%d %H:%M')

        # set random seed for reproducibility
        np.random.seed(self.args.random_seed)
        tf.set_random_seed(self.args.random_seed)

        # save arguments
        save_args(args=self.args)

    def add_args(self, parser):
        """
        If some model needs more arguments, override this method.

        :param parser: the argparse group for model-specific parameters
        """
        pass

    def get_args(self):
        """
        Build the argument parser, parse twice (before and after loading
        args_file) and apply debug overrides.

        The priority of args:
        [low] ... args defined in the code
        [middle] ... args defined in args_file
        [high] ... args defined in command line
        """

        def str2bool(v):
            # argparse-friendly boolean conversion
            if v.lower() in ("yes", "true", "t", "y", "1"):
                return True
            if v.lower() in ("no", "false", "f", "n", "0", "none"):
                return False
            else:
                raise argparse.ArgumentTypeError('Boolean value expected.')

        def str_or_none(v):
            # map falsy spellings to None, keep anything else as a string
            if not v or v.lower() in ("no", "false", "f", "n", "0", "none", "null"):
                return None
            return v

        def int_or_none(v):
            # map falsy spellings to None, otherwise parse an int
            if not v or v.lower() in ("no", "false", "f", "n", "0", "none", "null"):
                return None
            return int(v)

        # TODO:Implement ensemble test
        parser = argparse.ArgumentParser(description="Reading Comprehension Experiment Code Base.")
        # -----------------------------------------------------------------------------------------------------------
        group1 = parser.add_argument_group("1.Basic options")
        # basis argument
        group1.add_argument("--debug", default=False, type=str2bool, help="is debug mode on or off")

        group1.add_argument("--train", default=True, type=str2bool, help="train or not")

        group1.add_argument("--test", default=False, type=str2bool, help="test or not")

        group1.add_argument("--ensemble", default=False, type=str2bool, help="ensemble test or not")

        group1.add_argument("--random_seed", default=2088, type=int, help="random seed")

        group1.add_argument("--log_file", default=None, type=str_or_none,
                            help="which file to save the log,if None,use screen")

        group1.add_argument("--weight_path", default="weights", help="path to save all trained models")

        group1.add_argument("--args_file", default=None, type=str_or_none, help="json file of current args")

        group1.add_argument("--print_every_n", default=10, type=int, help="print performance every n steps")

        # data specific argument
        group2 = parser.add_argument_group("2.Data specific options")
        # noinspection PyUnresolvedReferences
        import dataset
        group2.add_argument("--dataset", default="CBT", choices=sys.modules['dataset'].__all__, type=str,
                            help='type of the dataset to load')

        group2.add_argument("--embedding_file", default="data/glove.6B/glove.6B.200d.txt",
                            type=str_or_none, help="pre-trained embedding file")

        group2.add_argument("--max_vocab_num", default=100000, type=int, help="the max number of words in vocabulary")

        subgroup = group2.add_argument_group("Some default options related to dataset, don't change if it works")

        subgroup.add_argument("--data_root", default="data/CBTest/CBTest/data/",
                              help="root path of the dataset")

        subgroup.add_argument("--tmp_dir", default="tmp", help="dataset specific tmp folder")

        subgroup.add_argument("--train_file", default="cbtest_NE_train.txt", help="train file")

        subgroup.add_argument("--valid_file", default="cbtest_NE_valid_2000ex.txt", help="validation file")

        subgroup.add_argument("--test_file", default="cbtest_NE_test_2500ex.txt", help="test file")

        subgroup.add_argument("--max_count", default=None, type=int_or_none,
                              help="read n lines of data file, if None, read all data")

        # hyper-parameters
        group3 = parser.add_argument_group("3.Hyper parameters shared by all models")

        group3.add_argument("--use_char_embedding", default=False, type=str2bool,
                            help="use character embedding or not")

        group3.add_argument("--char_embedding_dim", default=100, type=int, help="dimension of char embeddings")

        group3.add_argument("--embedding_dim", default=200, type=int, help="dimension of word embeddings")

        group3.add_argument("--hidden_size", default=128, type=int, help="RNN hidden size")

        group3.add_argument("--grad_clipping", default=10, type=int, help="the threshold value of gradient clip")

        group3.add_argument("--lr", default=0.001, type=float, help="learning rate")

        group3.add_argument("--keep_prob", default=0.9, type=float, help="dropout,percentage to keep during training")

        group3.add_argument("--l2", default=0.0001, type=float, help="l2 regularization weight")

        group3.add_argument("--num_layers", default=1, type=int, help="RNN layer number")

        group3.add_argument("--use_lstm", default=False, type=str2bool,
                            help="RNN kind, if False, use GRU else LSTM")

        group3.add_argument("--batch_size", default=32, type=int, help="batch_size")

        group3.add_argument("--optimizer", default="ADAM", choices=["SGD", "ADAM"],
                            help="optimize algorithms, SGD or Adam")

        group3.add_argument("--evaluate_every_n", default=400, type=int,
                            help="evaluate performance on validation set and possibly saving the best model")

        group3.add_argument("--num_epoches", default=10, type=int, help="max epoch iterations")

        group3.add_argument("--patience", default=5, type=int, help="early stopping patience")
        # -----------------------------------------------------------------------------------------------------------
        group4 = parser.add_argument_group("4.model [{}] specific parameters".format(self.model_name))

        self.add_args(group4)

        # first parse: obtain --args_file so its contents can take effect,
        # then re-parse so explicit command-line flags keep the highest priority
        # NOTE(review): this relies on side effects of setup_from_args_file --
        # verify it adjusts defaults/argv, otherwise the second parse is a no-op.
        args = parser.parse_args()

        setup_from_args_file(args.args_file)

        args = parser.parse_args()

        # set debug params
        args.max_count = 7392 if args.debug else args.max_count
        args.evaluate_every_n = 5 if args.debug else args.evaluate_every_n
        args.num_epoches = 2 if args.debug else args.num_epoches

        args = self.tune_args(args)

        return args

    @staticmethod
    def tune_args(args):
        """
        Tune the dataset-specific args so train_file/valid_file/test_file need
        not be changed by hand for a known dataset.

        :param args: parsed argparse namespace
        :return: the namespace, with file paths filled in when the dataset is known
        """
        files = dataset_files_pairs.get(args.dataset)
        # BUG FIX: the original wrapped the unpacking in `except AssertionError`,
        # but a missing key makes dict.get return None and the unpacking raise
        # TypeError, so the handler never fired and the function then fell
        # through returning None. Check explicitly and keep args usable.
        if files is None or len(files) != 4:
            err("Error. Cannot find the specific key -> {} in dataset_files_pairs.".format(args.dataset))
            return args
        args.data_root, args.train_file, args.valid_file, args.test_file = files
        return args
187 |
--------------------------------------------------------------------------------
/models/r_net.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cairoHy/RC-experiments/0262f83481c364f29a43ac7cfc28da88d31f5adc/models/r_net.py
--------------------------------------------------------------------------------
/models/rc_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 | import sys
4 |
5 | import tensorflow as tf
6 |
7 | # noinspection PyUnresolvedReferences
8 | import dataset
9 | from models import models_in_datasets
10 | from models.nlp_base import NLPBase
11 | from utils.log import logger, save_obj_to_json, err
12 |
13 |
14 | # noinspection PyAttributeOutsideInit
15 | class RcBase(NLPBase, metaclass=abc.ABCMeta):
16 | """
17 | Base class of reading comprehension experiments.
18 | Reads different reading comprehension datasets according to specific class.
19 | creates a model and starts training it.
20 | Any deep learning model should inherit from this class and implement the create_model method.
21 | """
22 |
    @property
    def loss(self):
        # training loss tensor; must be assigned by the subclass's create_model()
        return self._loss

    @loss.setter
    def loss(self, value):
        self._loss = value
30 |
31 | @property
32 | def correct_prediction(self):
33 | return self._correct_prediction
34 |
35 | @correct_prediction.setter
36 | def correct_prediction(self, value):
37 | self._correct_prediction = value
38 |
39 | def get_train_op(self):
40 | """
41 | define optimization operation
42 | """
43 | if self.args.optimizer == "SGD":
44 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.args.lr)
45 | elif self.args.optimizer == "ADAM":
46 | optimizer = tf.train.AdamOptimizer(learning_rate=self.args.lr)
47 | else:
48 | raise NotImplementedError("Other Optimizer Not Implemented.-_-||")
49 |
50 | # gradient clip
51 | grad_vars = optimizer.compute_gradients(self.loss)
52 | grad_vars = [
53 | (tf.clip_by_norm(grad, self.args.grad_clipping), var)
54 | if grad is not None else (grad, var)
55 | for grad, var in grad_vars]
56 | self.train_op = optimizer.apply_gradients(grad_vars, self.step)
57 | return
58 |
59 | @abc.abstractmethod
60 | def create_model(self):
61 | """
62 | should be override by sub-class and create some operations include [loss, correct_prediction]
63 | as class attributes.
64 | """
65 | return
66 |
67 | def execute(self):
68 | """
69 | main method to train and test
70 | """
71 | self.confirm_model_dataset_fitness()
72 |
73 | self.dataset = getattr(sys.modules["dataset"], self.args.dataset)(self.args)
74 |
75 | # Get the statistics of data
76 | # [document length] and [question length] to build the model
77 | # train/valid/test sample number to train and validate and test the model
78 | statistics = self.dataset.get_data_stream()
79 | self.d_len, self.q_len, self.train_nums, self.valid_nums, self.test_num = statistics
80 | self.dataset.preprocess()
81 |
82 | # Get the word embedding and character embedding(if necessary)
83 | self.embedding_matrix = self.dataset.get_embedding_matrix(self.dataset.vocab_file)
84 | if self.args.use_char_embedding and getattr(self.dataset, "char_vocab_file"):
85 | self.char_embedding_matrix = self.dataset.get_embedding_matrix(self.dataset.char_vocab_file, True)
86 |
87 | self.create_model()
88 |
89 | self.make_sure_model_is_valid()
90 |
91 | self.saver = tf.train.Saver(max_to_keep=20)
92 |
93 | if self.args.train:
94 | self.train()
95 | if self.args.test:
96 | self.test()
97 |
98 | self.sess.close()
99 |
100 | def get_batch_data(self, mode, idx):
101 | """
102 | Get batch data and feed it to tensorflow graph
103 | Modify it in sub-class if needed.
104 | """
105 | return self.dataset.get_next_batch(mode, idx)
106 |
107 | def train(self):
108 | """
109 | train model
110 | """
111 | self.step = tf.Variable(0, name="global_step", trainable=False)
112 | batch_size = self.args.batch_size
113 | epochs = self.args.num_epoches
114 | self.get_train_op()
115 | self.sess.run(tf.global_variables_initializer())
116 | self.load_weight()
117 |
118 | # early stopping params, by default val_acc is the metric
119 | self.patience, self.best_val_acc = self.args.patience, 0.
120 | # Start training
121 | corrects_in_epoch, samples_in_epoch, loss_in_epoch = 0, 0, 0
122 | batch_num = self.train_nums // batch_size
123 | logger("Train on {} batches, {} samples per batch, {} total.".format(batch_num, batch_size, self.train_nums))
124 |
125 | step = self.sess.run(self.step)
126 | while step < batch_num * epochs:
127 | step = self.sess.run(self.step)
128 | # on Epoch start
129 | if step % batch_num == 0:
130 | corrects_in_epoch, samples_in_epoch, loss_in_epoch = 0, 0, 0
131 | logger("{}Epoch : {}{}".format("-" * 40, step // batch_num + 1, "-" * 40))
132 | self.dataset.shuffle()
133 |
134 | data, samples = self.get_batch_data("train", step % batch_num)
135 | loss, _, corrects_in_batch = self.sess.run([self.loss, self.train_op, self.correct_prediction],
136 | feed_dict=data)
137 | corrects_in_epoch += corrects_in_batch
138 | loss_in_epoch += loss * samples
139 | samples_in_epoch += samples
140 |
141 | # logger
142 | if step % self.args.print_every_n == 0:
143 | logger("Samples : {}/{}.\tStep : {}/{}.\tLoss : {:.4f}.\tAccuracy : {:.4f}".format(
144 | samples_in_epoch, self.train_nums,
145 | step % batch_num, batch_num,
146 | loss_in_epoch / samples_in_epoch, corrects_in_epoch / samples_in_epoch))
147 |
148 | # evaluate on the valid set and early stopping
149 | if step and step % self.args.evaluate_every_n == 0:
150 | val_acc, val_loss = self.validate()
151 | self.early_stopping(val_acc, val_loss, step)
152 |
153 | def validate(self):
154 | batch_size = self.args.batch_size
155 | v_batch_num = self.valid_nums // batch_size
156 | # ensure the entire valid set is selected
157 | v_batch_num = v_batch_num + 1 if (self.valid_nums % batch_size) != 0 else v_batch_num
158 | logger("Validate on {} batches, {} samples per batch, {} total."
159 | .format(v_batch_num, batch_size, self.valid_nums))
160 | val_num, val_corrects, v_loss = 0, 0, 0
161 | for i in range(v_batch_num):
162 | data, samples = self.get_batch_data("valid", i)
163 | if samples != 0:
164 | loss, v_correct = self.sess.run([self.loss, self.correct_prediction], feed_dict=data)
165 | val_num += samples
166 | val_corrects += v_correct
167 | v_loss += loss * samples
168 | assert (val_num == self.valid_nums)
169 | val_acc = val_corrects / val_num
170 | val_loss = v_loss / val_num
171 | logger("Evaluate on : {}/{}.\tVal acc : {:.4f}.\tVal Loss : {:.4f}".format(val_num,
172 | self.valid_nums,
173 | val_acc,
174 | val_loss))
175 | return val_acc, val_loss
176 |
177 | # noinspection PyUnusedLocal
178 | def early_stopping(self, val_acc, val_loss, step):
179 | if val_acc > self.best_val_acc:
180 | self.patience = self.args.patience
181 | self.best_val_acc = val_acc
182 | self.save_weight(val_acc, step)
183 | elif self.patience == 1:
184 | logger("Oh u, stop training.")
185 | exit(0)
186 | else:
187 | self.patience -= 1
188 | logger("Remaining/Patience : {}/{} .".format(self.patience, self.args.patience))
189 |
190 | def save_weight(self, val_acc, step):
191 | path = self.saver.save(self.sess,
192 | os.path.join(self.args.weight_path,
193 | "{}-val_acc-{:.4f}.models".format(self.model_name, val_acc)),
194 | global_step=step)
195 | logger("Save models to {}.".format(path))
196 |
197 | def load_weight(self):
198 | ckpt = tf.train.get_checkpoint_state(self.args.weight_path)
199 | if ckpt is not None:
200 | logger("Load models from {}.".format(ckpt.model_checkpoint_path))
201 | self.saver.restore(self.sess, ckpt.model_checkpoint_path)
202 | else:
203 | logger("No previous models.")
204 |
205 | def test(self):
206 | if not self.args.train:
207 | self.sess.run(tf.global_variables_initializer())
208 | self.load_weight()
209 | batch_size = self.args.batch_size
210 | batch_num = self.test_num // batch_size
211 | batch_num = batch_num + 1 if (self.test_num % batch_size) != 0 else batch_num
212 | correct_num, total_num = 0, 0
213 | for i in range(batch_num):
214 | data, samples = self.get_batch_data("test", i)
215 | if samples != 0:
216 | correct, = self.sess.run([self.correct_prediction], feed_dict=data)
217 | correct_num, total_num = correct_num + correct, total_num + samples
218 | assert (total_num == self.test_num)
219 | logger("Test on : {}/{}".format(total_num, self.test_num))
220 | test_acc = correct_num / total_num
221 | logger("Test accuracy is : {:.5f}".format(test_acc))
222 | res = {
223 | "model": self.model_name,
224 | "test_acc": test_acc
225 | }
226 | save_obj_to_json(self.args.weight_path, res, "result.json")
227 |
228 | def confirm_model_dataset_fitness(self):
229 | # make sure the models_in_datasets var is correct
230 | try:
231 | assert (models_in_datasets.get(self.args.dataset, None) is not None)
232 | except AssertionError:
233 | err("Models_in_datasets doesn't have the specified dataset key: {}.".format(self.args.dataset))
234 | self.sess.close()
235 | exit(1)
236 | # make sure the model fit the dataset
237 | try:
238 | assert (self.model_name in models_in_datasets.get(self.args.dataset, None))
239 | except AssertionError:
240 | err("The model -> {} doesn't support the dataset -> {}".format(self.model_name, self.args.dataset))
241 | self.sess.close()
242 | exit(1)
243 |
244 | def make_sure_model_is_valid(self):
245 | """
246 | check if the model has necessary attributes
247 | """
248 | try:
249 | _ = self.loss
250 | _ = self.correct_prediction
251 | except AttributeError as e:
252 | err("Your model {} doesn't have enough attributes.\nError Message:\n\t{}".format(self.model_name, e))
253 | self.sess.close()
254 | exit(1)
255 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk>=3.2.1
2 | numpy>=1.12.1
3 | tensorflow>=1.2.0
3 |
--------------------------------------------------------------------------------
/test/dataset_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import sys
4 | import unittest
5 |
6 | # noinspection PyUnresolvedReferences
7 | import dataset
8 |
9 |
class TestDataset(unittest.TestCase):
    """Shared fixture: builds the argparse namespace the dataset classes expect."""

    def setUp(self):
        """Configure logging and parse the (mostly defaulted) dataset args."""

        def str2bool(value):
            # argparse passes option values in as strings and bool("False")
            # is True, so boolean options need an explicit parser; the
            # previous `type=bool` (and `type=str` for --debug) silently
            # treated every non-empty string as True.
            return str(value).lower() not in ("false", "0", "no", "n", "f", "")

        logging.basicConfig(filename=None,
                            level=logging.DEBUG,
                            format='%(asctime)s %(message)s', datefmt='%y-%m-%d %H:%M')
        parser = argparse.ArgumentParser()
        parser.add_argument("--debug", default=True, type=str2bool, help="is debug mode on or off")

        parser.add_argument("--data_root", default="../data/SQuAD/",
                            help="root path of the dataset")

        parser.add_argument("--tmp_dir", default="tmp", help="dataset specific tmp folder")

        parser.add_argument("--train_file", default="train-v1.1.json", help="train file")

        parser.add_argument("--valid_file", default="dev-v1.1.json", help="validation file")

        parser.add_argument("--max_count", default=None, type=int,
                            help="read n lines of data file, if None, read all data")

        parser.add_argument("--max_vocab_num", default=100000, type=int, help="the max number of words in vocabulary")

        parser.add_argument("--batch_size", default=32, type=int, help="batch_size")

        parser.add_argument("--train", default=True, type=str2bool, help="train or not")

        parser.add_argument("--test", default=True, type=str2bool, help="test or not")

        # parse_known_args ignores options injected by the test runner
        self.args = parser.parse_known_args()[0]
39 |
40 |
class TestCBT(TestDataset):
    def runTest(self):
        """Smoke-test the CBT pipeline: every statistic after the first
        returned by get_data_stream must be a positive int."""
        overrides = {
            "data_root": "../data/CBTest/CBTest/data/",
            "train_file": "cbtest_NE_train.txt",
            "valid_file": "cbtest_NE_valid_2000ex.txt",
            "test_file": "cbtest_NE_test_2500ex.txt",
        }
        for option, value in overrides.items():
            setattr(self.args, option, value)
        self.dataset = getattr(sys.modules["dataset"], "CBT")(self.args)
        stats = self.dataset.get_data_stream()
        for stat in stats[1:]:
            self.assertEqual(type(stat), int, "Some data statistic not int.")
            self.assertGreater(stat, 0, "Some data number not greater than zero.")
52 |
53 |
class TestSQuAD(TestDataset):
    def runTest(self):
        """Prepare the SQuAD data, index it, and check that each corpus is a
        non-empty list of lists of non-negative int word indices."""
        self.dataset = getattr(sys.modules["dataset"], "SQuAD")(self.args)
        args = self.args

        prepared = self.dataset.prepare_data(args.data_root, args.train_file,
                                             args.valid_file, args.max_vocab_num,
                                             args.tmp_dir)
        os_train_file, os_valid_file, vocab_file, char_vocab_file = prepared

        documents, questions, _ = self.dataset.read_squad_data(os_train_file)
        v_documents, v_questions, _ = self.dataset.read_squad_data(os_valid_file)
        data = self.dataset.squad_data_to_idx(vocab_file, documents, questions,
                                              v_documents, v_questions)
        # make sure that each one of (d,q,v_d,v_q) is a list, and each element is a list too.
        for corpus in data:
            self.assertEqual(type(corpus), list, "some data in train set or valid set is not a list.")
            self.assertGreater(len(corpus), 0, "some data in train set or valid set is None.")
            self.assertEqual(type(corpus[0]), list, "some elements in train set or valid set is not a list.")
            for word in corpus[0]:
                self.assertEqual(type(word), int, "Not all the word is index form.")
                self.assertGreaterEqual(word, 0, "Invalid index for some word.")
76 |
--------------------------------------------------------------------------------
/test/notebook/test_aoa.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true
7 | },
8 | "source": [
9 | "### 1.test M_mask calculation"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import tensorflow as tf"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 25,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "q_mask shape:(5, 5)\nd_mask shape:(5, 10)\n"
31 | ]
32 | },
33 | {
34 | "data": {
35 | "text/plain": [
36 | "TensorShape([Dimension(5), Dimension(10), Dimension(5)])"
37 | ]
38 | },
39 | "execution_count": 25,
40 | "metadata": {},
41 | "output_type": "execute_result"
42 | }
43 | ],
44 | "source": [
45 | "q_len, d_len = 5, 10\n",
46 | "q_lens = tf.constant([3, 2, 1, 3, 4], dtype=tf.int32)\n",
47 | "d_lens = tf.constant([7, 8, 9, 6, 6], dtype=tf.int32)\n",
48 | "q_mask = tf.sequence_mask(q_lens, q_len, dtype=tf.float32)\n",
49 | "d_mask = tf.sequence_mask(d_lens, d_len, dtype=tf.float32)\n",
50 | "\n",
51 | "print(\"q_mask shape:{}\".format(q_mask.get_shape()))\n",
52 | "print(\"d_mask shape:{}\".format(d_mask.get_shape()))\n",
53 | "M_mask = tf.einsum(\"bi,bj->bij\", d_mask, q_mask)\n",
54 | "M_mask.get_shape()"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 26,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "[[ 1. 1. 1. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 1. 0.]]\n--------------------------------------------------\n[[ 1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]\n [ 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]]\n--------------------------------------------------\n[[[ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 1. 1. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 1. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 1. 1. 1. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]\n\n [[ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 1. 1. 1. 1. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.]]]\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "with tf.Session() as sess:\n",
72 | " sess.run(tf.global_variables_initializer())\n",
73 | " print(sess.run(q_mask))\n",
74 | " print(\"-\" * 50)\n",
75 | " print(sess.run(d_mask))\n",
76 | " print(\"-\" * 50)\n",
77 | " print(sess.run(M_mask))"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "### 2.test attention sum"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 1,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "import tensorflow as tf\n",
94 | "import numpy as np"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 7,
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "[[ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]\n(3, 10)\n--------------------------------------------------\n[[ 4 3 18 19 7 9 13 1 10 11]\n [ 5 16 11 6 4 0 16 1 11 8]\n [ 3 5 2 12 2 8 14 1 15 11]]\n(3, 10)\n--------------------------------------------------\n[[ 3 11 19 11]\n [ 9 1 14 11]\n [13 10 10 13]]\n(3, 4)\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "batch_size = 3\n",
112 | "d_len = 10\n",
113 | "vocab_size = 20\n",
114 | "A_len = 4\n",
115 | "true_s_bd = np.array([0.1]*d_len*batch_size).reshape(batch_size,d_len)\n",
116 | "true_documents_bt = np.random.randint(0,vocab_size,size=(batch_size,d_len))\n",
117 | "true_candidates_bi = np.random.randint(0,vocab_size,size=(batch_size,A_len))\n",
118 | "\n",
119 | "print(true_s_bd)\n",
120 | "print(true_s_bd.shape)\n",
121 | "print(\"-\"*50)\n",
122 | "print(true_documents_bt)\n",
123 | "print(true_documents_bt.shape)\n",
124 | "print(\"-\"*50)\n",
125 | "print(true_candidates_bi)\n",
126 | "print(true_candidates_bi.shape)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 8,
132 | "metadata": {
133 | "collapsed": false
134 | },
135 | "outputs": [
136 | {
137 | "name": "stdout",
138 | "output_type": "stream",
139 | "text": [
140 | "[[ 0.1 0.1 0.1 0.1]\n [ 0. 0.1 0. 0.2]\n [ 0. 0. 0. 0. ]]\n"
141 | ]
142 | }
143 | ],
144 | "source": [
145 | "s_bd = tf.placeholder(dtype=tf.float32, shape=(None, d_len), name=\"s_bd\")\n",
146 | "documents_bt = tf.placeholder(dtype=tf.int32, shape=(None, d_len), name=\"documents_bt\")\n",
147 | "candidates_bi = tf.placeholder(dtype=tf.int32, shape=(None, A_len), name=\"candidates_bi\")\n",
148 | "y_hat_bi = tf.scan(fn=lambda prev, cur:\n",
149 | "tf.gather(tf.unsorted_segment_sum(cur[0], cur[1], vocab_size), cur[2]),\n",
150 | " elems=[s_bd, documents_bt, candidates_bi],\n",
151 | " initializer=tf.Variable([0.] * A_len,dtype=tf.float32))\n",
152 | "with tf.Session() as sess:\n",
153 | " sess.run(tf.global_variables_initializer())\n",
154 | " data = {\n",
155 | " s_bd:true_s_bd,\n",
156 | " documents_bt:true_documents_bt,\n",
157 | " candidates_bi:true_candidates_bi\n",
158 | " }\n",
159 | " print(sess.run(y_hat_bi,feed_dict=data))"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | ""
169 | ]
170 | }
171 | ],
172 | "metadata": {
173 | "kernelspec": {
174 | "display_name": "Python 2",
175 | "language": "python",
176 | "name": "python2"
177 | },
178 | "language_info": {
179 | "codemirror_mode": {
180 | "name": "ipython",
181 | "version": 2.0
182 | },
183 | "file_extension": ".py",
184 | "mimetype": "text/x-python",
185 | "name": "python",
186 | "nbconvert_exporter": "python",
187 | "pygments_lexer": "ipython2",
188 | "version": "2.7.6"
189 | }
190 | },
191 | "nbformat": 4,
192 | "nbformat_minor": 0
193 | }
--------------------------------------------------------------------------------
/test/notebook/test_as_reader.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 测试注意力向量计算"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 3,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stderr",
17 | "output_type": "stream",
18 | "text": [
19 | "Using TensorFlow backend.\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "import tensorflow as tf\n",
25 | "import numpy as np\n",
26 | "import keras.backend as K"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 76,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "def my_dot(x):\n",
36 | " c = [tf.reduce_sum(tf.multiply(x[0][:, inx, :], x[1]), -1, keep_dims=True) for inx in range(3)]\n",
37 | " return tf.concat(c, -1)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 77,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "def my_dot_v2(x, y):\n",
47 | " \"\"\"注意力点乘函数,快速版本\"\"\"\n",
48 | " res = K.batch_dot(tf.expand_dims(y, -1),x, (1, 2))\n",
49 | " return K.reshape(res, [-1, 3])"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 4,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "def att_dot(x):\n",
59 | " \"\"\"注意力点乘函数\"\"\"\n",
60 | " d_btf, q_bf = x\n",
61 | " res = K.batch_dot(tf.expand_dims(y, -1),x, (1, 2))\n",
62 | " return tf.reshape(res, [-1, 3])"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 12,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "def new_att_dot(x):\n",
72 | " d_btf, q_bf = x\n",
73 | " res = tf.matmul(tf.expand_dims(q_bf, -1), d_btf, adjoint_a=True,adjoint_b=True)\n",
74 | " return tf.reshape(res, [-1, 3])"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 5,
80 | "metadata": {},
81 | "outputs": [
82 | {
83 | "name": "stdout",
84 | "output_type": "stream",
85 | "text": [
86 | "[[[ 0 1 2 3]\n [ 4 5 6 7]\n [ 8 9 10 11]]\n\n [[12 13 14 15]\n [16 17 18 19]\n [20 21 22 23]]]\n--------------------\n(2, 3, 4)\n--------------------\n(2, 4)\n--------------------\n[[0 1 2 3]\n [4 5 6 7]]\n"
87 | ]
88 | }
89 | ],
90 | "source": [
91 | "a = tf.placeholder(tf.float32,shape=(None,3,4))\n",
92 | "b = tf.placeholder(tf.float32,shape=(None,4))\n",
93 | "true_a = np.arange(24).reshape(2,3,4)\n",
94 | "true_b = np.arange(8).reshape(2,4)\n",
95 | "print(true_a)\n",
96 | "print('-'*20)\n",
97 | "print(true_a.shape)\n",
98 | "print('-'*20)\n",
99 | "print(true_b.shape)\n",
100 | "print('-'*20)\n",
101 | "print(true_b)"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 14,
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "data": {
111 | "text/plain": [
112 | "TensorShape([Dimension(None), Dimension(3)])"
113 | ]
114 | },
115 | "execution_count": 14,
116 | "metadata": {},
117 | "output_type": "execute_result"
118 | }
119 | ],
120 | "source": [
121 | "d = att_dot([a,b])\n",
122 | "d.get_shape()\n",
123 | "e = new_att_dot([a,b])\n",
124 | "e.get_shape()"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 15,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "[[ 14. 38. 62.]\n [ 302. 390. 478.]]\n--------------------------------------------------\n[[ 14. 38. 62.]\n [ 302. 390. 478.]]\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "with tf.Session() as sess:\n",
142 | " sess.run(tf.global_variables_initializer())\n",
143 | " print(sess.run(d, {a: true_a, b: true_b}))\n",
144 | " print(\"-\"*50)\n",
145 | " print(sess.run(e, {a: true_a, b: true_b}))"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 80,
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "(?, 4, 1)\n(?, 3, 4)\n"
158 | ]
159 | }
160 | ],
161 | "source": [
162 | "x = tf.expand_dims(b,-1)\n",
163 | "print(x.get_shape())\n",
164 | "y = a\n",
165 | "print(y.get_shape())"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 81,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "TensorShape([Dimension(None), Dimension(3)])"
177 | ]
178 | },
179 | "execution_count": 81,
180 | "metadata": {},
181 | "output_type": "execute_result"
182 | }
183 | ],
184 | "source": [
185 | "res = K.batch_dot(x,y,(1,2))\n",
186 | "res = tf.reshape(res,[-1,3])\n",
187 | "res.get_shape()"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 82,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "data": {
197 | "text/plain": [
198 | "TensorShape([Dimension(None), Dimension(3)])"
199 | ]
200 | },
201 | "execution_count": 82,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "d = my_dot_v2(a,b)\n",
208 | "d.get_shape()"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 63,
214 | "metadata": {},
215 | "outputs": [
216 | {
217 | "data": {
218 | "text/plain": [
219 | "478"
220 | ]
221 | },
222 | "execution_count": 63,
223 | "metadata": {},
224 | "output_type": "execute_result"
225 | }
226 | ],
227 | "source": [
228 | "np.arange(20,24).dot(np.array([4,5,6,7]))"
229 | ]
230 | },
231 | {
232 | "cell_type": "markdown",
233 | "metadata": {},
234 | "source": [
235 | "## 测试tensorflow的scan"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 4,
241 | "metadata": {
242 | "collapsed": true
243 | },
244 | "outputs": [
245 | {
246 | "name": "stdout",
247 | "output_type": "stream",
248 | "text": [
249 | "(3, 4)\nTensor(\"scan_1/while/TensorArrayReadV3:0\", shape=(3, 4), dtype=int32)\nTensor(\"scan_1/while/Identity_1:0\", shape=(3, 4), dtype=int32)\n"
250 | ]
251 | },
252 | {
253 | "name": "stdout",
254 | "output_type": "stream",
255 | "text": [
256 | "(?, 1)\n[ 27. 5. 7.]\n"
257 | ]
258 | }
259 | ],
260 | "source": [
261 | "sentence_ids = tf.Variable([5, 8, 1, 3, 0, 34, 8, 7, 3, 8])\n",
262 | "attentions = tf.Variable([5, 9, 1, 3, 0, 34, 9, 7, 3, 9],dtype=\"float32\")\n",
263 | "word_id = tf.Variable(8)\n",
264 | "word_ids = tf.Variable([8,5,7])\n",
265 | "aaa = tf.equal(sentence_ids, word_id)\n",
266 | "ccc = tf.where(aaa)\n",
267 | "qqq = tf.reduce_sum(tf.gather(attentions,ccc))\n",
268 | "\n",
269 | "\n",
270 | "def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):\n",
271 | " word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))\n",
272 | " return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))\n",
273 | "\n",
274 | "test_func = lambda x:sum_prob_of_word(x,sentence_ids,attentions)\n",
275 | "\n",
276 | "ppp = test_func(word_id)\n",
277 | "\n",
278 | "def sum_probs_single_sentence(prev,cur):\n",
279 | " candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur\n",
280 | " result = tf.scan(\n",
281 | " fn=lambda prev,x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),\n",
282 | " elems=[candidate_indices_i],\n",
283 | " initializer=tf.Variable(0.,dtype=\"float32\"))\n",
284 | " return result\n",
285 | "\n",
286 | "zzz = sum_probs_single_sentence(None,[word_ids,sentence_ids,attentions])\n",
287 | "\n",
288 | "def func(prev, cur):\n",
289 | " print(cur.get_shape())\n",
290 | " print(cur)\n",
291 | " print(prev)\n",
292 | " return cur\n",
293 | "\n",
294 | "v = tf.Variable(np.arange(24).reshape(2, 3, 4))\n",
295 | "# print(v.get_shape())\n",
296 | "\n",
297 | "bbb = tf.scan(func, elems=v)\n",
298 | "with tf.Session() as sess:\n",
299 | " sess.run(tf.global_variables_initializer())\n",
300 | " print(ccc.get_shape())\n",
301 | " print(sess.run(zzz))\n",
302 | " # print(sess.run(bbb))"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 1,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "name": "stdout",
312 | "output_type": "stream",
313 | "text": [
314 | "[[ 0.60000002 0.2 0. ]\n [ 0. 0. 0. ]\n [ 0.1 0. 0.1 ]]\n"
315 | ]
316 | }
317 | ],
318 | "source": [
319 | "import tensorflow as tf\n",
320 | "import numpy as np\n",
321 | "\n",
322 | "\n",
323 | "def sum_prob_of_word(word_ix, sentence_ixs, sentence_attention_probs):\n",
324 | " word_ixs_in_sentence = tf.where(tf.equal(sentence_ixs, word_ix))\n",
325 | " return tf.reduce_sum(tf.gather(sentence_attention_probs, word_ixs_in_sentence))\n",
326 | "\n",
327 | "def sum_probs_single_sentence(prev,cur):\n",
328 | " candidate_indices_i, sentence_ixs_t, sentence_attention_probs_t = cur\n",
329 | " result = tf.scan(\n",
330 | " fn=lambda prev,x: sum_prob_of_word(x, sentence_ixs_t, sentence_attention_probs_t),\n",
331 | " elems=[candidate_indices_i],\n",
332 | " initializer=tf.constant(0.,dtype=\"float32\"))\n",
333 | " return result\n",
334 | "\n",
335 | "def sum_probs_batch(candidate_indices_bt, sentence_ixs_bt, sentence_attention_probs_bt):\n",
336 | " result = tf.scan(\n",
337 | " fn=sum_probs_single_sentence,\n",
338 | " elems=[candidate_indices_bt, sentence_ixs_bt, sentence_attention_probs_bt],\n",
339 | " initializer=tf.Variable([1,2,3],dtype=\"float32\"))\n",
340 | " return result\n",
341 | "\n",
342 | "candidate_idx = tf.Variable([\n",
343 | " [16, 21, 8],\n",
344 | " [13, 19, 26],\n",
345 | " [23, 9, 23]\n",
346 | "])\n",
347 | "\n",
348 | "sentence_idx = tf.Variable([\n",
349 | " [16, 21, 23, 16, 8, 9, 21],\n",
350 | " [16, 21, 23, 16, 8, 9, 21],\n",
351 | " [16, 21, 23, 16, 8, 9, 21],\n",
352 | "])\n",
353 | "\n",
354 | "attention_idx = tf.Variable([\n",
355 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0],\n",
356 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0],\n",
357 | " [0.3, 0.2, 0.1, 0.3, 0, 0, 0]\n",
358 | "],dtype=\"float32\")\n",
359 | "\n",
360 | "o = sum_probs_batch(candidate_idx, sentence_idx, attention_idx)\n",
361 | "with tf.Session() as sess:\n",
362 | " sess.run(tf.global_variables_initializer())\n",
363 | " print(sess.run(o))"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 1,
369 | "metadata": {},
370 | "outputs": [
371 | {
372 | "name": "stderr",
373 | "output_type": "stream",
374 | "text": [
375 | "Using TensorFlow backend.\n"
376 | ]
377 | },
378 | {
379 | "data": {
380 | "text/plain": [
381 | "array([[ 1., 0., 0., 0., 0., 0., 0.],\n [ 1., 1., 0., 0., 0., 0., 0.],\n [ 1., 1., 1., 0., 0., 0., 0.],\n [ 1., 1., 1., 1., 0., 0., 0.],\n [ 1., 1., 1., 1., 1., 0., 0.]], dtype=float32)"
382 | ]
383 | },
384 | "execution_count": 1,
385 | "metadata": {},
386 | "output_type": "execute_result"
387 | }
388 | ],
389 | "source": [
390 | "import tensorflow as tf\n",
391 | "import keras.backend as K\n",
392 | "\n",
393 | "a = [1,2,3,4,5]\n",
394 | "K.eval(tf.sequence_mask(a,7,dtype=tf.float32))"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 2,
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "import numpy as np"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": 1,
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "batch_size = 4\n",
413 | "A_size = 5"
414 | ]
415 | },
416 | {
417 | "cell_type": "markdown",
418 | "metadata": {},
419 | "source": [
420 | "### 测试获取dynamic_rnn输出的有效位"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 1,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "import tensorflow as tf\n",
430 | "import numpy as np\n",
431 | "from tensorflow.contrib.rnn import LSTMCell, GRUCell, MultiRNNCell\n",
432 | "\n",
433 | "i = np.random.rand(1000)"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 2,
439 | "metadata": {},
440 | "outputs": [],
441 | "source": [
442 | "with tf.variable_scope('q_encoder'):\n",
443 | " cell = MultiRNNCell(cells=[GRUCell(2)] * 1)\n",
444 | " x = tf.placeholder(dtype=tf.float32, shape=(2, 5, 100), name=\"x\")\n",
445 | " q_lens = tf.placeholder(dtype=tf.int32,shape=(2))\n",
446 | " outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_bw=cell,\n",
447 | " cell_fw=cell,\n",
448 | " dtype=\"float32\",\n",
449 | " sequence_length=q_lens,\n",
450 | " inputs=x,\n",
451 | " swap_memory=True)\n",
452 | " q_enc = tf.gather_nd(outputs, tf.stack([tf.range(q_lens.get_shape()[0]), q_lens - 1], axis=1))\n",
453 | " q_enc_c = tf.concat(outputs,axis=-1)"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": 3,
459 | "metadata": {
460 | "collapsed": false
461 | },
462 | "outputs": [
463 | {
464 | "name": "stdout",
465 | "output_type": "stream",
466 | "text": [
467 | "outputs:(array([[[ 0.0560709 , -0.10033438],\n [ 0.07781161, -0.220449 ],\n [ 0.06069176, -0.23297542],\n [ 0. , 0. ],\n [ 0. , 0. ]],\n\n [[ 0.02716292, -0.05908116],\n [-0.03233079, -0.02452423],\n [ 0. , 0. ],\n [ 0. , 0. ],\n [ 0. , 0. ]]], dtype=float32), array([[[ 0.08484475, 0.06869046],\n [ 0.07913341, 0.17418024],\n [ 0.11813057, -0.0282253 ],\n [ 0. , 0. ],\n [ 0. , 0. ]],\n\n [[-0.12557939, -0.38799691],\n [-0.09159085, -0.26809931],\n [ 0. , 0. ],\n [ 0. , 0. ],\n [ 0. , 0. ]]], dtype=float32))\nstates:((array([[ 0.06069176, -0.23297542],\n [-0.03233079, -0.02452423]], dtype=float32),), (array([[ 0.08484475, 0.06869046],\n [-0.12557939, -0.38799691]], dtype=float32),))\nq_enc:[[[ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]]\n\n [[-0.12557939 -0.38799691]\n [-0.09159085 -0.26809931]\n [ 0. 0. ]\n [ 0. 0. ]\n [ 0. 0. ]]]\nq_enc_c:[[[ 0.0560709 -0.10033438 0.08484475 0.06869046]\n [ 0.07781161 -0.220449 0.07913341 0.17418024]\n [ 0.06069176 -0.23297542 0.11813057 -0.0282253 ]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]]\n\n [[ 0.02716292 -0.05908116 -0.12557939 -0.38799691]\n [-0.03233079 -0.02452423 -0.09159085 -0.26809931]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]\n [ 0. 0. 0. 0. ]]]\n"
468 | ]
469 | }
470 | ],
471 | "source": [
472 | "with tf.Session() as sess:\n",
473 | " sess.run(tf.global_variables_initializer())\n",
474 | " a, b, c, d = sess.run([outputs, last_states, q_enc, q_enc_c], feed_dict={\n",
475 | " q_lens: (3, 2),\n",
476 | " x: i.reshape(2, 5, 100)\n",
477 | " })\n",
478 | " print(\"outputs:{}\\nstates:{}\\nq_enc:{}\\nq_enc_c:{}\".format(a, b, c, d))"
479 | ]
480 | },
481 | {
482 | "cell_type": "markdown",
483 | "metadata": {},
484 | "source": [
485 | "### 测试参数初始化"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 3,
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "import tensorflow as tf\n",
495 | "import numpy as np\n",
496 | "from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell, GRUCell"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": 9,
502 | "metadata": {},
503 | "outputs": [
504 | {
505 | "name": "stdout",
506 | "output_type": "stream",
507 | "text": [
508 | "[[ 0.89980626 0.31822035 0.29847053]\n [-0.29329553 0.94766068 -0.12616226]\n [-0.3229962 0.02598153 0.94604361]]\n[[ 0. 0. 0.]\n [ 0. 0. 0.]\n [ 0. 0. 0.]]\n"
509 | ]
510 | }
511 | ],
512 | "source": [
513 | "from tensorflow import variable_scope, get_variable\n",
514 | "\n",
515 | "with variable_scope(\"cc\",initializer=tf.orthogonal_initializer()):\n",
516 | " with variable_scope(\"ddd\"):\n",
517 | " a = get_variable(\"weight\",shape=(3,3),dtype=tf.float32)\n",
518 | "\n",
519 | "with variable_scope(\"cccc\",initializer=tf.zeros_initializer()):\n",
520 | " with variable_scope(\"ccc\"):\n",
521 | " b = get_variable(\"weight\",shape=(3,3),dtype=tf.float32)\n",
522 | "\n",
523 | "with tf.Session() as sess:\n",
524 | " sess.run(tf.global_variables_initializer())\n",
525 | " print(sess.run(a))\n",
526 | " print(sess.run(b))"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": null,
532 | "metadata": {},
533 | "outputs": [],
534 | "source": [
535 | ""
536 | ]
537 | }
538 | ],
539 | "metadata": {
540 | "kernelspec": {
541 | "display_name": "Python 2",
542 | "language": "python",
543 | "name": "python2"
544 | },
545 | "language_info": {
546 | "codemirror_mode": {
547 | "name": "ipython",
548 | "version": 2.0
549 | },
550 | "file_extension": ".py",
551 | "mimetype": "text/x-python",
552 | "name": "python",
553 | "nbconvert_exporter": "python",
554 | "pygments_lexer": "ipython2",
555 | "version": "2.7.6"
556 | }
557 | },
558 | "nbformat": 4,
559 | "nbformat_minor": 0
560 | }
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cairoHy/RC-experiments/0262f83481c364f29a43ac7cfc28da88d31f5adc/utils/__init__.py
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 |
5 | import sys
6 | from pprint import pprint
7 |
8 | logger = logging.info  # module-level shorthand: logger("msg") emits an info-level record
9 | err = logging.error  # module-level shorthand: err("msg") emits an error-level record
10 |
11 |
12 | def setup_from_args_file(file):
13 | if not file:
14 | return
15 | json_dict = json.load(open(file, encoding="utf-8"))
16 | args = [sys.argv[0]]
17 | for k, v in json_dict.items():
18 | args.append("--{}".format(k))
19 | args.append(str(v))
20 | sys.argv = args.copy() + sys.argv[1:]
21 |
22 |
23 | def save_args(args):
24 | """
25 | save all arguments.
26 | """
27 | save_obj_to_json(args.weight_path, vars(args), filename="args.json")
28 | pprint(vars(args), indent=4)
29 |
30 |
31 | def save_obj_to_json(path, obj, filename):
32 | if not os.path.exists(path):
33 | os.mkdir(path)
34 | file = os.path.join(path, filename)
35 | with open(file, "w", encoding="utf-8") as fp:
36 | json.dump(obj, fp, sort_keys=True, indent=4)
37 |
--------------------------------------------------------------------------------
/weights/AS-reader/best-CBT-CN/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 300,
8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 0.5,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_CN_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_CN_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_CN_valid_2000ex.txt",
32 | "weight_path": "weights/"
33 | }
--------------------------------------------------------------------------------
/weights/AS-reader/best-CBT-CN/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_sum_reader.py",
3 | "test_acc": 0.65
4 | }
--------------------------------------------------------------------------------
/weights/AS-reader/best-CBT-NE/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/AS-reader/best-CBT-NE/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 200,
8 | "embedding_file": "data/glove.6B/glove.6B.200d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 0.5,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_NE_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_NE_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_NE_valid_2000ex.txt",
32 | "weight_path": "weights/AS-reader/best-CBT-NE"
33 | }
--------------------------------------------------------------------------------
/weights/AS-reader/best-CBT-NE/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_sum_reader.py",
3 | "test_acc": 0.6896
4 | }
--------------------------------------------------------------------------------
/weights/AS-reader/best-best-CBT-NE/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/AS-reader/best-best-CBT-NE/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 300,
8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 0.5,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_NE_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_NE_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_NE_valid_2000ex.txt",
32 | "weight_path": "weights/AS-reader/best-best-CBT-NE"
33 | }
--------------------------------------------------------------------------------
/weights/AS-reader/best-best-CBT-NE/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_sum_reader.py",
3 | "test_acc": 0.6988
4 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-CBT-CN/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 300,
8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 0.5,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_CN_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_CN_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_CN_valid_2000ex.txt",
32 | "weight_path": "weights/"
33 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-CBT-CN/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_over_attention_reader.py",
3 | "test_acc": 0.6812
4 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-CBT-NE/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/AoA-reader/best-CBT-NE/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 200,
8 | "embedding_file": "data/glove.6B/glove.6B.200d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 1.0,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_NE_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_NE_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_NE_valid_2000ex.txt",
32 | "weight_path": "weights/AoA-reader/best-CBT-NE"
33 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-CBT-NE/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_over_attention_reader.py",
3 | "test_acc": 0.7088
4 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-best-CBT-NE/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "args_file": "weights/args.json",
3 | "batch_size": 32,
4 | "data_root": "data/CBTest/CBTest/data/",
5 | "dataset": "cbt",
6 | "debug": false,
7 | "embedding_dim": 300,
8 | "embedding_file": "data/glove.6B/glove.6B.300d.txt",
9 | "ensemble": false,
10 | "evaluate_every_n": 400,
11 | "grad_clipping": 10,
12 | "hidden_size": 128,
13 | "keep_prob": 0.5,
14 | "l2": 0.0001,
15 | "log_file": null,
16 | "lr": 0.001,
17 | "max_count": null,
18 | "max_vocab_num": 100000,
19 | "num_epoches": 10,
20 | "num_layers": 1,
21 | "optimizer": "ADAM",
22 | "patience": 5,
23 | "print_every_n": 10,
24 | "random_seed": 2088,
25 | "test": true,
26 | "test_file": "cbtest_NE_test_2500ex.txt",
27 | "tmp_dir": "tmp",
28 | "train": false,
29 | "train_file": "cbtest_NE_train.txt",
30 | "use_lstm": false,
31 | "valid_file": "cbtest_NE_valid_2000ex.txt",
32 | "weight_path": "weights/"
33 | }
--------------------------------------------------------------------------------
/weights/AoA-reader/best-best-CBT-NE/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "attention_over_attention_reader.py",
3 | "test_acc": 0.71
4 | }
--------------------------------------------------------------------------------