├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── data └── CSAbstruct │ ├── dev.jsonl │ ├── test.jsonl │ └── train.jsonl ├── requirements.txt ├── scripts ├── rouge_eval.py └── train.sh └── sequential_sentence_classification ├── __init__.py ├── config.jsonnet ├── dataset_reader.py ├── model.py └── predictor.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | tmp 3 | .vscode 4 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | init-hook='import sys; sys.path.append("./")' 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=CVS,custom_extensions 13 | 14 | # Add files or directories matching the regex patterns to the blacklist. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=yes 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code 35 | extension-pkg-whitelist=numpy,torch 36 | 37 | # Allow optimization of some AST trees. 
This will activate a peephole AST 38 | # optimizer, which will apply various small optimizations. For instance, it can 39 | # be used to obtain the result of joining multiple strings with the addition 40 | # operator. Joining a lot of strings can lead to a maximum recursion error in 41 | # Pylint and this flag can prevent that. It has one side effect, the resulting 42 | # AST will be different than the one from reality. This option is deprecated 43 | # and it will be removed in Pylint 2.0. 44 | optimize-ast=no 45 | 46 | 47 | [MESSAGES CONTROL] 48 | 49 | # Only show warnings with the listed confidence levels. Leave empty to show 50 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 51 | confidence= 52 | 53 | # Enable the message, report, category or checker with the given id(s). You can 54 | # either give multiple identifier separated by comma (,) or put this option 55 | # multiple time (only on the command line, not in the configuration file where 56 | # it should appear only once). See also the "--disable" option for examples. 57 | #enable= 58 | 59 | # Disable the message, report, category or checker with the given id(s). You 60 | # can either give multiple identifiers separated by comma (,) or put this 61 | # option multiple times (only on the command line, not in the configuration 62 | # file where it should appear only once).You can also use "--disable=all" to 63 | # disable everything first and then reenable specific checks. For example, if 64 | # you want to run only the similarities checker, you can use "--disable=all 65 | # --enable=similarities". 
If you want to run only the classes checker, but have 66 | # no Warning level messages displayed, use"--disable=all --enable=classes 67 | # --disable=W" 68 | disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,missing-docstring,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks,too-many-instance-attributes,fixme,too-few-public-methods,no-else-return,invalid-name,len-as-condition 69 | 70 | 71 | [REPORTS] 72 | 73 | # Set the output format. Available formats are text, parseable, colorized, msvs 74 | # (visual studio) and html. You can also give a reporter class, eg 75 | # mypackage.mymodule.MyReporterClass. 76 | output-format=text 77 | 78 | # Put messages in a separate file for each module / package specified on the 79 | # command line instead of printing them on stdout. Reports (if any) will be 80 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 81 | # and it will be removed in Pylint 2.0. 82 | files-output=no 83 | 84 | # Tells whether to display a full report or only the messages 85 | reports=no 86 | 87 | # Python expression which should return a note less than 10 (10 is the highest 88 | # note). 
You have access to the variables errors warning, statement which 89 | # respectively contain the number of errors / warnings messages and the total 90 | # number of statements analyzed. This is used by the global evaluation report 91 | # (RP0004). 92 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 93 | 94 | # Template used to display messages. This is a python new-style format string 95 | # used to format the message information. See doc for all details 96 | #msg-template= 97 | 98 | 99 | [LOGGING] 100 | 101 | # Logging modules to check that the string format arguments are in logging 102 | # function parameter format 103 | logging-modules=logging 104 | 105 | 106 | [TYPECHECK] 107 | 108 | # Tells whether missing members accessed in mixin class should be ignored. A 109 | # mixin class is detected if its name ends with "mixin" (case insensitive). 110 | ignore-mixin-members=yes 111 | 112 | # List of module names for which member attributes should not be checked 113 | # (useful for modules/projects where namespaces are manipulated during runtime 114 | # and thus existing member attributes cannot be deduced by static analysis. It 115 | # supports qualified module names, as well as Unix pattern matching. 116 | ignored-modules= 117 | 118 | # List of class names for which member attributes should not be checked (useful 119 | # for classes with dynamically set attributes). This supports the use of 120 | # qualified names. 121 | ignored-classes=optparse.Values,thread._local,_thread._local,responses 122 | 123 | # List of members which are set dynamically and missed by pylint inference 124 | # system, and so shouldn't trigger E1101 when accessed. Python regular 125 | # expressions are accepted. 126 | generated-members=torch.* 127 | 128 | # List of decorators that produce context managers, such as 129 | # contextlib.contextmanager. Add to this list to register other decorators that 130 | # produce valid context managers. 
131 | contextmanager-decorators=contextlib.contextmanager 132 | 133 | 134 | [SIMILARITIES] 135 | 136 | # Minimum lines number of a similarity. 137 | min-similarity-lines=4 138 | 139 | # Ignore comments when computing similarities. 140 | ignore-comments=yes 141 | 142 | # Ignore docstrings when computing similarities. 143 | ignore-docstrings=yes 144 | 145 | # Ignore imports when computing similarities. 146 | ignore-imports=no 147 | 148 | 149 | [FORMAT] 150 | 151 | # Maximum number of characters on a single line. Ideally, lines should be under 100 characters, 152 | # but we allow some leeway before calling it an error. 153 | max-line-length=115 154 | 155 | # Regexp for a line that is allowed to be longer than the limit. 156 | ignore-long-lines=^\s*(# )??$ 157 | 158 | # Allow the body of an if to be on the same line as the test if there is no 159 | # else. 160 | single-line-if-stmt=no 161 | 162 | # List of optional constructs for which whitespace checking is disabled. `dict- 163 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 164 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 165 | # `empty-line` allows space-only lines. 166 | no-space-check=trailing-comma,dict-separator 167 | 168 | # Maximum number of lines in a module 169 | max-module-lines=1000 170 | 171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 172 | # tab). 173 | indent-string=' ' 174 | 175 | # Number of spaces of indent required inside a hanging or continued line. 176 | indent-after-paren=8 177 | 178 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
179 | expected-line-ending-format= 180 | 181 | 182 | [BASIC] 183 | 184 | # Good variable names which should always be accepted, separated by a comma 185 | good-names=i,j,k,ex,Run,_ 186 | 187 | # Bad variable names which should always be refused, separated by a comma 188 | bad-names=foo,bar,baz,toto,tutu,tata 189 | 190 | # Colon-delimited sets of names that determine each other's naming style when 191 | # the name regexes allow several styles. 192 | name-group= 193 | 194 | # Include a hint for the correct naming format with invalid-name 195 | include-naming-hint=no 196 | 197 | # List of decorators that produce properties, such as abc.abstractproperty. Add 198 | # to this list to register other decorators that produce valid properties. 199 | property-classes=abc.abstractproperty 200 | 201 | # Regular expression matching correct function names 202 | function-rgx=[a-z_][a-z0-9_]{2,40}$ 203 | 204 | # Naming hint for function names 205 | function-name-hint=[a-z_][a-z0-9_]{2,40}$ 206 | 207 | # Regular expression matching correct variable names 208 | variable-rgx=[a-z_][a-z0-9_]{2,40}$ 209 | 210 | # Naming hint for variable names 211 | variable-name-hint=[a-z_][a-z0-9_]{2,40}$ 212 | 213 | # Regular expression matching correct constant names 214 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 215 | 216 | # Naming hint for constant names 217 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 218 | 219 | # Regular expression matching correct attribute names 220 | attr-rgx=[a-z_][a-z0-9_]{2,40}$ 221 | 222 | # Naming hint for attribute names 223 | attr-name-hint=[a-z_][a-z0-9_]{2,40}$ 224 | 225 | # Regular expression matching correct argument names 226 | argument-rgx=[a-z_][a-z0-9_]{2,40}$ 227 | 228 | # Naming hint for argument names 229 | argument-name-hint=[a-z_][a-z0-9_]{2,40}$ 230 | 231 | # Regular expression matching correct class attribute names 232 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 233 | 234 | # Naming hint for class attribute names 235 | 
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 236 | 237 | # Regular expression matching correct inline iteration names 238 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 239 | 240 | # Naming hint for inline iteration names 241 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 242 | 243 | # Regular expression matching correct class names 244 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 245 | 246 | # Naming hint for class names 247 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 248 | 249 | # Regular expression matching correct module names 250 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 251 | 252 | # Naming hint for module names 253 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 254 | 255 | # Regular expression matching correct method names 256 | method-rgx=[a-z_][a-z0-9_]{2,40}$ 257 | 258 | # Naming hint for method names 259 | method-name-hint=[a-z_][a-z0-9_]{2,40}$ 260 | 261 | # Regular expression which should only match function or class names that do 262 | # not require a docstring. 263 | no-docstring-rgx=^_ 264 | 265 | # Minimum line length for functions/classes that require docstrings, shorter 266 | # ones are exempt. 267 | docstring-min-length=-1 268 | 269 | 270 | [ELIF] 271 | 272 | # Maximum number of nested blocks for function / method body 273 | max-nested-blocks=5 274 | 275 | 276 | [VARIABLES] 277 | 278 | # Tells whether we should check for unused import in __init__ files. 279 | init-import=no 280 | 281 | # A regular expression matching the name of dummy variables (i.e. expectedly 282 | # not used). 283 | dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy 284 | 285 | # List of additional names supposed to be defined in builtins. Remember that 286 | # you should avoid to define new builtins when possible. 287 | additional-builtins= 288 | 289 | # List of strings which can identify a callback function by name. A callback 290 | # name must start or end with one of those strings. 
291 | callbacks=cb_,_cb 292 | 293 | # List of qualified module names which can have objects that can redefine 294 | # builtins. 295 | redefining-builtins-modules=six.moves,future.builtins 296 | 297 | 298 | [SPELLING] 299 | 300 | # Spelling dictionary name. Available dictionaries: none. To make it working 301 | # install python-enchant package. 302 | spelling-dict= 303 | 304 | # List of comma separated words that should not be checked. 305 | spelling-ignore-words= 306 | 307 | # A path to a file that contains private dictionary; one word per line. 308 | spelling-private-dict-file= 309 | 310 | # Tells whether to store unknown words to indicated private dictionary in 311 | # --spelling-private-dict-file option instead of raising a message. 312 | spelling-store-unknown-words=no 313 | 314 | 315 | [MISCELLANEOUS] 316 | 317 | # List of note tags to take in consideration, separated by a comma. 318 | notes=FIXME,XXX,TODO 319 | 320 | 321 | [DESIGN] 322 | 323 | # Maximum number of arguments for function / method 324 | max-args=5 325 | 326 | # Argument names that match this expression will be ignored. Default to name 327 | # with leading underscore 328 | ignored-argument-names=_.* 329 | 330 | # Maximum number of locals for function / method body 331 | max-locals=15 332 | 333 | # Maximum number of return / yield for function / method body 334 | max-returns=6 335 | 336 | # Maximum number of branch for function / method body 337 | max-branches=12 338 | 339 | # Maximum number of statements in function / method body 340 | max-statements=50 341 | 342 | # Maximum number of parents for a class (see R0901). 343 | max-parents=7 344 | 345 | # Maximum number of attributes for a class (see R0902). 346 | max-attributes=7 347 | 348 | # Minimum number of public methods for a class (see R0903). 349 | min-public-methods=2 350 | 351 | # Maximum number of public methods for a class (see R0904). 
352 | max-public-methods=20 353 | 354 | # Maximum number of boolean expressions in a if statement 355 | max-bool-expr=5 356 | 357 | 358 | [CLASSES] 359 | 360 | # List of method names used to declare (i.e. assign) instance attributes. 361 | defining-attr-methods=__init__,__new__,setUp 362 | 363 | # List of valid names for the first argument in a class method. 364 | valid-classmethod-first-arg=cls 365 | 366 | # List of valid names for the first argument in a metaclass class method. 367 | valid-metaclass-classmethod-first-arg=mcs 368 | 369 | # List of member names, which should be excluded from the protected access 370 | # warning. 371 | exclude-protected=_asdict,_fields,_replace,_source,_make 372 | 373 | 374 | [IMPORTS] 375 | 376 | # Deprecated modules which should not be used, separated by a comma 377 | deprecated-modules=regsub,TERMIOS,Bastion,rexec 378 | 379 | # Create a graph of every (i.e. internal and external) dependencies in the 380 | # given file (report RP0402 must not be disabled) 381 | import-graph= 382 | 383 | # Create a graph of external dependencies in the given file (report RP0402 must 384 | # not be disabled) 385 | ext-import-graph= 386 | 387 | # Create a graph of internal dependencies in the given file (report RP0402 must 388 | # not be disabled) 389 | int-import-graph= 390 | 391 | # Force import order to recognize a module as part of the standard 392 | # compatibility libraries. 393 | known-standard-library= 394 | 395 | # Force import order to recognize a module as part of a third party library. 396 | known-third-party=enchant 397 | 398 | # Analyse import fallback blocks. This can be used to support both Python 2 and 399 | # 3 compatible code, which means that the block might have code that exists 400 | # only in one or another interpreter, leading to false positives when analysed. 401 | analyse-fallback-blocks=no 402 | 403 | 404 | [EXCEPTIONS] 405 | 406 | # Exceptions that will emit a warning when being caught. 
Defaults to 407 | # "Exception" 408 | overgeneral-exceptions=Exception 409 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

Sequential Sentence Classification

2 | This repo has code and data for our paper ["Pretrained Language Models for Sequential Sentence Classification"](https://arxiv.org/pdf/1909.04054.pdf). 3 | 4 | 5 | 6 | ### How to run 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | scripts/train.sh tmp_output_dir 11 | ``` 12 | 13 | Update the `scripts/train.sh` script with the appropriate hyperparameters and datapaths. 14 | 15 | ### CSAbstruct dataset 16 | 17 | The train, dev, test splits of the dataset are in `data/CSAbstruct`. 18 | 19 | CSAbstruct is also available on the [Huggingface Hub](https://huggingface.co/datasets/allenai/csabstruct). 20 | 21 | ### Citing 22 | 23 | If you use the data or the model, please cite, 24 | ``` 25 | @inproceedings{Cohan2019EMNLP, 26 | title={Pretrained Language Models for Sequential Sentence Classification}, 27 | author={Arman Cohan, Iz Beltagy, Daniel King, Bhavana Dalvi, Dan Weld}, 28 | year={2019}, 29 | booktitle={EMNLP}, 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines 2 | -e git+https://github.com/ibeltagy/allennlp@fp16_and_others#egg=allennlp -------------------------------------------------------------------------------- /scripts/rouge_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from allennlp.models.archival import load_archive 8 | from allennlp.service.predictors import Predictor 9 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 10 | 11 | # Rouge computation is taken from https://github.com/EdCo95/scientific-paper-summarisation/blob/master/Evaluation/rouge.py 12 | # 13 | # File Name : https://github.com/EdCo95/scientific-paper-summarisation/blob/master/Evaluation/rouge.py 14 | # 15 | # Description : Computes ROUGE-L metric as 
# described by Lin and Hovy (2004)
#
# Creation Date : 2015-01-07 06:03
# Author : Ramakrishna Vedantam


def my_lcs(string, sub):
    """
    Calculate the length of the longest common subsequence of two token lists.

    :param string: list of str : tokens from a string split using whitespace
    :param sub: list of str : shorter string, also split using whitespace
    :returns: int : length of the longest common subsequence between the two
        token lists

    Note: my_lcs only gives the length of the LCS, not the actual subsequence.
    """
    # Make sure `string` is always the longer of the two sequences.
    if len(string) < len(sub):
        sub, string = string, sub

    # Classic O(len(string) * len(sub)) dynamic-programming table.
    lengths = [[0] * (len(sub) + 1) for _ in range(len(string) + 1)]

    for j in range(1, len(sub) + 1):
        for i in range(1, len(string) + 1):
            if string[i - 1] == sub[j - 1]:
                lengths[i][j] = lengths[i - 1][j - 1] + 1
            else:
                lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])

    return lengths[len(string)][len(sub)]


class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences
    for the MS COCO test set.
    '''

    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovy
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image.

        :param candidate: list of one str : candidate sentence to be evaluated
        :param refs: list of str : COCO reference sentences for the image
        :returns: float : ROUGE-L score of the candidate against the references
        """
        assert len(candidate) == 1
        assert len(refs) > 0
        precisions = []
        recalls = []

        # Tokenize the candidate once; each reference is tokenized below.
        candidate_tokens = candidate[0].split(" ")

        for reference in refs:
            reference_tokens = reference.split(" ")
            # Longest common subsequence between candidate and this reference.
            lcs = my_lcs(reference_tokens, candidate_tokens)
            precisions.append(lcs / float(len(candidate_tokens)))
            recalls.append(lcs / float(len(reference_tokens)))

        prec_max = max(precisions)
        rec_max = max(recalls)

        if prec_max != 0 and rec_max != 0:
            # F-measure, with `beta` weighting recall over precision.
            score = ((1 + self.beta**2) * prec_max * rec_max) / \
                float(rec_max + self.beta**2 * prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        """
        Compute corpus-level ROUGE-L given reference and candidate sentences
        keyed by image id. Invoked by evaluate_captions.py.

        :param gts: dict : reference sentences, "image id" -> list of str
        :param res: dict : candidate sentences, "image id" -> one-element list
        :returns: (float, np.ndarray) : mean ROUGE-L over all images, and the
            per-image score array
        """
        assert gts.keys() == res.keys()
        img_ids = gts.keys()

        scores = []
        for img_id in img_ids:  # renamed from `id`, which shadowed the builtin
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check (bug fix: previously ran only on the *last* pair,
            # after the loop; now every pair is validated).
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

            scores.append(self.calc_score(hypo, ref))

        average_score = np.mean(np.array(scores))
        return average_score, np.array(scores)

    def method(self):
        return "Rouge"


def main(model_path: str, test_jsonl_file: str, test_highlights_path: str,
         test_abstracts_path: str = os.path.join("data", "sci_sum", "test_abstracts.json")):
    """
    Evaluate a trained model on the summarization test set with ROUGE-L.

    First reports the score obtained by using the raw abstract sentences as the
    summary (a baseline), then the score of the model's top-10 ranked sentences.

    :param model_path: path to the trained allennlp model archive
    :param test_jsonl_file: path to the jsonl file with the test examples
    :param test_highlights_path: path to the json file with gold highlights
    :param test_abstracts_path: path to the json file with abstract sentences.
        Bug fix: this used to be read from a module-level global of the same
        name; it is now an explicit parameter whose default matches the value
        set in the `__main__` block, so existing callers are unaffected.
    """
    rouge = Rouge()
    # Load paper highlights
    with open(test_highlights_path) as _highlights_json_file:
        highlights_by_id = json.load(_highlights_json_file)

    with open(test_abstracts_path) as _abstracts_json_file:
        abstracts_by_id = json.load(_abstracts_json_file)

    # Load allennlp model
    text_field_embedder = {"token_embedders": {"bert": {"pretrained_model": "/net/nfs.corp/s2-research/scibert/scibert_scivocab_uncased.tar.gz"}}}
    token_indexers = {"bert": {"pretrained_model": "/net/nfs.corp/s2-research/scibert/scivocab_uncased.vocab"}}
    overrides = {"model": {"text_field_embedder": text_field_embedder},
                 "dataset_reader": {"token_indexers": token_indexers},}
    model_archive = load_archive(model_path, overrides=json.dumps(overrides), cuda_device=0)
    predictor = Predictor.from_archive(model_archive, 'SeqClassificationPredictor')
    # Bug fix: read the dataset reader off the predictor *instance*; the
    # `Predictor` class itself has no `_dataset_reader` attribute.
    dataset_reader = predictor._dataset_reader

    # Load papers to predict on
    with open(test_jsonl_file) as _test_jsonl_file:
        test_lines = [json.loads(line) for line in _test_jsonl_file.read().split('\n')[:-1]]

    print("{} test lines loaded".format(len(test_lines)))

    abstract_total_score = 0
    abstract_total_instances = 0
    # Baseline: use the abstract sentences themselves as the predictions.
    for line in test_lines:
        paper_id = line["abstract_id"]
        abstract_sentences = abstracts_by_id[paper_id]
        highlights = highlights_by_id[paper_id]

        summary_score = 0
        summary_sentences = 0
        for sentence in abstract_sentences:
            score = rouge.calc_score([sentence], highlights)
            summary_score += score
            summary_sentences += 1

        avg_summary_score = summary_score / summary_sentences
        abstract_total_score += avg_summary_score
        abstract_total_instances += 1

    print("final score:", abstract_total_score / abstract_total_instances)

    test_jsons = []
    with open(test_jsonl_file) as f:
        for line in f:
            test_jsons.append(json.loads(line))

    print("{} test jsons loaded".format(len(test_jsons)))

    # Predict on said papers
    total_score = 0
    total_instances = 0
    for json_dict in tqdm(test_jsons, desc="Predicting..."):
        instances = dataset_reader.read_one_example(json_dict)
        if not isinstance(instances, list):  # if the datareader returns one instance, put it in a list
            instances = [instances]

        sentences = json_dict['sentences']
        gold_scores_list = json_dict['highlight_scores']
        paper_id = instances[0].fields["abstract_id"].metadata
        highlights = highlights_by_id[paper_id]

        scores_list = []
        for instance in instances:
            prediction = predictor.predict_instance(instance)
            probs = prediction['action_probs']
            scores_list.extend(probs)

        assert len(sentences) == len(scores_list)
        assert len(sentences) == len(gold_scores_list)

        sentences_with_scores = list(zip(sentences, scores_list))

        # Note: the following line should get Oracle performance
        # sentences_with_scores = list(zip(sentences, gold_scores_list))
        sentences_with_scores = sorted(sentences_with_scores, key=lambda x: x[1], reverse=True)

        top_10_sentences = [s[0] for s in sentences_with_scores[:10]]

        summary_score = 0
        summary_sentences = 0
        for sentence in top_10_sentences:
            score = rouge.calc_score([sentence], highlights)
            summary_score += score
            summary_sentences += 1

        avg_summary_score = summary_score / summary_sentences
summary_sentences 205 | total_score += avg_summary_score 206 | total_instances += 1 207 | 208 | print("final score:", total_score / total_instances) 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument( 213 | "--path_to_model", 214 | help="Path to the model to evaluate" 215 | ) 216 | args = parser.parse_args() 217 | 218 | test_jsonl_file = os.path.join("data", "sci_sum", "rouge_test.jsonl") 219 | test_highlights_path = os.path.join("data", "sci_sum", "test_highlights.json") 220 | test_abstracts_path = os.path.join("data", "sci_sum", "test_abstracts.json") 221 | 222 | main(args.path_to_model, test_jsonl_file, test_highlights_path) 223 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export SEED=15270 4 | export PYTORCH_SEED=`expr $SEED / 10` 5 | export NUMPY_SEED=`expr $PYTORCH_SEED / 10` 6 | 7 | # path to bert vocab and weights 8 | export BERT_VOCAB=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scivocab_uncased.vocab 9 | export BERT_WEIGHTS=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scibert_scivocab_uncased.tar.gz 10 | 11 | # path to dataset files 12 | export TRAIN_PATH=data/CSAbstruct/train.jsonl 13 | export DEV_PATH=data/CSAbstruct/dev.jsonl 14 | export TEST_PATH=data/CSAbstruct/test.jsonl 15 | 16 | # model 17 | export USE_SEP=true # true for our model. false for baseline 18 | export WITH_CRF=false # CRF only works for the baseline 19 | 20 | # training params 21 | export cuda_device=0 22 | export BATCH_SIZE=4 23 | export LR=5e-5 24 | export TRAINING_DATA_INSTANCES=1668 25 | export NUM_EPOCHS=2 26 | 27 | # limit number of sentences per example, and number of words per sentence. 
# This is dataset dependent.
export MAX_SENT_PER_EXAMPLE=10
export SENT_MAX_LEN=80

# this is for the evaluation of the summarization dataset
export SCI_SUM=false
export USE_ABSTRACT_SCORES=false
export SCI_SUM_FAKE_SCORES=false  # use fake scores for testing

CONFIG_FILE=sequential_sentence_classification/config.jsonnet

# Bug fix: SERIALIZATION_DIR was referenced below but never set; the README
# invokes `scripts/train.sh tmp_output_dir`, so take the output directory from
# the first CLI argument when the variable is not already in the environment.
if [ -z "$SERIALIZATION_DIR" ]; then
  SERIALIZATION_DIR=$1
  shift
fi

python -m allennlp.run train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"

# ---------------------------------------------------------------------------
# sequential_sentence_classification/__init__.py (contents not shown in dump)
# ---------------------------------------------------------------------------

// ---------------------------------------------------------------------------
// sequential_sentence_classification/config.jsonnet
// ---------------------------------------------------------------------------
local stringToBool(s) =
  if s == "true" then true
  else if s == "false" then false
  else error "invalid boolean: " + std.manifestJson(s);

local boolToInt(s) =
  if s == true then 1
  else if s == false then 0
  else error "invalid boolean: " + std.manifestJson(s);

{
  "random_seed": std.parseInt(std.extVar("SEED")),
  "pytorch_seed": std.parseInt(std.extVar("PYTORCH_SEED")),
  "numpy_seed": std.parseInt(std.extVar("NUMPY_SEED")),
  "dataset_reader": {
    "type": "SeqClassificationReader",
    "lazy": false,
    // Bug fix: std.extVar always yields a string; parse the numeric and
    // boolean settings so the reader receives real ints/bools. This matches
    // the stringToBool(...) treatment already applied to SCI_SUM below and to
    // USE_SEP in embedder_to_indexer_map (a non-empty string "false" would
    // otherwise be truthy in Python).
    "sent_max_len": std.parseInt(std.extVar("SENT_MAX_LEN")),
    "word_splitter": "bert-basic",
    "max_sent_per_example": std.parseInt(std.extVar("MAX_SENT_PER_EXAMPLE")),
    "token_indexers": {
      "bert": {
        "type": "bert-pretrained",
        "pretrained_model": std.extVar("BERT_VOCAB"),
        "do_lowercase": true,
        "use_starting_offsets": false
      },
    },
    "use_sep": stringToBool(std.extVar("USE_SEP")),
    "sci_sum": stringToBool(std.extVar("SCI_SUM")),
    "use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
    "sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
  },

  "train_data_path": std.extVar("TRAIN_PATH"),
  "validation_data_path": std.extVar("DEV_PATH"),
  "test_data_path": std.extVar("TEST_PATH"),
  "evaluate_on_test": true,
  "model": {
    "type": "SeqClassificationModel",
    "text_field_embedder": {
      "allow_unmatched_keys": true,
      "embedder_to_indexer_map": {
        "bert": if stringToBool(std.extVar("USE_SEP")) then ["bert"] else ["bert", "bert-offsets"],
        "tokens": ["tokens"],
      },
      "token_embedders": {
        "bert": {
          "type": "bert-pretrained",
          "pretrained_model": std.extVar("BERT_WEIGHTS"),
          "requires_grad": 'all',
          "top_layer_only": false,
        }
      }
    },
    // Bug fix (consistency): parse these booleans too; they were passed as
    // raw strings while every other boolean extVar went through stringToBool.
    "use_sep": stringToBool(std.extVar("USE_SEP")),
    "with_crf": stringToBool(std.extVar("WITH_CRF")),
    "bert_dropout": 0.1,
    "sci_sum": stringToBool(std.extVar("SCI_SUM")),
    "additional_feature_size": boolToInt(stringToBool(std.extVar("USE_ABSTRACT_SCORES"))),
    "self_attn": {
      "type": "stacked_self_attention",
      "input_dim": 768,
      "projection_dim": 100,
      "feedforward_hidden_dim": 50,
      "num_layers": 2,
      "num_attention_heads": 2,
      "hidden_dim": 100,
    },
  },
  "iterator": {
    "type": "bucket",
    "sorting_keys": [["sentences", "num_fields"]],
    "batch_size": std.parseInt(std.extVar("BATCH_SIZE")),
    "cache_instances": true,
    "biggest_batch_first": true
  },

  "trainer": {
    "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
    "grad_clipping": 1.0,
    "patience": 5,
    "model_save_interval": 3600,
    "validation_metric": if stringToBool(std.extVar("SCI_SUM")) then "-loss" else '+acc',
    "min_delta": 0.001,
    "cuda_device": std.parseInt(std.extVar("cuda_device")),
    "gradient_accumulation_batch_size": 32,
    "optimizer": {
      "type": "bert_adam",
      "lr": std.extVar("LR"),
      "t_total": -1,
      "max_grad_norm": 1.0,
      "weight_decay": 0.01,
      "parameter_groups": [
        [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}],
      ],
    },
    "should_log_learning_rate": true,
    "learning_rate_scheduler": {
      "type": "slanted_triangular",
      "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
      "num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_DATA_INSTANCES")) / 32,
      "cut_frac": 0.1,
    },
  }
}

# ---------------------------------------------------------------------------
# sequential_sentence_classification/dataset_reader.py
# ---------------------------------------------------------------------------
import itertools
import json
from typing import Dict, List
from overrides import overrides

import numpy as np

from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data import Tokenizer
from allennlp.data.instance import Instance
from allennlp.data.fields.field import Field
from allennlp.data.fields import TextField, LabelField, ListField, ArrayField, MultiLabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter, WordSplitter, SpacyWordSplitter


@DatasetReader.register("SeqClassificationReader")
class SeqClassificationReader(DatasetReader):
    """
    Reads a file from Pubmed-RCT dataset. Each instance contains an abstract_id,
    a list of sentences and a list of labels (one per sentence).
25 | Input File Format: Example abstract below: 26 | { 27 | "abstract_id": 5337700, 28 | "sentences": ["this is motivation", "this is method", "this is conclusion"], 29 | "labels": ["BACKGROUND", "RESULTS", "CONCLUSIONS"] 30 | } 31 | """ 32 | 33 | def __init__(self, 34 | lazy: bool = False, 35 | token_indexers: Dict[str, TokenIndexer] = None, 36 | word_splitter: WordSplitter = None, 37 | tokenizer: Tokenizer = None, 38 | sent_max_len: int = 100, 39 | max_sent_per_example: int = 20, 40 | use_sep: bool = True, 41 | sci_sum: bool = False, 42 | use_abstract_scores: bool = True, 43 | sci_sum_fake_scores: bool = True, 44 | predict: bool = False, 45 | ) -> None: 46 | super().__init__(lazy) 47 | self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=False)) 48 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 49 | self.sent_max_len = sent_max_len 50 | self.use_sep = use_sep 51 | self.predict = predict 52 | self.sci_sum = sci_sum 53 | self.max_sent_per_example = max_sent_per_example 54 | self.use_abstract_scores = use_abstract_scores 55 | self.sci_sum_fake_scores = sci_sum_fake_scores 56 | 57 | @overrides 58 | def _read(self, file_path: str): 59 | file_path = cached_path(file_path) 60 | 61 | with open(file_path) as f: 62 | for line in f: 63 | json_dict = json.loads(line) 64 | instances = self.read_one_example(json_dict) 65 | for instance in instances: 66 | yield instance 67 | 68 | def read_one_example(self, json_dict): 69 | instances = [] 70 | sentences = json_dict["sentences"] 71 | 72 | if not self.predict: 73 | labels = json_dict["labels"] 74 | else: 75 | labels = None 76 | 77 | confidences = json_dict.get("confs", None) 78 | 79 | additional_features = None 80 | if self.sci_sum: 81 | if self.sci_sum_fake_scores: 82 | labels = [np.random.rand() for _ in sentences] 83 | else: 84 | labels = [s if s > 0 else 0.000001 for s in json_dict["highlight_scores"]] 85 | 86 | if self.use_abstract_scores: 87 | features = [] 88 | if 
self.use_abstract_scores: 89 | if self.sci_sum_fake_scores: 90 | abstract_scores = [np.random.rand() for _ in sentences] 91 | else: 92 | abstract_scores = json_dict["abstract_scores"] 93 | features.append(abstract_scores) 94 | 95 | additional_features = list(map(list, zip(*features))) # some magic transpose function 96 | 97 | sentences, labels = self.filter_bad_sci_sum_sentences(sentences, labels) 98 | 99 | if len(sentences) == 0: 100 | return [] 101 | 102 | for sentences_loop, labels_loop, confidences_loop, additional_features_loop in \ 103 | self.enforce_max_sent_per_example(sentences, labels, confidences, additional_features): 104 | 105 | instance = self.text_to_instance( 106 | sentences=sentences_loop, 107 | labels=labels_loop, 108 | confidences=confidences_loop, 109 | additional_features=additional_features_loop, 110 | ) 111 | instances.append(instance) 112 | return instances 113 | 114 | def enforce_max_sent_per_example(self, sentences, labels=None, confidences=None, additional_features=None): 115 | """ 116 | Splits examples with len(sentences) > self.max_sent_per_example into multiple smaller examples 117 | with len(sentences) <= self.max_sent_per_example. 118 | Recursively split the list of sentences into two halves until each half 119 | has len(sentences) < <= self.max_sent_per_example. The goal is to produce splits that are of almost 120 | equal size to avoid the scenario where all splits are of size 121 | self.max_sent_per_example then the last split is 1 or 2 sentences 122 | This will result into losing context around the edges of each examples. 
123 | """ 124 | if labels is not None: 125 | assert len(sentences) == len(labels) 126 | if confidences is not None: 127 | assert len(sentences) == len(confidences) 128 | if additional_features is not None: 129 | assert len(sentences) == len(additional_features) 130 | 131 | if len(sentences) > self.max_sent_per_example and self.max_sent_per_example > 0: 132 | i = len(sentences) // 2 133 | l1 = self.enforce_max_sent_per_example( 134 | sentences[:i], None if labels is None else labels[:i], 135 | None if confidences is None else confidences[:i], 136 | None if additional_features is None else additional_features[:i]) 137 | l2 = self.enforce_max_sent_per_example( 138 | sentences[i:], None if labels is None else labels[i:], 139 | None if confidences is None else confidences[i:], 140 | None if additional_features is None else additional_features[i:]) 141 | return l1 + l2 142 | else: 143 | return [(sentences, labels, confidences, additional_features)] 144 | 145 | def is_bad_sentence(self, sentence: str): 146 | if len(sentence) > 10 and len(sentence) < 600: 147 | return False 148 | else: 149 | return True 150 | 151 | def filter_bad_sci_sum_sentences(self, sentences, labels): 152 | filtered_sentences = [] 153 | filtered_labels = [] 154 | if not self.predict: 155 | for sentence, label in zip(sentences, labels): 156 | # most sentences outside of this range are bad sentences 157 | if not self.is_bad_sentence(sentence): 158 | filtered_sentences.append(sentence) 159 | filtered_labels.append(label) 160 | else: 161 | filtered_sentences.append("BADSENTENCE") 162 | filtered_labels.append(0.000001) 163 | sentences = filtered_sentences 164 | labels = filtered_labels 165 | else: 166 | for sentence in sentences: 167 | # most sentences outside of this range are bad sentences 168 | if not self.is_bad_sentence(sentence): 169 | filtered_sentences.append(sentence) 170 | else: 171 | filtered_sentences.append("BADSENTENCE") 172 | sentences = filtered_sentences 173 | 174 | return sentences, 
labels 175 | 176 | @overrides 177 | def text_to_instance(self, 178 | sentences: List[str], 179 | labels: List[str] = None, 180 | confidences: List[float] = None, 181 | additional_features: List[float] = None, 182 | ) -> Instance: 183 | if not self.predict: 184 | assert len(sentences) == len(labels) 185 | if confidences is not None: 186 | assert len(sentences) == len(confidences) 187 | if additional_features is not None: 188 | assert len(sentences) == len(additional_features) 189 | 190 | if self.use_sep: 191 | tokenized_sentences = [self._tokenizer.tokenize(s)[:self.sent_max_len] + [Token("[SEP]")] for s in sentences] 192 | sentences = [list(itertools.chain.from_iterable(tokenized_sentences))[:-1]] 193 | else: 194 | # Tokenize the sentences 195 | sentences = [ 196 | self._tokenizer.tokenize(sentence_text)[:self.sent_max_len] 197 | for sentence_text in sentences 198 | ] 199 | 200 | fields: Dict[str, Field] = {} 201 | fields["sentences"] = ListField([ 202 | TextField(sentence, self._token_indexers) 203 | for sentence in sentences 204 | ]) 205 | 206 | if labels is not None: 207 | if isinstance(labels[0], list): 208 | fields["labels"] = ListField([ 209 | MultiLabelField(label) for label in labels 210 | ]) 211 | else: 212 | # make the labels strings for easier identification of the neutral label 213 | # probably not strictly necessary 214 | if self.sci_sum: 215 | fields["labels"] = ArrayField(np.array(labels)) 216 | else: 217 | fields["labels"] = ListField([ 218 | LabelField(str(label)+"_label") for label in labels 219 | ]) 220 | 221 | if confidences is not None: 222 | fields['confidences'] = ArrayField(np.array(confidences)) 223 | if additional_features is not None: 224 | fields["additional_features"] = ArrayField(np.array(additional_features)) 225 | 226 | return Instance(fields) -------------------------------------------------------------------------------- /sequential_sentence_classification/model.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import torch 5 | from torch.nn import Linear 6 | from allennlp.data import Vocabulary 7 | from allennlp.models.model import Model 8 | from allennlp.modules import TextFieldEmbedder, TimeDistributed, Seq2SeqEncoder 9 | from allennlp.nn.util import get_text_field_mask 10 | from allennlp.training.metrics import F1Measure, CategoricalAccuracy 11 | from allennlp.modules.conditional_random_field import ConditionalRandomField 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | @Model.register("SeqClassificationModel") 16 | class SeqClassificationModel(Model): 17 | """ 18 | Question answering model where answers are sentences 19 | """ 20 | 21 | def __init__(self, vocab: Vocabulary, 22 | text_field_embedder: TextFieldEmbedder, 23 | use_sep: bool = True, 24 | with_crf: bool = False, 25 | self_attn: Seq2SeqEncoder = None, 26 | bert_dropout: float = 0.1, 27 | sci_sum: bool = False, 28 | additional_feature_size: int = 0, 29 | ) -> None: 30 | super(SeqClassificationModel, self).__init__(vocab) 31 | 32 | self.text_field_embedder = text_field_embedder 33 | self.vocab = vocab 34 | self.use_sep = use_sep 35 | self.with_crf = with_crf 36 | self.sci_sum = sci_sum 37 | self.self_attn = self_attn 38 | self.additional_feature_size = additional_feature_size 39 | 40 | self.dropout = torch.nn.Dropout(p=bert_dropout) 41 | 42 | # define loss 43 | if self.sci_sum: 44 | self.loss = torch.nn.MSELoss(reduction='none') # labels are rouge scores 45 | self.labels_are_scores = True 46 | self.num_labels = 1 47 | else: 48 | self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none') 49 | self.labels_are_scores = False 50 | self.num_labels = self.vocab.get_vocab_size(namespace='labels') 51 | # define accuracy metrics 52 | self.label_accuracy = CategoricalAccuracy() 53 | self.label_f1_metrics = {} 54 | 55 | # define F1 metrics per label 56 | for 
label_index in range(self.num_labels): 57 | label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index) 58 | self.label_f1_metrics[label_name] = F1Measure(label_index) 59 | 60 | encoded_senetence_dim = text_field_embedder._token_embedders['bert'].output_dim 61 | 62 | ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim() 63 | ff_in_dim += self.additional_feature_size 64 | 65 | self.time_distributed_aggregate_feedforward = TimeDistributed(Linear(ff_in_dim, self.num_labels)) 66 | 67 | if self.with_crf: 68 | self.crf = ConditionalRandomField( 69 | self.num_labels, constraints=None, 70 | include_start_end_transitions=True 71 | ) 72 | 73 | def forward(self, # type: ignore 74 | sentences: torch.LongTensor, 75 | labels: torch.IntTensor = None, 76 | confidences: torch.Tensor = None, 77 | additional_features: torch.Tensor = None, 78 | ) -> Dict[str, torch.Tensor]: 79 | # pylint: disable=arguments-differ 80 | """ 81 | Parameters 82 | ---------- 83 | TODO: add description 84 | 85 | Returns 86 | ------- 87 | An output dictionary consisting of: 88 | loss : torch.FloatTensor, optional 89 | A scalar loss to be optimised. 90 | """ 91 | # =========================================================================================================== 92 | # Layer 1: For each sentence, participant pair: create a Glove embedding for each token 93 | # Input: sentences 94 | # Output: embedded_sentences 95 | 96 | # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size 97 | embedded_sentences = self.text_field_embedder(sentences) 98 | mask = get_text_field_mask(sentences, num_wrapping_dims=1).float() 99 | batch_size, num_sentences, _, _ = embedded_sentences.size() 100 | 101 | if self.use_sep: 102 | # The following code collects vectors of the SEP tokens from all the examples in the batch, 103 | # and arrange them in one list. It does the same for the labels and confidences. 
104 | # TODO: replace 103 with '[SEP]' 105 | sentences_mask = sentences['bert'] == 103 # mask for all the SEP tokens in the batch 106 | embedded_sentences = embedded_sentences[sentences_mask] # given batch_size x num_sentences_per_example x sent_len x vector_len 107 | # returns num_sentences_per_batch x vector_len 108 | assert embedded_sentences.dim() == 2 109 | num_sentences = embedded_sentences.shape[0] 110 | # for the rest of the code in this model to work, think of the data we have as one example 111 | # with so many sentences and a batch of size 1 112 | batch_size = 1 113 | embedded_sentences = embedded_sentences.unsqueeze(dim=0) 114 | embedded_sentences = self.dropout(embedded_sentences) 115 | 116 | if labels is not None: 117 | if self.labels_are_scores: 118 | labels_mask = labels != 0.0 # mask for all the labels in the batch (no padding) 119 | else: 120 | labels_mask = labels != -1 # mask for all the labels in the batch (no padding) 121 | 122 | labels = labels[labels_mask] # given batch_size x num_sentences_per_example return num_sentences_per_batch 123 | assert labels.dim() == 1 124 | if confidences is not None: 125 | confidences = confidences[labels_mask] 126 | assert confidences.dim() == 1 127 | if additional_features is not None: 128 | additional_features = additional_features[labels_mask] 129 | assert additional_features.dim() == 2 130 | 131 | num_labels = labels.shape[0] 132 | if num_labels != num_sentences: # bert truncates long sentences, so some of the SEP tokens might be gone 133 | assert num_labels > num_sentences # but `num_labels` should be at least greater than `num_sentences` 134 | logger.warning(f'Found {num_labels} labels but {num_sentences} sentences') 135 | labels = labels[:num_sentences] # Ignore some labels. This is ok for training but bad for testing. 136 | # We are ignoring this problem for now. 
137 | # TODO: fix, at least for testing 138 | 139 | # do the same for `confidences` 140 | if confidences is not None: 141 | num_confidences = confidences.shape[0] 142 | if num_confidences != num_sentences: 143 | assert num_confidences > num_sentences 144 | confidences = confidences[:num_sentences] 145 | 146 | # and for `additional_features` 147 | if additional_features is not None: 148 | num_additional_features = additional_features.shape[0] 149 | if num_additional_features != num_sentences: 150 | assert num_additional_features > num_sentences 151 | additional_features = additional_features[:num_sentences] 152 | 153 | # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1 154 | labels = labels.unsqueeze(dim=0) 155 | if confidences is not None: 156 | confidences = confidences.unsqueeze(dim=0) 157 | if additional_features is not None: 158 | additional_features = additional_features.unsqueeze(dim=0) 159 | else: 160 | # ['CLS'] token 161 | embedded_sentences = embedded_sentences[:, :, 0, :] 162 | embedded_sentences = self.dropout(embedded_sentences) 163 | batch_size, num_sentences, _ = embedded_sentences.size() 164 | sent_mask = (mask.sum(dim=2) != 0) 165 | embedded_sentences = self.self_attn(embedded_sentences, sent_mask) 166 | 167 | if additional_features is not None: 168 | embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1) 169 | 170 | label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences) 171 | # label_logits: batch_size, num_sentences, num_labels 172 | 173 | if self.labels_are_scores: 174 | label_probs = label_logits 175 | else: 176 | label_probs = torch.nn.functional.softmax(label_logits, dim=-1) 177 | 178 | # Create output dictionary for the trainer 179 | # Compute loss and epoch metrics 180 | output_dict = {"action_probs": label_probs} 181 | 182 | # ===================================================================== 183 | 184 | if self.with_crf: 185 | # Layer 4 = 
CRF layer across labels of sentences in an abstract 186 | mask_sentences = (labels != -1) 187 | best_paths = self.crf.viterbi_tags(label_logits, mask_sentences) 188 | # 189 | # # Just get the tags and ignore the score. 190 | predicted_labels = [x for x, y in best_paths] 191 | # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}") 192 | 193 | label_loss = 0.0 194 | if labels is not None: 195 | # Compute cross entropy loss 196 | flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels) 197 | flattened_gold = labels.contiguous().view(-1) 198 | 199 | if not self.with_crf: 200 | label_loss = self.loss(flattened_logits.squeeze(), flattened_gold) 201 | if confidences is not None: 202 | label_loss = label_loss * confidences.type_as(label_loss).view(-1) 203 | label_loss = label_loss.mean() 204 | flattened_probs = torch.softmax(flattened_logits, dim=-1) 205 | else: 206 | clamped_labels = torch.clamp(labels, min=0) 207 | log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences) 208 | label_loss = -log_likelihood 209 | # compute categorical accuracy 210 | crf_label_probs = label_logits * 0. 
211 | for i, instance_labels in enumerate(predicted_labels): 212 | for j, label_id in enumerate(instance_labels): 213 | crf_label_probs[i, j, label_id] = 1 214 | flattened_probs = crf_label_probs.view((batch_size * num_sentences), self.num_labels) 215 | 216 | if not self.labels_are_scores: 217 | evaluation_mask = (flattened_gold != -1) 218 | self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask) 219 | 220 | # compute F1 per label 221 | for label_index in range(self.num_labels): 222 | label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index) 223 | metric = self.label_f1_metrics[label_name] 224 | metric(flattened_probs, flattened_gold, mask=evaluation_mask) 225 | 226 | if labels is not None: 227 | output_dict["loss"] = label_loss 228 | output_dict['action_logits'] = label_logits 229 | return output_dict 230 | 231 | def get_metrics(self, reset: bool = False): 232 | metric_dict = {} 233 | 234 | if not self.labels_are_scores: 235 | type_accuracy = self.label_accuracy.get_metric(reset) 236 | metric_dict['acc'] = type_accuracy 237 | 238 | average_F1 = 0.0 239 | for name, metric in self.label_f1_metrics.items(): 240 | metric_val = metric.get_metric(reset) 241 | metric_dict[name + 'F'] = metric_val[2] 242 | average_F1 += metric_val[2] 243 | 244 | average_F1 /= len(self.label_f1_metrics.items()) 245 | metric_dict['avgF'] = average_F1 246 | 247 | return metric_dict 248 | -------------------------------------------------------------------------------- /sequential_sentence_classification/predictor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from overrides import overrides 3 | 4 | from allennlp.common.util import JsonDict, sanitize 5 | from allennlp.data import Instance 6 | from allennlp.predictors.predictor import Predictor 7 | 8 | 9 | @Predictor.register('SeqClassificationPredictor') 10 | class 
SeqClassificationPredictor(Predictor): 11 | """ 12 | Predictor for the abstruct model 13 | """ 14 | def predict_json(self, json_dict: JsonDict) -> JsonDict: 15 | pred_labels = [] 16 | sentences = json_dict['sentences'] 17 | paper_id = json_dict['abstract_id'] 18 | for sentences_loop, _, _, _ in \ 19 | self._dataset_reader.enforce_max_sent_per_example(sentences): 20 | instance = self._dataset_reader.text_to_instance(abstract_id=0, sentences=sentences_loop) 21 | output = self._model.forward_on_instance(instance) 22 | idx = output['action_probs'].argmax(axis=1).tolist() 23 | labels = [self._model.vocab.get_token_from_index(i, namespace='labels') for i in idx] 24 | pred_labels.extend(labels) 25 | assert len(pred_labels) == len(sentences) 26 | preds = list(zip(sentences, pred_labels)) 27 | return paper_id, preds 28 | --------------------------------------------------------------------------------