├── .gitignore ├── .pylintrc ├── README.md ├── adversarialnlp ├── __init__.py ├── commands │ ├── __init__.py │ └── test_install.py ├── common │ ├── __init__.py │ └── file_utils.py ├── generators │ ├── __init__.py │ ├── addsent │ │ ├── __init__.py │ │ ├── addsent_generator.py │ │ ├── corenlp.py │ │ ├── rules │ │ │ ├── __init__.py │ │ │ ├── alteration_rules.py │ │ │ ├── answer_rules.py │ │ │ └── conversion_rules.py │ │ ├── squad_reader.py │ │ └── utils.py │ ├── generator.py │ └── swag │ │ ├── __init__.py │ │ ├── activitynet_captions_reader.py │ │ ├── openai_transformer_model.py │ │ ├── simple_bilm.py │ │ ├── swag_generator.py │ │ └── utils.py ├── pruners │ ├── __init__.py │ └── pruner.py ├── run.py ├── tests │ ├── __init__.py │ ├── dataset_readers │ │ ├── __init__.py │ │ └── activitynet_captions_test.py │ ├── fixtures │ │ ├── activitynet_captions.json │ │ └── squad.json │ └── generators │ │ ├── __init__.py │ │ ├── addsent_generator_test.py │ │ └── swag_generator_test.py └── version.py ├── bin └── adversarialnlp ├── docs ├── Makefile ├── common.rst ├── conf.py ├── generators.rst ├── index.rst ├── make.bat ├── readme.rst └── readthedoc_requirements.txt ├── pytest.ini ├── readthedocs.yml ├── requirements.txt ├── setup.cfg ├── setup.py └── tutorials └── usage.py /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | adversarialnlp.egg-info/ 6 | build/ 7 | dist/ 8 | data/* 9 | lib/* 10 | 11 | # dev tools 12 | 13 | .envrc 14 | .python-version 15 | 16 | 17 | # jupyter notebooks 18 | 19 | .ipynb_checkpoints 20 | 21 | 22 | # miscellaneous 23 | 24 | .cache/ 25 | .vscode/ 26 | docs/_build/ 27 | 28 | 29 | # python 30 | 31 | *.pyc 32 | *.pyo 33 | __pycache__ 34 | 35 | 36 | # testing and continuous integration 37 | 38 | .coverage 39 | .pytest_cache/ 40 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | init-hook='import sys; sys.path.append("./")' 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=CVS,custom_extensions 13 | 14 | # Add files or directories matching the regex patterns to the blacklist. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=yes 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code 35 | extension-pkg-whitelist=numpy,torch,spacy,_jsonnet 36 | 37 | # Allow optimization of some AST trees. This will activate a peephole AST 38 | # optimizer, which will apply various small optimizations. For instance, it can 39 | # be used to obtain the result of joining multiple strings with the addition 40 | # operator. 
Joining a lot of strings can lead to a maximum recursion error in 41 | # Pylint and this flag can prevent that. It has one side effect, the resulting 42 | # AST will be different than the one from reality. This option is deprecated 43 | # and it will be removed in Pylint 2.0. 44 | optimize-ast=no 45 | 46 | 47 | [MESSAGES CONTROL] 48 | 49 | # Only show warnings with the listed confidence levels. Leave empty to show 50 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 51 | confidence= 52 | 53 | # Enable the message, report, category or checker with the given id(s). You can 54 | # either give multiple identifier separated by comma (,) or put this option 55 | # multiple time (only on the command line, not in the configuration file where 56 | # it should appear only once). See also the "--disable" option for examples. 57 | #enable= 58 | 59 | # Disable the message, report, category or checker with the given id(s). You 60 | # can either give multiple identifiers separated by comma (,) or put this 61 | # option multiple times (only on the command line, not in the configuration 62 | # file where it should appear only once).You can also use "--disable=all" to 63 | # disable everything first and then reenable specific checks. For example, if 64 | # you want to run only the similarities checker, you can use "--disable=all 65 | # --enable=similarities". If you want to run only the classes checker, but have 66 | # no Warning level messages displayed, use"--disable=all --enable=classes 67 | # --disable=W" 68 | disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,missing-docstring,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks,too-many-instance-attributes,fixme,too-few-public-methods,no-else-return 69 | 70 | 71 | [REPORTS] 72 | 73 | # Set the output format. Available formats are text, parseable, colorized, msvs 74 | # (visual studio) and html. You can also give a reporter class, eg 75 | # mypackage.mymodule.MyReporterClass. 76 | output-format=text 77 | 78 | # Put messages in a separate file for each module / package specified on the 79 | # command line instead of printing them on stdout. Reports (if any) will be 80 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 81 | # and it will be removed in Pylint 2.0. 82 | files-output=no 83 | 84 | # Tells whether to display a full report or only the messages 85 | reports=yes 86 | 87 | # Python expression which should return a note less than 10 (10 is the highest 88 | # note). You have access to the variables errors warning, statement which 89 | # respectively contain the number of errors / warnings messages and the total 90 | # number of statements analyzed. 
This is used by the global evaluation report 91 | # (RP0004). 92 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 93 | 94 | # Template used to display messages. This is a python new-style format string 95 | # used to format the message information. See doc for all details 96 | #msg-template= 97 | 98 | 99 | [LOGGING] 100 | 101 | # Logging modules to check that the string format arguments are in logging 102 | # function parameter format 103 | logging-modules=logging 104 | 105 | 106 | [TYPECHECK] 107 | 108 | # Tells whether missing members accessed in mixin class should be ignored. A 109 | # mixin class is detected if its name ends with "mixin" (case insensitive). 110 | ignore-mixin-members=yes 111 | 112 | # List of module names for which member attributes should not be checked 113 | # (useful for modules/projects where namespaces are manipulated during runtime 114 | # and thus existing member attributes cannot be deduced by static analysis. It 115 | # supports qualified module names, as well as Unix pattern matching. 116 | ignored-modules= 117 | 118 | # List of class names for which member attributes should not be checked (useful 119 | # for classes with dynamically set attributes). This supports the use of 120 | # qualified names. 121 | ignored-classes=optparse.Values,thread._local,_thread._local,responses 122 | 123 | # List of members which are set dynamically and missed by pylint inference 124 | # system, and so shouldn't trigger E1101 when accessed. Python regular 125 | # expressions are accepted. 126 | generated-members=torch.* 127 | 128 | # List of decorators that produce context managers, such as 129 | # contextlib.contextmanager. Add to this list to register other decorators that 130 | # produce valid context managers. 131 | contextmanager-decorators=contextlib.contextmanager 132 | 133 | 134 | [SIMILARITIES] 135 | 136 | # Minimum lines number of a similarity. 137 | min-similarity-lines=4 138 | 139 | # Ignore comments when computing similarities. 140 | ignore-comments=yes 141 | 142 | # Ignore docstrings when computing similarities. 143 | ignore-docstrings=yes 144 | 145 | # Ignore imports when computing similarities. 146 | ignore-imports=no 147 | 148 | 149 | [FORMAT] 150 | 151 | # Maximum number of characters on a single line. Ideally, lines should be under 100 characters, 152 | # but we allow some leeway before calling it an error. 153 | max-line-length=115 154 | 155 | # Regexp for a line that is allowed to be longer than the limit. 156 | ignore-long-lines=^\s*(# )??$ 157 | 158 | # Allow the body of an if to be on the same line as the test if there is no 159 | # else. 160 | single-line-if-stmt=no 161 | 162 | # List of optional constructs for which whitespace checking is disabled. `dict- 163 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 164 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 165 | # `empty-line` allows space-only lines. 166 | no-space-check=trailing-comma,dict-separator 167 | 168 | # Maximum number of lines in a module 169 | max-module-lines=1000 170 | 171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 172 | # tab). 173 | indent-string=' ' 174 | 175 | # Number of spaces of indent required inside a hanging or continued line. 176 | indent-after-paren=8 177 | 178 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
179 | expected-line-ending-format= 180 | 181 | 182 | [BASIC] 183 | 184 | # Good variable names which should always be accepted, separated by a comma 185 | good-names=i,j,k,ex,Run,_ 186 | 187 | # Bad variable names which should always be refused, separated by a comma 188 | bad-names=foo,bar,baz,toto,tutu,tata 189 | 190 | # Colon-delimited sets of names that determine each other's naming style when 191 | # the name regexes allow several styles. 192 | name-group= 193 | 194 | # Include a hint for the correct naming format with invalid-name 195 | include-naming-hint=no 196 | 197 | # List of decorators that produce properties, such as abc.abstractproperty. Add 198 | # to this list to register other decorators that produce valid properties. 199 | property-classes=abc.abstractproperty 200 | 201 | # Regular expression matching correct function names 202 | function-rgx=[a-z_][a-z0-9_]{2,40}$ 203 | 204 | # Naming hint for function names 205 | function-name-hint=[a-z_][a-z0-9_]{2,40}$ 206 | 207 | # Regular expression matching correct variable names 208 | variable-rgx=[a-z_][a-z0-9_]{2,40}$ 209 | 210 | # Naming hint for variable names 211 | variable-name-hint=[a-z_][a-z0-9_]{2,40}$ 212 | 213 | # Regular expression matching correct constant names 214 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 215 | 216 | # Naming hint for constant names 217 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 218 | 219 | # Regular expression matching correct attribute names 220 | attr-rgx=[a-z_][a-z0-9_]{2,40}$ 221 | 222 | # Naming hint for attribute names 223 | attr-name-hint=[a-z_][a-z0-9_]{2,40}$ 224 | 225 | # Regular expression matching correct argument names 226 | argument-rgx=[a-z_][a-z0-9_]{2,40}$ 227 | 228 | # Naming hint for argument names 229 | argument-name-hint=[a-z_][a-z0-9_]{2,40}$ 230 | 231 | # Regular expression matching correct class attribute names 232 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 233 | 234 | # Naming hint for class attribute names 235 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 236 | 237 | # Regular expression matching correct inline iteration names 238 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 239 | 240 | # Naming hint for inline iteration names 241 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 242 | 243 | # Regular expression matching correct class names 244 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 245 | 246 | # Naming hint for class names 247 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 248 | 249 | # Regular expression matching correct module names 250 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 251 | 252 | # Naming hint for module names 253 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 254 | 255 | # Regular expression matching correct method names 256 | method-rgx=[a-z_][a-z0-9_]{2,40}$ 257 | 258 | # Naming hint for method names 259 | method-name-hint=[a-z_][a-z0-9_]{2,40}$ 260 | 261 | # Regular expression which should only match function or class names that do 262 | # not require a docstring. 263 | no-docstring-rgx=^_ 264 | 265 | # Minimum line length for functions/classes that require docstrings, shorter 266 | # ones are exempt. 267 | docstring-min-length=-1 268 | 269 | 270 | [ELIF] 271 | 272 | # Maximum number of nested blocks for function / method body 273 | max-nested-blocks=5 274 | 275 | 276 | [VARIABLES] 277 | 278 | # Tells whether we should check for unused import in __init__ files. 279 | init-import=no 280 | 281 | # A regular expression matching the name of dummy variables (i.e. expectedly 282 | # not used). 
283 | dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy 284 | 285 | # List of additional names supposed to be defined in builtins. Remember that 286 | # you should avoid to define new builtins when possible. 287 | additional-builtins= 288 | 289 | # List of strings which can identify a callback function by name. A callback 290 | # name must start or end with one of those strings. 291 | callbacks=cb_,_cb 292 | 293 | # List of qualified module names which can have objects that can redefine 294 | # builtins. 295 | redefining-builtins-modules=six.moves,future.builtins 296 | 297 | 298 | [SPELLING] 299 | 300 | # Spelling dictionary name. Available dictionaries: none. To make it working 301 | # install python-enchant package. 302 | spelling-dict= 303 | 304 | # List of comma separated words that should not be checked. 305 | spelling-ignore-words= 306 | 307 | # A path to a file that contains private dictionary; one word per line. 308 | spelling-private-dict-file= 309 | 310 | # Tells whether to store unknown words to indicated private dictionary in 311 | # --spelling-private-dict-file option instead of raising a message. 312 | spelling-store-unknown-words=no 313 | 314 | 315 | [MISCELLANEOUS] 316 | 317 | # List of note tags to take in consideration, separated by a comma. 318 | notes=FIXME,XXX,TODO 319 | 320 | 321 | [DESIGN] 322 | 323 | # Maximum number of arguments for function / method 324 | max-args=5 325 | 326 | # Argument names that match this expression will be ignored. Default to name 327 | # with leading underscore 328 | ignored-argument-names=_.* 329 | 330 | # Maximum number of locals for function / method body 331 | max-locals=15 332 | 333 | # Maximum number of return / yield for function / method body 334 | max-returns=6 335 | 336 | # Maximum number of branch for function / method body 337 | max-branches=12 338 | 339 | # Maximum number of statements in function / method body 340 | max-statements=50 341 | 342 | # Maximum number of parents for a class (see R0901). 343 | max-parents=7 344 | 345 | # Maximum number of attributes for a class (see R0902). 346 | max-attributes=7 347 | 348 | # Minimum number of public methods for a class (see R0903). 349 | min-public-methods=2 350 | 351 | # Maximum number of public methods for a class (see R0904). 352 | max-public-methods=20 353 | 354 | # Maximum number of boolean expressions in a if statement 355 | max-bool-expr=5 356 | 357 | 358 | [CLASSES] 359 | 360 | # List of method names used to declare (i.e. assign) instance attributes. 361 | defining-attr-methods=__init__,__new__,setUp 362 | 363 | # List of valid names for the first argument in a class method. 364 | valid-classmethod-first-arg=cls 365 | 366 | # List of valid names for the first argument in a metaclass class method. 367 | valid-metaclass-classmethod-first-arg=mcs 368 | 369 | # List of member names, which should be excluded from the protected access 370 | # warning. 371 | exclude-protected=_asdict,_fields,_replace,_source,_make 372 | 373 | 374 | [IMPORTS] 375 | 376 | # Deprecated modules which should not be used, separated by a comma 377 | deprecated-modules=regsub,TERMIOS,Bastion,rexec 378 | 379 | # Create a graph of every (i.e. 
internal and external) dependencies in the 380 | # given file (report RP0402 must not be disabled) 381 | import-graph= 382 | 383 | # Create a graph of external dependencies in the given file (report RP0402 must 384 | # not be disabled) 385 | ext-import-graph= 386 | 387 | # Create a graph of internal dependencies in the given file (report RP0402 must 388 | # not be disabled) 389 | int-import-graph= 390 | 391 | # Force import order to recognize a module as part of the standard 392 | # compatibility libraries. 393 | known-standard-library= 394 | 395 | # Force import order to recognize a module as part of a third party library. 396 | known-third-party=enchant 397 | 398 | # Analyse import fallback blocks. This can be used to support both Python 2 and 399 | # 3 compatible code, which means that the block might have code that exists 400 | # only in one or another interpreter, leading to false positives when analysed. 401 | analyse-fallback-blocks=no 402 | 403 | 404 | [EXCEPTIONS] 405 | 406 | # Exceptions that will emit a warning when being caught. Defaults to 407 | # "Exception" 408 | overgeneral-exceptions=Exception 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdversarialNLP - WIP 2 | 3 | AdversarialNLP is a generic library for crafting and using Adversarial NLP examples. 4 | 5 | Work in Progress 6 | 7 | ## Installation 8 | 9 | AdversarialNLP requires Python 3.6.1 or later. The preferred way to install AdversarialNLP is via `pip`. Just run `pip install adversarialnlp` in your Python environment and you're good to go! 10 | -------------------------------------------------------------------------------- /adversarialnlp/__init__.py: -------------------------------------------------------------------------------- 1 | from adversarialnlp.version import VERSION as __version__ 2 | -------------------------------------------------------------------------------- /adversarialnlp/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import argparse 3 | import logging 4 | 5 | from allennlp.commands.subcommand import Subcommand 6 | from allennlp.common.util import import_submodules 7 | 8 | from adversarialnlp import __version__ 9 | from adversarialnlp.commands.test_install import TestInstall 10 | 11 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | def main(prog: str = None, 15 | subcommand_overrides: Dict[str, Subcommand] = {}) -> None: 16 | """ 17 | :mod:`~adversarialnlp.run` command. 18 | """ 19 | # pylint: disable=dangerous-default-value 20 | parser = argparse.ArgumentParser(description="Run AdversarialNLP", usage='%(prog)s', prog=prog) 21 | parser.add_argument('--version', action='version', version='%(prog)s ' + __version__) 22 | 23 | subparsers = parser.add_subparsers(title='Commands', metavar='') 24 | 25 | subcommands = { 26 | # Default commands 27 | "test-install": TestInstall(), 28 | 29 | # Superseded by overrides 30 | **subcommand_overrides 31 | } 32 | 33 | for name, subcommand in subcommands.items(): 34 | subparser = subcommand.add_subparser(name, subparsers) 35 | # configure doesn't need include-package because it imports 36 | # whatever classes it needs. 
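        # Hedged usage sketch of the resulting command-line interface (the
        # console-script name comes from the bin/adversarialnlp entry point;
        # "my_custom_generators" is a hypothetical package name):
        #
        #     adversarialnlp test-install --include-package my_custom_generators
        #
        # --include-package lets import_submodules() below register any
        # user-defined classes before the chosen subcommand runs.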
37 | if name != "configure": 38 | subparser.add_argument('--include-package', 39 | type=str, 40 | action='append', 41 | default=[], 42 | help='additional packages to include') 43 | 44 | args = parser.parse_args() 45 | 46 | # If a subparser is triggered, it adds its work as `args.func`. 47 | # So if no such attribute has been added, no subparser was triggered, 48 | # so give the user some help. 49 | if 'func' in dir(args): 50 | # Import any additional modules needed (to register custom classes). 51 | for package_name in getattr(args, 'include_package', ()): 52 | import_submodules(package_name) 53 | args.func(args) 54 | else: 55 | parser.print_help() 56 | -------------------------------------------------------------------------------- /adversarialnlp/commands/test_install.py: -------------------------------------------------------------------------------- 1 | """ 2 | The ``test-install`` subcommand verifies 3 | an installation by running the unit tests. 4 | .. code-block:: bash 5 | $ adversarialnlp test-install --help 6 | usage: adversarialnlp test-install [-h] [--run-all] 7 | [--include-package INCLUDE_PACKAGE] 8 | Test that installation works by running the unit tests. 9 | optional arguments: 10 | -h, --help show this help message and exit 11 | --run-all By default, we skip tests that are slow or download 12 | large files. This flag will run all tests. 13 | --include-package INCLUDE_PACKAGE 14 | additional packages to include 15 | """ 16 | 17 | import argparse 18 | import logging 19 | import os 20 | import pathlib 21 | 22 | import pytest 23 | 24 | from allennlp.commands.subcommand import Subcommand 25 | 26 | import adversarialnlp 27 | 28 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 29 | 30 | class TestInstall(Subcommand): 31 | def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser: 32 | # pylint: disable=protected-access 33 | description = '''Test that installation works by running the unit tests.''' 34 | subparser = parser.add_parser( 35 | name, description=description, help='Run the unit tests.') 36 | 37 | subparser.add_argument('--run-all', action="store_true", 38 | help="By default, we skip tests that are slow " 39 | "or download large files. This flag will run all tests.") 40 | 41 | subparser.set_defaults(func=_run_test) 42 | 43 | return subparser 44 | 45 | 46 | def _get_module_root(): 47 | return pathlib.Path(adversarialnlp.__file__).parent 48 | 49 | 50 | def _run_test(args: argparse.Namespace): 51 | initial_working_dir = os.getcwd() 52 | module_parent = _get_module_root().parent 53 | logger.info("Changing directory to %s", module_parent) 54 | os.chdir(module_parent) 55 | test_dir = os.path.join(module_parent, "adversarialnlp") 56 | logger.info("Running tests at %s", test_dir) 57 | if args.run_all: 58 | # TODO(nfliu): remove this when notebooks have been rewritten as markdown. 
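        # The -k expressions below deselect tests by name: with --run-all only the
        # notebook tests are skipped, otherwise sniff tests and anything carrying the
        # custom 'java' marker (presumably declared in pytest.ini) are skipped as well.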
59 | exit_code = pytest.main([test_dir, '--color=no', '-k', 'not notebooks_test']) 60 | else: 61 | exit_code = pytest.main([test_dir, '--color=no', '-k', 'not sniff_test and not notebooks_test', 62 | '-m', 'not java']) 63 | # Change back to original working directory after running tests 64 | os.chdir(initial_working_dir) 65 | exit(exit_code) 66 | -------------------------------------------------------------------------------- /adversarialnlp/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/common/__init__.py -------------------------------------------------------------------------------- /adversarialnlp/common/file_utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name,protected-access 2 | #!/usr/bin/env python3 3 | # Copyright (c) 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | """ 9 | Utilities for downloading and building data. 10 | These can be replaced if your particular file system does not support them. 11 | """ 12 | from typing import Union, List 13 | from pathlib import Path 14 | import time 15 | import datetime 16 | import os 17 | import shutil 18 | import requests 19 | 20 | MODULE_ROOT = Path(__file__).parent.parent 21 | FIXTURES_ROOT = (MODULE_ROOT / "tests" / "fixtures").resolve() 22 | PACKAGE_ROOT = MODULE_ROOT.parent 23 | DATA_ROOT = (PACKAGE_ROOT / "data").resolve() 24 | 25 | class ProgressLogger(object): 26 | """Throttles and display progress in human readable form.""" 27 | 28 | def __init__(self, throttle=1, should_humanize=True): 29 | """Initialize Progress logger. 30 | :param throttle: default 1, number in seconds to use as throttle rate 31 | :param should_humanize: default True, whether to humanize data units 32 | """ 33 | self.latest = time.time() 34 | self.throttle_speed = throttle 35 | self.should_humanize = should_humanize 36 | 37 | def humanize(self, num, suffix='B'): 38 | """Convert units to more human-readable format.""" 39 | if num < 0: 40 | return num 41 | for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: 42 | if abs(num) < 1024.0: 43 | return "%3.1f%s%s" % (num, unit, suffix) 44 | num /= 1024.0 45 | return "%.1f%s%s" % (num, 'Yi', suffix) 46 | 47 | def log(self, curr, total, width=40, force=False): 48 | """Display a bar showing the current progress.""" 49 | if curr == 0 and total == -1: 50 | print('[ no data received for this file ]', end='\r') 51 | return 52 | curr_time = time.time() 53 | if not force and curr_time - self.latest < self.throttle_speed: 54 | return 55 | else: 56 | self.latest = curr_time 57 | 58 | self.latest = curr_time 59 | done = min(curr * width // total, width) 60 | remain = width - done 61 | 62 | if self.should_humanize: 63 | curr = self.humanize(curr) 64 | total = self.humanize(total) 65 | 66 | progress = '[{}{}] {} / {}'.format( 67 | ''.join(['|'] * done), 68 | ''.join(['.'] * remain), 69 | curr, 70 | total 71 | ) 72 | print(progress, end='\r') 73 | 74 | def built(path, version_string=None): 75 | """Checks if '.built' flag has been set for that task. 
76 | If a version_string is provided, this has to match, or the version 77 | is regarded as not built. 78 | """ 79 | built_file_path = os.path.join(path, '.built') 80 | if not os.path.isfile(built_file_path): 81 | return False 82 | else: 83 | with open(built_file_path, 'r') as built_file: 84 | text = built_file.read().split('\n') 85 | if len(text) <= 2: 86 | return False 87 | for fname in text[1:-1]: 88 | if not os.path.isfile(os.path.join(path, fname)) and not os.path.isdir(os.path.join(path, fname)): 89 | return False 90 | return text[-1] == version_string if version_string else True 91 | 92 | 93 | def mark_done(path, fnames, version_string='vXX'): 94 | """Marks the path as done by adding a '.built' file with the current 95 | timestamp plus a version description string if specified. 96 | """ 97 | with open(os.path.join(path, '.built'), 'w') as built_file: 98 | built_file.write(str(datetime.datetime.today())) 99 | for fname in fnames: 100 | fname = fname.replace('.tar.gz', '').replace('.tgz', '').replace('.gz', '').replace('.zip', '') 101 | built_file.write('\n' + fname) 102 | built_file.write('\n' + version_string) 103 | 104 | 105 | def download(url, path, fname, redownload=False): 106 | """Downloads file using `requests`. If ``redownload`` is set to false, then 107 | will not download tar file again if it is present (default ``True``).""" 108 | outfile = os.path.join(path, fname) 109 | curr_download = not os.path.isfile(outfile) or redownload 110 | print("[ downloading: " + url + " to " + outfile + " ]") 111 | retry = 5 112 | exp_backoff = [2 ** r for r in reversed(range(retry))] 113 | 114 | logger = ProgressLogger() 115 | 116 | while curr_download and retry >= 0: 117 | resume_file = outfile + '.part' 118 | resume = os.path.isfile(resume_file) 119 | if resume: 120 | resume_pos = os.path.getsize(resume_file) 121 | mode = 'ab' 122 | else: 123 | resume_pos = 0 124 | mode = 'wb' 125 | response = None 126 | 127 | with requests.Session() as session: 128 | try: 129 | header = {'Range': 'bytes=%d-' % resume_pos, 130 | 'Accept-Encoding': 'identity'} if resume else {} 131 | response = session.get(url, stream=True, timeout=5, headers=header) 132 | 133 | # negative reply could be 'none' or just missing 134 | if resume and response.headers.get('Accept-Ranges', 'none') == 'none': 135 | resume_pos = 0 136 | mode = 'wb' 137 | 138 | CHUNK_SIZE = 32768 139 | total_size = int(response.headers.get('Content-Length', -1)) 140 | # server returns remaining size if resuming, so adjust total 141 | total_size += resume_pos 142 | done = resume_pos 143 | 144 | with open(resume_file, mode) as f: 145 | for chunk in response.iter_content(CHUNK_SIZE): 146 | if chunk: # filter out keep-alive new chunks 147 | f.write(chunk) 148 | if total_size > 0: 149 | done += len(chunk) 150 | if total_size < done: 151 | # don't freak out if content-length was too small 152 | total_size = done 153 | logger.log(done, total_size) 154 | break 155 | except requests.exceptions.ConnectionError: 156 | retry -= 1 157 | # TODO Better way to clean progress bar? 158 | print(''.join([' '] * 60), end='\r') 159 | if retry >= 0: 160 | print('Connection error, retrying. (%d retries left)' % retry) 161 | time.sleep(exp_backoff[retry]) 162 | else: 163 | print('Retried too many times, stopped retrying.') 164 | finally: 165 | if response: 166 | response.close() 167 | if retry < 0: 168 | raise RuntimeWarning('Connection broken too many times. 
Stopped retrying.') 169 | 170 | if curr_download and retry > 0: 171 | logger.log(done, total_size, force=True) 172 | print() 173 | if done < total_size: 174 | raise RuntimeWarning('Received less data than specified in ' + 175 | 'Content-Length header for ' + url + '.' + 176 | ' There may be a download problem.') 177 | move(resume_file, outfile) 178 | 179 | 180 | def make_dir(path): 181 | """Makes the directory and any nonexistent parent directories.""" 182 | # the current working directory is a fine path 183 | if path != '': 184 | os.makedirs(path, exist_ok=True) 185 | 186 | 187 | def move(path1, path2): 188 | """Renames the given file.""" 189 | shutil.move(path1, path2) 190 | 191 | 192 | def remove_dir(path): 193 | """Removes the given directory, if it exists.""" 194 | shutil.rmtree(path, ignore_errors=True) 195 | 196 | 197 | def untar(path, fname, deleteTar=True): 198 | """Unpacks the given archive file to the same directory, then (by default) 199 | deletes the archive file. 200 | """ 201 | print('unpacking ' + fname) 202 | fullpath = os.path.join(path, fname) 203 | if '.tar.gz' in fname: 204 | shutil.unpack_archive(fullpath, path, format='gztar') 205 | else: 206 | shutil.unpack_archive(fullpath, path) 207 | if deleteTar: 208 | os.remove(fullpath) 209 | 210 | 211 | def cat(file1, file2, outfile, deleteFiles=True): 212 | with open(outfile, 'wb') as wfd: 213 | for f in [file1, file2]: 214 | with open(f, 'rb') as fd: 215 | shutil.copyfileobj(fd, wfd, 1024 * 1024 * 10) 216 | # 10MB per writing chunk to avoid reading big file into memory. 217 | if deleteFiles: 218 | os.remove(file1) 219 | os.remove(file2) 220 | 221 | 222 | def _get_confirm_token(response): 223 | for key, value in response.cookies.items(): 224 | if key.startswith('download_warning'): 225 | return value 226 | return None 227 | 228 | def download_from_google_drive(gd_id, destination): 229 | """Uses the requests package to download a file from Google Drive.""" 230 | URL = 'https://docs.google.com/uc?export=download' 231 | 232 | with requests.Session() as session: 233 | response = session.get(URL, params={'id': gd_id}, stream=True) 234 | token = _get_confirm_token(response) 235 | 236 | if token: 237 | response.close() 238 | params = {'id': gd_id, 'confirm': token} 239 | response = session.get(URL, params=params, stream=True) 240 | 241 | CHUNK_SIZE = 32768 242 | with open(destination, 'wb') as f: 243 | for chunk in response.iter_content(CHUNK_SIZE): 244 | if chunk: # filter out keep-alive new chunks 245 | f.write(chunk) 246 | response.close() 247 | 248 | 249 | def download_files(fnames: List[Union[str, Path]], 250 | local_folder: str, 251 | version: str = 'v1.0', 252 | paths: Union[List[str], str] = 'aws') -> List[str]: 253 | r"""Download model/data files from a url. 254 | 255 | Args: 256 | fnames: List of filenames to download 257 | local_folder: Sub-folder of `./data` where models/data will 258 | be downloaded. 259 | version: Version of the model 260 | path: url or respective urls for downloading filenames. 261 | 262 | Return: 263 | List[str]: List of downloaded file path. 264 | If the downloaded file was a compressed file (`.tar.gz`, 265 | `.zip`, `.tgz`, `.gz`), return the path of the folder 266 | containing the extracted files. 
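    Example:
        A minimal usage sketch (illustrative only; it assumes the default
        ``aws`` mirror hosts the requested file under ``local_folder``)::

            from adversarialnlp.common.file_utils import download_files

            paths = download_files(fnames=['postag_dict.json'],
                                   local_folder='addsent')
            # -> ['<package root>/data/addsent/postag_dict.json']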
267 | """ 268 | 269 | dpath = str(DATA_ROOT / local_folder) 270 | out_paths = list(dpath + '/' + fname.replace('.tar.gz', '').replace('.tgz', '').replace('.gz', '').replace('.zip', '') 271 | for fname in fnames) 272 | 273 | if not built(dpath, version): 274 | for fname in fnames: 275 | print('[building data: ' + dpath + '/' + fname + ']') 276 | if built(dpath): 277 | # An older version exists, so remove these outdated files. 278 | remove_dir(dpath) 279 | make_dir(dpath) 280 | 281 | if isinstance(paths, str): 282 | paths = [paths] * len(fnames) 283 | # Download the data. 284 | for fname, path in zip(fnames, paths): 285 | if path == 'aws': 286 | url = 'http://huggingface.co/downloads/models/' 287 | url += local_folder + '/' 288 | url += fname 289 | else: 290 | url = path + '/' + fname 291 | download(url, dpath, fname) 292 | if '.tar.gz' in fname or '.tgz' in fname or '.gz' in fname or '.zip' in fname: 293 | untar(dpath, fname) 294 | # Mark the data as built. 295 | mark_done(dpath, fnames, version) 296 | return out_paths 297 | -------------------------------------------------------------------------------- /adversarialnlp/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import Generator 2 | from .swag import SwagGenerator 3 | from .addsent import AddSentGenerator 4 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/__init__.py: -------------------------------------------------------------------------------- 1 | from .addsent_generator import AddSentGenerator 2 | from .squad_reader import squad_reader 3 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/addsent_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import itertools 4 | from typing import Iterable, Dict, Tuple 5 | from collections import defaultdict 6 | 7 | from adversarialnlp.common.file_utils import download_files 8 | from adversarialnlp.generators import Generator 9 | from adversarialnlp.generators.addsent.rules import (ANSWER_RULES, HIGH_CONF_ALTER_RULES, ALL_ALTER_RULES, 10 | DO_NOT_ALTER, BAD_ALTERATIONS, CONVERSION_RULES) 11 | from adversarialnlp.generators.addsent.utils import (rejoin, ConstituencyParse, get_tokens_for_answers, 12 | get_determiner_for_answers, read_const_parse) 13 | from adversarialnlp.generators.addsent.squad_reader import squad_reader 14 | from adversarialnlp.generators.addsent.corenlp import StanfordCoreNLP 15 | 16 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 17 | 18 | SQUAD_FILE = 'data/squad/train-v1.1.json' 19 | NEARBY_GLOVE_FILE = 'data/addsent/nearby_n100_glove_6B_100d.json' 20 | POSTAG_FILE = 'data/addsent/postag_dict.json' 21 | 22 | class AddSentGenerator(Generator): 23 | r"""Adversarial examples generator based on AddSent. 24 | 25 | AddSent is described in the paper `Adversarial Examples for 26 | Evaluating Reading Comprehension Systems`_ 27 | by Robin Jia & Percy Liang 28 | 29 | Args, input and yield: 30 | See the ``Generator`` class. 31 | 32 | Additional arguments: 33 | alteration_strategy: Alteration strategy. Options: 34 | 35 | - `separate`: Do best alteration for each word separately. 36 | - `best`: Generate exactly one best alteration 37 | (may over-alter). 38 | - `high-conf`: Do all possible high-confidence alterations. 
39 | - `high-conf-separate`: Do best high-confidence alteration 40 | for each word separately. 41 | - `all`: Do all possible alterations (very conservative) 42 | 43 | prepend: Insert adversarial example at the beginning 44 | or end of the context. 45 | use_answer_placeholder: Use and answer placeholder. 46 | 47 | Seeds: 48 | Tuple of SQuAD-like instances containing 49 | - question-answer-span, and 50 | - context paragraph. 51 | 52 | default_seeds: 53 | If no seeds are provided, the default_seeds are the training 54 | set of the 55 | `SQuAD V1.0 dataset `_. 56 | 57 | """ 58 | def __init__(self, 59 | alteration_strategy: str = 'high-conf', 60 | prepend: bool = False, 61 | use_answer_placeholder: bool = False, 62 | default_seeds: Iterable = None, 63 | quiet: bool = False): 64 | super(AddSentGenerator).__init__(default_seeds, quiet) 65 | model_files = download_files(fnames=['nearby_n100_glove_6B_100d.json', 66 | 'postag_dict.json'], 67 | local_folder='addsent') 68 | corenlp_path = download_files(fnames=['stanford-corenlp-full-2018-02-27.zip'], 69 | paths='http://nlp.stanford.edu/software/', 70 | local_folder='corenlp') 71 | 72 | self.nlp: StanfordCoreNLP = StanfordCoreNLP(corenlp_path[0]) 73 | with open(model_files[0], 'r') as data_file: 74 | self.nearby_word_dict: Dict = json.load(data_file) 75 | with open(model_files[1], 'r') as data_file: 76 | self.postag_dict: Dict = json.load(data_file) 77 | 78 | self.alteration_strategy: str = alteration_strategy 79 | self.prepend: bool = prepend 80 | self.use_answer_placeholder: bool = use_answer_placeholder 81 | if default_seeds is None: 82 | self.default_seeds = squad_reader(SQUAD_FILE) 83 | else: 84 | self.default_seeds = default_seeds 85 | 86 | def close(self): 87 | self.nlp.close() 88 | 89 | def _annotate(self, text: str, annotators: str): 90 | r"""Wrapper to call CoreNLP. """ 91 | props = {'annotators': annotators, 92 | 'ssplit.newlineIsSentenceBreak': 'always', 93 | 'outputFormat':'json'} 94 | return json.loads(self.nlp.annotate(text, properties=props)) 95 | 96 | def _alter_question(self, question, tokens, const_parse): 97 | r"""Alter the question to make it ask something else. 
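        Returns a list of ``(new_question, tokens, const_parse, tag)`` tuples,
        one per applicable alteration, where ``tag`` names the rule or
        strategy that produced it. Illustrative example only: with the
        default 'high-conf' strategy, "What is the largest city?" may be
        altered to "What is the smallest city?".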
""" 98 | used_words = [tok['word'].lower() for tok in tokens] 99 | new_qs = [] 100 | toks_all = [] 101 | if self.alteration_strategy.startswith('high-conf'): 102 | rules = HIGH_CONF_ALTER_RULES 103 | else: 104 | rules = ALL_ALTER_RULES 105 | for i, tok in enumerate(tokens): 106 | if tok['word'].lower() in DO_NOT_ALTER: 107 | if self.alteration_strategy in ('high-conf', 'all'): 108 | toks_all.append(tok) 109 | continue 110 | begin = tokens[:i] 111 | end = tokens[i+1:] 112 | found = False 113 | for rule_name in rules: 114 | rule = rules[rule_name] 115 | new_words = rule(tok, nearby_word_dict=self.nearby_word_dict, postag_dict=self.postag_dict) 116 | if new_words: 117 | for word in new_words: 118 | if word.lower() in used_words: 119 | continue 120 | if word.lower() in BAD_ALTERATIONS: 121 | continue 122 | # Match capitzliation 123 | if tok['word'] == tok['word'].upper(): 124 | word = word.upper() 125 | elif tok['word'] == tok['word'].title(): 126 | word = word.title() 127 | new_tok = dict(tok) 128 | new_tok['word'] = new_tok['lemma'] = new_tok['originalText'] = word 129 | new_tok['altered'] = True 130 | # NOTE: obviously this is approximate 131 | if self.alteration_strategy.endswith('separate'): 132 | new_tokens = begin + [new_tok] + end 133 | new_q = rejoin(new_tokens) 134 | tag = '%s-%d-%s' % (rule_name, i, word) 135 | new_const_parse = ConstituencyParse.replace_words( 136 | const_parse, [tok['word'] for tok in new_tokens]) 137 | new_qs.append((new_q, new_tokens, new_const_parse, tag)) 138 | break 139 | elif self.alteration_strategy in ('high-conf', 'all'): 140 | toks_all.append(new_tok) 141 | found = True 142 | break 143 | if self.alteration_strategy in ('high-conf', 'all') and found: 144 | break 145 | if self.alteration_strategy in ('high-conf', 'all') and not found: 146 | toks_all.append(tok) 147 | if self.alteration_strategy in ('high-conf', 'all'): 148 | new_q = rejoin(toks_all) 149 | new_const_parse = ConstituencyParse.replace_words( 150 | const_parse, [tok['word'] for tok in toks_all]) 151 | if new_q != question: 152 | new_qs.append((rejoin(toks_all), toks_all, new_const_parse, self.alteration_strategy)) 153 | return new_qs 154 | 155 | def generate_from_seed(self, seed: Tuple): 156 | r"""Edit a SQuAD example using rules. 
""" 157 | qas, paragraph = seed 158 | question = qas['question'].strip() 159 | if not self.quiet: 160 | print(f"Question: {question}") 161 | if self.use_answer_placeholder: 162 | answer = 'ANSWER' 163 | determiner = '' 164 | else: 165 | p_parse = self._annotate(paragraph, 'tokenize,ssplit,pos,ner,entitymentions') 166 | ind, a_toks = get_tokens_for_answers(qas['answers'], p_parse) 167 | determiner = get_determiner_for_answers(qas['answers']) 168 | answer_obj = qas['answers'][ind] 169 | for _, func in ANSWER_RULES: 170 | answer = func(answer_obj, a_toks, question, determiner=determiner) 171 | if answer: 172 | break 173 | else: 174 | raise ValueError('Missing answer') 175 | q_parse = self._annotate(question, 'tokenize,ssplit,pos,parse,ner') 176 | q_parse = q_parse['sentences'][0] 177 | q_tokens = q_parse['tokens'] 178 | q_const_parse = read_const_parse(q_parse['parse']) 179 | if self.alteration_strategy: 180 | # Easiest to alter the question before converting 181 | q_list = self._alter_question(question, q_tokens, q_const_parse) 182 | else: 183 | q_list = [(question, q_tokens, q_const_parse, 'unaltered')] 184 | for q_str, q_tokens, q_const_parse, tag in q_list: 185 | for rule in CONVERSION_RULES: 186 | sent = rule.convert(q_str, answer, q_tokens, q_const_parse) 187 | if sent: 188 | if not self.quiet: 189 | print(f" Sent ({tag}): {sent}'") 190 | cur_qa = { 191 | 'question': qas['question'], 192 | 'id': '%s-%s' % (qas['id'], tag), 193 | 'answers': qas['answers'] 194 | } 195 | if self.prepend: 196 | cur_text = '%s %s' % (sent, paragraph) 197 | new_answers = [] 198 | for ans in qas['answers']: 199 | new_answers.append({ 200 | 'text': ans['text'], 201 | 'answer_start': ans['answer_start'] + len(sent) + 1 202 | }) 203 | cur_qa['answers'] = new_answers 204 | else: 205 | cur_text = '%s %s' % (paragraph, sent) 206 | out_example = {'title': title, 207 | 'seed_context': paragraph, 208 | 'seed_qas': qas, 209 | 'context': cur_text, 210 | 'qas': [cur_qa]} 211 | yield out_example 212 | 213 | # from adversarialnlp.common.file_utils import FIXTURES_ROOT 214 | # generator = AddSentGenerator() 215 | # test_instances = squad_reader(FIXTURES_ROOT / 'squad.json') 216 | # batches = list(generator(test_instances, num_epochs=1)) 217 | # assert len(batches) != 0 218 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/corenlp.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for Stanford CoreNLP 2 | # Copyright (c) 2017 Lynten Guo, 2018 Thomas Wolf 3 | # Extracted and adapted from https://github.com/Lynten/stanford-corenlp 4 | 5 | from __future__ import print_function 6 | 7 | import glob 8 | import json 9 | import logging 10 | import os 11 | import re 12 | import socket 13 | import subprocess 14 | import sys 15 | import time 16 | 17 | import psutil 18 | 19 | try: 20 | from urlparse import urlparse 21 | except ImportError: 22 | from urllib.parse import urlparse 23 | 24 | import requests 25 | 26 | 27 | class StanfordCoreNLP: 28 | def __init__(self, path_or_host, port=None, memory='4g', lang='en', timeout=1500, quiet=True, 29 | logging_level=logging.WARNING, max_retries=5): 30 | self.path_or_host = path_or_host 31 | self.port = port 32 | self.memory = memory 33 | self.lang = lang 34 | self.timeout = timeout 35 | self.quiet = quiet 36 | self.logging_level = logging_level 37 | 38 | logging.basicConfig(level=self.logging_level) 39 | 40 | # Check args 41 | self._check_args() 42 | 43 | if 
path_or_host.startswith('http'): 44 | self.url = path_or_host + ':' + str(port) 45 | logging.info('Using an existing server {}'.format(self.url)) 46 | else: 47 | 48 | # Check Java 49 | if not subprocess.call(['java', '-version'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) == 0: 50 | raise RuntimeError('Java not found.') 51 | 52 | # Check if the dir exists 53 | if not os.path.isdir(self.path_or_host): 54 | raise IOError(str(self.path_or_host) + ' is not a directory.') 55 | directory = os.path.normpath(self.path_or_host) + os.sep 56 | self.class_path_dir = directory 57 | 58 | # Check if the language specific model file exists 59 | switcher = { 60 | 'en': 'stanford-corenlp-[0-9].[0-9].[0-9]-models.jar', 61 | 'zh': 'stanford-chinese-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 62 | 'ar': 'stanford-arabic-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 63 | 'fr': 'stanford-french-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 64 | 'de': 'stanford-german-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 65 | 'es': 'stanford-spanish-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar' 66 | } 67 | jars = { 68 | 'en': 'stanford-corenlp-x.x.x-models.jar', 69 | 'zh': 'stanford-chinese-corenlp-yyyy-MM-dd-models.jar', 70 | 'ar': 'stanford-arabic-corenlp-yyyy-MM-dd-models.jar', 71 | 'fr': 'stanford-french-corenlp-yyyy-MM-dd-models.jar', 72 | 'de': 'stanford-german-corenlp-yyyy-MM-dd-models.jar', 73 | 'es': 'stanford-spanish-corenlp-yyyy-MM-dd-models.jar' 74 | } 75 | if len(glob.glob(directory + switcher.get(self.lang))) <= 0: 76 | raise IOError(jars.get( 77 | self.lang) + ' not exists. You should download and place it in the ' + directory + ' first.') 78 | 79 | # If port not set, auto select 80 | # Commenting: see https://github.com/Lynten/stanford-corenlp/issues/26 81 | # if self.port is None: 82 | # for port_candidate in range(9000, 65535): 83 | # if port_candidate not in [conn.laddr[1] for conn in psutil.net_connections()]: 84 | # self.port = port_candidate 85 | # break 86 | self.port = 9999 87 | 88 | # Check if the port is in use 89 | # Also commenting: see https://github.com/Lynten/stanford-corenlp/issues/26 90 | # if self.port in [conn.laddr[1] for conn in psutil.net_connections()]: 91 | # raise IOError('Port ' + str(self.port) + ' is already in use.') 92 | 93 | # Start native server 94 | logging.info('Initializing native server...') 95 | cmd = "java" 96 | java_args = "-Xmx{}".format(self.memory) 97 | java_class = "edu.stanford.nlp.pipeline.StanfordCoreNLPServer" 98 | class_path = '"{}*"'.format(directory) 99 | 100 | args = [cmd, java_args, '-cp', class_path, java_class, '-port', str(self.port)] 101 | 102 | args = ' '.join(args) 103 | 104 | logging.info(args) 105 | 106 | # Silence 107 | with open(os.devnull, 'w') as null_file: 108 | out_file = None 109 | if self.quiet: 110 | out_file = null_file 111 | 112 | self.p = subprocess.Popen(args, shell=True, stdout=out_file, stderr=subprocess.STDOUT) 113 | logging.info('Server shell PID: {}'.format(self.p.pid)) 114 | 115 | self.url = 'http://localhost:' + str(self.port) 116 | 117 | # Wait until server starts 118 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 119 | host_name = urlparse(self.url).hostname 120 | time.sleep(1) # OSX, not tested 121 | trial = 1 122 | while sock.connect_ex((host_name, self.port)): 123 | if trial > max_retries: 124 | raise ValueError('Corenlp server is not available') 125 | logging.info('Waiting until the server is available.') 126 | 
trial += 1 127 | time.sleep(1) 128 | logging.info('The server is available.') 129 | 130 | def __enter__(self): 131 | return self 132 | 133 | def __exit__(self, exc_type, exc_val, exc_tb): 134 | self.close() 135 | 136 | def close(self): 137 | logging.info('Cleanup...') 138 | if hasattr(self, 'p'): 139 | try: 140 | parent = psutil.Process(self.p.pid) 141 | except psutil.NoSuchProcess: 142 | logging.info('No process: {}'.format(self.p.pid)) 143 | return 144 | 145 | if self.class_path_dir not in ' '.join(parent.cmdline()): 146 | logging.info('Process not in: {}'.format(parent.cmdline())) 147 | return 148 | 149 | children = parent.children(recursive=True) 150 | for process in children: 151 | logging.info('Killing pid: {}, cmdline: {}'.format(process.pid, process.cmdline())) 152 | # process.send_signal(signal.SIGTERM) 153 | process.kill() 154 | 155 | logging.info('Killing shell pid: {}, cmdline: {}'.format(parent.pid, parent.cmdline())) 156 | # parent.send_signal(signal.SIGTERM) 157 | parent.kill() 158 | 159 | def annotate(self, text, properties=None): 160 | if sys.version_info.major >= 3: 161 | text = text.encode('utf-8') 162 | 163 | r = requests.post(self.url, params={'properties': str(properties)}, data=text, 164 | headers={'Connection': 'close'}) 165 | return r.text 166 | 167 | def tregex(self, sentence, pattern): 168 | tregex_url = self.url + '/tregex' 169 | r_dict = self._request(tregex_url, "tokenize,ssplit,depparse,parse", sentence, pattern=pattern) 170 | return r_dict 171 | 172 | def tokensregex(self, sentence, pattern): 173 | tokensregex_url = self.url + '/tokensregex' 174 | r_dict = self._request(tokensregex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern) 175 | return r_dict 176 | 177 | def semgrex(self, sentence, pattern): 178 | semgrex_url = self.url + '/semgrex' 179 | r_dict = self._request(semgrex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern) 180 | return r_dict 181 | 182 | def word_tokenize(self, sentence, span=False): 183 | r_dict = self._request('ssplit,tokenize', sentence) 184 | tokens = [token['originalText'] for s in r_dict['sentences'] for token in s['tokens']] 185 | 186 | # Whether return token span 187 | if span: 188 | spans = [(token['characterOffsetBegin'], token['characterOffsetEnd']) for s in r_dict['sentences'] for token 189 | in s['tokens']] 190 | return tokens, spans 191 | else: 192 | return tokens 193 | 194 | def pos_tag(self, sentence): 195 | r_dict = self._request(self.url, 'pos', sentence) 196 | words = [] 197 | tags = [] 198 | for s in r_dict['sentences']: 199 | for token in s['tokens']: 200 | words.append(token['originalText']) 201 | tags.append(token['pos']) 202 | return list(zip(words, tags)) 203 | 204 | def ner(self, sentence): 205 | r_dict = self._request(self.url, 'ner', sentence) 206 | words = [] 207 | ner_tags = [] 208 | for s in r_dict['sentences']: 209 | for token in s['tokens']: 210 | words.append(token['originalText']) 211 | ner_tags.append(token['ner']) 212 | return list(zip(words, ner_tags)) 213 | 214 | def parse(self, sentence): 215 | r_dict = self._request(self.url, 'pos,parse', sentence) 216 | return [s['parse'] for s in r_dict['sentences']][0] 217 | 218 | def dependency_parse(self, sentence): 219 | r_dict = self._request(self.url, 'depparse', sentence) 220 | return [(dep['dep'], dep['governor'], dep['dependent']) for s in r_dict['sentences'] for dep in 221 | s['basicDependencies']] 222 | 223 | def coref(self, text): 224 | r_dict = self._request('coref', text) 225 | 226 | corefs = [] 227 | for k, mentions in 
r_dict['corefs'].items(): 228 | simplified_mentions = [] 229 | for m in mentions: 230 | simplified_mentions.append((m['sentNum'], m['startIndex'], m['endIndex'], m['text'])) 231 | corefs.append(simplified_mentions) 232 | return corefs 233 | 234 | def switch_language(self, language="en"): 235 | self._check_language(language) 236 | self.lang = language 237 | 238 | def _request(self, url, annotators=None, data=None, *args, **kwargs): 239 | if sys.version_info.major >= 3: 240 | data = data.encode('utf-8') 241 | 242 | properties = {'annotators': annotators, 'outputFormat': 'json'} 243 | params = {'properties': str(properties), 'pipelineLanguage': self.lang} 244 | if 'pattern' in kwargs: 245 | params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang} 246 | 247 | logging.info(params) 248 | r = requests.post(url, params=params, data=data, headers={'Connection': 'close'}) 249 | r_dict = json.loads(r.text) 250 | 251 | return r_dict 252 | 253 | def _check_args(self): 254 | self._check_language(self.lang) 255 | if not re.match('\dg', self.memory): 256 | raise ValueError('memory=' + self.memory + ' not supported. Use 4g, 6g, 8g and etc. ') 257 | 258 | def _check_language(self, lang): 259 | if lang not in ['en', 'zh', 'ar', 'fr', 'de', 'es']: 260 | raise ValueError('lang=' + self.lang + ' not supported. Use English(en), Chinese(zh), Arabic(ar), ' 261 | 'French(fr), German(de), Spanish(es).') 262 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/rules/__init__.py: -------------------------------------------------------------------------------- 1 | from .answer_rules import ANSWER_RULES 2 | from .alteration_rules import (HIGH_CONF_ALTER_RULES, ALL_ALTER_RULES, 3 | DO_NOT_ALTER, BAD_ALTERATIONS) 4 | from .conversion_rules import CONVERSION_RULES 5 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/rules/alteration_rules.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import nltk 4 | nltk.download('wordnet') 5 | 6 | from nltk.corpus import wordnet as wn 7 | from nltk.stem.lancaster import LancasterStemmer 8 | 9 | STEMMER = LancasterStemmer() 10 | 11 | POS_TO_WORDNET = { 12 | 'NN': wn.NOUN, 13 | 'JJ': wn.ADJ, 14 | 'JJR': wn.ADJ, 15 | 'JJS': wn.ADJ, 16 | } 17 | 18 | def alter_special(token, **kwargs): 19 | w = token['originalText'] 20 | if w in SPECIAL_ALTERATIONS: 21 | return [SPECIAL_ALTERATIONS[w]] 22 | return None 23 | 24 | def alter_nearby(pos_list, ignore_pos=False, is_ner=False): 25 | def func(token, nearby_word_dict=None, postag_dict=None, **kwargs): 26 | if token['pos'] not in pos_list: return None 27 | if is_ner and token['ner'] not in ('PERSON', 'LOCATION', 'ORGANIZATION', 'MISC'): 28 | return None 29 | w = token['word'].lower() 30 | if w in ('war'): return None 31 | if w not in nearby_word_dict: return None 32 | new_words = [] 33 | w_stem = STEMMER.stem(w.replace('.', '')) 34 | for x in nearby_word_dict[w][1:]: 35 | new_word = x['word'] 36 | # Make sure words aren't too similar (e.g. 
same stem) 37 | new_stem = STEMMER.stem(new_word.replace('.', '')) 38 | if w_stem.startswith(new_stem) or new_stem.startswith(w_stem): continue 39 | if not ignore_pos: 40 | # Check for POS tag match 41 | if new_word not in postag_dict: continue 42 | new_postag = postag_dict[new_word] 43 | if new_postag != token['pos']: continue 44 | new_words.append(new_word) 45 | return new_words 46 | return func 47 | 48 | def alter_entity_glove(token, nearby_word_dict=None, **kwargs): 49 | # NOTE: Deprecated 50 | if token['ner'] not in ('PERSON', 'LOCATION', 'ORGANIZATION', 'MISC'): return None 51 | w = token['word'].lower() 52 | if w == token['word']: return None # Only do capitalized words 53 | if w not in nearby_word_dict: return None 54 | new_words = [] 55 | for x in nearby_word_dict[w][1:3]: 56 | if token['word'] == w.upper(): 57 | new_words.append(x['word'].upper()) 58 | else: 59 | new_words.append(x['word'].title()) 60 | return new_words 61 | 62 | def alter_entity_type(token, **kwargs): 63 | pos = token['pos'] 64 | ner = token['ner'] 65 | word = token['word'] 66 | is_abbrev = word == word.upper() and not word == word.lower() 67 | if token['pos'] not in ( 68 | 'JJ', 'JJR', 'JJS', 'NN', 'NNS', 'NNP', 'NNPS', 'RB', 'RBR', 'RBS', 69 | 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'): 70 | # Don't alter non-content words 71 | return None 72 | if ner == 'PERSON': 73 | return ['Jackson'] 74 | elif ner == 'LOCATION': 75 | return ['Berlin'] 76 | elif ner == 'ORGANIZATION': 77 | if is_abbrev: return ['UNICEF'] 78 | return ['Acme'] 79 | elif ner == 'MISC': 80 | return ['Neptune'] 81 | elif ner == 'NNP': 82 | if is_abbrev: return ['XKCD'] 83 | return ['Dalek'] 84 | elif pos == 'NNPS': 85 | return ['Daleks'] 86 | return None 87 | 88 | def alter_wordnet_antonyms(token, **kwargs): 89 | if token['pos'] not in POS_TO_WORDNET: return None 90 | w = token['word'].lower() 91 | wn_pos = POS_TO_WORDNET[token['pos']] 92 | synsets = wn.synsets(w, wn_pos) 93 | if not synsets: return None 94 | synset = synsets[0] 95 | antonyms = [] 96 | for lem in synset.lemmas(): 97 | if lem.antonyms(): 98 | for a in lem.antonyms(): 99 | new_word = a.name() 100 | if '_' in a.name(): continue 101 | antonyms.append(new_word) 102 | return antonyms 103 | 104 | SPECIAL_ALTERATIONS = { 105 | 'States': 'Kingdom', 106 | 'US': 'UK', 107 | 'U.S': 'U.K.', 108 | 'U.S.': 'U.K.', 109 | 'UK': 'US', 110 | 'U.K.': 'U.S.', 111 | 'U.K': 'U.S.', 112 | 'largest': 'smallest', 113 | 'smallest': 'largest', 114 | 'highest': 'lowest', 115 | 'lowest': 'highest', 116 | 'May': 'April', 117 | 'Peyton': 'Trevor', 118 | } 119 | 120 | DO_NOT_ALTER = ['many', 'such', 'few', 'much', 'other', 'same', 'general', 121 | 'type', 'record', 'kind', 'sort', 'part', 'form', 'terms', 'use', 122 | 'place', 'way', 'old', 'young', 'bowl', 'united', 'one', 123 | 'likely', 'different', 'square', 'war', 'republic', 'doctor', 'color'] 124 | 125 | BAD_ALTERATIONS = ['mx2004', 'planet', 'u.s.', 'Http://Www.Co.Mo.Md.Us'] 126 | 127 | HIGH_CONF_ALTER_RULES = collections.OrderedDict([ 128 | ('special', alter_special), 129 | ('wn_antonyms', alter_wordnet_antonyms), 130 | ('nearbyNum', alter_nearby(['CD'], ignore_pos=True)), 131 | ('nearbyProperNoun', alter_nearby(['NNP', 'NNPS'])), 132 | ('nearbyProperNoun', alter_nearby(['NNP', 'NNPS'], ignore_pos=True)), 133 | ('nearbyEntityNouns', alter_nearby(['NN', 'NNS'], is_ner=True)), 134 | ('nearbyEntityJJ', alter_nearby(['JJ', 'JJR', 'JJS'], is_ner=True)), 135 | ('entityType', alter_entity_type), 136 | #('entity_glove', alter_entity_glove), 137 | ]) 138 | 
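# Each alteration rule maps a CoreNLP token dict to a list of candidate
# replacement words, or None when the rule does not apply. Illustrative
# sketch with a hand-made token (real tokens come from the CoreNLP
# annotation; GloVe-based rules additionally need the nearby-word and
# POS-tag dictionaries passed as keyword arguments, as done in
# AddSentGenerator._alter_question):
#
#     token = {'word': 'largest', 'originalText': 'largest',
#              'lemma': 'largest', 'pos': 'JJS', 'ner': 'O'}
#     HIGH_CONF_ALTER_RULES['special'](token)       # -> ['smallest']
#     HIGH_CONF_ALTER_RULES['wn_antonyms'](token)   # -> WordNet antonyms, e.g. ['small']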
ALL_ALTER_RULES = collections.OrderedDict(list(HIGH_CONF_ALTER_RULES.items()) + [ 139 | ('nearbyAdj', alter_nearby(['JJ', 'JJR', 'JJS'])), 140 | ('nearbyNoun', alter_nearby(['NN', 'NNS'])), 141 | #('nearbyNoun', alter_nearby(['NN', 'NNS'], ignore_pos=True)), 142 | ]) 143 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/rules/answer_rules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from adversarialnlp.generators.addsent.utils import rejoin 4 | 5 | MONTHS = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 6 | 'august', 'september', 'october', 'november', 'december'] 7 | 8 | def ans_number(a, tokens, q, **kwargs): 9 | out_toks = [] 10 | seen_num = False 11 | for t in tokens: 12 | ner = t['ner'] 13 | pos = t['pos'] 14 | w = t['word'] 15 | out_tok = {'before': t['before']} 16 | 17 | # Split on dashes 18 | leftover = '' 19 | dash_toks = w.split('-') 20 | if len(dash_toks) > 1: 21 | w = dash_toks[0] 22 | leftover = '-'.join(dash_toks[1:]) 23 | 24 | # Try to get a number out 25 | value = None 26 | if w != '%': 27 | # Percent sign should just pass through 28 | try: 29 | value = float(w.replace(',', '')) 30 | except: 31 | try: 32 | norm_ner = t['normalizedNER'] 33 | if norm_ner[0] in ('%', '>', '<'): 34 | norm_ner = norm_ner[1:] 35 | value = float(norm_ner) 36 | except: 37 | pass 38 | if not value and ( 39 | ner == 'NUMBER' or 40 | (ner == 'PERCENT' and pos == 'CD')): 41 | # Force this to be a number anyways 42 | value = 10 43 | if value: 44 | if math.isinf(value) or math.isnan(value): value = 9001 45 | seen_num = True 46 | if w in ('thousand', 'million', 'billion', 'trillion'): 47 | if w == 'thousand': 48 | new_val = 'million' 49 | else: 50 | new_val = 'thousand' 51 | else: 52 | if value < 2500 and value > 1000: 53 | new_val = str(value - 75) 54 | else: 55 | # Change leading digit 56 | if value == int(value): 57 | val_chars = list('%d' % value) 58 | else: 59 | val_chars = list('%g' % value) 60 | c = val_chars[0] 61 | for i in range(len(val_chars)): 62 | c = val_chars[i] 63 | if c >= '0' and c <= '9': 64 | val_chars[i] = str(max((int(c) + 5) % 10, 1)) 65 | break 66 | new_val = ''.join(val_chars) 67 | if leftover: 68 | new_val = '%s-%s' % (new_val, leftover) 69 | out_tok['originalText'] = new_val 70 | else: 71 | out_tok['originalText'] = t['originalText'] 72 | out_toks.append(out_tok) 73 | if seen_num: 74 | return rejoin(out_toks).strip() 75 | else: 76 | return None 77 | 78 | def ans_date(a, tokens, q, **kwargs): 79 | out_toks = [] 80 | if not all(t['ner'] == 'DATE' for t in tokens): 81 | return None 82 | for t in tokens: 83 | if t['pos'] == 'CD' or t['word'].isdigit(): 84 | try: 85 | value = int(t['word']) 86 | except: 87 | value = 10 # fallback 88 | if value > 50: new_val = str(value - 25) # Year 89 | else: # Day of month 90 | if value > 15: new_val = str(value - 11) 91 | else: new_val = str(value + 11) 92 | else: 93 | if t['word'].lower() in MONTHS: 94 | m_ind = MONTHS.index(t['word'].lower()) 95 | new_val = MONTHS[(m_ind + 6) % 12].title() 96 | else: 97 | # Give up 98 | new_val = t['originalText'] 99 | out_toks.append({'before': t['before'], 'originalText': new_val}) 100 | new_ans = rejoin(out_toks).strip() 101 | if new_ans == a['text']: return None 102 | return new_ans 103 | 104 | def ans_entity_full(ner_tag, new_ans): 105 | """Returns a function that yields new_ans iff every token has |ner_tag|.""" 106 | def func(a, tokens, q, **kwargs): 107 | for 
t in tokens: 108 | if t['ner'] != ner_tag: return None 109 | return new_ans 110 | return func 111 | 112 | def ans_abbrev(new_ans): 113 | def func(a, tokens, q, **kwargs): 114 | s = a['text'] 115 | if s == s.upper() and s != s.lower(): 116 | return new_ans 117 | return None 118 | return func 119 | 120 | def ans_match_wh(wh_word, new_ans): 121 | """Returns a function that yields new_ans if the question starts with |wh_word|.""" 122 | def func(a, tokens, q, **kwargs): 123 | if q.lower().startswith(wh_word + ' '): 124 | return new_ans 125 | return None 126 | return func 127 | 128 | def ans_pos(pos, new_ans, end=False, add_dt=False): 129 | """Returns a function that yields new_ans if the first/last token has |pos|.""" 130 | def func(a, tokens, q, determiner, **kwargs): 131 | if end: 132 | t = tokens[-1] 133 | else: 134 | t = tokens[0] 135 | if t['pos'] != pos: return None 136 | if add_dt and determiner: 137 | return '%s %s' % (determiner, new_ans) 138 | return new_ans 139 | return func 140 | 141 | 142 | def ans_catch_all(new_ans): 143 | def func(a, tokens, q, **kwargs): 144 | return new_ans 145 | return func 146 | 147 | ANSWER_RULES = [ 148 | ('date', ans_date), 149 | ('number', ans_number), 150 | ('ner_person', ans_entity_full('PERSON', 'Jeff Dean')), 151 | ('ner_location', ans_entity_full('LOCATION', 'Chicago')), 152 | ('ner_organization', ans_entity_full('ORGANIZATION', 'Stark Industries')), 153 | ('ner_misc', ans_entity_full('MISC', 'Jupiter')), 154 | ('abbrev', ans_abbrev('LSTM')), 155 | ('wh_who', ans_match_wh('who', 'Jeff Dean')), 156 | ('wh_when', ans_match_wh('when', '1956')), 157 | ('wh_where', ans_match_wh('where', 'Chicago')), 158 | ('wh_where', ans_match_wh('how many', '42')), 159 | # Starts with verb 160 | ('pos_begin_vb', ans_pos('VB', 'learn')), 161 | ('pos_end_vbd', ans_pos('VBD', 'learned')), 162 | ('pos_end_vbg', ans_pos('VBG', 'learning')), 163 | ('pos_end_vbp', ans_pos('VBP', 'learns')), 164 | ('pos_end_vbz', ans_pos('VBZ', 'learns')), 165 | # Ends with some POS tag 166 | ('pos_end_nn', ans_pos('NN', 'hamster', end=True, add_dt=True)), 167 | ('pos_end_nnp', ans_pos('NNP', 'Central Park', end=True, add_dt=True)), 168 | ('pos_end_nns', ans_pos('NNS', 'hamsters', end=True, add_dt=True)), 169 | ('pos_end_nnps', ans_pos('NNPS', 'Kew Gardens', end=True, add_dt=True)), 170 | ('pos_end_jj', ans_pos('JJ', 'deep', end=True)), 171 | ('pos_end_jjr', ans_pos('JJR', 'deeper', end=True)), 172 | ('pos_end_jjs', ans_pos('JJS', 'deepest', end=True)), 173 | ('pos_end_rb', ans_pos('RB', 'silently', end=True)), 174 | ('pos_end_vbg', ans_pos('VBG', 'learning', end=True)), 175 | ('catch_all', ans_catch_all('aliens')), 176 | ] 177 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/rules/conversion_rules.py: -------------------------------------------------------------------------------- 1 | from pattern import en as patten 2 | 3 | CONST_PARSE_MACROS = { 4 | '$Noun': '$NP/$NN/$NNS/$NNP/$NNPS', 5 | '$Verb': '$VB/$VBD/$VBP/$VBZ', 6 | '$Part': '$VBN/$VG', 7 | '$Be': 'is/are/was/were', 8 | '$Do': "do/did/does/don't/didn't/doesn't", 9 | '$WHP': '$WHADJP/$WHADVP/$WHNP/$WHPP', 10 | } 11 | 12 | # Map to pattern.en aliases 13 | # http://www.clips.ua.ac.be/pages/pattern-en#conjugation 14 | POS_TO_PATTERN = { 15 | 'vb': 'inf', # Infinitive 16 | 'vbp': '1sg', # non-3rd-person singular present 17 | 'vbz': '3sg', # 3rd-person singular present 18 | 'vbg': 'part', # gerund or present participle 19 | 'vbd': 'p', # past 20 | 'vbn': 'ppart', # past 
participle 21 | } 22 | # Tenses prioritized by likelihood of arising 23 | PATTERN_TENSES = ['inf', '3sg', 'p', 'part', 'ppart', '1sg'] 24 | 25 | def _check_match(node, pattern_tok): 26 | if pattern_tok in CONST_PARSE_MACROS: 27 | pattern_tok = CONST_PARSE_MACROS[pattern_tok] 28 | if ':' in pattern_tok: 29 | # ':' means you match the LHS category and start with something on the right 30 | lhs, rhs = pattern_tok.split(':') 31 | match_lhs = _check_match(node, lhs) 32 | if not match_lhs: return False 33 | phrase = node.get_phrase().lower() 34 | retval = any(phrase.startswith(w) for w in rhs.split('/')) 35 | return retval 36 | elif '/' in pattern_tok: 37 | return any(_check_match(node, t) for t in pattern_tok.split('/')) 38 | return ((pattern_tok.startswith('$') and pattern_tok[1:] == node.tag) or 39 | (node.word and pattern_tok.lower() == node.word.lower())) 40 | 41 | def _recursive_match_pattern(pattern_toks, stack, matches): 42 | """Recursively try to match a pattern, greedily.""" 43 | if len(matches) == len(pattern_toks): 44 | # We matched everything in the pattern; also need stack to be empty 45 | return len(stack) == 0 46 | if len(stack) == 0: return False 47 | cur_tok = pattern_toks[len(matches)] 48 | node = stack.pop() 49 | # See if we match the current token at this level 50 | is_match = _check_match(node, cur_tok) 51 | if is_match: 52 | cur_num_matches = len(matches) 53 | matches.append(node) 54 | new_stack = list(stack) 55 | success = _recursive_match_pattern(pattern_toks, new_stack, matches) 56 | if success: return True 57 | # Backtrack 58 | while len(matches) > cur_num_matches: 59 | matches.pop() 60 | # Recurse to children 61 | if not node.children: return False # No children to recurse on, we failed 62 | stack.extend(node.children[::-1]) # Leftmost children should be popped first 63 | return _recursive_match_pattern(pattern_toks, stack, matches) 64 | 65 | def match_pattern(pattern, const_parse): 66 | pattern_toks = pattern.split(' ') 67 | whole_phrase = const_parse.get_phrase() 68 | if whole_phrase.endswith('?') or whole_phrase.endswith('.'): 69 | # Match trailing punctuation as needed 70 | pattern_toks.append(whole_phrase[-1]) 71 | matches = [] 72 | success = _recursive_match_pattern(pattern_toks, [const_parse], matches) 73 | if success: 74 | return matches 75 | else: 76 | return None 77 | 78 | def run_postprocessing(s, rules, all_args): 79 | rule_list = rules.split(',') 80 | for rule in rule_list: 81 | if rule == 'lower': 82 | s = s.lower() 83 | elif rule.startswith('tense-'): 84 | ind = int(rule[6:]) 85 | orig_vb = all_args[ind] 86 | tenses = patten.tenses(orig_vb) 87 | for tense in PATTERN_TENSES: # Prioritize by PATTERN_TENSES 88 | if tense in tenses: 89 | break 90 | else: # Default to first tense 91 | tense = PATTERN_TENSES[0] 92 | s = patten.conjugate(s, tense) 93 | elif rule in POS_TO_PATTERN: 94 | s = patten.conjugate(s, POS_TO_PATTERN[rule]) 95 | return s 96 | 97 | def convert_whp(node, q, a, tokens, quiet=False): 98 | if node.tag in ('WHNP', 'WHADJP', 'WHADVP', 'WHPP'): 99 | # Apply WHP rules 100 | cur_phrase = node.get_phrase() 101 | cur_tokens = tokens[node.get_start_index():node.get_end_index()] 102 | for r in WHP_RULES: 103 | phrase = r.convert(cur_phrase, a, cur_tokens, node, run_fix_style=False) 104 | if phrase: 105 | if not quiet: 106 | print(f" WHP Rule '{r.name}': {phrase}") 107 | return phrase 108 | return None 109 | 110 | ### Rules for converting questions into declarative sentences 111 | def fix_style(s): 112 | """Minor, general style fixes for questions.""" 
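    # Illustrative trace (made-up question, assumed input; follows the steps below):
    #   fix_style('what is the capital of France ?') -> 'What is the capital of France.'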
113 | s = s.replace('?', '') # Delete question marks anywhere in sentence. 114 | s = s.strip(' .') 115 | if s[0] == s[0].lower(): 116 | s = s[0].upper() + s[1:] 117 | return s + '.' 118 | 119 | class ConversionRule(object): 120 | def convert(self, q, a, tokens, const_parse, run_fix_style=True): 121 | raise NotImplementedError 122 | 123 | class ConstituencyRule(ConversionRule): 124 | """A rule for converting question to sentence based on constituency parse.""" 125 | def __init__(self, in_pattern, out_pattern, postproc=None): 126 | self.in_pattern = in_pattern # e.g. "where did $NP $VP" 127 | self.out_pattern = out_pattern #unicode(out_pattern) 128 | # e.g. "{1} did {2} at {0}." Answer is always 0 129 | self.name = in_pattern 130 | if postproc: 131 | self.postproc = postproc 132 | else: 133 | self.postproc = {} 134 | 135 | def convert(self, q, a, tokens, const_parse, run_fix_style=True) -> str: 136 | pattern_toks = self.in_pattern.split(' ') # Don't care about trailing punctuation 137 | match = match_pattern(self.in_pattern, const_parse) 138 | appended_clause = False 139 | if not match: 140 | # Try adding a PP at the beginning 141 | appended_clause = True 142 | new_pattern = '$PP , ' + self.in_pattern 143 | pattern_toks = new_pattern.split(' ') 144 | match = match_pattern(new_pattern, const_parse) 145 | if not match: 146 | # Try adding an SBAR at the beginning 147 | new_pattern = '$SBAR , ' + self.in_pattern 148 | pattern_toks = new_pattern.split(' ') 149 | match = match_pattern(new_pattern, const_parse) 150 | if not match: return None 151 | appended_clause_match = None 152 | fmt_args = [a] 153 | for t, m in zip(pattern_toks, match): 154 | if t.startswith('$') or '/' in t: 155 | # First check if it's a WHP 156 | phrase = convert_whp(m, q, a, tokens) 157 | if not phrase: 158 | phrase = m.get_phrase() 159 | fmt_args.append(phrase) 160 | if appended_clause: 161 | appended_clause_match = fmt_args[1] 162 | fmt_args = [a] + fmt_args[2:] 163 | for i in range(len(fmt_args)): 164 | if i in self.postproc: 165 | # Run postprocessing filters 166 | fmt_args[i] = run_postprocessing(fmt_args[i], self.postproc[i], fmt_args) 167 | output = self.gen_output(fmt_args) 168 | if appended_clause: 169 | output = appended_clause_match + ', ' + output 170 | if run_fix_style: 171 | output = fix_style(output) 172 | return output 173 | 174 | 175 | def gen_output(self, fmt_args): 176 | """By default, use self.out_pattern. 
Can be overridden.""" 177 | return self.out_pattern.format(*fmt_args) 178 | 179 | class ReplaceRule(ConversionRule): 180 | """A simple rule that replaces some tokens with the answer.""" 181 | def __init__(self, target, replacement='{}', start=False): 182 | self.target = target 183 | self.replacement = replacement #unicode(replacement) 184 | self.name = 'replace(%s)' % target 185 | self.start = start 186 | 187 | def convert(self, q, a, tokens, const_parse, run_fix_style=True): 188 | t_toks = self.target.split(' ') 189 | q_toks = q.rstrip('?.').split(' ') 190 | replacement_text = self.replacement.format(a) 191 | for i in range(len(q_toks)): 192 | if self.start and i != 0: continue 193 | if ' '.join(q_toks[i:i + len(t_toks)]).rstrip(',').lower() == self.target: 194 | begin = q_toks[:i] 195 | end = q_toks[i + len(t_toks):] 196 | output = ' '.join(begin + [replacement_text] + end) 197 | if run_fix_style: 198 | output = fix_style(output) 199 | return output 200 | return None 201 | 202 | class FindWHPRule(ConversionRule): 203 | """A rule that looks for $WHP's from right to left and does replacements.""" 204 | name = 'FindWHP' 205 | def _recursive_convert(self, node, q, a, tokens, found_whp): 206 | if node.word: 207 | return node.word, found_whp 208 | if not found_whp: 209 | whp_phrase = convert_whp(node, q, a, tokens) 210 | if whp_phrase: 211 | return whp_phrase, True 212 | child_phrases = [] 213 | for c in node.children[::-1]: 214 | c_phrase, found_whp = self._recursive_convert(c, q, a, tokens, found_whp) 215 | child_phrases.append(c_phrase) 216 | out_toks = [] 217 | for i, p in enumerate(child_phrases[::-1]): 218 | if i == 0 or p.startswith("'"): 219 | out_toks.append(p) 220 | else: 221 | out_toks.append(' ' + p) 222 | return ''.join(out_toks), found_whp 223 | 224 | def convert(self, q, a, tokens, const_parse, run_fix_style=True): 225 | out_phrase, found_whp = self._recursive_convert(const_parse, q, a, tokens, False) 226 | if found_whp: 227 | if run_fix_style: 228 | out_phrase = fix_style(out_phrase) 229 | return out_phrase 230 | return None 231 | 232 | class AnswerRule(ConversionRule): 233 | """Just return the answer.""" 234 | name = 'AnswerRule' 235 | def convert(self, q, a, tokens, const_parse, run_fix_style=True): 236 | return a 237 | 238 | CONVERSION_RULES = [ 239 | # Special rules 240 | ConstituencyRule('$WHP:what $Be $NP called that $VP', '{2} that {3} {1} called {1}'), 241 | 242 | # What type of X 243 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP/$Noun $Be $NP", '{5} {4} a {1} {3}'), 244 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP/$Noun $Be $VP", '{1} {3} {4} {5}'), 245 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP $VP", '{1} {3} {4}'), 246 | 247 | # How $JJ 248 | ConstituencyRule('how $JJ $Be $NP $IN $NP', '{3} {2} {0} {1} {4} {5}'), 249 | ConstituencyRule('how $JJ $Be $NP $SBAR', '{3} {2} {0} {1} {4}'), 250 | ConstituencyRule('how $JJ $Be $NP', '{3} {2} {0} {1}'), 251 | 252 | # When/where $Verb 253 | ConstituencyRule('$WHP:when/where $Do $NP', '{3} occurred in {1}'), 254 | ConstituencyRule('$WHP:when/where $Do $NP $Verb', '{3} {4} in {1}', {4: 'tense-2'}), 255 | ConstituencyRule('$WHP:when/where $Do $NP $Verb $NP/$PP', '{3} {4} {5} in {1}', {4: 'tense-2'}), 256 | ConstituencyRule('$WHP:when/where $Do $NP $Verb $NP $PP', '{3} {4} {5} {6} in {1}', {4: 'tense-2'}), 257 | ConstituencyRule('$WHP:when/where $Be $NP', '{3} {2} in {1}'), 258 | ConstituencyRule('$WHP:when/where $Verb $NP $VP/$ADJP', '{3} {2} {4} in {1}'), 259 | 260 | # 
What/who/how $Do 261 | ConstituencyRule("$WHP:what/which/who $Do $NP do", '{3} {1}', {0: 'tense-2'}), 262 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb", '{3} {4} {1}', {4: 'tense-2'}), 263 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $IN/$NP", '{3} {4} {5} {1}', {4: 'tense-2', 0: 'vbg'}), 264 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $PP", '{3} {4} {1} {5}', {4: 'tense-2', 0: 'vbg'}), 265 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $NP $VP", '{3} {4} {5} {6} {1}', {4: 'tense-2'}), 266 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb to $VB", '{3} {4} to {5} {1}', {4: 'tense-2'}), 267 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb to $VB $VP", '{3} {4} to {5} {1} {6}', {4: 'tense-2'}), 268 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $NP $IN $VP", '{3} {4} {5} {6} {1} {7}', {4: 'tense-2'}), 269 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $PP/$S/$VP/$SBAR/$SQ", '{3} {4} {1} {5}', {4: 'tense-2'}), 270 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $PP $PP/$S/$VP/$SBAR", '{3} {4} {1} {5} {6}', {4: 'tense-2'}), 271 | 272 | # What/who/how $Be 273 | # Watch out for things that end in a preposition 274 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP of $NP $Verb/$Part $IN", '{3} of {4} {2} {5} {6} {1}'), 275 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $NP $IN", '{3} {2} {4} {5} {1}'), 276 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $VP/$IN", '{3} {2} {4} {1}'), 277 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $IN $NP/$VP", '{1} {2} {3} {4} {5}'), 278 | ConstituencyRule('$WHP:what/which/who $Be/$MD $NP $Verb $PP', '{3} {2} {4} {1} {5}'), 279 | ConstituencyRule('$WHP:what/which/who $Be/$MD $NP/$VP/$PP', '{1} {2} {3}'), 280 | ConstituencyRule("$WHP:how $Be/$MD $NP $VP", '{3} {2} {4} by {1}'), 281 | 282 | # What/who $Verb 283 | ConstituencyRule("$WHP:what/which/who $VP", '{1} {2}'), 284 | 285 | # $IN what/which $NP 286 | ConstituencyRule('$IN what/which $NP $Do $NP $Verb $NP', '{5} {6} {7} {1} the {3} of {0}', 287 | {1: 'lower', 6: 'tense-4'}), 288 | ConstituencyRule('$IN what/which $NP $Be $NP $VP/$ADJP', '{5} {4} {6} {1} the {3} of {0}', 289 | {1: 'lower'}), 290 | ConstituencyRule('$IN what/which $NP $Verb $NP/$ADJP $VP', '{5} {4} {6} {1} the {3} of {0}', 291 | {1: 'lower'}), 292 | FindWHPRule(), 293 | ] 294 | 295 | # Rules for going from WHP to an answer constituent 296 | WHP_RULES = [ 297 | # WHPP rules 298 | ConstituencyRule('$IN what/which type/sort/kind/group of $NP/$Noun', '{1} {0} {4}'), 299 | ConstituencyRule('$IN what/which type/sort/kind/group of $NP/$Noun $PP', '{1} {0} {4} {5}'), 300 | ConstituencyRule('$IN what/which $NP', '{1} the {3} of {0}'), 301 | ConstituencyRule('$IN $WP/$WDT', '{1} {0}'), 302 | 303 | # what/which 304 | ConstituencyRule('what/which type/sort/kind/group of $NP/$Noun', '{0} {3}'), 305 | ConstituencyRule('what/which type/sort/kind/group of $NP/$Noun $PP', '{0} {3} {4}'), 306 | ConstituencyRule('what/which $NP', 'the {2} of {0}'), 307 | 308 | # How many 309 | ConstituencyRule('how many/much $NP', '{0} {2}'), 310 | 311 | # Replace 312 | ReplaceRule('what'), 313 | ReplaceRule('who'), 314 | ReplaceRule('how many'), 315 | ReplaceRule('how much'), 316 | ReplaceRule('which'), 317 | ReplaceRule('where'), 318 | ReplaceRule('when'), 319 | ReplaceRule('why'), 320 | ReplaceRule('how'), 321 | 322 | # Just give the answer 323 | AnswerRule(), 324 | ] 325 | -------------------------------------------------------------------------------- 
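The rule objects defined above all expose the same ``convert(q, a, tokens, const_parse, run_fix_style=True)`` interface and return either a declarative sentence or ``None``. For orientation, the sketch below shows one way ``CONVERSION_RULES`` can be driven: parse the question with ``read_const_parse`` and keep the first rule that produces a sentence. The driver function and its argument names are illustrative assumptions, not a copy of the actual generator; the real loop in ``addsent_generator.py`` may differ.

# Hypothetical driver for the conversion rules above (a sketch, not the AddSentGenerator code)
from adversarialnlp.generators.addsent.rules import CONVERSION_RULES
from adversarialnlp.generators.addsent.utils import read_const_parse

def convert_question(question, fake_answer, tokens, parse_str):
    """Return the first declarative sentence produced by a matching rule, or None."""
    const_parse = read_const_parse(parse_str)  # ConstituencyParse tree of the question
    for rule in CONVERSION_RULES:
        sentence = rule.convert(question, fake_answer, tokens, const_parse)
        if sentence:
            return sentence
    return None

``WHP_RULES`` is consumed the same way inside ``convert_whp``; because ``AnswerRule`` closes that list, a WH-phrase node always falls back to the plain answer.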
/adversarialnlp/generators/addsent/squad_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Iterator, List, Tuple 4 | 5 | from adversarialnlp.common.file_utils import download_files 6 | 7 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 8 | 9 | 10 | def squad_reader(file_path: str = None) -> Iterator[List[Tuple[str, str]]]: 11 | r""" Reads a JSON-formatted SQuAD file and returns an Iterator. 12 | 13 | Args: 14 | file_path: Path to a JSON-formatted SQuAD file. 15 | If no path is provided, download and use SQuAD v1.0 training dataset. 16 | 17 | Return: 18 | list of tuple (question_answer, paragraph). 19 | """ 20 | if file_path is None: 21 | file_path = download_files(fnames=['train-v1.1.json'], 22 | paths='https://rajpurkar.github.io/SQuAD-explorer/dataset/', 23 | local_folder='squad') 24 | file_path = file_path[0] 25 | 26 | logger.info("Reading file at %s", file_path) 27 | with open(file_path) as dataset_file: 28 | dataset_json = json.load(dataset_file) 29 | dataset = dataset_json['data'] 30 | logger.info("Reading the dataset") 31 | out_data = [] 32 | for article in dataset: 33 | for paragraph_json in article['paragraphs']: 34 | paragraph = paragraph_json["context"] 35 | for question_answer in paragraph_json['qas']: 36 | question_answer["question"] = question_answer["question"].strip().replace("\n", "") 37 | out_data.append((question_answer, paragraph)) 38 | return out_data 39 | -------------------------------------------------------------------------------- /adversarialnlp/generators/addsent/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for AddSent generator.""" 2 | from typing import List, Dict, Tuple, Optional 3 | 4 | class ConstituencyParse(object): 5 | """A CoreNLP constituency parse (or a node in a parse tree). 6 | 7 | Word-level constituents have |word| and |index| set and no children. 8 | Phrase-level constituents have no |word| or |index| and have at least one child. 9 | """ 10 | def __init__(self, tag, children=None, word=None, index=None): 11 | self.tag = tag 12 | if children: 13 | self.children = children 14 | else: 15 | self.children = None 16 | self.word = word 17 | self.index = index 18 | 19 | @classmethod 20 | def _recursive_parse_corenlp(cls, tokens, i, j): 21 | orig_i = i 22 | if tokens[i] == '(': 23 | tag = tokens[i + 1] 24 | children = [] 25 | i = i + 2 26 | while True: 27 | child, i, j = cls._recursive_parse_corenlp(tokens, i, j) 28 | if isinstance(child, cls): 29 | children.append(child) 30 | if tokens[i] == ')': 31 | return cls(tag, children), i + 1, j 32 | else: 33 | if tokens[i] != ')': 34 | raise ValueError('Expected ")" following leaf') 35 | return cls(tag, word=child, index=j), i + 1, j + 1 36 | else: 37 | # Only other possibility is it's a word 38 | return tokens[i], i + 1, j 39 | 40 | @classmethod 41 | def from_corenlp(cls, s): 42 | """Parses the "parse" attribute returned by CoreNLP parse annotator.""" 43 | # "parse": "(ROOT\n (SBARQ\n (WHNP (WDT What)\n (NP (NN portion)\n (PP (IN of)\n (NP\n (NP (NNS households))\n (PP (IN in)\n (NP (NNP Jacksonville)))))))\n (SQ\n (VP (VBP have)\n (NP (RB only) (CD one) (NN person))))\n (. ? 
)))", 44 | s_spaced = s.replace('\n', ' ').replace('(', ' ( ').replace(')', ' ) ') 45 | tokens = [t for t in s_spaced.split(' ') if t] 46 | tree, index, num_words = cls._recursive_parse_corenlp(tokens, 0, 0) 47 | if index != len(tokens): 48 | raise ValueError('Only parsed %d of %d tokens' % (index, len(tokens))) 49 | return tree 50 | 51 | def is_singleton(self): 52 | if self.word: 53 | return True 54 | if len(self.children) > 1: 55 | return False 56 | return self.children[0].is_singleton() 57 | 58 | def print_tree(self, indent=0): 59 | spaces = ' ' * indent 60 | if self.word: 61 | print(f"{spaces}{self.tag}: {self.word} ({self.index})") 62 | else: 63 | print(f"{spaces}{self.tag}") 64 | for c in self.children: 65 | c.print_tree(indent=indent + 1) 66 | 67 | def get_phrase(self): 68 | if self.word: 69 | return self.word 70 | toks = [] 71 | for i, c in enumerate(self.children): 72 | p = c.get_phrase() 73 | if i == 0 or p.startswith("'"): 74 | toks.append(p) 75 | else: 76 | toks.append(' ' + p) 77 | return ''.join(toks) 78 | 79 | def get_start_index(self): 80 | if self.index is not None: 81 | return self.index 82 | return self.children[0].get_start_index() 83 | 84 | def get_end_index(self): 85 | if self.index is not None: 86 | return self.index + 1 87 | return self.children[-1].get_end_index() 88 | 89 | @classmethod 90 | def _recursive_replace_words(cls, tree, new_words, i): 91 | if tree.word: 92 | new_word = new_words[i] 93 | return (cls(tree.tag, word=new_word, index=tree.index), i + 1) 94 | new_children = [] 95 | for c in tree.children: 96 | new_child, i = cls._recursive_replace_words(c, new_words, i) 97 | new_children.append(new_child) 98 | return cls(tree.tag, children=new_children), i 99 | 100 | @classmethod 101 | def replace_words(cls, tree, new_words): 102 | """Return a new tree, with new words replacing old ones.""" 103 | new_tree, i = cls._recursive_replace_words(tree, new_words, 0) 104 | if i != len(new_words): 105 | raise ValueError('len(new_words) == %d != i == %d' % (len(new_words), i)) 106 | return new_tree 107 | 108 | def rejoin(tokens: List[Dict[str, str]], sep: str = None) -> str: 109 | """Rejoin tokens into the original sentence. 110 | 111 | Args: 112 | tokens: a list of dicts containing 'originalText' and 'before' fields. 113 | All other fields will be ignored. 114 | sep: if provided, use the given character as a separator instead of 115 | the 'before' field (e.g. if you want to preserve where tokens are). 116 | Returns: the original sentence that generated this CoreNLP token list. 
117 | """ 118 | if sep is None: 119 | return ''.join('%s%s' % (t['before'], t['originalText']) for t in tokens) 120 | else: 121 | # Use the given separator instead 122 | return sep.join(t['originalText'] for t in tokens) 123 | 124 | 125 | def get_tokens_for_answers(answer_objs: List[Tuple[int, Dict]], corenlp_obj: Dict) -> Tuple[int, List]: 126 | """Get CoreNLP tokens corresponding to a SQuAD answer object.""" 127 | first_a_toks = None 128 | for i, a_obj in enumerate(answer_objs): 129 | a_toks = [] 130 | answer_start = a_obj['answer_start'] 131 | answer_end = answer_start + len(a_obj['text']) 132 | for sent in corenlp_obj['sentences']: 133 | for tok in sent['tokens']: 134 | if tok['characterOffsetBegin'] >= answer_end: 135 | continue 136 | if tok['characterOffsetEnd'] <= answer_start: 137 | continue 138 | a_toks.append(tok) 139 | if rejoin(a_toks).strip() == a_obj['text']: 140 | # Make sure that the tokens reconstruct the answer 141 | return i, a_toks 142 | if i == 0: 143 | first_a_toks = a_toks 144 | # None of the extracted token lists reconstruct the answer 145 | # Default to the first 146 | return 0, first_a_toks 147 | 148 | def get_determiner_for_answers(answer_objs: List[Dict]) -> Optional[str]: 149 | for ans in answer_objs: 150 | words = ans['text'].split(' ') 151 | if words[0].lower() == 'the': 152 | return 'the' 153 | if words[0].lower() in ('a', 'an'): 154 | return 'a' 155 | return None 156 | 157 | def compress_whnp(tree, inside_whnp=False): 158 | if not tree.children: return tree # Reached leaf 159 | # Compress all children 160 | for i, c in enumerate(tree.children): 161 | tree.children[i] = compress_whnp(c, inside_whnp=inside_whnp or tree.tag == 'WHNP') 162 | if tree.tag != 'WHNP': 163 | if inside_whnp: 164 | # Wrap everything in an NP 165 | return ConstituencyParse('NP', children=[tree]) 166 | return tree 167 | wh_word = None 168 | new_np_children = [] 169 | new_siblings = [] 170 | for i, c in enumerate(tree.children): 171 | if i == 0: 172 | if c.tag in ('WHNP', 'WHADJP', 'WHAVP', 'WHPP'): 173 | wh_word = c.children[0] 174 | new_np_children.extend(c.children[1:]) 175 | elif c.tag in ('WDT', 'WP', 'WP$', 'WRB'): 176 | wh_word = c 177 | else: 178 | # No WH-word at start of WHNP 179 | return tree 180 | else: 181 | if c.tag == 'SQ': # Due to bad parse, SQ may show up here 182 | new_siblings = tree.children[i:] 183 | break 184 | # Wrap everything in an NP 185 | new_np_children.append(ConstituencyParse('NP', children=[c])) 186 | if new_np_children: 187 | new_np = ConstituencyParse('NP', children=new_np_children) 188 | new_tree = ConstituencyParse('WHNP', children=[wh_word, new_np]) 189 | else: 190 | new_tree = tree 191 | if new_siblings: 192 | new_tree = ConstituencyParse('SBARQ', children=[new_tree] + new_siblings) 193 | return new_tree 194 | 195 | def read_const_parse(parse_str): 196 | tree = ConstituencyParse.from_corenlp(parse_str) 197 | new_tree = compress_whnp(tree) 198 | return new_tree 199 | -------------------------------------------------------------------------------- /adversarialnlp/generators/generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Union, Iterable, List 3 | from collections import defaultdict 4 | import itertools 5 | 6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 7 | 8 | class Generator(): 9 | r"""An abstract ``Generator`` class. 
10 | 11 | A ``Generator`` takes as inputs an iterable of seeds (for examples 12 | samples from a training dataset) and edit them to generate 13 | potential adversarial examples. 14 | 15 | This class is an abstract class. To implement a ``Generator``, you 16 | should override the `generate_from_seed(self, seed: any)` method 17 | with a specific method to use for yielding adversarial samples 18 | from a seed sample. 19 | 20 | Optionally, you should also: 21 | - define a typing class for the ``seed`` objects 22 | - define a default seed source in the ``__init__`` class, for 23 | examples by downloading an appropriate dataset. See examples 24 | in the ``AddSentGenerator`` class. 25 | 26 | Args: 27 | default_seeds: Default Iterable to use as source of seeds. 28 | quiet: Output debuging information. 29 | 30 | Inputs: 31 | **seed_instances** (optional): Instances to use as seed 32 | for adversarial example generation. If None uses the 33 | default_seeds providing at class instantiation. 34 | Default to None 35 | **num_epochs** (optional): How many times should we iterate 36 | over the seeds? If None, we will iterate over it forever. 37 | Default to None. 38 | **shuffle** (optional): Shuffle the instances before iteration. 39 | If True, we will shuffle the instances before iterating. 40 | Default to False. 41 | 42 | Yields: 43 | **adversarial_examples** (Iterable): Adversarial examples 44 | generated from the seeds. 45 | 46 | Examples:: 47 | 48 | >> generator = Generator() 49 | >> examples = generator(num_epochs=1) 50 | """ 51 | 52 | def __init__(self, 53 | default_seeds: Iterable = None, 54 | quiet: bool = False): 55 | self.default_seeds = default_seeds 56 | self.quiet: bool = quiet 57 | 58 | self._epochs: Dict[int, int] = defaultdict(int) 59 | 60 | def generate_from_seed(self, seed: any): 61 | r"""Generate an adversarial example from a seed. 62 | """ 63 | raise NotImplementedError 64 | 65 | def __call__(self, 66 | seeds: Iterable = None, 67 | num_epochs: int = None, 68 | shuffle: bool = True) -> Iterable: 69 | r"""Generate adversarial examples using _generate_from_seed. 70 | 71 | Args: 72 | seeds: Instances to use as seed for adversarial 73 | example generation. 74 | num_epochs: How many times should we iterate over the seeds? 75 | If None, we will iterate over it forever. 76 | shuffle: Shuffle the instances before iteration. 77 | If True, we will shuffle the instances before iterating. 78 | 79 | Yields: adversarial_examples 80 | adversarial_examples: Adversarial examples generated 81 | from the seeds. 82 | """ 83 | if seeds is None: 84 | if self.default_seeds is not None: 85 | seeds = self.default_seeds 86 | else: 87 | return 88 | # Instances is likely to be a list, which cannot be used as a key, 89 | # so we take the object id instead. 
90 | key = id(seeds) 91 | starting_epoch = self._epochs[key] 92 | 93 | if num_epochs is None: 94 | epochs: Iterable[int] = itertools.count(starting_epoch) 95 | else: 96 | epochs = range(starting_epoch, starting_epoch + num_epochs) 97 | 98 | for epoch in epochs: 99 | self._epochs[key] = epoch 100 | for seed in seeds: 101 | yield from self.generate_from_seed(seed) 102 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/__init__.py: -------------------------------------------------------------------------------- 1 | from .swag_generator import SwagGenerator 2 | from .activitynet_captions_reader import ActivityNetCaptionsDatasetReader 3 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/activitynet_captions_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import json 3 | import logging 4 | from overrides import overrides 5 | from unidecode import unidecode 6 | 7 | from allennlp.common.file_utils import cached_path 8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 9 | from allennlp.data.fields import TextField, MetadataField 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer 12 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 13 | 14 | from adversarialnlp.generators.swag.utils import pairwise, postprocess 15 | 16 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 17 | 18 | 19 | @DatasetReader.register("activitynet_captions") 20 | class ActivityNetCaptionsDatasetReader(DatasetReader): 21 | r""" Reads ActivityNet Captions JSON files and creates a dataset suitable for crafting 22 | adversarial examples with swag using these captions. 23 | 24 | Expected format: 25 | JSON dict[video_id, video_obj] where 26 | video_id: str, 27 | video_obj: { 28 | "duration": float, 29 | "timestamps": list of pairs of float, 30 | "sentences": list of strings 31 | } 32 | 33 | The output of ``read`` is a list of ``Instance`` s with the fields: 34 | video_id: ``MetadataField`` 35 | first_sentence: ``TextField`` 36 | second_sentence: ``TextField`` 37 | 38 | The instances are created from all consecutive pair of sentences 39 | associated to each video. 40 | Ex: if a video has three associated sentences: s1, s2, s3 read will 41 | generate two instances: 42 | 43 | 1. Instance("first_sentence" = s1, "second_sentence" = s2) 44 | 2. Instance("first_sentence" = s2, "second_sentence" = s3) 45 | 46 | Args: 47 | lazy : If True, training will start sooner, but will take 48 | longer per batch. This allows training with datasets that 49 | are too large to fit in memory. Passed to DatasetReader. 50 | tokenizer : Tokenizer to use to split the title and abstract 51 | into words or other kinds of tokens. 52 | token_indexers : Indexers used to define input token 53 | representations. 
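    Example (illustrative usage; the file path is a placeholder):
        reader = ActivityNetCaptionsDatasetReader()
        instances = reader.read('path/to/activitynet_captions/train.json')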
54 | """ 55 | def __init__(self, 56 | lazy: bool = False, 57 | tokenizer: Tokenizer = None, 58 | token_indexers: Dict[str, TokenIndexer] = None) -> None: 59 | super().__init__(lazy) 60 | self._tokenizer = tokenizer or WordTokenizer() 61 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 62 | 63 | @overrides 64 | def _read(self, file_path): 65 | with open(cached_path(file_path), "r") as data_file: 66 | logger.info("Reading instances from: %s", file_path) 67 | json_data = json.load(data_file) 68 | for video_id, value in json_data.items(): 69 | sentences = [postprocess(unidecode(x.strip())) 70 | for x in value['sentences']] 71 | for first_sentence, second_sentence in pairwise(sentences): 72 | yield self.text_to_instance(video_id, first_sentence, second_sentence) 73 | 74 | @overrides 75 | def text_to_instance(self, 76 | video_id: str, 77 | first_sentence: str, 78 | second_sentence: str) -> Instance: # type: ignore 79 | # pylint: disable=arguments-differ 80 | tokenized_first_sentence = self._tokenizer.tokenize(first_sentence) 81 | tokenized_second_sentence = self._tokenizer.tokenize(second_sentence) 82 | first_sentence_field = TextField(tokenized_first_sentence, self._token_indexers) 83 | second_sentence_field = TextField(tokenized_second_sentence, self._token_indexers) 84 | fields = {'video_id': MetadataField(video_id), 85 | 'first_sentence': first_sentence_field, 86 | 'second_sentence': second_sentence_field} 87 | return Instance(fields) 88 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/openai_transformer_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union, Optional 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from allennlp.common.checks import ConfigurationError 7 | from allennlp.data.vocabulary import Vocabulary 8 | from allennlp.models.model import Model 9 | from allennlp.modules.openai_transformer import OpenaiTransformer 10 | from allennlp.modules.token_embedders import OpenaiTransformerEmbedder 11 | from allennlp.nn.util import get_text_field_mask, remove_sentence_boundaries 12 | 13 | 14 | @Model.register('openai-transformer-language-model') 15 | class OpenAITransformerLanguageModel(Model): 16 | """ 17 | The ``OpenAITransformerLanguageModel`` is a wrapper around ``OpenATransformerModule``. 18 | 19 | Parameters 20 | ---------- 21 | vocab: ``Vocabulary`` 22 | remove_bos_eos: ``bool``, optional (default: True) 23 | Typically the provided token indexes will be augmented with 24 | begin-sentence and end-sentence tokens. If this flag is True 25 | the corresponding embeddings will be removed from the return values. 
26 | """ 27 | def __init__(self, 28 | vocab: Vocabulary, 29 | openai_token_embedder: OpenaiTransformerEmbedder, 30 | remove_bos_eos: bool = True) -> None: 31 | super().__init__(vocab) 32 | model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz" 33 | indexer = OpenaiTransformerBytePairIndexer(model_path=model_path) 34 | transformer = OpenaiTransformer(model_path=model_path) 35 | self._token_embedders = OpenaiTransformerEmbedder(transformer=transformer, top_layer_only=True) 36 | self._remove_bos_eos = remove_bos_eos 37 | 38 | def _get_target_token_embedding(self, 39 | token_embeddings: torch.Tensor, 40 | mask: torch.Tensor, 41 | direction: int) -> torch.Tensor: 42 | # Need to shift the mask in the correct direction 43 | zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte() 44 | if direction == 0: 45 | # forward direction, get token to right 46 | shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1) 47 | else: 48 | shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1) 49 | return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim) 50 | 51 | def _compute_loss(self, 52 | lm_embeddings: torch.Tensor, 53 | token_embeddings: torch.Tensor, 54 | forward_targets: torch.Tensor, 55 | backward_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 56 | # lm_embeddings is shape (batch_size, timesteps, dim * 2) 57 | # forward_targets, backward_targets are shape (batch_size, timesteps) 58 | # masked with 0 59 | forward_embeddings, backward_embeddings = lm_embeddings.chunk(2, -1) 60 | losses: List[torch.Tensor] = [] 61 | for idx, embedding, targets in ((0, forward_embeddings, forward_targets), 62 | (1, backward_embeddings, backward_targets)): 63 | mask = targets > 0 64 | # we need to subtract 1 to undo the padding id since the softmax 65 | # does not include a padding dimension 66 | non_masked_targets = targets.masked_select(mask) - 1 67 | non_masked_embedding = embedding.masked_select( 68 | mask.unsqueeze(-1) 69 | ).view(-1, self._forward_dim) 70 | # note: need to return average loss across forward and backward 71 | # directions, but total sum loss across all batches. 72 | # Assuming batches include full sentences, forward and backward 73 | # directions have the same number of samples, so sum up loss 74 | # here then divide by 2 just below 75 | if not self._softmax_loss.tie_embeddings or not self._use_character_inputs: 76 | losses.append(self._softmax_loss(non_masked_embedding, non_masked_targets)) 77 | else: 78 | # we also need the token embeddings corresponding to the 79 | # the targets 80 | raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.") 81 | # pylint: disable=unreachable 82 | non_masked_token_embedding = self._get_target_token_embedding(token_embeddings, mask, idx) 83 | losses.append(self._softmax(non_masked_embedding, 84 | non_masked_targets, 85 | non_masked_token_embedding)) 86 | 87 | return losses[0], losses[1] 88 | 89 | def forward(self, # type: ignore 90 | source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: 91 | """ 92 | Computes the averaged forward and backward LM loss from the batch. 93 | 94 | By convention, the input dict is required to have at least a ``"tokens"`` 95 | entry that's the output of a ``SingleIdTokenIndexer``, which is used 96 | to compute the language model targets. 
97 | 98 | If the model was instantatiated with ``remove_bos_eos=True``, 99 | then it is expected that each of the input sentences was augmented with 100 | begin-sentence and end-sentence tokens. 101 | 102 | Parameters 103 | ---------- 104 | tokens: ``torch.Tensor``, required. 105 | The output of ``Batch.as_tensor_dict()`` for a batch of sentences. 106 | 107 | Returns 108 | ------- 109 | Dict with keys: 110 | 111 | ``'loss'``: ``torch.Tensor`` 112 | averaged forward/backward negative log likelihood 113 | ``'forward_loss'``: ``torch.Tensor`` 114 | forward direction negative log likelihood 115 | ``'backward_loss'``: ``torch.Tensor`` 116 | backward direction negative log likelihood 117 | ``'lm_embeddings'``: ``torch.Tensor`` 118 | (batch_size, timesteps, embed_dim) tensor of top layer contextual representations 119 | ``'mask'``: ``torch.Tensor`` 120 | (batch_size, timesteps) mask for the embeddings 121 | """ 122 | # pylint: disable=arguments-differ 123 | mask = get_text_field_mask(source) 124 | 125 | # We must have token_ids so that we can compute targets 126 | token_ids = source.get("tokens") 127 | if token_ids is None: 128 | raise ConfigurationError("Your data must have a 'tokens': SingleIdTokenIndexer() " 129 | "in order to use the BidirectionalLM") 130 | 131 | # Use token_ids to compute targets 132 | forward_targets = torch.zeros_like(token_ids) 133 | backward_targets = torch.zeros_like(token_ids) 134 | forward_targets[:, 0:-1] = token_ids[:, 1:] 135 | backward_targets[:, 1:] = token_ids[:, 0:-1] 136 | 137 | embeddings = self._text_field_embedder(source) 138 | 139 | # Apply LayerNorm if appropriate. 140 | embeddings = self._layer_norm(embeddings) 141 | 142 | contextual_embeddings = self._contextualizer(embeddings, mask) 143 | 144 | # add dropout 145 | contextual_embeddings = self._dropout(contextual_embeddings) 146 | 147 | # compute softmax loss 148 | forward_loss, backward_loss = self._compute_loss(contextual_embeddings, 149 | embeddings, 150 | forward_targets, 151 | backward_targets) 152 | 153 | num_targets = torch.sum((forward_targets > 0).long()) 154 | if num_targets > 0: 155 | average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float() 156 | else: 157 | average_loss = torch.tensor(0.0).to(forward_targets.device) # pylint: disable=not-callable 158 | # this is stored to compute perplexity if needed 159 | self._last_average_loss[0] = average_loss.detach().item() 160 | 161 | if num_targets > 0: 162 | # loss is directly minimized 163 | if self._loss_scale == 'n_samples': 164 | scale_factor = num_targets.float() 165 | else: 166 | scale_factor = self._loss_scale 167 | 168 | return_dict = { 169 | 'loss': average_loss * scale_factor, 170 | 'forward_loss': forward_loss * scale_factor / num_targets.float(), 171 | 'backward_loss': backward_loss * scale_factor / num_targets.float() 172 | } 173 | else: 174 | # average_loss zero tensor, return it for all 175 | return_dict = { 176 | 'loss': average_loss, 177 | 'forward_loss': average_loss, 178 | 'backward_loss': average_loss 179 | } 180 | 181 | if self._remove_bos_eos: 182 | contextual_embeddings, mask = remove_sentence_boundaries(contextual_embeddings, mask) 183 | 184 | return_dict.update({ 185 | 'lm_embeddings': contextual_embeddings, 186 | 'mask': mask 187 | }) 188 | 189 | return return_dict 190 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/simple_bilm.py: -------------------------------------------------------------------------------- 1 | """ 2 | A wrapper around 
ai2s elmo LM to allow for an lm objective... 3 | """ 4 | 5 | from typing import Optional, Tuple 6 | from typing import Union, List, Dict 7 | 8 | import numpy as np 9 | import torch 10 | from allennlp.common.checks import ConfigurationError 11 | from allennlp.data import Token, Vocabulary, Instance 12 | from allennlp.data.dataset import Batch 13 | from allennlp.data.fields import TextField 14 | from allennlp.data.token_indexers import SingleIdTokenIndexer 15 | from allennlp.modules.augmented_lstm import AugmentedLstm 16 | from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import PytorchSeq2SeqWrapper 17 | from allennlp.nn.util import sequence_cross_entropy_with_logits 18 | from torch.autograd import Variable 19 | from torch.nn import functional as F 20 | from torch.nn.utils.rnn import PackedSequence 21 | 22 | 23 | def _de_duplicate_generations(generations): 24 | """ 25 | Given a list of list of strings, filter out the ones that are duplicates. and return an idx corresponding 26 | to the good ones 27 | :param generations: 28 | :return: 29 | """ 30 | dup_set = set() 31 | unique_idx = [] 32 | for i, gen_i in enumerate(generations): 33 | gen_i_str = ' '.join(gen_i) 34 | if gen_i_str not in dup_set: 35 | unique_idx.append(i) 36 | dup_set.add(gen_i_str) 37 | return [generations[i] for i in unique_idx], np.array(unique_idx) 38 | 39 | 40 | class StackedLstm(torch.nn.Module): 41 | """ 42 | A stacked LSTM. 43 | 44 | Parameters 45 | ---------- 46 | input_size : int, required 47 | The dimension of the inputs to the LSTM. 48 | hidden_size : int, required 49 | The dimension of the outputs of the LSTM. 50 | num_layers : int, required 51 | The number of stacked LSTMs to use. 52 | recurrent_dropout_probability: float, optional (default = 0.0) 53 | The dropout probability to be used in a dropout scheme as stated in 54 | `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks 55 | `_ . 56 | use_input_projection_bias : bool, optional (default = True) 57 | Whether or not to use a bias on the input projection layer. This is mainly here 58 | for backwards compatibility reasons and will be removed (and set to False) 59 | in future releases. 60 | 61 | Returns 62 | ------- 63 | output_accumulator : PackedSequence 64 | The outputs of the interleaved LSTMs per timestep. A tensor of shape 65 | (batch_size, max_timesteps, hidden_size) where for a given batch 66 | element, all outputs past the sequence length for that batch are 67 | zero tensors. 68 | """ 69 | 70 | def __init__(self, 71 | input_size: int, 72 | hidden_size: int, 73 | num_layers: int, 74 | recurrent_dropout_probability: float = 0.0, 75 | use_highway: bool = True, 76 | use_input_projection_bias: bool = True, 77 | go_forward: bool = True) -> None: 78 | super(StackedLstm, self).__init__() 79 | 80 | # Required to be wrapped with a :class:`PytorchSeq2SeqWrapper`. 
81 | self.input_size = input_size 82 | self.hidden_size = hidden_size 83 | self.num_layers = num_layers 84 | 85 | layers = [] 86 | lstm_input_size = input_size 87 | for layer_index in range(num_layers): 88 | layer = AugmentedLstm(lstm_input_size, hidden_size, go_forward, 89 | recurrent_dropout_probability=recurrent_dropout_probability, 90 | use_highway=use_highway, 91 | use_input_projection_bias=use_input_projection_bias) 92 | lstm_input_size = hidden_size 93 | self.add_module('layer_{}'.format(layer_index), layer) 94 | layers.append(layer) 95 | self.lstm_layers = layers 96 | 97 | def forward(self, # pylint: disable=arguments-differ 98 | inputs: PackedSequence, 99 | initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): 100 | """ 101 | Parameters 102 | ---------- 103 | inputs : ``PackedSequence``, required. 104 | A batch first ``PackedSequence`` to run the stacked LSTM over. 105 | initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None) 106 | A tuple (state, memory) representing the initial hidden state and memory 107 | of the LSTM. Each tensor has shape (1, batch_size, output_dimension). 108 | 109 | Returns 110 | ------- 111 | output_sequence : PackedSequence 112 | The encoded sequence of shape (batch_size, sequence_length, hidden_size) 113 | final_states: torch.Tensor 114 | The per-layer final (state, memory) states of the LSTM, each with shape 115 | (num_layers, batch_size, hidden_size). 116 | """ 117 | if not initial_state: 118 | hidden_states = [None] * len(self.lstm_layers) 119 | elif initial_state[0].size()[0] != len(self.lstm_layers): 120 | raise ConfigurationError("Initial states were passed to forward() but the number of " 121 | "initial states does not match the number of layers.") 122 | else: 123 | hidden_states = list(zip(initial_state[0].split(1, 0), 124 | initial_state[1].split(1, 0))) 125 | 126 | output_sequence = inputs 127 | final_states = [] 128 | for i, state in enumerate(hidden_states): 129 | layer = getattr(self, 'layer_{}'.format(i)) 130 | # The state is duplicated to mirror the Pytorch API for LSTMs. 
131 | output_sequence, final_state = layer(output_sequence, state) 132 | final_states.append(final_state) 133 | 134 | final_state_tuple = tuple(torch.cat(state_list, 0) for state_list in zip(*final_states)) 135 | return output_sequence, final_state_tuple 136 | 137 | 138 | class SimpleBiLM(torch.nn.Module): 139 | def __init__(self, 140 | vocab: Vocabulary, 141 | recurrent_dropout_probability: float = 0.0, 142 | embedding_dropout_probability: float = 0.0, 143 | input_size=512, 144 | hidden_size=512) -> None: 145 | """ 146 | :param options_file: for initializing elmo BiLM 147 | :param weight_file: for initializing elmo BiLM 148 | :param requires_grad: Whether or not to finetune the LSTM layers 149 | :param recurrent_dropout_probability: recurrent dropout to add to LSTM layers 150 | """ 151 | super(SimpleBiLM, self).__init__() 152 | 153 | self.forward_lm = PytorchSeq2SeqWrapper(StackedLstm( 154 | input_size=input_size, hidden_size=hidden_size, num_layers=2, go_forward=True, 155 | recurrent_dropout_probability=recurrent_dropout_probability, 156 | use_input_projection_bias=False, use_highway=True), stateful=True) 157 | self.reverse_lm = PytorchSeq2SeqWrapper(StackedLstm( 158 | input_size=input_size, hidden_size=hidden_size, num_layers=2, go_forward=False, 159 | recurrent_dropout_probability=recurrent_dropout_probability, 160 | use_input_projection_bias=False, use_highway=True), stateful=True) 161 | 162 | # This will also be the encoder 163 | self.decoder = torch.nn.Linear(512, vocab.get_vocab_size(namespace='tokens')) 164 | 165 | self.vocab = vocab 166 | self.register_buffer('eos_tokens', torch.LongTensor([vocab.get_token_index(tok) for tok in 167 | ['.', '!', '?', '@@UNKNOWN@@', '@@PADDING@@', '@@bos@@', 168 | '@@eos@@']])) 169 | self.register_buffer('invalid_tokens', torch.LongTensor([vocab.get_token_index(tok) for tok in 170 | ['@@UNKNOWN@@', '@@PADDING@@', '@@bos@@', '@@eos@@', 171 | '@@NEWLINE@@']])) 172 | self.embedding_dropout_probability = embedding_dropout_probability 173 | 174 | def embed_words(self, words): 175 | # assert words.dim() == 2 176 | return F.embedding(words, self.decoder.weight) 177 | # if not self.training: 178 | # return F.embedding(words, self.decoder.weight) 179 | # Embedding dropout 180 | # vocab_size = self.decoder.weight.size(0) 181 | # mask = Variable( 182 | # self.decoder.weight.data.new(vocab_size, 1).bernoulli_(1 - self.embedding_dropout_probability).expand_as( 183 | # self.decoder.weight) / (1 - self.embedding_dropout_probability)) 184 | 185 | # padding_idx = 0 186 | # embeds = self.decoder._backend.Embedding.apply(words, mask * self.decoder.weight, padding_idx, None, 187 | # 2, False, False) 188 | # return embeds 189 | 190 | def timestep_to_ids(self, timestep_tokenized: List[str]): 191 | """ Just a single timestep (so dont add BOS or EOS""" 192 | return torch.tensor([self.vocab.get_token_index(x) for x in timestep_tokenized])[:, None] 193 | 194 | def batch_to_ids(self, stories_tokenized: List[List[str]]): 195 | """ 196 | Simple wrapper around _elmo_batch_to_ids 197 | :param batch: A list of tokenized sentences. 198 | :return: A tensor of padded character ids. 
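        Example (illustrative): batch_to_ids([['the', 'cat', 'sat']]) yields a
        (1, 5) LongTensor of token ids; the two extra positions come from the
        @@bos@@ and @@eos@@ markers added around each story.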
199 | """ 200 | batch = Batch([Instance( 201 | {'story': TextField([Token('@@bos@@')] + [Token(x) for x in story] + [Token('@@eos@@')], 202 | token_indexers={ 203 | 'tokens': SingleIdTokenIndexer(namespace='tokens', lowercase_tokens=True)})}) 204 | for story in stories_tokenized]) 205 | batch.index_instances(self.vocab) 206 | words = {k: v['tokens'] for k, v in batch.as_tensor_dict().items()}['story'] 207 | return words 208 | 209 | def conditional_generation(self, context: List[str], gt_completion: List[str], 210 | batch_size: int = 128, max_gen_length: int = 25, 211 | same_length_as_gt: bool = False, first_is_gold: bool = False): 212 | """ 213 | Generate conditoned on the context. While we're at it we'll score the GT going forwards 214 | :param context: List of tokens to condition on. We'll add the BOS marker to it 215 | :param gt_completion: The gold truth completion 216 | :param batch_size: Number of sentences to generate 217 | :param max_gen_length: Max length for genertaed sentences (irrelvant if same_length_as_gt=True) 218 | :param same_length_as_gt: set to True if you want all the sents to have the same length as the gt_completion 219 | :param first_is_gold: set to True if you want the first sample to be the gt_completion 220 | :return: 221 | """ 222 | # Forward condition on context, then repeat to be the right batch size: 223 | # (layer_index, batch_size, fwd hidden dim) 224 | log_probs = self(self.batch_to_ids([context]), use_forward=True, 225 | use_reverse=False, compute_logprobs=True) 226 | forward_logprobs = log_probs['forward_logprobs'] 227 | self.forward_lm._states = tuple(x.repeat(1, batch_size, 1).contiguous() for x in self.forward_lm._states) 228 | # Each item will be (token, score) 229 | generations = [[(context[-1], 0.0)] for i in range(batch_size)] 230 | mask = forward_logprobs.new(batch_size).long().fill_(1) 231 | 232 | gt_completion_padded = [self.vocab.get_token_index(gt_token) for gt_token in 233 | [x.lower() for x in gt_completion] + ['@@PADDING@@'] * ( 234 | max_gen_length - len(gt_completion))] 235 | 236 | for index, gt_token_ind in enumerate(gt_completion_padded): 237 | embeds = self.embed_words(self.timestep_to_ids([gen[-1][0] for gen in generations])) 238 | next_dists = F.softmax(self.decoder(self.forward_lm(embeds, mask[:, None]))[:, 0], dim=1) 239 | 240 | # Perform hacky stuff on the distribution (disallowing BOS, EOS, that sorta thing 241 | sampling_probs = next_dists.clone() 242 | sampling_probs[:, self.invalid_tokens] = 0.0 243 | 244 | if first_is_gold: 245 | # Gold truth is first row 246 | sampling_probs[0].zero_() 247 | sampling_probs[0, gt_token_ind] = 1 248 | 249 | if same_length_as_gt: 250 | if index == (len(gt_completion) - 1): 251 | sampling_probs.zero_() 252 | sampling_probs[:, gt_token_ind] = 1 253 | else: 254 | sampling_probs[:, self.eos_tokens] = 0.0 255 | 256 | sampling_probs = sampling_probs / sampling_probs.sum(1, keepdim=True) 257 | 258 | next_preds = torch.multinomial(sampling_probs, 1).squeeze(1) 259 | next_scores = np.log(next_dists[ 260 | torch.arange(0, next_dists.size(0), 261 | out=mask.data.new(next_dists.size(0))), 262 | next_preds, 263 | ].cpu().detach().numpy()) 264 | for i, (gen_list, pred_id, score_i, mask_i) in enumerate( 265 | zip(generations, next_preds.cpu().detach().numpy(), next_scores, mask.data.cpu().detach().numpy())): 266 | if mask_i: 267 | gen_list.append((self.vocab.get_token_from_index(pred_id), score_i)) 268 | is_eos = (next_preds[:, None] == self.eos_tokens[None]).max(1)[0] 269 | mask[is_eos] = 0 270 | if 
mask.sum().item() == 0: 271 | break 272 | generation_scores = np.zeros((len(generations), max([len(g) - 1 for g in generations])), dtype=np.float32) 273 | for i, gen in enumerate(generations): 274 | for j, (_, v) in enumerate(gen[1:]): 275 | generation_scores[i, j] = v 276 | 277 | generation_toks, idx = _de_duplicate_generations([[tok for (tok, score) in gen[1:]] for gen in generations]) 278 | return generation_toks, generation_scores[idx], forward_logprobs.cpu().detach().numpy() 279 | 280 | def _chunked_logsoftmaxes(self, activation, word_targets, chunk_size=256): 281 | """ 282 | do the softmax in chunks so memory doesnt explode 283 | :param activation: [batch, T, dim] 284 | :param targets: [batch, T] indices 285 | :param chunk_size: you might need to tune this based on GPU specs 286 | :return: 287 | """ 288 | all_logprobs = [] 289 | num_chunks = (activation.size(0) - 1) // chunk_size + 1 290 | for activation_chunk, target_chunk in zip(torch.chunk(activation, num_chunks, dim=0), 291 | torch.chunk(word_targets, num_chunks, dim=0)): 292 | assert activation_chunk.size()[:2] == target_chunk.size()[:2] 293 | targets_flat = target_chunk.view(-1) 294 | time_indexer = torch.arange(0, targets_flat.size(0), 295 | out=target_chunk.data.new(targets_flat.size(0))) % target_chunk.size(1) 296 | batch_indexer = torch.arange(0, targets_flat.size(0), 297 | out=target_chunk.data.new(targets_flat.size(0))) / target_chunk.size(1) 298 | all_logprobs.append(F.log_softmax(self.decoder(activation_chunk), 2)[ 299 | batch_indexer, time_indexer, targets_flat].view(*target_chunk.size())) 300 | return torch.cat(all_logprobs, 0) 301 | 302 | def forward(self, words: torch.Tensor, use_forward: bool = True, use_reverse: bool = True, 303 | compute_logprobs: bool = False) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 304 | """ 305 | use this for training the LM 306 | :param words: [batch_size, N] words. 
These are assumed to start with BOS and end with EOS. 307 | :return: a dict with the mask and word targets plus, depending on the flags, forward/reverse logits and losses or log-probabilities 308 | """ 309 | encoded_inputs = self.embed_words(words) 310 | mask = (words != 0).long()[:, 2:] 311 | word_targets = words[:, 1:-1].contiguous() 312 | 313 | result_dict = { 314 | 'mask': mask, 315 | 'word_targets': word_targets, 316 | } 317 | # TODO: try to reduce duplicate code here 318 | if use_forward: 319 | self.forward_lm.reset_states() 320 | forward_activation = self.forward_lm(encoded_inputs[:, :-2], mask) 321 | 322 | if compute_logprobs: 323 | # being memory efficient here is critical if the input tensors are large 324 | result_dict['forward_logprobs'] = self._chunked_logsoftmaxes(forward_activation, 325 | word_targets) * mask.float() 326 | else: 327 | 328 | result_dict['forward_logits'] = self.decoder(forward_activation) 329 | result_dict['forward_loss'] = sequence_cross_entropy_with_logits(result_dict['forward_logits'], 330 | word_targets, 331 | mask) 332 | if use_reverse: 333 | self.reverse_lm.reset_states() 334 | reverse_activation = self.reverse_lm(encoded_inputs[:, 2:], mask) 335 | if compute_logprobs: 336 | result_dict['reverse_logprobs'] = self._chunked_logsoftmaxes(reverse_activation, 337 | word_targets) * mask.float() 338 | else: 339 | result_dict['reverse_logits'] = self.decoder(reverse_activation) 340 | result_dict['reverse_loss'] = sequence_cross_entropy_with_logits(result_dict['reverse_logits'], 341 | word_targets, 342 | mask) 343 | return result_dict 344 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/swag_generator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name,arguments-differ 2 | from typing import List, Iterable, Tuple 3 | import logging 4 | 5 | 6 | import torch 7 | from allennlp.common.util import JsonDict 8 | from allennlp.common.file_utils import cached_path 9 | from allennlp.data import Instance, Token, Vocabulary 10 | from allennlp.data.fields import TextField 11 | from allennlp.pretrained import PretrainedModel 12 | 13 | from adversarialnlp.common.file_utils import download_files, DATA_ROOT 14 | from adversarialnlp.generators import Generator 15 | from adversarialnlp.generators.swag.simple_bilm import SimpleBiLM 16 | from adversarialnlp.generators.swag.utils import optimistic_restore 17 | from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader 18 | 19 | BATCH_SIZE = 1 20 | BEAM_SIZE = 8 * BATCH_SIZE 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class SwagGenerator(Generator): 25 | """ 26 | ``SwagGenerator`` inherits from the ``Generator`` class. 27 | This ``Generator`` generates adversarial examples from seeds using 28 | the method described in 29 | `SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference <https://arxiv.org/abs/1808.05326>`_. 30 | 31 | This method goes schematically as follows: 32 | - In a seed sample containing a pair of sequential sentences (e.g. video captions), 33 | the second sentence is split into noun and verb phrases. 34 | - A language model generates several possible endings from the second sentence's noun phrase. 35 | 36 | Args, input and yield: 37 | See the ``Generator`` class. 38 | 39 | Seeds: 40 | AllenNLP ``Instance`` containing two ``TextField``: 41 | `first_sentence` and `second_sentence`, respectively containing 42 | the first and the second consecutive sentences.
43 | 44 | default_seeds: 45 | If no seeds are provided, the default_seeds are the training set 46 | of the 47 | `ActivityNet Captions dataset `_. 48 | 49 | """ 50 | def __init__(self, 51 | default_seeds: Iterable = None, 52 | quiet: bool = False): 53 | super().__init__(default_seeds, quiet) 54 | 55 | lm_files = download_files(fnames=['vocabulary.zip', 56 | 'lm-fold-0.bin'], 57 | local_folder='swag_lm') 58 | 59 | activity_data_files = download_files(fnames=['captions.zip'], 60 | paths='https://cs.stanford.edu/people/ranjaykrishna/densevid/', 61 | local_folder='activitynet_captions') 62 | 63 | const_parser_files = cached_path('https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz', 64 | cache_dir=str(DATA_ROOT / 'allennlp_constituency_parser')) 65 | 66 | self.const_parser = PretrainedModel(const_parser_files, 'constituency-parser').predictor() 67 | vocab = Vocabulary.from_files(lm_files[0]) 68 | self.language_model = SimpleBiLM(vocab=vocab, recurrent_dropout_probability=0.2, 69 | embedding_dropout_probability=0.2) 70 | optimistic_restore(self.language_model, torch.load(lm_files[1], map_location='cpu')['state_dict']) 71 | 72 | if default_seeds is None: 73 | self.default_seeds = ActivityNetCaptionsDatasetReader().read(activity_data_files[0] + '/train.json') 74 | else: 75 | self.default_seeds = default_seeds 76 | 77 | def _find_VP(self, tree: JsonDict) -> List[Tuple[str, any]]: 78 | r"""Recurse on a constituency parse tree until we find verb phrases""" 79 | 80 | # Recursion is annoying because we need to check whether each is a list or not 81 | def _recurse_on_children(): 82 | assert 'children' in tree 83 | result = [] 84 | for child in tree['children']: 85 | res = self._find_VP(child) 86 | if isinstance(res, tuple): 87 | result.append(res) 88 | else: 89 | result.extend(res) 90 | return result 91 | 92 | if 'VP' in tree['attributes']: 93 | # # Now we'll get greedy and see if we can find something better 94 | # if 'children' in tree and len(tree['children']) > 1: 95 | # recurse_result = _recurse_on_children() 96 | # if all([x[1] in ('VP', 'NP', 'CC') for x in recurse_result]): 97 | # return recurse_result 98 | return [(tree['word'], 'VP')] 99 | # base cases 100 | if 'NP' in tree['attributes']: 101 | return [(tree['word'], 'NP')] 102 | # No children 103 | if not 'children' in tree: 104 | return [(tree['word'], tree['attributes'][0])] 105 | 106 | # If a node only has 1 child then we'll have to stick with that 107 | if len(tree['children']) == 1: 108 | return _recurse_on_children() 109 | # try recursing on everything 110 | return _recurse_on_children() 111 | 112 | def _split_on_final_vp(self, sentence: Instance) -> (List[str], List[str]): 113 | r"""Splits a sentence on the final verb phrase """ 114 | sentence_txt = ' '.join(t.text for t in sentence.tokens) 115 | res = self.const_parser.predict(sentence_txt) 116 | res_chunked = self._find_VP(res['hierplane_tree']['root']) 117 | is_vp: List[int] = [i for i, (word, pos) in enumerate(res_chunked) if pos == 'VP'] 118 | if not is_vp: 119 | return None, None 120 | vp_ind = max(is_vp) 121 | not_vp = [token for x in res_chunked[:vp_ind] for token in x[0].split(' ')] 122 | is_vp = [token for x in res_chunked[vp_ind:] for token in x[0].split(' ')] 123 | return not_vp, is_vp 124 | 125 | def generate_from_seed(self, seed: Tuple): 126 | """Edit a seed example. 
127 | """ 128 | first_sentence: TextField = seed.fields["first_sentence"] 129 | second_sentence: TextField = seed.fields["second_sentence"] 130 | eos_bounds = [i + 1 for i, x in enumerate(first_sentence.tokens) if x.text in ('.', '?', '!')] 131 | if not eos_bounds: 132 | first_sentence = TextField(tokens=first_sentence.tokens + [Token(text='.')], 133 | token_indexers=first_sentence.token_indexers) 134 | context_len = len(first_sentence.tokens) 135 | if context_len < 6 or context_len > 100: 136 | print("skipping on {} (too short or long)".format( 137 | ' '.join(t.text for t in first_sentence.tokens + second_sentence.tokens))) 138 | return 139 | # Something I should have done: 140 | # make sure that there aren't multiple periods, etc. in s2 or in the middle 141 | eos_bounds_s2 = [i + 1 for i, x in enumerate(second_sentence.tokens) if x.text in ('.', '?', '!')] 142 | if eos_bounds_s2 and (len(eos_bounds_s2) > 1 or max(eos_bounds_s2) != len(second_sentence.tokens)): 143 | return 144 | elif not eos_bounds_s2: 145 | second_sentence = TextField(tokens=second_sentence.tokens + [Token(text='.')], 146 | token_indexers=second_sentence.token_indexers) 147 | 148 | # Now split on the VP 149 | startphrase, endphrase = self._split_on_final_vp(second_sentence) 150 | if startphrase is None or not startphrase or len(endphrase) < 5 or len(endphrase) > 25: 151 | print("skipping on {}->{},{}".format(' '.join(t.text for t in first_sentence.tokens + second_sentence.tokens), 152 | startphrase, endphrase), flush=True) 153 | return 154 | 155 | # if endphrase contains unk then it's hopeless 156 | # if any(vocab.get_token_index(tok.lower()) == vocab.get_token_index(vocab._oov_token) 157 | # for tok in endphrase): 158 | # print("skipping on {} (unk!)".format(' '.join(s1_toks + s2_toks))) 159 | # return 160 | 161 | context = [token.text for token in first_sentence.tokens] + startphrase 162 | 163 | lm_out = self.language_model.conditional_generation(context, gt_completion=endphrase, 164 | batch_size=BEAM_SIZE, 165 | max_gen_length=25) 166 | gens0, fwd_scores, ctx_scores = lm_out 167 | if len(gens0) < BATCH_SIZE: 168 | print("Couldn't generate enough candidates so skipping") 169 | return 170 | gens0 = gens0[:BATCH_SIZE] 171 | yield gens0 172 | # fwd_scores = fwd_scores[:BATCH_SIZE] 173 | 174 | # # Now get the backward scores.
175 | # full_sents = [context + gen for gen in gens0] # NOTE: #1 is GT 176 | # result_dict = self.language_model(self.language_model.batch_to_ids(full_sents), 177 | # use_forward=False, use_reverse=True, compute_logprobs=True) 178 | # ending_lengths = (fwd_scores < 0).sum(1) 179 | # ending_lengths_float = ending_lengths.astype(np.float32) 180 | # rev_scores = result_dict['reverse_logprobs'].cpu().detach().numpy() 181 | 182 | # forward_logperp_ending = -fwd_scores.sum(1) / ending_lengths_float 183 | # reverse_logperp_ending = -rev_scores[:, context_len:].sum(1) / ending_lengths_float 184 | # forward_logperp_begin = -ctx_scores.mean() 185 | # reverse_logperp_begin = -rev_scores[:, :context_len].mean(1) 186 | # eos_logperp = -fwd_scores[np.arange(fwd_scores.shape[0]), ending_lengths - 1] 187 | # # print("Time elapsed {:.3f}".format(time() - tic), flush=True) 188 | 189 | # scores = np.exp(np.column_stack(( 190 | # forward_logperp_ending, 191 | # reverse_logperp_ending, 192 | # reverse_logperp_begin, 193 | # eos_logperp, 194 | # np.ones(forward_logperp_ending.shape[0], dtype=np.float32) * forward_logperp_begin, 195 | # ))) 196 | 197 | # PRINTOUT 198 | # low2high = scores[:, 2].argsort() 199 | # print("\n\n Dataset={} ctx: {} (perp={:.3f})\n~~~\n".format(item['dataset'], ' '.join(context), 200 | # np.exp(forward_logperp_begin)), flush=True) 201 | # for i, ind in enumerate(low2high.tolist()): 202 | # gen_i = ' '.join(gens0[ind]) 203 | # if (ind == 0) or (i < 128): 204 | # print("{:3d}/{:4d}) ({}, end|ctx:{:5.1f} end:{:5.1f} ctx|end:{:5.1f} EOS|(ctx, end):{:5.1f}) {}".format( 205 | # i, len(gens0), 'GOLD' if ind == 0 else ' ', *scores[ind][:-1], gen_i), flush=True) 206 | # gt_score = low2high.argsort()[0] 207 | 208 | # item_full = deepcopy(item) 209 | # item_full['sent1'] = first_sentence 210 | # item_full['startphrase'] = startphrase 211 | # item_full['context'] = context 212 | # item_full['generations'] = gens0 213 | # item_full['postags'] = [ # parse real fast 214 | # [x.orth_.lower() if pos_vocab.get_token_index(x.orth_.lower()) != 1 else x.pos_ for x in y] 215 | # for y in spacy_model.pipe([startphrase + gen for gen in gens0], batch_size=BATCH_SIZE)] 216 | # item_full['scores'] = pd.DataFrame(data=scores, index=np.arange(scores.shape[0]), 217 | # columns=['end-from-ctx', 'end', 'ctx-from-end', 'eos-from-ctxend', 'ctx']) 218 | 219 | # generated_examples.append(gens0) 220 | # if len(generated_examples) > 0: 221 | # yield generated_examples 222 | # generated_examples = [] 223 | 224 | # from adversarialnlp.common.file_utils import FIXTURES_ROOT 225 | # generator = SwagGenerator() 226 | # test_instances = ActivityNetCaptionsDatasetReader().read(FIXTURES_ROOT / 'activitynet_captions.json') 227 | # batches = list(generator(test_instances, num_epochs=1)) 228 | # assert len(batches) != 0 229 | -------------------------------------------------------------------------------- /adversarialnlp/generators/swag/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import tee 3 | 4 | from num2words import num2words 5 | 6 | def optimistic_restore(network, state_dict): 7 | mismatch = False 8 | own_state = network.state_dict() 9 | for name, param in state_dict.items(): 10 | if name not in own_state: 11 | print("Unexpected key {} in state_dict with size {}".format(name, param.size())) 12 | mismatch = True 13 | elif param.size() == own_state[name].size(): 14 | own_state[name].copy_(param) 15 | else: 16 | print("Network has {} with size {}, ckpt 
has {}".format(name, 17 | own_state[name].size(), 18 | param.size())) 19 | mismatch = True 20 | 21 | missing = set(own_state.keys()) - set(state_dict.keys()) 22 | if len(missing) > 0: 23 | print("We couldn't find {}".format(','.join(missing))) 24 | mismatch = True 25 | return not mismatch 26 | 27 | def pairwise(iterable): 28 | "s -> (s0,s1), (s1,s2), (s2, s3), ..." 29 | a, b = tee(iterable) 30 | next(b, None) 31 | return zip(a, b) 32 | 33 | def n2w_1k(num, use_ordinal=False): 34 | if num > 1000: 35 | return '' 36 | return num2words(num, to='ordinal' if use_ordinal else 'cardinal') 37 | 38 | def postprocess(sentence): 39 | """ 40 | make sure punctuation is followed by a space 41 | :param sentence: 42 | :return: 43 | """ 44 | sentence = remove_allcaps(sentence) 45 | # Aggressively get rid of some punctuation markers 46 | sent0 = re.sub(r'^.*(\\|/|!!!|~|=|#|@|\*|¡|©|¿|«|»|¬|{|}|\||\(|\)|\+|\]|\[).*$', 47 | ' ', sentence, flags=re.MULTILINE|re.IGNORECASE) 48 | 49 | # Less aggressively get rid of quotes, apostrophes 50 | sent1 = re.sub(r'"', ' ', sent0) 51 | sent2 = re.sub(r'`', '\'', sent1) 52 | 53 | # match ordinals 54 | sent3 = re.sub(r'(\d+(?:rd|st|nd))', 55 | lambda x: n2w_1k(int(x.group(0)[:-2]), use_ordinal=True), sent2) 56 | 57 | #These things all need to be followed by spaces or else we'll run into problems 58 | sent4 = re.sub(r'[:;,\"\!\.\-\?](?! )', lambda x: x.group(0) + ' ', sent3) 59 | 60 | #These things all need to be preceded by spaces or else we'll run into problems 61 | sent5 = re.sub(r'(?! )[-]', lambda x: ' ' + x.group(0), sent4) 62 | 63 | # Several spaces 64 | sent6 = re.sub(r'\s\s+', ' ', sent5) 65 | 66 | sent7 = sent6.strip() 67 | return sent7 68 | 69 | def remove_allcaps(sent): 70 | """ 71 | Given a sentence, filter it so that it doesn't contain some words that are ALLcaps 72 | :param sent: string, like SOMEONE wheels SOMEONE on, mouthing silent words of earnest prayer. 73 | :return: Someone wheels someone on, mouthing silent words of earnest prayer. 74 | """ 75 | # Remove all caps 76 | def _sanitize(word, is_first): 77 | if word == "I": 78 | return word 79 | num_capitals = len([x for x in word if not x.islower()]) 80 | if num_capitals > len(word) // 2: 81 | # We have an all caps word here. 82 | if is_first: 83 | return word[0] + word[1:].lower() 84 | return word.lower() 85 | return word 86 | 87 | return ' '.join([_sanitize(word, i == 0) for i, word in enumerate(sent.split(' '))]) 88 | -------------------------------------------------------------------------------- /adversarialnlp/pruners/__init__.py: -------------------------------------------------------------------------------- 1 | from .pruner import Pruner -------------------------------------------------------------------------------- /adversarialnlp/pruners/pruner.py: -------------------------------------------------------------------------------- 1 | from allennlp.common import Registrable 2 | 3 | class Pruner(Registrable): 4 | """ 5 | ``Pruner`` is used to fil potential adversarial samples 6 | 7 | Parameters 8 | ---------- 9 | dataset_reader : ``DatasetReader`` 10 | The ``DatasetReader`` object that will be used to sample training examples. 
11 | 12 | """ 13 | def __init__(self, ) -> None: 14 | super().__init__() 15 | -------------------------------------------------------------------------------- /adversarialnlp/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | import os 4 | import sys 5 | 6 | if os.environ.get("ALLENNLP_DEBUG"): 7 | LEVEL = logging.DEBUG 8 | else: 9 | LEVEL = logging.INFO 10 | 11 | sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | level=LEVEL) 14 | 15 | from adversarialnlp.commands import main # pylint: disable=wrong-import-position 16 | 17 | if __name__ == "__main__": 18 | main(prog="adversarialnlp") 19 | -------------------------------------------------------------------------------- /adversarialnlp/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/__init__.py -------------------------------------------------------------------------------- /adversarialnlp/tests/dataset_readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/dataset_readers/__init__.py -------------------------------------------------------------------------------- /adversarialnlp/tests/dataset_readers/activitynet_captions_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import pytest 3 | 4 | from allennlp.common.util import ensure_list 5 | 6 | from adversarialnlp.dataset_readers import ActivityNetCaptionsDatasetReader 7 | from adversarialnlp.tests.utils import FIXTURES_ROOT 8 | 9 | class TestActivityNetCaptionsReader(): 10 | @pytest.mark.parametrize("lazy", (True, False)) 11 | def test_read_from_file(self, lazy): 12 | reader = ActivityNetCaptionsDatasetReader(lazy=lazy) 13 | instances = reader.read(FIXTURES_ROOT / 'activitynet_captions.json') 14 | instances = ensure_list(instances) 15 | 16 | instance1 = {"video_id": "v_uqiMw7tQ1Cc", 17 | "first_sentence": "A weight lifting tutorial is given .".split(), 18 | "second_sentence": "The coach helps the guy in red with the proper body placement and lifting technique .".split()} 19 | 20 | instance2 = {"video_id": "v_bXdq2zI1Ms0", 21 | "first_sentence": "A man is seen speaking to the camera and pans out into more men standing behind him .".split(), 22 | "second_sentence": "The first man then begins performing martial arts moves while speaking to he camera .".split()} 23 | 24 | instance3 = {"video_id": "v_bXdq2zI1Ms0", 25 | "first_sentence": "The first man then begins performing martial arts moves while speaking to he camera .".split(), 26 | "second_sentence": "He continues moving around and looking to the camera .".split()} 27 | 28 | assert len(instances) == 3 29 | 30 | for instance, expected_instance in zip(instances, [instance1, instance2, instance3]): 31 | fields = instance.fields 32 | assert [t.text for t in fields["first_sentence"].tokens] == expected_instance["first_sentence"] 33 | assert [t.text for t in fields["second_sentence"].tokens] == expected_instance["second_sentence"] 34 | assert fields["video_id"].metadata == expected_instance["video_id"] 35 
| -------------------------------------------------------------------------------- /adversarialnlp/tests/fixtures/activitynet_captions.json: -------------------------------------------------------------------------------- 1 | {"v_uqiMw7tQ1Cc": {"duration": 55.15, "timestamps": [[0.28, 55.15], [13.79, 54.32]], "sentences": ["A weight lifting tutorial is given.", " The coach helps the guy in red with the proper body placement and lifting technique."]}, "v_bXdq2zI1Ms0": {"duration": 73.1, "timestamps": [[0, 10.23], [10.6, 39.84], [38.01, 73.1]], "sentences": ["A man is seen speaking to the camera and pans out into more men standing behind him.", " The first man then begins performing martial arts moves while speaking to he camera.", " He continues moving around and looking to the camera."]}} -------------------------------------------------------------------------------- /adversarialnlp/tests/fixtures/squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [{ 3 | "title": "Super_Bowl_50", 4 | "paragraphs": [{ 5 | "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.", 6 | "qas": [{ 7 | "answers": [{ 8 | "answer_start": 177, 9 | "text": "Denver Broncos" 10 | }, { 11 | "answer_start": 177, 12 | "text": "Denver Broncos" 13 | }, { 14 | "answer_start": 177, 15 | "text": "Denver Broncos" 16 | }], 17 | "question": "Which NFL team represented the AFC at Super Bowl 50?", 18 | "id": "56be4db0acb8001400a502ec" 19 | }, { 20 | "answers": [{ 21 | "answer_start": 177, 22 | "text": "Denver Broncos" 23 | }, { 24 | "answer_start": 177, 25 | "text": "Denver Broncos" 26 | }, { 27 | "answer_start": 177, 28 | "text": "Denver Broncos" 29 | }], 30 | "question": "Which NFL team won Super Bowl 50?", 31 | "id": "56be4db0acb8001400a502ef" 32 | }] 33 | }, { 34 | "context": "The Panthers finished the regular season with a 15\u20131 record, and quarterback Cam Newton was named the NFL Most Valuable Player (MVP). They defeated the Arizona Cardinals 49\u201315 in the NFC Championship Game and advanced to their second Super Bowl appearance since the franchise was founded in 1995. The Broncos finished the regular season with a 12\u20134 record, and denied the New England Patriots a chance to defend their title from Super Bowl XLIX by defeating them 20\u201318 in the AFC Championship Game. 
They joined the Patriots, Dallas Cowboys, and Pittsburgh Steelers as one of four teams that have made eight appearances in the Super Bowl.", 35 | "qas": [{ 36 | "answers": [{ 37 | "answer_start": 77, 38 | "text": "Cam Newton" 39 | }, { 40 | "answer_start": 77, 41 | "text": "Cam Newton" 42 | }, { 43 | "answer_start": 77, 44 | "text": "Cam Newton" 45 | }], 46 | "question": "Which Carolina Panthers player was named Most Valuable Player?", 47 | "id": "56be4e1facb8001400a502f6" 48 | }, { 49 | "answers": [{ 50 | "answer_start": 467, 51 | "text": "8" 52 | }, { 53 | "answer_start": 601, 54 | "text": "eight" 55 | }, { 56 | "answer_start": 601, 57 | "text": "eight" 58 | }], 59 | "question": "How many appearances have the Denver Broncos made in the Super Bowl?", 60 | "id": "56be4e1facb8001400a502f9" 61 | }] 62 | }] 63 | },{ 64 | "title": "Warsaw", 65 | "paragraphs": [{ 66 | "context": "One of the most famous people born in Warsaw was Maria Sk\u0142odowska-Curie, who achieved international recognition for her research on radioactivity and was the first female recipient of the Nobel Prize. Famous musicians include W\u0142adys\u0142aw Szpilman and Fr\u00e9d\u00e9ric Chopin. Though Chopin was born in the village of \u017belazowa Wola, about 60 km (37 mi) from Warsaw, he moved to the city with his family when he was seven months old. Casimir Pulaski, a Polish general and hero of the American Revolutionary War, was born here in 1745.", 67 | "qas": [{ 68 | "answers": [{ 69 | "answer_start": 188, 70 | "text": "Nobel Prize" 71 | }, { 72 | "answer_start": 188, 73 | "text": "Nobel Prize" 74 | }, { 75 | "answer_start": 188, 76 | "text": "Nobel Prize" 77 | }], 78 | "question": "What was Maria Curie the first female recipient of?", 79 | "id": "5733a5f54776f41900660f45" 80 | }, { 81 | "answers": [{ 82 | "answer_start": 517, 83 | "text": "1745" 84 | }, { 85 | "answer_start": 517, 86 | "text": "1745" 87 | }, { 88 | "answer_start": 517, 89 | "text": "1745" 90 | }], 91 | "question": "What year was Casimir Pulaski born in Warsaw?", 92 | "id": "5733a5f54776f41900660f48" 93 | }] 94 | }, { 95 | "context": "The Saxon Garden, covering the area of 15.5 ha, was formally a royal garden. There are over 100 different species of trees and the avenues are a place to sit and relax. At the east end of the park, the Tomb of the Unknown Soldier is situated. In the 19th century the Krasi\u0144ski Palace Garden was remodelled by Franciszek Szanior. Within the central area of the park one can still find old trees dating from that period: maidenhair tree, black walnut, Turkish hazel and Caucasian wingnut trees. With its benches, flower carpets, a pond with ducks on and a playground for kids, the Krasi\u0144ski Palace Garden is a popular strolling destination for the Varsovians. The Monument of the Warsaw Ghetto Uprising is also situated here. The \u0141azienki Park covers the area of 76 ha. The unique character and history of the park is reflected in its landscape architecture (pavilions, sculptures, bridges, cascades, ponds) and vegetation (domestic and foreign species of trees and bushes). What makes this park different from other green spaces in Warsaw is the presence of peacocks and pheasants, which can be seen here walking around freely, and royal carps in the pond. The Wilan\u00f3w Palace Park, dates back to the second half of the 17th century. It covers the area of 43 ha. Its central French-styled area corresponds to the ancient, baroque forms of the palace. 
The eastern section of the park, closest to the Palace, is the two-level garden with a terrace facing the pond. The park around the Kr\u00f3likarnia Palace is situated on the old escarpment of the Vistula. The park has lanes running on a few levels deep into the ravines on both sides of the palace.", 96 | "qas": [{ 97 | "answers": [{ 98 | "answer_start": 92, 99 | "text": "100" 100 | }, { 101 | "answer_start": 87, 102 | "text": "over 100" 103 | }, { 104 | "answer_start": 92, 105 | "text": "100" 106 | }], 107 | "question": "Over how many species of trees can be found in the Saxon Garden?", 108 | "id": "57336755d058e614000b5a3d" 109 | }] 110 | }] 111 | }], 112 | "version": "1.1" 113 | } -------------------------------------------------------------------------------- /adversarialnlp/tests/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/generators/__init__.py -------------------------------------------------------------------------------- /adversarialnlp/tests/generators/addsent_generator_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | from typing import List 3 | import pytest 4 | 5 | from adversarialnlp.generators.addsent.addsent_generator import AddSentGenerator 6 | from adversarialnlp.generators.addsent.squad_reader import squad_reader 7 | from adversarialnlp.common.file_utils import FIXTURES_ROOT 8 | 9 | 10 | # class GeneratorTest(AllenNlpTestCase): 11 | # def setUp(self): 12 | # super(GeneratorTest, self).setUp() 13 | # self.token_indexers = {"tokens": SingleIdTokenIndexer()} 14 | # self.vocab = Vocabulary() 15 | # self.this_index = self.vocab.add_token_to_namespace('this') 16 | # self.is_index = self.vocab.add_token_to_namespace('is') 17 | # self.a_index = self.vocab.add_token_to_namespace('a') 18 | # self.sentence_index = self.vocab.add_token_to_namespace('sentence') 19 | # self.another_index = self.vocab.add_token_to_namespace('another') 20 | # self.yet_index = self.vocab.add_token_to_namespace('yet') 21 | # self.very_index = self.vocab.add_token_to_namespace('very') 22 | # self.long_index = self.vocab.add_token_to_namespace('long') 23 | # instances = [ 24 | # self.create_instance(["this", "is", "a", "sentence"], ["this", "is", "another", "sentence"]), 25 | # self.create_instance(["yet", "another", "sentence"], 26 | # ["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]), 27 | # ] 28 | 29 | # class LazyIterable: 30 | # def __iter__(self): 31 | # return (instance for instance in instances) 32 | 33 | # self.instances = instances 34 | # self.lazy_instances = LazyIterable() 35 | 36 | # def create_instance(self, first_sentence: List[str], second_sentence: List[str]): 37 | # first_tokens = [Token(t) for t in first_sentence] 38 | # second_tokens = [Token(t) for t in second_sentence] 39 | # instance = Instance({'first_sentence': TextField(first_tokens, self.token_indexers), 40 | # 'second_sentence': TextField(second_tokens, self.token_indexers)}) 41 | # return instance 42 | 43 | # def assert_instances_are_correct(self, candidate_instances): 44 | # # First we need to remove padding tokens from the candidates. 
45 | # # pylint: disable=protected-access 46 | # candidate_instances = [tuple(w for w in instance if w != 0) for instance in candidate_instances] 47 | # expected_instances = [tuple(instance.fields["first_sentence"]._indexed_tokens["tokens"]) 48 | # for instance in self.instances] 49 | # assert set(candidate_instances) == set(expected_instances) 50 | 51 | 52 | class TestSwagGenerator(): 53 | # The Generator should work the same for lazy and non lazy datasets, 54 | # so each remaining test runs over both. 55 | def test_yield_one_epoch_generation_over_the_data_once(self): 56 | generator = AddSentGenerator() 57 | test_instances = squad_reader(FIXTURES_ROOT / 'squad.json') 58 | batches = list(generator(test_instances, num_epochs=1)) 59 | # We just want to get the single-token array for the text field in the instance. 60 | # instances = [tuple(instance.detach().cpu().numpy()) 61 | # for batch in batches 62 | # for instance in batch['text']["tokens"]] 63 | assert len(batches) == 5 64 | # self.assert_instances_are_correct(instances) 65 | -------------------------------------------------------------------------------- /adversarialnlp/tests/generators/swag_generator_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | from typing import List 3 | import pytest 4 | 5 | from allennlp.data.fields import TextField 6 | from allennlp.common.util import ensure_list 7 | from allennlp.common.testing import AllenNlpTestCase 8 | from allennlp.data import Instance, Token, Vocabulary 9 | from allennlp.data.iterators import BasicIterator 10 | from allennlp.data.token_indexers import SingleIdTokenIndexer 11 | from allennlp.data.dataset_readers.dataset_reader import _LazyInstances 12 | 13 | from adversarialnlp.generators.swag.swag_generator import SwagGenerator 14 | from adversarialnlp.generators.swag.activitynet_captions import ActivityNetCaptionsDatasetReader 15 | from adversarialnlp.tests.utils import FIXTURES_ROOT 16 | 17 | 18 | class GeneratorTest(AllenNlpTestCase): 19 | def setUp(self): 20 | super(GeneratorTest, self).setUp() 21 | self.token_indexers = {"tokens": SingleIdTokenIndexer()} 22 | self.vocab = Vocabulary() 23 | self.this_index = self.vocab.add_token_to_namespace('this') 24 | self.is_index = self.vocab.add_token_to_namespace('is') 25 | self.a_index = self.vocab.add_token_to_namespace('a') 26 | self.sentence_index = self.vocab.add_token_to_namespace('sentence') 27 | self.another_index = self.vocab.add_token_to_namespace('another') 28 | self.yet_index = self.vocab.add_token_to_namespace('yet') 29 | self.very_index = self.vocab.add_token_to_namespace('very') 30 | self.long_index = self.vocab.add_token_to_namespace('long') 31 | instances = [ 32 | self.create_instance(["this", "is", "a", "sentence"], ["this", "is", "another", "sentence"]), 33 | self.create_instance(["yet", "another", "sentence"], 34 | ["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]), 35 | ] 36 | 37 | class LazyIterable: 38 | def __iter__(self): 39 | return (instance for instance in instances) 40 | 41 | self.instances = instances 42 | self.lazy_instances = LazyIterable() 43 | 44 | def create_instance(self, first_sentence: List[str], second_sentence: List[str]): 45 | first_tokens = [Token(t) for t in first_sentence] 46 | second_tokens = [Token(t) for t in second_sentence] 47 | instance = Instance({'first_sentence': TextField(first_tokens, self.token_indexers), 48 | 'second_sentence': TextField(second_tokens, self.token_indexers)}) 
49 | return instance 50 | 51 | def assert_instances_are_correct(self, candidate_instances): 52 | # First we need to remove padding tokens from the candidates. 53 | # pylint: disable=protected-access 54 | candidate_instances = [tuple(w for w in instance if w != 0) for instance in candidate_instances] 55 | expected_instances = [tuple(instance.fields["first_sentence"]._indexed_tokens["tokens"]) 56 | for instance in self.instances] 57 | assert set(candidate_instances) == set(expected_instances) 58 | 59 | 60 | class TestSwagGenerator(GeneratorTest): 61 | # The Generator should work the same for lazy and non lazy datasets, 62 | # so each remaining test runs over both. 63 | def test_yield_one_epoch_generation_over_the_data_once(self): 64 | for test_instances in (self.instances, self.lazy_instances): 65 | generator = SwagGenerator(num_examples=1) 66 | test_instances = ActivityNetCaptionsDatasetReader().read(FIXTURES_ROOT / 'activitynet_captions.json') 67 | batches = list(generator(test_instances)) 68 | # We just want to get the single-token array for the text field in the instance. 69 | instances = [tuple(instance.detach().cpu().numpy()) 70 | for batch in batches 71 | for instance in batch['text']["tokens"]] 72 | assert len(instances) == 5 73 | self.assert_instances_are_correct(instances) 74 | -------------------------------------------------------------------------------- /adversarialnlp/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "1" 3 | _REVISION = "1-unreleased" 4 | 5 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 6 | VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION) 7 | -------------------------------------------------------------------------------- /bin/adversarialnlp: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python -m adversarialnlp.run "$@" -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = AdversarialNLP 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/common.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Common 5 | ====== 6 | 7 | .. automodule:: adversarialnlp.common 8 | .. currentmodule:: adversarialnlp.common 9 | 10 | 11 | Files 12 | ----- 13 | 14 | .. 
autofunction:: adversarialnlp.common.file_utils.download_files 15 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # AdversarialNLP documentation build configuration file, created by 5 | # sphinx-quickstart on Wed Oct 24 11:35:14 2018. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | # import os 21 | # import sys 22 | # sys.path.insert(0, os.path.abspath('.')) 23 | 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.autodoc', 35 | 'sphinx.ext.intersphinx', 36 | 'sphinx.ext.mathjax', 37 | 'sphinx.ext.ifconfig', 38 | 'sphinx.ext.viewcode', 39 | 'sphinx.ext.napoleon'] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # The suffix(es) of source filenames. 45 | # You can specify multiple suffix as a list of string: 46 | # 47 | # source_suffix = ['.rst', '.md'] 48 | source_suffix = '.rst' 49 | 50 | # The master toctree document. 51 | master_doc = 'index' 52 | 53 | # General information about the project. 54 | project = 'AdversarialNLP' 55 | copyright = '2018, Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers' 56 | author = 'Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers' 57 | 58 | # The version info for the project you're documenting, acts as replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = '0.1' 64 | # The full version, including alpha/beta/rc tags. 65 | release = '0.1' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = 'en' 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This patterns also effect to html_static_path and html_extra_path 77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | 82 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
83 | todo_include_todos = False 84 | 85 | 86 | # -- Options for HTML output ---------------------------------------------- 87 | 88 | # The theme to use for HTML and HTML Help pages. See the documentation for 89 | # a list of builtin themes. 90 | # 91 | html_theme = 'alabaster' 92 | 93 | # Theme options are theme-specific and customize the look and feel of a theme 94 | # further. For a list of options available for each theme, see the 95 | # documentation. 96 | # 97 | # html_theme_options = {} 98 | 99 | # Add any paths that contain custom static files (such as style sheets) here, 100 | # relative to this directory. They are copied after the builtin static files, 101 | # so a file named "default.css" will overwrite the builtin "default.css". 102 | html_static_path = ['_static'] 103 | 104 | 105 | # -- Options for HTMLHelp output ------------------------------------------ 106 | 107 | # Output file base name for HTML help builder. 108 | htmlhelp_basename = 'AdversarialNLPdoc' 109 | 110 | 111 | # -- Options for LaTeX output --------------------------------------------- 112 | 113 | latex_elements = { 114 | # The paper size ('letterpaper' or 'a4paper'). 115 | # 116 | # 'papersize': 'letterpaper', 117 | 118 | # The font size ('10pt', '11pt' or '12pt'). 119 | # 120 | # 'pointsize': '10pt', 121 | 122 | # Additional stuff for the LaTeX preamble. 123 | # 124 | # 'preamble': '', 125 | 126 | # Latex figure (float) alignment 127 | # 128 | # 'figure_align': 'htbp', 129 | } 130 | 131 | # Grouping the document tree into LaTeX files. List of tuples 132 | # (source start file, target name, title, 133 | # author, documentclass [howto, manual, or own class]). 134 | latex_documents = [ 135 | (master_doc, 'AdversarialNLP.tex', 'AdversarialNLP Documentation', 136 | 'Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers', 'manual'), 137 | ] 138 | 139 | 140 | # -- Options for manual page output --------------------------------------- 141 | 142 | # One entry per manual page. List of tuples 143 | # (source start file, name, description, authors, manual section). 144 | man_pages = [ 145 | (master_doc, 'adversarialnlp', 'AdversarialNLP Documentation', 146 | [author], 1) 147 | ] 148 | 149 | 150 | # -- Options for Texinfo output ------------------------------------------- 151 | 152 | # Grouping the document tree into Texinfo files. List of tuples 153 | # (source start file, target name, title, author, 154 | # dir menu entry, description, category) 155 | texinfo_documents = [ 156 | (master_doc, 'AdversarialNLP', 'AdversarialNLP Documentation', 157 | author, 'AdversarialNLP', 'One line description of project.', 158 | 'Miscellaneous'), 159 | ] 160 | 161 | 162 | 163 | 164 | # Example configuration for intersphinx: refer to the Python standard library. 165 | intersphinx_mapping = {'https://docs.python.org/': None} 166 | -------------------------------------------------------------------------------- /docs/generators.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Generators 5 | ========== 6 | 7 | .. automodule:: adversarialnlp.generators 8 | .. currentmodule:: adversarialnlp.generators 9 | 10 | Generator 11 | ---------- 12 | 13 | .. autoclass:: adversarialnlp.generators.Generator 14 | 15 | AddSent 16 | ---------- 17 | 18 | .. autoclass:: adversarialnlp.generators.addsent.AddSentGenerator 19 | .. autofunction:: adversarialnlp.generators.addsent.squad_reader 20 | 21 | SWAG 22 | ---- 23 | 24 | .. 
autoclass:: adversarialnlp.generators.swag.SwagGenerator 25 | .. autoclass:: adversarialnlp.generators.swag.ActivityNetCaptionsDatasetReader 26 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. AdversarialNLP documentation master file, created by 2 | sphinx-quickstart on Wed Oct 24 11:35:14 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/pytorch/pytorch 7 | 8 | AdversarialNLP documentation 9 | ============================ 10 | 11 | AdversarialNLP is a generic library for crafting and using Adversarial NLP examples. 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | :caption: Contents 16 | 17 | readme 18 | 19 | common 20 | generators 21 | 22 | 23 | Indices and tables 24 | ================== 25 | 26 | * :ref:`genindex` 27 | * :ref:`modindex` 28 | * :ref:`search` 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=AdversarialNLP 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.md -------------------------------------------------------------------------------- /docs/readthedoc_requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | typing 3 | pytest 4 | PyYAML==3.13 -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests/ 3 | python_paths = ./ 4 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.6 8 | setup_py_install: true 9 | 10 | requirements_file: docs/readthedoc_requirements.txt 11 | 12 | formats: [] 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for the python code. You need to install these with 2 | # `pip install -r requirements.txt` before you can run this. 
3 | # NOTE: all essential packages must be placed under a section named 'ESSENTIAL ...' 4 | # so that the script `./scripts/check_requirements_and_setup.py` can find them. 5 | 6 | #### ESSENTIAL LIBRARIES FOR MAIN FUNCTIONALITY #### 7 | 8 | # This installs Pytorch for CUDA 8 only. If you are using a newer version, 9 | # please visit http://pytorch.org/ and install the relevant version. 10 | torch>=0.4.1,<0.5.0 11 | 12 | # Parameter parsing (but not on Windows). 13 | jsonnet==0.10.0 ; sys.platform != 'win32' 14 | 15 | # Adds an @overrides decorator for better documentation and error checking when using subclasses. 16 | overrides 17 | 18 | # Used by some old code. We moved away from it because it's too slow, but some old code still 19 | # imports this. 20 | nltk 21 | 22 | # Mainly used for the faster tokenizer. 23 | spacy>=2.0,<2.1 24 | 25 | # Used by span prediction models. 26 | numpy 27 | 28 | # Used for reading configuration info out of numpy-style docstrings. 29 | numpydoc==0.8.0 30 | 31 | # Used in coreference resolution evaluation metrics. 32 | scipy 33 | scikit-learn 34 | 35 | # Write logs for training visualisation with the Tensorboard application 36 | # Install the Tensorboard application separately (part of tensorflow) to view them. 37 | tensorboardX==1.2 38 | 39 | # Required by torch.utils.ffi 40 | cffi==1.11.2 41 | 42 | # aws commandline tools for running on Docker remotely. 43 | # second requirement is to get botocore < 1.11, to avoid the below bug 44 | awscli>=1.11.91 45 | 46 | # Accessing files from S3 directly. 47 | boto3 48 | 49 | # REST interface for models 50 | flask==0.12.4 51 | flask-cors==3.0.3 52 | gevent==1.3.6 53 | 54 | # Used by semantic parsing code to strip diacritics from unicode strings. 55 | unidecode 56 | 57 | # Used by semantic parsing code to parse SQL 58 | parsimonious==0.8.0 59 | 60 | # Used by semantic parsing code to format and postprocess SQL 61 | sqlparse==0.2.4 62 | 63 | # For text normalization 64 | ftfy 65 | 66 | #### ESSENTIAL LIBRARIES USED IN SCRIPTS #### 67 | 68 | # Plot graphs for learning rate finder 69 | matplotlib==2.2.3 70 | 71 | # Used for downloading datasets over HTTP 72 | requests>=2.18 73 | 74 | # progress bars in data cleaning scripts 75 | tqdm>=4.19 76 | 77 | # In SQuAD eval script, we use this to see if we likely have some tokenization problem. 78 | editdistance 79 | 80 | # For pretrained model weights 81 | h5py 82 | 83 | # For timezone utilities 84 | pytz==2017.3 85 | 86 | # Reads Universal Dependencies files. 87 | conllu==0.11 88 | 89 | #### ESSENTIAL TESTING-RELATED PACKAGES #### 90 | 91 | # We'll use pytest to run our tests; this isn't really necessary to run the code, but it is to run 92 | # the tests. With this here, you can run the tests with `py.test` from the base directory. 93 | pytest 94 | 95 | # Allows marking tests as flaky, to be rerun if they fail 96 | flaky 97 | 98 | # Required to mock out `requests` calls 99 | responses>=0.7 100 | 101 | # For mocking s3. 102 | moto==1.3.4 103 | 104 | #### TESTING-RELATED PACKAGES #### 105 | 106 | # Checks style, syntax, and other useful errors. 107 | pylint==1.8.1 108 | 109 | # Tutorial notebooks 110 | # see: https://github.com/jupyter/jupyter/issues/370 for ipykernel 111 | ipykernel<5.0.0 112 | jupyter 113 | 114 | # Static type checking 115 | mypy==0.521 116 | 117 | # Allows generation of coverage reports with pytest. 
118 | pytest-cov 119 | 120 | # Allows codecov to generate coverage reports 121 | coverage 122 | codecov 123 | 124 | # Required to run sanic tests 125 | aiohttp 126 | 127 | #### DOC-RELATED PACKAGES #### 128 | 129 | # Builds our documentation. 130 | sphinx==1.5.3 131 | 132 | # Watches the documentation directory and rebuilds on changes. 133 | sphinx-autobuild 134 | 135 | # doc theme 136 | sphinx_rtd_theme 137 | 138 | # Only used to convert our readme to reStructuredText on Pypi. 139 | pypandoc 140 | 141 | # Pypi uploads 142 | twine==1.11.0 143 | 144 | #### GENERATOR-RELATED PACKAGES #### 145 | 146 | # Used by AddSent. 147 | psutil 148 | pattern 149 | 150 | # Used by SWAG. 151 | allennlp 152 | num2words 153 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages 3 | 4 | # PEP0440 compatible formatted version, see: 5 | # https://www.python.org/dev/peps/pep-0440/ 6 | # 7 | # release markers: 8 | # X.Y 9 | # X.Y.Z # For bugfix releases 10 | # 11 | # pre-release markers: 12 | # X.YaN # Alpha release 13 | # X.YbN # Beta release 14 | # X.YrcN # Release Candidate 15 | # X.Y # Final release 16 | 17 | # version.py defines the VERSION and VERSION_SHORT variables. 18 | # We use exec here so we don't import allennlp whilst setting up. 19 | VERSION = {} 20 | with open("adversarialnlp/version.py", "r") as version_file: 21 | exec(version_file.read(), VERSION) 22 | 23 | # make pytest-runner a conditional requirement, 24 | # per: https://github.com/pytest-dev/pytest-runner#considerations 25 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) 26 | pytest_runner = ['pytest-runner'] if needs_pytest else [] 27 | 28 | with open('requirements.txt', 'r') as f: 29 | install_requires = [l for l in f.readlines() if not l.startswith('# ')] 30 | 31 | setup_requirements = [ 32 | # add other setup requirements as necessary 33 | ] + pytest_runner 34 | 35 | setup(name='adversarialnlp', 36 | version=VERSION["VERSION"], 37 | description='A generice library for crafting adversarial NLP examples, built on AllenNLP and PyTorch.', 38 | long_description=open("README.md").read(), 39 | long_description_content_type="text/markdown", 40 | classifiers=[ 41 | 'Intended Audience :: Science/Research', 42 | 'Development Status :: 3 - Alpha', 43 | 'License :: OSI Approved :: Apache Software License', 44 | 'Programming Language :: Python :: 3.6', 45 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 46 | ], 47 | keywords='adversarialnlp allennlp NLP deep learning machine reading', 48 | url='https://github.com/huggingface/adversarialnlp', 49 | author='Thomas WOLF', 50 | author_email='thomas@huggingface.co', 51 | license='Apache', 52 | packages=find_packages(exclude=["*.tests", "*.tests.*", 53 | "tests.*", "tests"]), 54 | install_requires=install_requires, 55 | scripts=["bin/adversarialnlp"], 56 | setup_requires=setup_requirements, 57 | tests_require=[ 58 | 'pytest', 59 | ], 60 | include_package_data=True, 61 | python_requires='>=3.6.1', 62 | zip_safe=False) 63 | -------------------------------------------------------------------------------- /tutorials/usage.py: 
-------------------------------------------------------------------------------- 1 | from adversarialnlp import Adversarial 2 | from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader 3 | 4 | adversarial = Adversarial(dataset_reader=SquadReader, editor='lstm_lm', num_samples=10) 5 | examples = adversarial.generate() --------------------------------------------------------------------------------
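A minimal usage sketch for the generators shipped in this repository, distilled from adversarialnlp/tests/generators/addsent_generator_test.py and the commented-out harness at the end of swag_generator.py. Treat it as an illustration of the intended API under those assumptions rather than a tested script; the reader, generator name and `num_epochs` argument are taken directly from the test code.

    # Read seed examples from the bundled SQuAD fixture, hand them to a generator,
    # and collect the batches of adversarial candidates it yields.
    from adversarialnlp.common.file_utils import FIXTURES_ROOT
    from adversarialnlp.generators.addsent.addsent_generator import AddSentGenerator
    from adversarialnlp.generators.addsent.squad_reader import squad_reader

    generator = AddSentGenerator()
    seeds = squad_reader(FIXTURES_ROOT / 'squad.json')
    adversarial_batches = list(generator(seeds, num_epochs=1))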