├── .gitignore
├── .pylintrc
├── LICENSE
├── README.md
├── data
└── CSAbstruct
│ ├── dev.jsonl
│ ├── test.jsonl
│ └── train.jsonl
├── requirements.txt
├── scripts
├── rouge_eval.py
└── train.sh
└── sequential_sentence_classification
├── __init__.py
├── config.jsonnet
├── dataset_reader.py
├── model.py
└── predictor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | tmp
3 | .vscode
4 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # Specify a configuration file.
4 | #rcfile=
5 |
6 | # Python code to execute, usually for sys.path manipulation such as
7 | # pygtk.require().
8 | init-hook='import sys; sys.path.append("./")'
9 |
10 | # Add files or directories to the blacklist. They should be base names, not
11 | # paths.
12 | ignore=CVS,custom_extensions
13 |
14 | # Add files or directories matching the regex patterns to the blacklist. The
15 | # regex matches against base names, not paths.
16 | ignore-patterns=
17 |
18 | # Pickle collected data for later comparisons.
19 | persistent=yes
20 |
21 | # List of plugins (as comma separated values of python modules names) to load,
22 | # usually to register additional checkers.
23 | load-plugins=
24 |
25 | # Use multiple processes to speed up Pylint.
26 | jobs=4
27 |
28 | # Allow loading of arbitrary C extensions. Extensions are imported into the
29 | # active Python interpreter and may run arbitrary code.
30 | unsafe-load-any-extension=no
31 |
32 | # A comma-separated list of package or module names from where C extensions may
33 | # be loaded. Extensions are loading into the active Python interpreter and may
34 | # run arbitrary code
35 | extension-pkg-whitelist=numpy,torch
36 |
37 | # Allow optimization of some AST trees. This will activate a peephole AST
38 | # optimizer, which will apply various small optimizations. For instance, it can
39 | # be used to obtain the result of joining multiple strings with the addition
40 | # operator. Joining a lot of strings can lead to a maximum recursion error in
41 | # Pylint and this flag can prevent that. It has one side effect, the resulting
42 | # AST will be different than the one from reality. This option is deprecated
43 | # and it will be removed in Pylint 2.0.
44 | optimize-ast=no
45 |
46 |
47 | [MESSAGES CONTROL]
48 |
49 | # Only show warnings with the listed confidence levels. Leave empty to show
50 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
51 | confidence=
52 |
53 | # Enable the message, report, category or checker with the given id(s). You can
54 | # either give multiple identifier separated by comma (,) or put this option
55 | # multiple time (only on the command line, not in the configuration file where
56 | # it should appear only once). See also the "--disable" option for examples.
57 | #enable=
58 |
59 | # Disable the message, report, category or checker with the given id(s). You
60 | # can either give multiple identifiers separated by comma (,) or put this
61 | # option multiple times (only on the command line, not in the configuration
62 | # file where it should appear only once).You can also use "--disable=all" to
63 | # disable everything first and then reenable specific checks. For example, if
64 | # you want to run only the similarities checker, you can use "--disable=all
65 | # --enable=similarities". If you want to run only the classes checker, but have
66 | # no Warning level messages displayed, use"--disable=all --enable=classes
67 | # --disable=W"
68 | disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,missing-docstring,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks,too-many-instance-attributes,fixme,too-few-public-methods,no-else-return,invalid-name,len-as-condition
69 |
70 |
71 | [REPORTS]
72 |
73 | # Set the output format. Available formats are text, parseable, colorized, msvs
74 | # (visual studio) and html. You can also give a reporter class, eg
75 | # mypackage.mymodule.MyReporterClass.
76 | output-format=text
77 |
78 | # Put messages in a separate file for each module / package specified on the
79 | # command line instead of printing them on stdout. Reports (if any) will be
80 | # written in a file name "pylint_global.[txt|html]". This option is deprecated
81 | # and it will be removed in Pylint 2.0.
82 | files-output=no
83 |
84 | # Tells whether to display a full report or only the messages
85 | reports=no
86 |
87 | # Python expression which should return a note less than 10 (10 is the highest
88 | # note). You have access to the variables errors warning, statement which
89 | # respectively contain the number of errors / warnings messages and the total
90 | # number of statements analyzed. This is used by the global evaluation report
91 | # (RP0004).
92 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
93 |
94 | # Template used to display messages. This is a python new-style format string
95 | # used to format the message information. See doc for all details
96 | #msg-template=
97 |
98 |
99 | [LOGGING]
100 |
101 | # Logging modules to check that the string format arguments are in logging
102 | # function parameter format
103 | logging-modules=logging
104 |
105 |
106 | [TYPECHECK]
107 |
108 | # Tells whether missing members accessed in mixin class should be ignored. A
109 | # mixin class is detected if its name ends with "mixin" (case insensitive).
110 | ignore-mixin-members=yes
111 |
112 | # List of module names for which member attributes should not be checked
113 | # (useful for modules/projects where namespaces are manipulated during runtime
114 | # and thus existing member attributes cannot be deduced by static analysis. It
115 | # supports qualified module names, as well as Unix pattern matching.
116 | ignored-modules=
117 |
118 | # List of class names for which member attributes should not be checked (useful
119 | # for classes with dynamically set attributes). This supports the use of
120 | # qualified names.
121 | ignored-classes=optparse.Values,thread._local,_thread._local,responses
122 |
123 | # List of members which are set dynamically and missed by pylint inference
124 | # system, and so shouldn't trigger E1101 when accessed. Python regular
125 | # expressions are accepted.
126 | generated-members=torch.*
127 |
128 | # List of decorators that produce context managers, such as
129 | # contextlib.contextmanager. Add to this list to register other decorators that
130 | # produce valid context managers.
131 | contextmanager-decorators=contextlib.contextmanager
132 |
133 |
134 | [SIMILARITIES]
135 |
136 | # Minimum lines number of a similarity.
137 | min-similarity-lines=4
138 |
139 | # Ignore comments when computing similarities.
140 | ignore-comments=yes
141 |
142 | # Ignore docstrings when computing similarities.
143 | ignore-docstrings=yes
144 |
145 | # Ignore imports when computing similarities.
146 | ignore-imports=no
147 |
148 |
149 | [FORMAT]
150 |
151 | # Maximum number of characters on a single line. Ideally, lines should be under 100 characters,
152 | # but we allow some leeway before calling it an error.
153 | max-line-length=115
154 |
155 | # Regexp for a line that is allowed to be longer than the limit.
156 | ignore-long-lines=^\s*(# )??$
157 |
158 | # Allow the body of an if to be on the same line as the test if there is no
159 | # else.
160 | single-line-if-stmt=no
161 |
162 | # List of optional constructs for which whitespace checking is disabled. `dict-
163 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
164 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
165 | # `empty-line` allows space-only lines.
166 | no-space-check=trailing-comma,dict-separator
167 |
168 | # Maximum number of lines in a module
169 | max-module-lines=1000
170 |
171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
172 | # tab).
173 | indent-string=' '
174 |
175 | # Number of spaces of indent required inside a hanging or continued line.
176 | indent-after-paren=8
177 |
178 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
179 | expected-line-ending-format=
180 |
181 |
182 | [BASIC]
183 |
184 | # Good variable names which should always be accepted, separated by a comma
185 | good-names=i,j,k,ex,Run,_
186 |
187 | # Bad variable names which should always be refused, separated by a comma
188 | bad-names=foo,bar,baz,toto,tutu,tata
189 |
190 | # Colon-delimited sets of names that determine each other's naming style when
191 | # the name regexes allow several styles.
192 | name-group=
193 |
194 | # Include a hint for the correct naming format with invalid-name
195 | include-naming-hint=no
196 |
197 | # List of decorators that produce properties, such as abc.abstractproperty. Add
198 | # to this list to register other decorators that produce valid properties.
199 | property-classes=abc.abstractproperty
200 |
201 | # Regular expression matching correct function names
202 | function-rgx=[a-z_][a-z0-9_]{2,40}$
203 |
204 | # Naming hint for function names
205 | function-name-hint=[a-z_][a-z0-9_]{2,40}$
206 |
207 | # Regular expression matching correct variable names
208 | variable-rgx=[a-z_][a-z0-9_]{2,40}$
209 |
210 | # Naming hint for variable names
211 | variable-name-hint=[a-z_][a-z0-9_]{2,40}$
212 |
213 | # Regular expression matching correct constant names
214 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
215 |
216 | # Naming hint for constant names
217 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
218 |
219 | # Regular expression matching correct attribute names
220 | attr-rgx=[a-z_][a-z0-9_]{2,40}$
221 |
222 | # Naming hint for attribute names
223 | attr-name-hint=[a-z_][a-z0-9_]{2,40}$
224 |
225 | # Regular expression matching correct argument names
226 | argument-rgx=[a-z_][a-z0-9_]{2,40}$
227 |
228 | # Naming hint for argument names
229 | argument-name-hint=[a-z_][a-z0-9_]{2,40}$
230 |
231 | # Regular expression matching correct class attribute names
232 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$
233 |
234 | # Naming hint for class attribute names
235 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$
236 |
237 | # Regular expression matching correct inline iteration names
238 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
239 |
240 | # Naming hint for inline iteration names
241 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
242 |
243 | # Regular expression matching correct class names
244 | class-rgx=[A-Z_][a-zA-Z0-9]+$
245 |
246 | # Naming hint for class names
247 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
248 |
249 | # Regular expression matching correct module names
250 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
251 |
252 | # Naming hint for module names
253 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
254 |
255 | # Regular expression matching correct method names
256 | method-rgx=[a-z_][a-z0-9_]{2,40}$
257 |
258 | # Naming hint for method names
259 | method-name-hint=[a-z_][a-z0-9_]{2,40}$
260 |
261 | # Regular expression which should only match function or class names that do
262 | # not require a docstring.
263 | no-docstring-rgx=^_
264 |
265 | # Minimum line length for functions/classes that require docstrings, shorter
266 | # ones are exempt.
267 | docstring-min-length=-1
268 |
269 |
270 | [ELIF]
271 |
272 | # Maximum number of nested blocks for function / method body
273 | max-nested-blocks=5
274 |
275 |
276 | [VARIABLES]
277 |
278 | # Tells whether we should check for unused import in __init__ files.
279 | init-import=no
280 |
281 | # A regular expression matching the name of dummy variables (i.e. expectedly
282 | # not used).
283 | dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy
284 |
285 | # List of additional names supposed to be defined in builtins. Remember that
286 | # you should avoid to define new builtins when possible.
287 | additional-builtins=
288 |
289 | # List of strings which can identify a callback function by name. A callback
290 | # name must start or end with one of those strings.
291 | callbacks=cb_,_cb
292 |
293 | # List of qualified module names which can have objects that can redefine
294 | # builtins.
295 | redefining-builtins-modules=six.moves,future.builtins
296 |
297 |
298 | [SPELLING]
299 |
300 | # Spelling dictionary name. Available dictionaries: none. To make it working
301 | # install python-enchant package.
302 | spelling-dict=
303 |
304 | # List of comma separated words that should not be checked.
305 | spelling-ignore-words=
306 |
307 | # A path to a file that contains private dictionary; one word per line.
308 | spelling-private-dict-file=
309 |
310 | # Tells whether to store unknown words to indicated private dictionary in
311 | # --spelling-private-dict-file option instead of raising a message.
312 | spelling-store-unknown-words=no
313 |
314 |
315 | [MISCELLANEOUS]
316 |
317 | # List of note tags to take in consideration, separated by a comma.
318 | notes=FIXME,XXX,TODO
319 |
320 |
321 | [DESIGN]
322 |
323 | # Maximum number of arguments for function / method
324 | max-args=5
325 |
326 | # Argument names that match this expression will be ignored. Default to name
327 | # with leading underscore
328 | ignored-argument-names=_.*
329 |
330 | # Maximum number of locals for function / method body
331 | max-locals=15
332 |
333 | # Maximum number of return / yield for function / method body
334 | max-returns=6
335 |
336 | # Maximum number of branch for function / method body
337 | max-branches=12
338 |
339 | # Maximum number of statements in function / method body
340 | max-statements=50
341 |
342 | # Maximum number of parents for a class (see R0901).
343 | max-parents=7
344 |
345 | # Maximum number of attributes for a class (see R0902).
346 | max-attributes=7
347 |
348 | # Minimum number of public methods for a class (see R0903).
349 | min-public-methods=2
350 |
351 | # Maximum number of public methods for a class (see R0904).
352 | max-public-methods=20
353 |
354 | # Maximum number of boolean expressions in a if statement
355 | max-bool-expr=5
356 |
357 |
358 | [CLASSES]
359 |
360 | # List of method names used to declare (i.e. assign) instance attributes.
361 | defining-attr-methods=__init__,__new__,setUp
362 |
363 | # List of valid names for the first argument in a class method.
364 | valid-classmethod-first-arg=cls
365 |
366 | # List of valid names for the first argument in a metaclass class method.
367 | valid-metaclass-classmethod-first-arg=mcs
368 |
369 | # List of member names, which should be excluded from the protected access
370 | # warning.
371 | exclude-protected=_asdict,_fields,_replace,_source,_make
372 |
373 |
374 | [IMPORTS]
375 |
376 | # Deprecated modules which should not be used, separated by a comma
377 | deprecated-modules=regsub,TERMIOS,Bastion,rexec
378 |
379 | # Create a graph of every (i.e. internal and external) dependencies in the
380 | # given file (report RP0402 must not be disabled)
381 | import-graph=
382 |
383 | # Create a graph of external dependencies in the given file (report RP0402 must
384 | # not be disabled)
385 | ext-import-graph=
386 |
387 | # Create a graph of internal dependencies in the given file (report RP0402 must
388 | # not be disabled)
389 | int-import-graph=
390 |
391 | # Force import order to recognize a module as part of the standard
392 | # compatibility libraries.
393 | known-standard-library=
394 |
395 | # Force import order to recognize a module as part of a third party library.
396 | known-third-party=enchant
397 |
398 | # Analyse import fallback blocks. This can be used to support both Python 2 and
399 | # 3 compatible code, which means that the block might have code that exists
400 | # only in one or another interpreter, leading to false positives when analysed.
401 | analyse-fallback-blocks=no
402 |
403 |
404 | [EXCEPTIONS]
405 |
406 | # Exceptions that will emit a warning when being caught. Defaults to
407 | # "Exception"
408 | overgeneral-exceptions=Exception
409 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sequential Sentence Classification
2 | This repo has code and data for our paper ["Pretrained Language Models for Sequential Sentence Classification"](https://arxiv.org/pdf/1909.04054.pdf).
3 |
4 |
5 |
6 | ### How to run
7 |
8 | ```
9 | pip install -r requirements.txt
10 | scripts/train.sh tmp_output_dir
11 | ```
12 |
13 | Update the `scripts/train.sh` script with the appropriate hyperparameters and datapaths.
14 |
15 | ### CSAbstruct dataset
16 |
17 | The train, dev, test splits of the dataset are in `data/CSAbstruct`.
18 |
19 | CSAbstruct is also available on the [Huggingface Hub](https://huggingface.co/datasets/allenai/csabstruct).
20 |
21 | ### Citing
22 |
23 | If you use the data or the model, please cite,
24 | ```
25 | @inproceedings{Cohan2019EMNLP,
26 | title={Pretrained Language Models for Sequential Sentence Classification},
27 | author={Arman Cohan, Iz Beltagy, Daniel King, Bhavana Dalvi, Dan Weld},
28 | year={2019},
29 | booktitle={EMNLP},
30 | }
31 | ```
32 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonlines
2 | -e git+https://github.com/ibeltagy/allennlp@fp16_and_others#egg=allennlp
--------------------------------------------------------------------------------
/scripts/rouge_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from allennlp.models.archival import load_archive
8 | from allennlp.service.predictors import Predictor
9 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
10 |
11 | # Rouge computation is taken from https://github.com/EdCo95/scientific-paper-summarisation/blob/master/Evaluation/rouge.py
12 | #
13 | # File Name : https://github.com/EdCo95/scientific-paper-summarisation/blob/master/Evaluation/rouge.py
14 | #
15 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
16 | #
17 | # Creation Date : 2015-01-07 06:03
18 | # Author : Ramakrishna Vedantam
19 |
def my_lcs(string, sub):
    """
    Calculate the length of the longest common subsequence of two token lists.

    :param string: list of str : tokens from a string split using whitespace
    :param sub: list of str : shorter string, also split using whitespace
    :returns: int : length of the longest common subsequence between the two strings

    Note: my_lcs only gives the length of the longest common subsequence,
    not the actual LCS.
    """
    # Keep `sub` as the shorter sequence so the rolling row below stays small.
    if len(string) < len(sub):
        sub, string = string, sub

    # Classic LCS dynamic program, but with a single rolling row instead of the
    # full (len(string)+1) x (len(sub)+1) table: O(min(n, m)) memory, same result.
    prev = [0] * (len(sub) + 1)
    for token in string:
        curr = [0] * (len(sub) + 1)
        for j in range(1, len(sub) + 1):
            if token == sub[j - 1]:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev = curr

    return prev[len(sub)]
42 |
class Rouge():
    '''
    Computes the ROUGE-L score of candidate sentences against reference
    sentences (originally written for the MS COCO test set).
    '''
    def __init__(self):
        # F-measure weight favouring recall, per the original implementation
        # (vrama91: updated the value below based on discussion with Hovey).
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute the ROUGE-L score of one candidate against its references.

        :param candidate: list containing exactly one candidate sentence (str)
        :param refs: non-empty list of reference sentences (str)
        :returns: float, ROUGE-L F-measure of the candidate vs. the references
        """
        assert len(candidate) == 1
        assert len(refs) > 0

        # Whitespace-tokenize the candidate once; each reference below.
        candidate_tokens = candidate[0].split(" ")

        precisions = []
        recalls = []
        for reference in refs:
            reference_tokens = reference.split(" ")
            lcs_len = my_lcs(reference_tokens, candidate_tokens)
            precisions.append(lcs_len / float(len(candidate_tokens)))
            recalls.append(lcs_len / float(len(reference_tokens)))

        best_prec = max(precisions)
        best_rec = max(recalls)

        # Weighted harmonic mean of the best precision and best recall.
        if best_prec != 0 and best_rec != 0:
            beta_sq = self.beta ** 2
            return ((1 + beta_sq) * best_prec * best_rec) / float(best_rec + beta_sq * best_prec)
        return 0.0

    def compute_score(self, gts, res):
        """
        Compute the mean ROUGE-L over a dataset of candidates and references.

        :param gts: dict mapping example id -> list of reference sentences
        :param res: dict mapping example id -> single-element list with the candidate
        :returns: (mean ROUGE-L as float, per-example scores as np.ndarray)
        """
        assert gts.keys() == res.keys()

        scores = []
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            scores.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

        per_example = np.array(scores)
        return np.mean(per_example), per_example

    def method(self):
        return "Rouge"
113 |
def _avg_sentence_rouge(rouge, sentences, highlights):
    """Return the mean ROUGE-L of `sentences`, each scored against `highlights`.

    :param rouge: a `Rouge` instance
    :param sentences: non-empty list of summary sentences (str)
    :param highlights: list of gold highlight sentences (str)
    """
    total = 0.0
    for sentence in sentences:
        total += rouge.calc_score([sentence], highlights)
    return total / len(sentences)


def main(model_path: str,
         test_jsonl_file: str,
         test_highlights_path: str,
         test_abstracts_path: str = os.path.join("data", "sci_sum", "test_abstracts.json")):
    """Evaluate ROUGE-L of (a) abstracts and (b) model-selected sentences.

    Prints two "final score:" lines: first the abstract-as-summary baseline,
    then the model's top-10-sentence summaries.

    :param model_path: path to the trained allennlp model archive
    :param test_jsonl_file: JSON-lines file, one test paper per line
    :param test_highlights_path: JSON file mapping paper id -> highlight sentences
    :param test_abstracts_path: JSON file mapping paper id -> abstract sentences.
        (Previously this was read from a module-level global defined only under
        ``__main__``; it is now an explicit, defaulted parameter.)
    """
    rouge = Rouge()

    # Load paper highlights (gold summaries), keyed by paper id.
    with open(test_highlights_path) as _highlights_json_file:
        highlights_by_id = json.load(_highlights_json_file)

    # Load paper abstracts (baseline summaries), keyed by paper id.
    with open(test_abstracts_path) as _abstracts_json_file:
        abstracts_by_id = json.load(_abstracts_json_file)

    # Load the allennlp model.
    # TODO(review): these scibert paths are machine-specific; consider making
    # them CLI arguments.
    text_field_embedder = {"token_embedders": {"bert": {"pretrained_model": "/net/nfs.corp/s2-research/scibert/scibert_scivocab_uncased.tar.gz"}}}
    token_indexers = {"bert": {"pretrained_model": "/net/nfs.corp/s2-research/scibert/scivocab_uncased.vocab"}}
    overrides = {"model": {"text_field_embedder": text_field_embedder},
                 "dataset_reader": {"token_indexers": token_indexers}}
    model_archive = load_archive(model_path, overrides=json.dumps(overrides), cuda_device=0)
    predictor = Predictor.from_archive(model_archive, 'SeqClassificationPredictor')
    # Bug fix: read the reader off the predictor *instance*; the original
    # `Predictor._dataset_reader` looked the attribute up on the class.
    dataset_reader = predictor._dataset_reader

    # Load the test papers once (the original read the file twice). Skipping
    # blank lines also handles files without a trailing newline, which the old
    # `split('\n')[:-1]` silently truncated.
    with open(test_jsonl_file) as _test_jsonl_file:
        test_jsons = [json.loads(line) for line in _test_jsonl_file if line.strip()]

    print("{} test lines loaded".format(len(test_jsons)))

    # Baseline: score the abstract's own sentences against the highlights.
    abstract_total_score = 0
    abstract_total_instances = 0
    for line in test_jsons:
        paper_id = line["abstract_id"]
        abstract_sentences = abstracts_by_id[paper_id]
        highlights = highlights_by_id[paper_id]

        abstract_total_score += _avg_sentence_rouge(rouge, abstract_sentences, highlights)
        abstract_total_instances += 1

    print("final score:", abstract_total_score / abstract_total_instances)

    # Model: rank sentences by predicted score and keep the top 10 as summary.
    total_score = 0
    total_instances = 0
    for json_dict in tqdm(test_jsons, desc="Predicting..."):
        instances = dataset_reader.read_one_example(json_dict)
        if not isinstance(instances, list):  # if the datareader returns one instance, put it in a list
            instances = [instances]

        sentences = json_dict['sentences']
        gold_scores_list = json_dict['highlight_scores']
        paper_id = instances[0].fields["abstract_id"].metadata
        highlights = highlights_by_id[paper_id]

        # One paper may have been split into several instances; concatenate
        # the per-instance probabilities back into one list per paper.
        scores_list = []
        for instance in instances:
            prediction = predictor.predict_instance(instance)
            scores_list.extend(prediction['action_probs'])

        assert len(sentences) == len(scores_list)
        assert len(sentences) == len(gold_scores_list)

        sentences_with_scores = list(zip(sentences, scores_list))
        # Note: zipping with gold_scores_list instead would give Oracle performance.
        sentences_with_scores = sorted(sentences_with_scores, key=lambda x: x[1], reverse=True)
        top_10_sentences = [s[0] for s in sentences_with_scores[:10]]

        total_score += _avg_sentence_rouge(rouge, top_10_sentences, highlights)
        total_instances += 1

    print("final score:", total_score / total_instances)
209 |
if __name__ == "__main__":
    # CLI entry point: only the model path is configurable; data locations are
    # fixed relative to the working directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--path_to_model", help="Path to the model to evaluate")
    cli_args = cli.parse_args()

    data_dir = os.path.join("data", "sci_sum")
    test_jsonl_file = os.path.join(data_dir, "rouge_test.jsonl")
    test_highlights_path = os.path.join(data_dir, "test_highlights.json")
    # `test_abstracts_path` is read as a module global by main().
    test_abstracts_path = os.path.join(data_dir, "test_abstracts.json")

    main(cli_args.path_to_model, test_jsonl_file, test_highlights_path)
223 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train the sequential sentence classification model with AllenNLP.
# NOTE(review): $SERIALIZATION_DIR is read by the final command but is never
# set in this script -- it must be exported by the caller.

# derived seeds keep the three RNGs distinct but reproducible from one SEED
export SEED=15270
export PYTORCH_SEED=`expr $SEED / 10`
export NUMPY_SEED=`expr $PYTORCH_SEED / 10`

# path to bert vocab and weights
export BERT_VOCAB=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scivocab_uncased.vocab
export BERT_WEIGHTS=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scibert_scivocab_uncased.tar.gz

# path to dataset files
export TRAIN_PATH=data/CSAbstruct/train.jsonl
export DEV_PATH=data/CSAbstruct/dev.jsonl
export TEST_PATH=data/CSAbstruct/test.jsonl

# model
export USE_SEP=true # true for our model. false for baseline
export WITH_CRF=false # CRF only works for the baseline

# training params
export cuda_device=0
export BATCH_SIZE=4
export LR=5e-5
export TRAINING_DATA_INSTANCES=1668
export NUM_EPOCHS=2

# limit the number of sentences per example and of words per sentence; these
# limits are dataset dependent
export MAX_SENT_PER_EXAMPLE=10
export SENT_MAX_LEN=80

# this is for the evaluation of the summarization dataset
export SCI_SUM=false
export USE_ABSTRACT_SCORES=false
export SCI_SUM_FAKE_SCORES=false # use fake scores for testing

CONFIG_FILE=sequential_sentence_classification/config.jsonnet

python -m allennlp.run train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"
39 |
--------------------------------------------------------------------------------
/sequential_sentence_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sequential_sentence_classification/cf5ad6c663550dd8203f148cd703768d9ee86ff4/sequential_sentence_classification/__init__.py
--------------------------------------------------------------------------------
/sequential_sentence_classification/config.jsonnet:
--------------------------------------------------------------------------------
// Convert an environment-variable string ("true"/"false") into a jsonnet
// boolean; any other value aborts evaluation with an error.
local stringToBool(s) =
  if s == "true" then true
  else if s == "false" then false
  else error "invalid boolean: " + std.manifestJson(s);
5 |
// Convert a jsonnet boolean into 1/0 (used below for a feature size);
// non-boolean input aborts evaluation with an error.
local boolToInt(s) =
  if s == true then 1
  else if s == false then 0
  else error "invalid boolean: " + std.manifestJson(s);
10 |
11 | {
12 | "random_seed": std.parseInt(std.extVar("SEED")),
13 | "pytorch_seed": std.parseInt(std.extVar("PYTORCH_SEED")),
14 | "numpy_seed": std.parseInt(std.extVar("NUMPY_SEED")),
15 | "dataset_reader":{
16 | "type":"SeqClassificationReader",
17 | "lazy": false,
18 | "sent_max_len": std.extVar("SENT_MAX_LEN"),
19 | "word_splitter": "bert-basic",
20 | "max_sent_per_example": std.extVar("MAX_SENT_PER_EXAMPLE"),
21 | "token_indexers": {
22 | "bert": {
23 | "type": "bert-pretrained",
24 | "pretrained_model": std.extVar("BERT_VOCAB"),
25 | "do_lowercase": true,
26 | "use_starting_offsets": false
27 | },
28 | },
29 | "use_sep": std.extVar("USE_SEP"),
30 | "sci_sum": stringToBool(std.extVar("SCI_SUM")),
31 | "use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
32 | "sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
33 | },
34 |
35 | "train_data_path": std.extVar("TRAIN_PATH"),
36 | "validation_data_path": std.extVar("DEV_PATH"),
37 | "test_data_path": std.extVar("TEST_PATH"),
38 | "evaluate_on_test": true,
39 | "model": {
40 | "type": "SeqClassificationModel",
41 | "text_field_embedder": {
42 | "allow_unmatched_keys": true,
43 | "embedder_to_indexer_map": {
44 | "bert": if stringToBool(std.extVar("USE_SEP")) then ["bert"] else ["bert", "bert-offsets"],
45 | "tokens": ["tokens"],
46 | },
47 | "token_embedders": {
48 | "bert": {
49 | "type": "bert-pretrained",
50 | "pretrained_model": std.extVar("BERT_WEIGHTS"),
51 | "requires_grad": 'all',
52 | "top_layer_only": false,
53 | }
54 | }
55 | },
56 | "use_sep": std.extVar("USE_SEP"),
57 | "with_crf": std.extVar("WITH_CRF"),
58 | "bert_dropout": 0.1,
59 | "sci_sum": stringToBool(std.extVar("SCI_SUM")),
60 | "additional_feature_size": boolToInt(stringToBool(std.extVar("USE_ABSTRACT_SCORES"))),
61 | "self_attn": {
62 | "type": "stacked_self_attention",
63 | "input_dim": 768,
64 | "projection_dim": 100,
65 | "feedforward_hidden_dim": 50,
66 | "num_layers": 2,
67 | "num_attention_heads": 2,
68 | "hidden_dim": 100,
69 | },
70 | },
71 | "iterator": {
72 | "type": "bucket",
73 | "sorting_keys": [["sentences", "num_fields"]],
74 | "batch_size" : std.parseInt(std.extVar("BATCH_SIZE")),
75 | "cache_instances": true,
76 | "biggest_batch_first": true
77 | },
78 |
79 | "trainer": {
80 | "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
81 | "grad_clipping": 1.0,
82 | "patience": 5,
83 | "model_save_interval": 3600,
84 | "validation_metric": if stringToBool(std.extVar("SCI_SUM")) then "-loss" else '+acc',
85 | "min_delta": 0.001,
86 | "cuda_device": std.parseInt(std.extVar("cuda_device")),
87 | "gradient_accumulation_batch_size": 32,
88 | "optimizer": {
89 | "type": "bert_adam",
90 | "lr": std.extVar("LR"),
91 | "t_total": -1,
92 | "max_grad_norm": 1.0,
93 | "weight_decay": 0.01,
94 | "parameter_groups": [
95 | [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}],
96 | ],
97 | },
98 | "should_log_learning_rate": true,
99 | "learning_rate_scheduler": {
100 | "type": "slanted_triangular",
101 | "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
102 | "num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_DATA_INSTANCES")) / 32,
103 | "cut_frac": 0.1,
104 | },
105 | }
106 | }
--------------------------------------------------------------------------------
/sequential_sentence_classification/dataset_reader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import json
3 | from typing import Dict, List
4 | from overrides import overrides
5 |
6 | import numpy as np
7 |
8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
9 | from allennlp.common.file_utils import cached_path
10 | from allennlp.data import Tokenizer
11 | from allennlp.data.instance import Instance
12 | from allennlp.data.fields.field import Field
13 | from allennlp.data.fields import TextField, LabelField, ListField, ArrayField, MultiLabelField
14 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
15 | from allennlp.data.tokenizers import WordTokenizer
16 | from allennlp.data.tokenizers.token import Token
17 | from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter, WordSplitter, SpacyWordSplitter
18 |
19 |
@DatasetReader.register("SeqClassificationReader")
class SeqClassificationReader(DatasetReader):
    """
    Reads a file from Pubmed-RCT dataset. Each instance contains an abstract_id,
    a list of sentences and a list of labels (one per sentence).
    Input File Format: Example abstract below:
    {
        "abstract_id": 5337700,
        "sentences": ["this is motivation", "this is method", "this is conclusion"],
        "labels": ["BACKGROUND", "RESULTS", "CONCLUSIONS"]
    }
    """

    def __init__(self,
                 lazy: bool = False,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 word_splitter: WordSplitter = None,
                 tokenizer: Tokenizer = None,
                 sent_max_len: int = 100,
                 max_sent_per_example: int = 20,
                 use_sep: bool = True,
                 sci_sum: bool = False,
                 use_abstract_scores: bool = True,
                 sci_sum_fake_scores: bool = True,
                 predict: bool = False,
                 ) -> None:
        """
        Parameters
        ----------
        lazy : whether the base ``DatasetReader`` should read lazily.
        token_indexers : indexers for the sentence TextFields
            (defaults to a single-id "tokens" indexer).
        word_splitter : accepted for config compatibility but currently
            ignored -- see NOTE below.
        tokenizer : accepted for config compatibility but currently
            ignored -- see NOTE below.
        sent_max_len : maximum number of tokens kept per sentence.
        max_sent_per_example : examples with more sentences are recursively
            split into smaller instances (see ``enforce_max_sent_per_example``).
        use_sep : if True, all sentences are flattened into one token sequence
            joined by [SEP] tokens; otherwise each sentence is its own TextField.
        sci_sum : if True, read the summarization task where labels are
            continuous highlight scores instead of class names.
        use_abstract_scores : (sci_sum only) add abstract scores as extra
            per-sentence features.
        sci_sum_fake_scores : (sci_sum only) generate random scores/features
            instead of reading them from the input (for testing).
        predict : if True, input JSON objects have no "labels" key.
        """
        super().__init__(lazy)
        # NOTE(review): `word_splitter` and `tokenizer` are accepted but never
        # used -- tokenization is hard-coded to a Spacy word splitter. Confirm
        # whether the "word_splitter" value in config.jsonnet is meant to take
        # effect.
        self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=False))
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.sent_max_len = sent_max_len
        self.use_sep = use_sep
        self.predict = predict
        self.sci_sum = sci_sum
        self.max_sent_per_example = max_sent_per_example
        self.use_abstract_scores = use_abstract_scores
        self.sci_sum_fake_scores = sci_sum_fake_scores

    @overrides
    def _read(self, file_path: str):
        """Yield instances from a JSON-lines file (one JSON object per line)."""
        file_path = cached_path(file_path)

        with open(file_path) as f:
            for line in f:
                json_dict = json.loads(line)
                # One input example may expand into several instances when it
                # has more than `max_sent_per_example` sentences.
                instances = self.read_one_example(json_dict)
                for instance in instances:
                    yield instance

    def read_one_example(self, json_dict):
        """Convert one parsed JSON example into a list of ``Instance``s.

        Returns an empty list when (in sci_sum mode) every sentence was
        filtered out.
        """
        instances = []
        sentences = json_dict["sentences"]

        if not self.predict:
            labels = json_dict["labels"]
        else:
            labels = None

        confidences = json_dict.get("confs", None)

        additional_features = None
        if self.sci_sum:
            # For summarization, labels are continuous highlight scores.
            if self.sci_sum_fake_scores:
                labels = [np.random.rand() for _ in sentences]
            else:
                # clamp non-positive scores to a tiny epsilon; 0.0 is used as
                # the padding value downstream (see the model's labels mask)
                labels = [s if s > 0 else 0.000001 for s in json_dict["highlight_scores"]]

            if self.use_abstract_scores:
                features = []
                if self.use_abstract_scores:
                    if self.sci_sum_fake_scores:
                        abstract_scores = [np.random.rand() for _ in sentences]
                    else:
                        abstract_scores = json_dict["abstract_scores"]
                    features.append(abstract_scores)

                additional_features = list(map(list, zip(*features)))  # some magic transpose function

            sentences, labels = self.filter_bad_sci_sum_sentences(sentences, labels)

            if len(sentences) == 0:
                return []

        for sentences_loop, labels_loop, confidences_loop, additional_features_loop in \
                self.enforce_max_sent_per_example(sentences, labels, confidences, additional_features):

            instance = self.text_to_instance(
                sentences=sentences_loop,
                labels=labels_loop,
                confidences=confidences_loop,
                additional_features=additional_features_loop,
            )
            instances.append(instance)
        return instances

    def enforce_max_sent_per_example(self, sentences, labels=None, confidences=None, additional_features=None):
        """
        Splits examples with len(sentences) > self.max_sent_per_example into multiple smaller examples
        with len(sentences) <= self.max_sent_per_example.
        Recursively split the list of sentences into two halves until each half
        has len(sentences) <= self.max_sent_per_example. The goal is to produce splits that are of almost
        equal size to avoid the scenario where all splits are of size
        self.max_sent_per_example then the last split is 1 or 2 sentences.
        This will result into losing context around the edges of each examples.

        Returns a list of (sentences, labels, confidences, additional_features)
        tuples; the optional lists are split in lockstep with the sentences
        (and stay None when not provided).
        """
        if labels is not None:
            assert len(sentences) == len(labels)
        if confidences is not None:
            assert len(sentences) == len(confidences)
        if additional_features is not None:
            assert len(sentences) == len(additional_features)

        if len(sentences) > self.max_sent_per_example and self.max_sent_per_example > 0:
            # Split at the midpoint and recurse on both halves.
            i = len(sentences) // 2
            l1 = self.enforce_max_sent_per_example(
                sentences[:i], None if labels is None else labels[:i],
                None if confidences is None else confidences[:i],
                None if additional_features is None else additional_features[:i])
            l2 = self.enforce_max_sent_per_example(
                sentences[i:], None if labels is None else labels[i:],
                None if confidences is None else confidences[i:],
                None if additional_features is None else additional_features[i:])
            return l1 + l2
        else:
            return [(sentences, labels, confidences, additional_features)]

    def is_bad_sentence(self, sentence: str) -> bool:
        """A sentence is "bad" if its character length is outside (10, 600)."""
        if len(sentence) > 10 and len(sentence) < 600:
            return False
        else:
            return True

    def filter_bad_sci_sum_sentences(self, sentences, labels):
        """Replace out-of-range sentences with a "BADSENTENCE" placeholder.

        The sentence count is kept unchanged so sentence/label alignment is
        preserved. In predict mode, `labels` is returned untouched.
        """
        filtered_sentences = []
        filtered_labels = []
        if not self.predict:
            for sentence, label in zip(sentences, labels):
                # most sentences outside of this range are bad sentences
                if not self.is_bad_sentence(sentence):
                    filtered_sentences.append(sentence)
                    filtered_labels.append(label)
                else:
                    # placeholder keeps list lengths stable; the epsilon label
                    # keeps it distinguishable from the 0.0 padding value
                    filtered_sentences.append("BADSENTENCE")
                    filtered_labels.append(0.000001)
            sentences = filtered_sentences
            labels = filtered_labels
        else:
            for sentence in sentences:
                # most sentences outside of this range are bad sentences
                if not self.is_bad_sentence(sentence):
                    filtered_sentences.append(sentence)
                else:
                    filtered_sentences.append("BADSENTENCE")
            sentences = filtered_sentences

        return sentences, labels

    @overrides
    def text_to_instance(self,
                         sentences: List[str],
                         labels: List[str] = None,
                         confidences: List[float] = None,
                         additional_features: List[float] = None,
                         ) -> Instance:
        """Build an ``Instance`` from one (already split) group of sentences.

        With ``use_sep``, all sentences are concatenated into a single
        TextField joined by [SEP] tokens; otherwise each sentence becomes its
        own TextField inside the "sentences" ListField.
        """
        if not self.predict:
            assert len(sentences) == len(labels)
        if confidences is not None:
            assert len(sentences) == len(confidences)
        if additional_features is not None:
            assert len(sentences) == len(additional_features)

        if self.use_sep:
            # Truncate each sentence, append [SEP], flatten everything, and
            # drop the final trailing [SEP]: the result is one long "sentence".
            tokenized_sentences = [self._tokenizer.tokenize(s)[:self.sent_max_len] + [Token("[SEP]")] for s in sentences]
            sentences = [list(itertools.chain.from_iterable(tokenized_sentences))[:-1]]
        else:
            # Tokenize the sentences
            sentences = [
                self._tokenizer.tokenize(sentence_text)[:self.sent_max_len]
                for sentence_text in sentences
            ]

        fields: Dict[str, Field] = {}
        fields["sentences"] = ListField([
            TextField(sentence, self._token_indexers)
            for sentence in sentences
        ])

        if labels is not None:
            if isinstance(labels[0], list):
                fields["labels"] = ListField([
                    MultiLabelField(label) for label in labels
                ])
            else:
                # make the labels strings for easier identification of the neutral label
                # probably not strictly necessary
                if self.sci_sum:
                    # continuous highlight scores -> regression targets
                    fields["labels"] = ArrayField(np.array(labels))
                else:
                    fields["labels"] = ListField([
                        LabelField(str(label)+"_label") for label in labels
                    ])

        if confidences is not None:
            fields['confidences'] = ArrayField(np.array(confidences))
        if additional_features is not None:
            fields["additional_features"] = ArrayField(np.array(additional_features))

        return Instance(fields)
--------------------------------------------------------------------------------
/sequential_sentence_classification/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict
3 |
4 | import torch
5 | from torch.nn import Linear
6 | from allennlp.data import Vocabulary
7 | from allennlp.models.model import Model
8 | from allennlp.modules import TextFieldEmbedder, TimeDistributed, Seq2SeqEncoder
9 | from allennlp.nn.util import get_text_field_mask
10 | from allennlp.training.metrics import F1Measure, CategoricalAccuracy
11 | from allennlp.modules.conditional_random_field import ConditionalRandomField
12 |
13 | logger = logging.getLogger(__name__)
14 |
@Model.register("SeqClassificationModel")
class SeqClassificationModel(Model):
    """
    Sequential sentence classification model: predicts one label per sentence
    of an abstract (or a continuous per-sentence score when ``sci_sum``).
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 use_sep: bool = True,
                 with_crf: bool = False,
                 self_attn: Seq2SeqEncoder = None,
                 bert_dropout: float = 0.1,
                 sci_sum: bool = False,
                 additional_feature_size: int = 0,
                 ) -> None:
        """
        Parameters
        ----------
        vocab : model vocabulary; provides the 'labels' namespace.
        text_field_embedder : embedder; must contain a 'bert' token embedder
            (its output_dim is read below).
        use_sep : if True, sentence vectors are taken from [SEP] token
            positions; otherwise from [CLS] vectors run through `self_attn`.
        with_crf : add a CRF layer over the per-sentence labels.
        self_attn : sentence-level Seq2SeqEncoder; only used when use_sep is
            False.
        bert_dropout : dropout applied to the embedded sentence vectors.
        sci_sum : if True, labels are continuous scores and MSE loss is used.
        additional_feature_size : size of extra per-sentence features
            concatenated before the final feedforward.
        """
        super(SeqClassificationModel, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
        # define accuracy metrics
        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}

        # define F1 metrics per label
        for label_index in range(self.num_labels):
            label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
            self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_senetence_dim = text_field_embedder._token_embedders['bert'].output_dim

        # Input size of the final classifier: sentence vector (+ extra features).
        ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim()
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels, constraints=None,
                include_start_end_transitions=True
            )

    def forward(self,  # type: ignore
                sentences: torch.LongTensor,
                labels: torch.IntTensor = None,
                confidences: torch.Tensor = None,
                additional_features: torch.Tensor = None,
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : dict of indexer tensors for the "sentences" field; must
            contain a 'bert' key. Presumably shaped
            batch_size x num_sentences x sent_len -- TODO confirm against the
            dataset reader's indexers.
        labels : per-sentence gold labels -- class indices padded with -1, or
            continuous scores padded with 0.0 when ``labels_are_scores``.
        confidences : optional per-sentence confidences used to weight the loss.
        additional_features : optional per-sentence extra features concatenated
            to the sentence vectors.

        Returns
        -------
        An output dictionary consisting of:
        action_probs : per-sentence label probabilities (raw logits when
            ``labels_are_scores``).
        action_logits : raw logits (only set when labels are provided).
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: embed each token of each sentence with the text field embedder
        # Input: sentences
        # Output: embedded_sentences

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arrange them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
            embedded_sentences = embedded_sentences[sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
                                                                     # returns num_sentences_per_batch x vector_len
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # but `num_labels` should be at least greater than `num_sentences`
                    logger.warning(f'Found {num_labels} labels but {num_sentences} sentences')
                    labels = labels[:num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                                                     # We are ignoring this problem for now.
                                                     # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            # a sentence is real (not padding) if it has at least one token
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            #
            # # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

        label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
                if confidences is not None:
                    # weight each sentence's loss by its label confidence
                    label_loss = label_loss * confidences.type_as(label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                # build one-hot "probabilities" from the Viterbi paths so the
                # same metric code works for the CRF branch
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view((batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs, flattened_gold, mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        """Return accuracy, per-label F1 and average F1.

        Empty for the sci_sum (regression) setting, where these metrics do
        not apply.
        """
        metric_dict = {}

        if not self.labels_are_scores:
            type_accuracy = self.label_accuracy.get_metric(reset)
            metric_dict['acc'] = type_accuracy

            average_F1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                # F1Measure.get_metric returns (precision, recall, f1)
                metric_dict[name + 'F'] = metric_val[2]
                average_F1 += metric_val[2]

            average_F1 /= len(self.label_f1_metrics.items())
            metric_dict['avgF'] = average_F1

        return metric_dict
248 |
--------------------------------------------------------------------------------
/sequential_sentence_classification/predictor.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from overrides import overrides
3 |
4 | from allennlp.common.util import JsonDict, sanitize
5 | from allennlp.data import Instance
6 | from allennlp.predictors.predictor import Predictor
7 |
8 |
@Predictor.register('SeqClassificationPredictor')
class SeqClassificationPredictor(Predictor):
    """
    Predictor for the abstruct model.

    Note: ``predict_json`` returns a ``(paper_id, [(sentence, label), ...])``
    tuple rather than a JsonDict.
    """
    def predict_json(self, json_dict: JsonDict):
        """Predict one label per sentence for a single paper.

        :param json_dict: dict with at least 'sentences' (list of str) and
            'abstract_id'
        :returns: (paper_id, list of (sentence, predicted_label) pairs)
        """
        pred_labels = []
        sentences = json_dict['sentences']
        paper_id = json_dict['abstract_id']
        # Long papers are split into several chunks; predict each chunk and
        # concatenate the labels back together.
        for sentences_loop, _, _, _ in \
                self._dataset_reader.enforce_max_sent_per_example(sentences):
            # Bug fix: SeqClassificationReader.text_to_instance takes no
            # `abstract_id` argument, so passing `abstract_id=0` raised a
            # TypeError.
            instance = self._dataset_reader.text_to_instance(sentences=sentences_loop)
            output = self._model.forward_on_instance(instance)
            # argmax over the label dimension, then map indices back to names
            idx = output['action_probs'].argmax(axis=1).tolist()
            labels = [self._model.vocab.get_token_from_index(i, namespace='labels') for i in idx]
            pred_labels.extend(labels)
        assert len(pred_labels) == len(sentences)
        preds = list(zip(sentences, pred_labels))
        return paper_id, preds
28 |
--------------------------------------------------------------------------------