├── .gitignore
├── .pylintrc
├── README.md
├── adversarialnlp
│   ├── __init__.py
│   ├── commands
│   │   ├── __init__.py
│   │   └── test_install.py
│   ├── common
│   │   ├── __init__.py
│   │   └── file_utils.py
│   ├── generators
│   │   ├── __init__.py
│   │   ├── addsent
│   │   │   ├── __init__.py
│   │   │   ├── addsent_generator.py
│   │   │   ├── corenlp.py
│   │   │   ├── rules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── alteration_rules.py
│   │   │   │   ├── answer_rules.py
│   │   │   │   └── conversion_rules.py
│   │   │   ├── squad_reader.py
│   │   │   └── utils.py
│   │   ├── generator.py
│   │   └── swag
│   │       ├── __init__.py
│   │       ├── activitynet_captions_reader.py
│   │       ├── openai_transformer_model.py
│   │       ├── simple_bilm.py
│   │       ├── swag_generator.py
│   │       └── utils.py
│   ├── pruners
│   │   ├── __init__.py
│   │   └── pruner.py
│   ├── run.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── dataset_readers
│   │   │   ├── __init__.py
│   │   │   └── activitynet_captions_test.py
│   │   ├── fixtures
│   │   │   ├── activitynet_captions.json
│   │   │   └── squad.json
│   │   └── generators
│   │       ├── __init__.py
│   │       ├── addsent_generator_test.py
│   │       └── swag_generator_test.py
│   └── version.py
├── bin
│   └── adversarialnlp
├── docs
│   ├── Makefile
│   ├── common.rst
│   ├── conf.py
│   ├── generators.rst
│   ├── index.rst
│   ├── make.bat
│   ├── readme.rst
│   └── readthedoc_requirements.txt
├── pytest.ini
├── readthedocs.yml
├── requirements.txt
├── setup.cfg
├── setup.py
└── tutorials
    └── usage.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # build artifacts
2 |
3 | .eggs/
4 | .mypy_cache
5 | adversarialnlp.egg-info/
6 | build/
7 | dist/
8 | data/*
9 | lib/*
10 |
11 | # dev tools
12 |
13 | .envrc
14 | .python-version
15 |
16 |
17 | # jupyter notebooks
18 |
19 | .ipynb_checkpoints
20 |
21 |
22 | # miscellaneous
23 |
24 | .cache/
25 | .vscode/
26 | docs/_build/
27 |
28 |
29 | # python
30 |
31 | *.pyc
32 | *.pyo
33 | __pycache__
34 |
35 |
36 | # testing and continuous integration
37 |
38 | .coverage
39 | .pytest_cache/
40 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # Specify a configuration file.
4 | #rcfile=
5 |
6 | # Python code to execute, usually for sys.path manipulation such as
7 | # pygtk.require().
8 | init-hook='import sys; sys.path.append("./")'
9 |
10 | # Add files or directories to the blacklist. They should be base names, not
11 | # paths.
12 | ignore=CVS,custom_extensions
13 |
14 | # Add files or directories matching the regex patterns to the blacklist. The
15 | # regex matches against base names, not paths.
16 | ignore-patterns=
17 |
18 | # Pickle collected data for later comparisons.
19 | persistent=yes
20 |
21 | # List of plugins (as comma separated values of python modules names) to load,
22 | # usually to register additional checkers.
23 | load-plugins=
24 |
25 | # Use multiple processes to speed up Pylint.
26 | jobs=4
27 |
28 | # Allow loading of arbitrary C extensions. Extensions are imported into the
29 | # active Python interpreter and may run arbitrary code.
30 | unsafe-load-any-extension=no
31 |
32 | # A comma-separated list of package or module names from where C extensions may
33 | # be loaded. Extensions are loading into the active Python interpreter and may
34 | # run arbitrary code
35 | extension-pkg-whitelist=numpy,torch,spacy,_jsonnet
36 |
37 | # Allow optimization of some AST trees. This will activate a peephole AST
38 | # optimizer, which will apply various small optimizations. For instance, it can
39 | # be used to obtain the result of joining multiple strings with the addition
40 | # operator. Joining a lot of strings can lead to a maximum recursion error in
41 | # Pylint and this flag can prevent that. It has one side effect, the resulting
42 | # AST will be different than the one from reality. This option is deprecated
43 | # and it will be removed in Pylint 2.0.
44 | optimize-ast=no
45 |
46 |
47 | [MESSAGES CONTROL]
48 |
49 | # Only show warnings with the listed confidence levels. Leave empty to show
50 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
51 | confidence=
52 |
53 | # Enable the message, report, category or checker with the given id(s). You can
54 | # either give multiple identifier separated by comma (,) or put this option
55 | # multiple time (only on the command line, not in the configuration file where
56 | # it should appear only once). See also the "--disable" option for examples.
57 | #enable=
58 |
59 | # Disable the message, report, category or checker with the given id(s). You
60 | # can either give multiple identifiers separated by comma (,) or put this
61 | # option multiple times (only on the command line, not in the configuration
62 | # file where it should appear only once).You can also use "--disable=all" to
63 | # disable everything first and then reenable specific checks. For example, if
64 | # you want to run only the similarities checker, you can use "--disable=all
65 | # --enable=similarities". If you want to run only the classes checker, but have
66 | # no Warning level messages displayed, use"--disable=all --enable=classes
67 | # --disable=W"
68 | disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,missing-docstring,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks,too-many-instance-attributes,fixme,too-few-public-methods,no-else-return
69 |
70 |
71 | [REPORTS]
72 |
73 | # Set the output format. Available formats are text, parseable, colorized, msvs
74 | # (visual studio) and html. You can also give a reporter class, eg
75 | # mypackage.mymodule.MyReporterClass.
76 | output-format=text
77 |
78 | # Put messages in a separate file for each module / package specified on the
79 | # command line instead of printing them on stdout. Reports (if any) will be
80 | # written in a file name "pylint_global.[txt|html]". This option is deprecated
81 | # and it will be removed in Pylint 2.0.
82 | files-output=no
83 |
84 | # Tells whether to display a full report or only the messages
85 | reports=yes
86 |
87 | # Python expression which should return a note less than 10 (10 is the highest
88 | # note). You have access to the variables errors warning, statement which
89 | # respectively contain the number of errors / warnings messages and the total
90 | # number of statements analyzed. This is used by the global evaluation report
91 | # (RP0004).
92 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
93 |
94 | # Template used to display messages. This is a python new-style format string
95 | # used to format the message information. See doc for all details
96 | #msg-template=
97 |
98 |
99 | [LOGGING]
100 |
101 | # Logging modules to check that the string format arguments are in logging
102 | # function parameter format
103 | logging-modules=logging
104 |
105 |
106 | [TYPECHECK]
107 |
108 | # Tells whether missing members accessed in mixin class should be ignored. A
109 | # mixin class is detected if its name ends with "mixin" (case insensitive).
110 | ignore-mixin-members=yes
111 |
112 | # List of module names for which member attributes should not be checked
113 | # (useful for modules/projects where namespaces are manipulated during runtime
114 | # and thus existing member attributes cannot be deduced by static analysis. It
115 | # supports qualified module names, as well as Unix pattern matching.
116 | ignored-modules=
117 |
118 | # List of class names for which member attributes should not be checked (useful
119 | # for classes with dynamically set attributes). This supports the use of
120 | # qualified names.
121 | ignored-classes=optparse.Values,thread._local,_thread._local,responses
122 |
123 | # List of members which are set dynamically and missed by pylint inference
124 | # system, and so shouldn't trigger E1101 when accessed. Python regular
125 | # expressions are accepted.
126 | generated-members=torch.*
127 |
128 | # List of decorators that produce context managers, such as
129 | # contextlib.contextmanager. Add to this list to register other decorators that
130 | # produce valid context managers.
131 | contextmanager-decorators=contextlib.contextmanager
132 |
133 |
134 | [SIMILARITIES]
135 |
136 | # Minimum lines number of a similarity.
137 | min-similarity-lines=4
138 |
139 | # Ignore comments when computing similarities.
140 | ignore-comments=yes
141 |
142 | # Ignore docstrings when computing similarities.
143 | ignore-docstrings=yes
144 |
145 | # Ignore imports when computing similarities.
146 | ignore-imports=no
147 |
148 |
149 | [FORMAT]
150 |
151 | # Maximum number of characters on a single line. Ideally, lines should be under 100 characters,
152 | # but we allow some leeway before calling it an error.
153 | max-line-length=115
154 |
155 | # Regexp for a line that is allowed to be longer than the limit.
156 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
157 |
158 | # Allow the body of an if to be on the same line as the test if there is no
159 | # else.
160 | single-line-if-stmt=no
161 |
162 | # List of optional constructs for which whitespace checking is disabled. `dict-
163 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
164 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
165 | # `empty-line` allows space-only lines.
166 | no-space-check=trailing-comma,dict-separator
167 |
168 | # Maximum number of lines in a module
169 | max-module-lines=1000
170 |
171 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
172 | # tab).
173 | indent-string='    '
174 |
175 | # Number of spaces of indent required inside a hanging or continued line.
176 | indent-after-paren=8
177 |
178 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
179 | expected-line-ending-format=
180 |
181 |
182 | [BASIC]
183 |
184 | # Good variable names which should always be accepted, separated by a comma
185 | good-names=i,j,k,ex,Run,_
186 |
187 | # Bad variable names which should always be refused, separated by a comma
188 | bad-names=foo,bar,baz,toto,tutu,tata
189 |
190 | # Colon-delimited sets of names that determine each other's naming style when
191 | # the name regexes allow several styles.
192 | name-group=
193 |
194 | # Include a hint for the correct naming format with invalid-name
195 | include-naming-hint=no
196 |
197 | # List of decorators that produce properties, such as abc.abstractproperty. Add
198 | # to this list to register other decorators that produce valid properties.
199 | property-classes=abc.abstractproperty
200 |
201 | # Regular expression matching correct function names
202 | function-rgx=[a-z_][a-z0-9_]{2,40}$
203 |
204 | # Naming hint for function names
205 | function-name-hint=[a-z_][a-z0-9_]{2,40}$
206 |
207 | # Regular expression matching correct variable names
208 | variable-rgx=[a-z_][a-z0-9_]{2,40}$
209 |
210 | # Naming hint for variable names
211 | variable-name-hint=[a-z_][a-z0-9_]{2,40}$
212 |
213 | # Regular expression matching correct constant names
214 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
215 |
216 | # Naming hint for constant names
217 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
218 |
219 | # Regular expression matching correct attribute names
220 | attr-rgx=[a-z_][a-z0-9_]{2,40}$
221 |
222 | # Naming hint for attribute names
223 | attr-name-hint=[a-z_][a-z0-9_]{2,40}$
224 |
225 | # Regular expression matching correct argument names
226 | argument-rgx=[a-z_][a-z0-9_]{2,40}$
227 |
228 | # Naming hint for argument names
229 | argument-name-hint=[a-z_][a-z0-9_]{2,40}$
230 |
231 | # Regular expression matching correct class attribute names
232 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$
233 |
234 | # Naming hint for class attribute names
235 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$
236 |
237 | # Regular expression matching correct inline iteration names
238 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
239 |
240 | # Naming hint for inline iteration names
241 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
242 |
243 | # Regular expression matching correct class names
244 | class-rgx=[A-Z_][a-zA-Z0-9]+$
245 |
246 | # Naming hint for class names
247 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
248 |
249 | # Regular expression matching correct module names
250 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
251 |
252 | # Naming hint for module names
253 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
254 |
255 | # Regular expression matching correct method names
256 | method-rgx=[a-z_][a-z0-9_]{2,40}$
257 |
258 | # Naming hint for method names
259 | method-name-hint=[a-z_][a-z0-9_]{2,40}$
260 |
261 | # Regular expression which should only match function or class names that do
262 | # not require a docstring.
263 | no-docstring-rgx=^_
264 |
265 | # Minimum line length for functions/classes that require docstrings, shorter
266 | # ones are exempt.
267 | docstring-min-length=-1
268 |
269 |
270 | [ELIF]
271 |
272 | # Maximum number of nested blocks for function / method body
273 | max-nested-blocks=5
274 |
275 |
276 | [VARIABLES]
277 |
278 | # Tells whether we should check for unused import in __init__ files.
279 | init-import=no
280 |
281 | # A regular expression matching the name of dummy variables (i.e. expectedly
282 | # not used).
283 | dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy
284 |
285 | # List of additional names supposed to be defined in builtins. Remember that
286 | # you should avoid to define new builtins when possible.
287 | additional-builtins=
288 |
289 | # List of strings which can identify a callback function by name. A callback
290 | # name must start or end with one of those strings.
291 | callbacks=cb_,_cb
292 |
293 | # List of qualified module names which can have objects that can redefine
294 | # builtins.
295 | redefining-builtins-modules=six.moves,future.builtins
296 |
297 |
298 | [SPELLING]
299 |
300 | # Spelling dictionary name. Available dictionaries: none. To make it working
301 | # install python-enchant package.
302 | spelling-dict=
303 |
304 | # List of comma separated words that should not be checked.
305 | spelling-ignore-words=
306 |
307 | # A path to a file that contains private dictionary; one word per line.
308 | spelling-private-dict-file=
309 |
310 | # Tells whether to store unknown words to indicated private dictionary in
311 | # --spelling-private-dict-file option instead of raising a message.
312 | spelling-store-unknown-words=no
313 |
314 |
315 | [MISCELLANEOUS]
316 |
317 | # List of note tags to take in consideration, separated by a comma.
318 | notes=FIXME,XXX,TODO
319 |
320 |
321 | [DESIGN]
322 |
323 | # Maximum number of arguments for function / method
324 | max-args=5
325 |
326 | # Argument names that match this expression will be ignored. Default to name
327 | # with leading underscore
328 | ignored-argument-names=_.*
329 |
330 | # Maximum number of locals for function / method body
331 | max-locals=15
332 |
333 | # Maximum number of return / yield for function / method body
334 | max-returns=6
335 |
336 | # Maximum number of branch for function / method body
337 | max-branches=12
338 |
339 | # Maximum number of statements in function / method body
340 | max-statements=50
341 |
342 | # Maximum number of parents for a class (see R0901).
343 | max-parents=7
344 |
345 | # Maximum number of attributes for a class (see R0902).
346 | max-attributes=7
347 |
348 | # Minimum number of public methods for a class (see R0903).
349 | min-public-methods=2
350 |
351 | # Maximum number of public methods for a class (see R0904).
352 | max-public-methods=20
353 |
354 | # Maximum number of boolean expressions in a if statement
355 | max-bool-expr=5
356 |
357 |
358 | [CLASSES]
359 |
360 | # List of method names used to declare (i.e. assign) instance attributes.
361 | defining-attr-methods=__init__,__new__,setUp
362 |
363 | # List of valid names for the first argument in a class method.
364 | valid-classmethod-first-arg=cls
365 |
366 | # List of valid names for the first argument in a metaclass class method.
367 | valid-metaclass-classmethod-first-arg=mcs
368 |
369 | # List of member names, which should be excluded from the protected access
370 | # warning.
371 | exclude-protected=_asdict,_fields,_replace,_source,_make
372 |
373 |
374 | [IMPORTS]
375 |
376 | # Deprecated modules which should not be used, separated by a comma
377 | deprecated-modules=regsub,TERMIOS,Bastion,rexec
378 |
379 | # Create a graph of every (i.e. internal and external) dependencies in the
380 | # given file (report RP0402 must not be disabled)
381 | import-graph=
382 |
383 | # Create a graph of external dependencies in the given file (report RP0402 must
384 | # not be disabled)
385 | ext-import-graph=
386 |
387 | # Create a graph of internal dependencies in the given file (report RP0402 must
388 | # not be disabled)
389 | int-import-graph=
390 |
391 | # Force import order to recognize a module as part of the standard
392 | # compatibility libraries.
393 | known-standard-library=
394 |
395 | # Force import order to recognize a module as part of a third party library.
396 | known-third-party=enchant
397 |
398 | # Analyse import fallback blocks. This can be used to support both Python 2 and
399 | # 3 compatible code, which means that the block might have code that exists
400 | # only in one or another interpreter, leading to false positives when analysed.
401 | analyse-fallback-blocks=no
402 |
403 |
404 | [EXCEPTIONS]
405 |
406 | # Exceptions that will emit a warning when being caught. Defaults to
407 | # "Exception"
408 | overgeneral-exceptions=Exception
409 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdversarialNLP - WIP
2 |
3 | AdversarialNLP is a generic library for crafting and using adversarial NLP examples.
4 |
5 | Work in Progress
6 |
7 | ## Installation
8 |
9 | AdversarialNLP requires Python 3.6.1 or later. The preferred way to install AdversarialNLP is via `pip`. Just run `pip install adversarialnlp` in your Python environment and you're good to go!
10 |
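11 | ## Usage
12 | 
13 | A minimal sketch, mirroring the smoke test at the bottom of `adversarialnlp/generators/addsent/addsent_generator.py`. It assumes a local Java install (for Stanford CoreNLP) and downloads the required model files on first use.
14 | 
15 | ```python
16 | from adversarialnlp.common.file_utils import FIXTURES_ROOT
17 | from adversarialnlp.generators.addsent import AddSentGenerator, squad_reader
18 | 
19 | generator = AddSentGenerator()
20 | seeds = squad_reader(FIXTURES_ROOT / 'squad.json')
21 | adversarial_examples = list(generator(seeds, num_epochs=1))
22 | ```
23 | 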
--------------------------------------------------------------------------------
/adversarialnlp/__init__.py:
--------------------------------------------------------------------------------
1 | from adversarialnlp.version import VERSION as __version__
2 |
--------------------------------------------------------------------------------
/adversarialnlp/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import argparse
3 | import logging
4 |
5 | from allennlp.commands.subcommand import Subcommand
6 | from allennlp.common.util import import_submodules
7 |
8 | from adversarialnlp import __version__
9 | from adversarialnlp.commands.test_install import TestInstall
10 |
11 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
12 |
13 |
14 | def main(prog: str = None,
15 | subcommand_overrides: Dict[str, Subcommand] = {}) -> None:
16 | """
17 | :mod:`~adversarialnlp.run` command.
18 | """
19 | # pylint: disable=dangerous-default-value
20 | parser = argparse.ArgumentParser(description="Run AdversarialNLP", usage='%(prog)s', prog=prog)
21 | parser.add_argument('--version', action='version', version='%(prog)s ' + __version__)
22 |
23 | subparsers = parser.add_subparsers(title='Commands', metavar='')
24 |
25 | subcommands = {
26 | # Default commands
27 | "test-install": TestInstall(),
28 |
29 | # Superseded by overrides
30 | **subcommand_overrides
31 | }
32 |
33 | for name, subcommand in subcommands.items():
34 | subparser = subcommand.add_subparser(name, subparsers)
35 | # configure doesn't need include-package because it imports
36 | # whatever classes it needs.
37 | if name != "configure":
38 | subparser.add_argument('--include-package',
39 | type=str,
40 | action='append',
41 | default=[],
42 | help='additional packages to include')
43 |
44 | args = parser.parse_args()
45 |
46 | # If a subparser is triggered, it adds its work as `args.func`.
47 | # So if no such attribute has been added, no subparser was triggered,
48 | # so give the user some help.
49 | if 'func' in dir(args):
50 | # Import any additional modules needed (to register custom classes).
51 | for package_name in getattr(args, 'include_package', ()):
52 | import_submodules(package_name)
53 | args.func(args)
54 | else:
55 | parser.print_help()
56 |
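57 | # Minimal usage sketch (an assumption: the console entry point, e.g. run.py,
58 | # presumably calls main() in a similar way):
59 | #
60 | #     from adversarialnlp.commands import main
61 | #     main(prog="adversarialnlp")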
--------------------------------------------------------------------------------
/adversarialnlp/commands/test_install.py:
--------------------------------------------------------------------------------
1 | """
2 | The ``test-install`` subcommand verifies
3 | an installation by running the unit tests.
4 | .. code-block:: bash
5 | $ adversarialnlp test-install --help
6 | usage: adversarialnlp test-install [-h] [--run-all]
7 | [--include-package INCLUDE_PACKAGE]
8 | Test that installation works by running the unit tests.
9 | optional arguments:
10 | -h, --help show this help message and exit
11 | --run-all By default, we skip tests that are slow or download
12 | large files. This flag will run all tests.
13 | --include-package INCLUDE_PACKAGE
14 | additional packages to include
15 | """
16 |
17 | import argparse
18 | import logging
19 | import os
20 | import pathlib
21 |
22 | import pytest
23 |
24 | from allennlp.commands.subcommand import Subcommand
25 |
26 | import adversarialnlp
27 |
28 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
29 |
30 | class TestInstall(Subcommand):
31 | def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
32 | # pylint: disable=protected-access
33 | description = '''Test that installation works by running the unit tests.'''
34 | subparser = parser.add_parser(
35 | name, description=description, help='Run the unit tests.')
36 |
37 | subparser.add_argument('--run-all', action="store_true",
38 | help="By default, we skip tests that are slow "
39 | "or download large files. This flag will run all tests.")
40 |
41 | subparser.set_defaults(func=_run_test)
42 |
43 | return subparser
44 |
45 |
46 | def _get_module_root():
47 | return pathlib.Path(adversarialnlp.__file__).parent
48 |
49 |
50 | def _run_test(args: argparse.Namespace):
51 | initial_working_dir = os.getcwd()
52 | module_parent = _get_module_root().parent
53 | logger.info("Changing directory to %s", module_parent)
54 | os.chdir(module_parent)
55 | test_dir = os.path.join(module_parent, "adversarialnlp")
56 | logger.info("Running tests at %s", test_dir)
57 | if args.run_all:
58 | # TODO(nfliu): remove this when notebooks have been rewritten as markdown.
59 | exit_code = pytest.main([test_dir, '--color=no', '-k', 'not notebooks_test'])
60 | else:
61 | exit_code = pytest.main([test_dir, '--color=no', '-k', 'not sniff_test and not notebooks_test',
62 | '-m', 'not java'])
63 | # Change back to original working directory after running tests
64 | os.chdir(initial_working_dir)
65 | exit(exit_code)
66 |
--------------------------------------------------------------------------------
/adversarialnlp/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/common/__init__.py
--------------------------------------------------------------------------------
/adversarialnlp/common/file_utils.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # pylint: disable=invalid-name,protected-access
3 | # Copyright (c) 2017-present, Facebook, Inc.
4 | # All rights reserved.
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree. An additional grant
7 | # of patent rights can be found in the PATENTS file in the same directory.
8 | """
9 | Utilities for downloading and building data.
10 | These can be replaced if your particular file system does not support them.
11 | """
12 | from typing import Union, List
13 | from pathlib import Path
14 | import time
15 | import datetime
16 | import os
17 | import shutil
18 | import requests
19 |
20 | MODULE_ROOT = Path(__file__).parent.parent
21 | FIXTURES_ROOT = (MODULE_ROOT / "tests" / "fixtures").resolve()
22 | PACKAGE_ROOT = MODULE_ROOT.parent
23 | DATA_ROOT = (PACKAGE_ROOT / "data").resolve()
24 |
25 | class ProgressLogger(object):
 26 |     """Throttles and displays progress in human-readable form."""
27 |
28 | def __init__(self, throttle=1, should_humanize=True):
29 | """Initialize Progress logger.
30 | :param throttle: default 1, number in seconds to use as throttle rate
31 | :param should_humanize: default True, whether to humanize data units
32 | """
33 | self.latest = time.time()
34 | self.throttle_speed = throttle
35 | self.should_humanize = should_humanize
36 |
37 | def humanize(self, num, suffix='B'):
38 | """Convert units to more human-readable format."""
39 | if num < 0:
40 | return num
41 | for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
42 | if abs(num) < 1024.0:
43 | return "%3.1f%s%s" % (num, unit, suffix)
44 | num /= 1024.0
45 | return "%.1f%s%s" % (num, 'Yi', suffix)
46 |
47 | def log(self, curr, total, width=40, force=False):
48 | """Display a bar showing the current progress."""
49 | if curr == 0 and total == -1:
50 | print('[ no data received for this file ]', end='\r')
51 | return
52 | curr_time = time.time()
 53 |         if not force and curr_time - self.latest < self.throttle_speed:
 54 |             return
 55 |         self.latest = curr_time
59 | done = min(curr * width // total, width)
60 | remain = width - done
61 |
62 | if self.should_humanize:
63 | curr = self.humanize(curr)
64 | total = self.humanize(total)
65 |
66 | progress = '[{}{}] {} / {}'.format(
67 | ''.join(['|'] * done),
68 | ''.join(['.'] * remain),
69 | curr,
70 | total
71 | )
72 | print(progress, end='\r')
73 |
74 | def built(path, version_string=None):
75 | """Checks if '.built' flag has been set for that task.
76 | If a version_string is provided, this has to match, or the version
77 | is regarded as not built.
78 | """
79 | built_file_path = os.path.join(path, '.built')
80 | if not os.path.isfile(built_file_path):
81 | return False
82 | else:
83 | with open(built_file_path, 'r') as built_file:
84 | text = built_file.read().split('\n')
85 | if len(text) <= 2:
86 | return False
87 | for fname in text[1:-1]:
88 | if not os.path.isfile(os.path.join(path, fname)) and not os.path.isdir(os.path.join(path, fname)):
89 | return False
90 | return text[-1] == version_string if version_string else True
91 |
92 |
93 | def mark_done(path, fnames, version_string='vXX'):
94 | """Marks the path as done by adding a '.built' file with the current
95 | timestamp plus a version description string if specified.
96 | """
97 | with open(os.path.join(path, '.built'), 'w') as built_file:
98 | built_file.write(str(datetime.datetime.today()))
99 | for fname in fnames:
100 | fname = fname.replace('.tar.gz', '').replace('.tgz', '').replace('.gz', '').replace('.zip', '')
101 | built_file.write('\n' + fname)
102 | built_file.write('\n' + version_string)
103 |
104 |
105 | def download(url, path, fname, redownload=False):
106 |     """Downloads a file using `requests`. If ``redownload`` is False (the
107 |     default), the file will not be downloaded again if it is already present."""
108 | outfile = os.path.join(path, fname)
109 | curr_download = not os.path.isfile(outfile) or redownload
110 | print("[ downloading: " + url + " to " + outfile + " ]")
111 | retry = 5
112 | exp_backoff = [2 ** r for r in reversed(range(retry))]
113 |
114 | logger = ProgressLogger()
115 |
116 | while curr_download and retry >= 0:
117 | resume_file = outfile + '.part'
118 | resume = os.path.isfile(resume_file)
119 | if resume:
120 | resume_pos = os.path.getsize(resume_file)
121 | mode = 'ab'
122 | else:
123 | resume_pos = 0
124 | mode = 'wb'
125 | response = None
126 |
127 | with requests.Session() as session:
128 | try:
129 | header = {'Range': 'bytes=%d-' % resume_pos,
130 | 'Accept-Encoding': 'identity'} if resume else {}
131 | response = session.get(url, stream=True, timeout=5, headers=header)
132 |
133 | # negative reply could be 'none' or just missing
134 | if resume and response.headers.get('Accept-Ranges', 'none') == 'none':
135 | resume_pos = 0
136 | mode = 'wb'
137 |
138 | CHUNK_SIZE = 32768
139 | total_size = int(response.headers.get('Content-Length', -1))
140 | # server returns remaining size if resuming, so adjust total
141 | total_size += resume_pos
142 | done = resume_pos
143 |
144 | with open(resume_file, mode) as f:
145 | for chunk in response.iter_content(CHUNK_SIZE):
146 | if chunk: # filter out keep-alive new chunks
147 | f.write(chunk)
148 | if total_size > 0:
149 | done += len(chunk)
150 | if total_size < done:
151 | # don't freak out if content-length was too small
152 | total_size = done
153 | logger.log(done, total_size)
154 | break
155 | except requests.exceptions.ConnectionError:
156 | retry -= 1
157 | # TODO Better way to clean progress bar?
158 | print(''.join([' '] * 60), end='\r')
159 | if retry >= 0:
160 | print('Connection error, retrying. (%d retries left)' % retry)
161 | time.sleep(exp_backoff[retry])
162 | else:
163 | print('Retried too many times, stopped retrying.')
164 | finally:
165 | if response:
166 | response.close()
167 | if retry < 0:
168 | raise RuntimeWarning('Connection broken too many times. Stopped retrying.')
169 |
170 | if curr_download and retry > 0:
171 | logger.log(done, total_size, force=True)
172 | print()
173 | if done < total_size:
174 | raise RuntimeWarning('Received less data than specified in ' +
175 | 'Content-Length header for ' + url + '.' +
176 | ' There may be a download problem.')
177 | move(resume_file, outfile)
178 |
179 |
180 | def make_dir(path):
181 | """Makes the directory and any nonexistent parent directories."""
182 | # the current working directory is a fine path
183 | if path != '':
184 | os.makedirs(path, exist_ok=True)
185 |
186 |
187 | def move(path1, path2):
188 | """Renames the given file."""
189 | shutil.move(path1, path2)
190 |
191 |
192 | def remove_dir(path):
193 | """Removes the given directory, if it exists."""
194 | shutil.rmtree(path, ignore_errors=True)
195 |
196 |
197 | def untar(path, fname, deleteTar=True):
198 | """Unpacks the given archive file to the same directory, then (by default)
199 | deletes the archive file.
200 | """
201 | print('unpacking ' + fname)
202 | fullpath = os.path.join(path, fname)
203 | if '.tar.gz' in fname:
204 | shutil.unpack_archive(fullpath, path, format='gztar')
205 | else:
206 | shutil.unpack_archive(fullpath, path)
207 | if deleteTar:
208 | os.remove(fullpath)
209 |
210 |
211 | def cat(file1, file2, outfile, deleteFiles=True):
212 | with open(outfile, 'wb') as wfd:
213 | for f in [file1, file2]:
214 | with open(f, 'rb') as fd:
215 | shutil.copyfileobj(fd, wfd, 1024 * 1024 * 10)
216 | # 10MB per writing chunk to avoid reading big file into memory.
217 | if deleteFiles:
218 | os.remove(file1)
219 | os.remove(file2)
220 |
221 |
222 | def _get_confirm_token(response):
223 | for key, value in response.cookies.items():
224 | if key.startswith('download_warning'):
225 | return value
226 | return None
227 |
228 | def download_from_google_drive(gd_id, destination):
229 | """Uses the requests package to download a file from Google Drive."""
230 | URL = 'https://docs.google.com/uc?export=download'
231 |
232 | with requests.Session() as session:
233 | response = session.get(URL, params={'id': gd_id}, stream=True)
234 | token = _get_confirm_token(response)
235 |
236 | if token:
237 | response.close()
238 | params = {'id': gd_id, 'confirm': token}
239 | response = session.get(URL, params=params, stream=True)
240 |
241 | CHUNK_SIZE = 32768
242 | with open(destination, 'wb') as f:
243 | for chunk in response.iter_content(CHUNK_SIZE):
244 | if chunk: # filter out keep-alive new chunks
245 | f.write(chunk)
246 | response.close()
247 |
248 |
249 | def download_files(fnames: List[Union[str, Path]],
250 | local_folder: str,
251 | version: str = 'v1.0',
252 | paths: Union[List[str], str] = 'aws') -> List[str]:
253 | r"""Download model/data files from a url.
254 |
255 | Args:
256 | fnames: List of filenames to download
257 | local_folder: Sub-folder of `./data` where models/data will
258 | be downloaded.
259 | version: Version of the model
260 |         paths: URL or list of URLs from which to download the files.
261 |
262 | Return:
263 |         List[str]: List of downloaded file paths.
264 | If the downloaded file was a compressed file (`.tar.gz`,
265 | `.zip`, `.tgz`, `.gz`), return the path of the folder
266 | containing the extracted files.
267 | """
268 |
269 | dpath = str(DATA_ROOT / local_folder)
270 | out_paths = list(dpath + '/' + fname.replace('.tar.gz', '').replace('.tgz', '').replace('.gz', '').replace('.zip', '')
271 | for fname in fnames)
272 |
273 | if not built(dpath, version):
274 | for fname in fnames:
275 | print('[building data: ' + dpath + '/' + fname + ']')
276 | if built(dpath):
277 | # An older version exists, so remove these outdated files.
278 | remove_dir(dpath)
279 | make_dir(dpath)
280 |
281 | if isinstance(paths, str):
282 | paths = [paths] * len(fnames)
283 | # Download the data.
284 | for fname, path in zip(fnames, paths):
285 | if path == 'aws':
286 | url = 'http://huggingface.co/downloads/models/'
287 | url += local_folder + '/'
288 | url += fname
289 | else:
290 | url = path + '/' + fname
291 | download(url, dpath, fname)
292 | if '.tar.gz' in fname or '.tgz' in fname or '.gz' in fname or '.zip' in fname:
293 | untar(dpath, fname)
294 | # Mark the data as built.
295 | mark_done(dpath, fnames, version)
296 | return out_paths
297 |
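298 | 
299 | if __name__ == '__main__':
300 |     # Minimal usage sketch: fetch the AddSent model files into ./data/addsent.
301 |     # These are the file names AddSentGenerator requests; running this needs
302 |     # network access and is only meant as an illustration.
303 |     print(download_files(fnames=['nearby_n100_glove_6B_100d.json', 'postag_dict.json'],
304 |                          local_folder='addsent'))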
--------------------------------------------------------------------------------
/adversarialnlp/generators/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import Generator
2 | from .swag import SwagGenerator
3 | from .addsent import AddSentGenerator
4 |
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/__init__.py:
--------------------------------------------------------------------------------
1 | from .addsent_generator import AddSentGenerator
2 | from .squad_reader import squad_reader
3 |
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/addsent_generator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import json
3 | import itertools
4 | from typing import Iterable, Dict, Tuple
5 | from collections import defaultdict
6 |
7 | from adversarialnlp.common.file_utils import download_files
8 | from adversarialnlp.generators import Generator
9 | from adversarialnlp.generators.addsent.rules import (ANSWER_RULES, HIGH_CONF_ALTER_RULES, ALL_ALTER_RULES,
10 | DO_NOT_ALTER, BAD_ALTERATIONS, CONVERSION_RULES)
11 | from adversarialnlp.generators.addsent.utils import (rejoin, ConstituencyParse, get_tokens_for_answers,
12 | get_determiner_for_answers, read_const_parse)
13 | from adversarialnlp.generators.addsent.squad_reader import squad_reader
14 | from adversarialnlp.generators.addsent.corenlp import StanfordCoreNLP
15 |
16 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
17 |
18 | SQUAD_FILE = 'data/squad/train-v1.1.json'
19 | NEARBY_GLOVE_FILE = 'data/addsent/nearby_n100_glove_6B_100d.json'
20 | POSTAG_FILE = 'data/addsent/postag_dict.json'
21 |
22 | class AddSentGenerator(Generator):
23 | r"""Adversarial examples generator based on AddSent.
24 |
25 | AddSent is described in the paper `Adversarial Examples for
26 | Evaluating Reading Comprehension Systems`_
27 | by Robin Jia & Percy Liang
28 |
29 | Args, input and yield:
30 | See the ``Generator`` class.
31 |
32 | Additional arguments:
33 | alteration_strategy: Alteration strategy. Options:
34 |
35 | - `separate`: Do best alteration for each word separately.
36 | - `best`: Generate exactly one best alteration
37 | (may over-alter).
38 | - `high-conf`: Do all possible high-confidence alterations.
39 | - `high-conf-separate`: Do best high-confidence alteration
40 | for each word separately.
41 | - `all`: Do all possible alterations (very conservative)
42 |
 43 |         prepend: If True, insert the adversarial sentence at the
 44 |             beginning of the context instead of appending it at the end.
 45 |         use_answer_placeholder: Use an answer placeholder instead of the real answer.
46 |
47 | Seeds:
48 | Tuple of SQuAD-like instances containing
49 | - question-answer-span, and
50 | - context paragraph.
51 |
52 | default_seeds:
53 | If no seeds are provided, the default_seeds are the training
54 | set of the
55 | `SQuAD V1.0 dataset `_.
56 |
57 | """
58 | def __init__(self,
59 | alteration_strategy: str = 'high-conf',
60 | prepend: bool = False,
61 | use_answer_placeholder: bool = False,
62 | default_seeds: Iterable = None,
63 | quiet: bool = False):
 64 |         super().__init__(default_seeds, quiet)
65 | model_files = download_files(fnames=['nearby_n100_glove_6B_100d.json',
66 | 'postag_dict.json'],
67 | local_folder='addsent')
68 | corenlp_path = download_files(fnames=['stanford-corenlp-full-2018-02-27.zip'],
69 | paths='http://nlp.stanford.edu/software/',
70 | local_folder='corenlp')
71 |
72 | self.nlp: StanfordCoreNLP = StanfordCoreNLP(corenlp_path[0])
73 | with open(model_files[0], 'r') as data_file:
74 | self.nearby_word_dict: Dict = json.load(data_file)
75 | with open(model_files[1], 'r') as data_file:
76 | self.postag_dict: Dict = json.load(data_file)
77 |
78 | self.alteration_strategy: str = alteration_strategy
79 | self.prepend: bool = prepend
80 | self.use_answer_placeholder: bool = use_answer_placeholder
81 | if default_seeds is None:
82 | self.default_seeds = squad_reader(SQUAD_FILE)
83 | else:
84 | self.default_seeds = default_seeds
85 |
86 | def close(self):
87 | self.nlp.close()
88 |
89 | def _annotate(self, text: str, annotators: str):
90 | r"""Wrapper to call CoreNLP. """
91 | props = {'annotators': annotators,
92 | 'ssplit.newlineIsSentenceBreak': 'always',
93 | 'outputFormat':'json'}
94 | return json.loads(self.nlp.annotate(text, properties=props))
95 |
96 | def _alter_question(self, question, tokens, const_parse):
97 | r"""Alter the question to make it ask something else. """
98 | used_words = [tok['word'].lower() for tok in tokens]
99 | new_qs = []
100 | toks_all = []
101 | if self.alteration_strategy.startswith('high-conf'):
102 | rules = HIGH_CONF_ALTER_RULES
103 | else:
104 | rules = ALL_ALTER_RULES
105 | for i, tok in enumerate(tokens):
106 | if tok['word'].lower() in DO_NOT_ALTER:
107 | if self.alteration_strategy in ('high-conf', 'all'):
108 | toks_all.append(tok)
109 | continue
110 | begin = tokens[:i]
111 | end = tokens[i+1:]
112 | found = False
113 | for rule_name in rules:
114 | rule = rules[rule_name]
115 | new_words = rule(tok, nearby_word_dict=self.nearby_word_dict, postag_dict=self.postag_dict)
116 | if new_words:
117 | for word in new_words:
118 | if word.lower() in used_words:
119 | continue
120 | if word.lower() in BAD_ALTERATIONS:
121 | continue
122 |                         # Match capitalization
123 | if tok['word'] == tok['word'].upper():
124 | word = word.upper()
125 | elif tok['word'] == tok['word'].title():
126 | word = word.title()
127 | new_tok = dict(tok)
128 | new_tok['word'] = new_tok['lemma'] = new_tok['originalText'] = word
129 | new_tok['altered'] = True
130 | # NOTE: obviously this is approximate
131 | if self.alteration_strategy.endswith('separate'):
132 | new_tokens = begin + [new_tok] + end
133 | new_q = rejoin(new_tokens)
134 | tag = '%s-%d-%s' % (rule_name, i, word)
135 | new_const_parse = ConstituencyParse.replace_words(
136 | const_parse, [tok['word'] for tok in new_tokens])
137 | new_qs.append((new_q, new_tokens, new_const_parse, tag))
138 | break
139 | elif self.alteration_strategy in ('high-conf', 'all'):
140 | toks_all.append(new_tok)
141 | found = True
142 | break
143 | if self.alteration_strategy in ('high-conf', 'all') and found:
144 | break
145 | if self.alteration_strategy in ('high-conf', 'all') and not found:
146 | toks_all.append(tok)
147 | if self.alteration_strategy in ('high-conf', 'all'):
148 | new_q = rejoin(toks_all)
149 | new_const_parse = ConstituencyParse.replace_words(
150 | const_parse, [tok['word'] for tok in toks_all])
151 | if new_q != question:
152 | new_qs.append((rejoin(toks_all), toks_all, new_const_parse, self.alteration_strategy))
153 | return new_qs
154 |
155 | def generate_from_seed(self, seed: Tuple):
156 | r"""Edit a SQuAD example using rules. """
157 | qas, paragraph = seed
158 | question = qas['question'].strip()
159 | if not self.quiet:
160 | print(f"Question: {question}")
161 | if self.use_answer_placeholder:
162 | answer = 'ANSWER'
163 | determiner = ''
164 | else:
165 | p_parse = self._annotate(paragraph, 'tokenize,ssplit,pos,ner,entitymentions')
166 | ind, a_toks = get_tokens_for_answers(qas['answers'], p_parse)
167 | determiner = get_determiner_for_answers(qas['answers'])
168 | answer_obj = qas['answers'][ind]
169 | for _, func in ANSWER_RULES:
170 | answer = func(answer_obj, a_toks, question, determiner=determiner)
171 | if answer:
172 | break
173 | else:
174 | raise ValueError('Missing answer')
175 | q_parse = self._annotate(question, 'tokenize,ssplit,pos,parse,ner')
176 | q_parse = q_parse['sentences'][0]
177 | q_tokens = q_parse['tokens']
178 | q_const_parse = read_const_parse(q_parse['parse'])
179 | if self.alteration_strategy:
180 | # Easiest to alter the question before converting
181 | q_list = self._alter_question(question, q_tokens, q_const_parse)
182 | else:
183 | q_list = [(question, q_tokens, q_const_parse, 'unaltered')]
184 | for q_str, q_tokens, q_const_parse, tag in q_list:
185 | for rule in CONVERSION_RULES:
186 | sent = rule.convert(q_str, answer, q_tokens, q_const_parse)
187 | if sent:
188 | if not self.quiet:
189 |                         print(f" Sent ({tag}): {sent}")
190 | cur_qa = {
191 | 'question': qas['question'],
192 | 'id': '%s-%s' % (qas['id'], tag),
193 | 'answers': qas['answers']
194 | }
195 | if self.prepend:
196 | cur_text = '%s %s' % (sent, paragraph)
197 | new_answers = []
198 | for ans in qas['answers']:
199 | new_answers.append({
200 | 'text': ans['text'],
201 | 'answer_start': ans['answer_start'] + len(sent) + 1
202 | })
203 | cur_qa['answers'] = new_answers
204 | else:
205 | cur_text = '%s %s' % (paragraph, sent)
206 |                     # NOTE: the (qas, paragraph) seed does not provide the article title.
207 |                     out_example = {'seed_context': paragraph,
208 |                                    'seed_qas': qas,
209 |                                    'context': cur_text,
210 |                                    'qas': [cur_qa]}
211 | yield out_example
212 |
213 | if __name__ == '__main__':
214 |     # Smoke test: requires Java (for the CoreNLP server) and network access on first run.
215 |     from adversarialnlp.common.file_utils import FIXTURES_ROOT
216 |     generator = AddSentGenerator()
217 |     test_instances = squad_reader(FIXTURES_ROOT / 'squad.json')
218 |     batches = list(generator(test_instances, num_epochs=1))
219 |     assert len(batches) != 0
220 | 
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/corenlp.py:
--------------------------------------------------------------------------------
1 | # Python wrapper for Stanford CoreNLP
2 | # Copyright (c) 2017 Lynten Guo, 2018 Thomas Wolf
3 | # Extracted and adapted from https://github.com/Lynten/stanford-corenlp
4 |
5 | from __future__ import print_function
6 |
7 | import glob
8 | import json
9 | import logging
10 | import os
11 | import re
12 | import socket
13 | import subprocess
14 | import sys
15 | import time
16 |
17 | import psutil
18 |
19 | try:
20 | from urlparse import urlparse
21 | except ImportError:
22 | from urllib.parse import urlparse
23 |
24 | import requests
25 |
26 |
27 | class StanfordCoreNLP:
28 | def __init__(self, path_or_host, port=None, memory='4g', lang='en', timeout=1500, quiet=True,
29 | logging_level=logging.WARNING, max_retries=5):
30 | self.path_or_host = path_or_host
31 | self.port = port
32 | self.memory = memory
33 | self.lang = lang
34 | self.timeout = timeout
35 | self.quiet = quiet
36 | self.logging_level = logging_level
37 |
38 | logging.basicConfig(level=self.logging_level)
39 |
40 | # Check args
41 | self._check_args()
42 |
43 | if path_or_host.startswith('http'):
44 | self.url = path_or_host + ':' + str(port)
45 | logging.info('Using an existing server {}'.format(self.url))
46 | else:
47 |
48 | # Check Java
49 | if not subprocess.call(['java', '-version'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) == 0:
50 | raise RuntimeError('Java not found.')
51 |
52 | # Check if the dir exists
53 | if not os.path.isdir(self.path_or_host):
54 | raise IOError(str(self.path_or_host) + ' is not a directory.')
55 | directory = os.path.normpath(self.path_or_host) + os.sep
56 | self.class_path_dir = directory
57 |
58 | # Check if the language specific model file exists
59 | switcher = {
60 | 'en': 'stanford-corenlp-[0-9].[0-9].[0-9]-models.jar',
61 | 'zh': 'stanford-chinese-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
62 | 'ar': 'stanford-arabic-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
63 | 'fr': 'stanford-french-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
64 | 'de': 'stanford-german-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
65 | 'es': 'stanford-spanish-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar'
66 | }
67 | jars = {
68 | 'en': 'stanford-corenlp-x.x.x-models.jar',
69 | 'zh': 'stanford-chinese-corenlp-yyyy-MM-dd-models.jar',
70 | 'ar': 'stanford-arabic-corenlp-yyyy-MM-dd-models.jar',
71 | 'fr': 'stanford-french-corenlp-yyyy-MM-dd-models.jar',
72 | 'de': 'stanford-german-corenlp-yyyy-MM-dd-models.jar',
73 | 'es': 'stanford-spanish-corenlp-yyyy-MM-dd-models.jar'
74 | }
75 | if len(glob.glob(directory + switcher.get(self.lang))) <= 0:
 76 |                 raise IOError(jars.get(self.lang) + ' not found. You should download it '
 77 |                               'and place it in ' + directory + ' first.')
78 |
79 | # If port not set, auto select
80 | # Commenting: see https://github.com/Lynten/stanford-corenlp/issues/26
81 | # if self.port is None:
82 | # for port_candidate in range(9000, 65535):
83 | # if port_candidate not in [conn.laddr[1] for conn in psutil.net_connections()]:
84 | # self.port = port_candidate
85 | # break
86 | self.port = 9999
87 |
88 | # Check if the port is in use
89 | # Also commenting: see https://github.com/Lynten/stanford-corenlp/issues/26
90 | # if self.port in [conn.laddr[1] for conn in psutil.net_connections()]:
91 | # raise IOError('Port ' + str(self.port) + ' is already in use.')
92 |
93 | # Start native server
94 | logging.info('Initializing native server...')
95 | cmd = "java"
96 | java_args = "-Xmx{}".format(self.memory)
97 | java_class = "edu.stanford.nlp.pipeline.StanfordCoreNLPServer"
98 | class_path = '"{}*"'.format(directory)
99 |
100 | args = [cmd, java_args, '-cp', class_path, java_class, '-port', str(self.port)]
101 |
102 | args = ' '.join(args)
103 |
104 | logging.info(args)
105 |
106 | # Silence
107 | with open(os.devnull, 'w') as null_file:
108 | out_file = None
109 | if self.quiet:
110 | out_file = null_file
111 |
112 | self.p = subprocess.Popen(args, shell=True, stdout=out_file, stderr=subprocess.STDOUT)
113 | logging.info('Server shell PID: {}'.format(self.p.pid))
114 |
115 | self.url = 'http://localhost:' + str(self.port)
116 |
117 | # Wait until server starts
118 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
119 | host_name = urlparse(self.url).hostname
120 | time.sleep(1) # OSX, not tested
121 | trial = 1
122 | while sock.connect_ex((host_name, self.port)):
123 | if trial > max_retries:
124 | raise ValueError('Corenlp server is not available')
125 | logging.info('Waiting until the server is available.')
126 | trial += 1
127 | time.sleep(1)
128 | logging.info('The server is available.')
129 |
130 | def __enter__(self):
131 | return self
132 |
133 | def __exit__(self, exc_type, exc_val, exc_tb):
134 | self.close()
135 |
136 | def close(self):
137 | logging.info('Cleanup...')
138 | if hasattr(self, 'p'):
139 | try:
140 | parent = psutil.Process(self.p.pid)
141 | except psutil.NoSuchProcess:
142 | logging.info('No process: {}'.format(self.p.pid))
143 | return
144 |
145 | if self.class_path_dir not in ' '.join(parent.cmdline()):
146 | logging.info('Process not in: {}'.format(parent.cmdline()))
147 | return
148 |
149 | children = parent.children(recursive=True)
150 | for process in children:
151 | logging.info('Killing pid: {}, cmdline: {}'.format(process.pid, process.cmdline()))
152 | # process.send_signal(signal.SIGTERM)
153 | process.kill()
154 |
155 | logging.info('Killing shell pid: {}, cmdline: {}'.format(parent.pid, parent.cmdline()))
156 | # parent.send_signal(signal.SIGTERM)
157 | parent.kill()
158 |
159 | def annotate(self, text, properties=None):
160 | if sys.version_info.major >= 3:
161 | text = text.encode('utf-8')
162 |
163 | r = requests.post(self.url, params={'properties': str(properties)}, data=text,
164 | headers={'Connection': 'close'})
165 | return r.text
166 |
167 | def tregex(self, sentence, pattern):
168 | tregex_url = self.url + '/tregex'
169 | r_dict = self._request(tregex_url, "tokenize,ssplit,depparse,parse", sentence, pattern=pattern)
170 | return r_dict
171 |
172 | def tokensregex(self, sentence, pattern):
173 | tokensregex_url = self.url + '/tokensregex'
174 | r_dict = self._request(tokensregex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern)
175 | return r_dict
176 |
177 | def semgrex(self, sentence, pattern):
178 | semgrex_url = self.url + '/semgrex'
179 | r_dict = self._request(semgrex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern)
180 | return r_dict
181 |
182 | def word_tokenize(self, sentence, span=False):
183 |         r_dict = self._request(self.url, 'ssplit,tokenize', sentence)
184 | tokens = [token['originalText'] for s in r_dict['sentences'] for token in s['tokens']]
185 |
186 | # Whether return token span
187 | if span:
188 | spans = [(token['characterOffsetBegin'], token['characterOffsetEnd']) for s in r_dict['sentences'] for token
189 | in s['tokens']]
190 | return tokens, spans
191 | else:
192 | return tokens
193 |
194 | def pos_tag(self, sentence):
195 | r_dict = self._request(self.url, 'pos', sentence)
196 | words = []
197 | tags = []
198 | for s in r_dict['sentences']:
199 | for token in s['tokens']:
200 | words.append(token['originalText'])
201 | tags.append(token['pos'])
202 | return list(zip(words, tags))
203 |
204 | def ner(self, sentence):
205 | r_dict = self._request(self.url, 'ner', sentence)
206 | words = []
207 | ner_tags = []
208 | for s in r_dict['sentences']:
209 | for token in s['tokens']:
210 | words.append(token['originalText'])
211 | ner_tags.append(token['ner'])
212 | return list(zip(words, ner_tags))
213 |
214 | def parse(self, sentence):
215 | r_dict = self._request(self.url, 'pos,parse', sentence)
216 | return [s['parse'] for s in r_dict['sentences']][0]
217 |
218 | def dependency_parse(self, sentence):
219 | r_dict = self._request(self.url, 'depparse', sentence)
220 | return [(dep['dep'], dep['governor'], dep['dependent']) for s in r_dict['sentences'] for dep in
221 | s['basicDependencies']]
222 |
223 | def coref(self, text):
224 |         r_dict = self._request(self.url, 'coref', text)
225 |
226 | corefs = []
227 | for k, mentions in r_dict['corefs'].items():
228 | simplified_mentions = []
229 | for m in mentions:
230 | simplified_mentions.append((m['sentNum'], m['startIndex'], m['endIndex'], m['text']))
231 | corefs.append(simplified_mentions)
232 | return corefs
233 |
234 | def switch_language(self, language="en"):
235 | self._check_language(language)
236 | self.lang = language
237 |
238 | def _request(self, url, annotators=None, data=None, *args, **kwargs):
239 | if sys.version_info.major >= 3:
240 | data = data.encode('utf-8')
241 |
242 | properties = {'annotators': annotators, 'outputFormat': 'json'}
243 | params = {'properties': str(properties), 'pipelineLanguage': self.lang}
244 | if 'pattern' in kwargs:
245 | params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang}
246 |
247 | logging.info(params)
248 | r = requests.post(url, params=params, data=data, headers={'Connection': 'close'})
249 | r_dict = json.loads(r.text)
250 |
251 | return r_dict
252 |
253 | def _check_args(self):
254 | self._check_language(self.lang)
255 |         if not re.match(r'\dg', self.memory):
256 |             raise ValueError('memory=' + self.memory + ' not supported. Use 4g, 6g, 8g, etc.')
257 |
258 | def _check_language(self, lang):
259 | if lang not in ['en', 'zh', 'ar', 'fr', 'de', 'es']:
260 |             raise ValueError('lang=' + lang + ' not supported. Use English(en), Chinese(zh), Arabic(ar), '
261 |                              'French(fr), German(de), Spanish(es).')
262 |
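263 | 
264 | if __name__ == '__main__':
265 |     # Minimal usage sketch. The path below is an assumption: it is where
266 |     # download_files() unpacks the CoreNLP archive for AddSentGenerator
267 |     # ('data/corenlp/stanford-corenlp-full-2018-02-27'); a local Java install is required.
268 |     with StanfordCoreNLP('data/corenlp/stanford-corenlp-full-2018-02-27') as nlp:
269 |         print(nlp.pos_tag('AdversarialNLP wraps a Stanford CoreNLP server.'))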
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/rules/__init__.py:
--------------------------------------------------------------------------------
1 | from .answer_rules import ANSWER_RULES
2 | from .alteration_rules import (HIGH_CONF_ALTER_RULES, ALL_ALTER_RULES,
3 | DO_NOT_ALTER, BAD_ALTERATIONS)
4 | from .conversion_rules import CONVERSION_RULES
5 |
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/rules/alteration_rules.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import nltk
4 | nltk.download('wordnet')
5 |
6 | from nltk.corpus import wordnet as wn
7 | from nltk.stem.lancaster import LancasterStemmer
8 |
9 | STEMMER = LancasterStemmer()
10 |
11 | POS_TO_WORDNET = {
12 | 'NN': wn.NOUN,
13 | 'JJ': wn.ADJ,
14 | 'JJR': wn.ADJ,
15 | 'JJS': wn.ADJ,
16 | }
17 |
18 | def alter_special(token, **kwargs):
19 | w = token['originalText']
20 | if w in SPECIAL_ALTERATIONS:
21 | return [SPECIAL_ALTERATIONS[w]]
22 | return None
23 |
24 | def alter_nearby(pos_list, ignore_pos=False, is_ner=False):
25 | def func(token, nearby_word_dict=None, postag_dict=None, **kwargs):
26 | if token['pos'] not in pos_list: return None
27 | if is_ner and token['ner'] not in ('PERSON', 'LOCATION', 'ORGANIZATION', 'MISC'):
28 | return None
29 | w = token['word'].lower()
 30 |         if w in ('war',): return None
31 | if w not in nearby_word_dict: return None
32 | new_words = []
33 | w_stem = STEMMER.stem(w.replace('.', ''))
34 | for x in nearby_word_dict[w][1:]:
35 | new_word = x['word']
36 | # Make sure words aren't too similar (e.g. same stem)
37 | new_stem = STEMMER.stem(new_word.replace('.', ''))
38 | if w_stem.startswith(new_stem) or new_stem.startswith(w_stem): continue
39 | if not ignore_pos:
40 | # Check for POS tag match
41 | if new_word not in postag_dict: continue
42 | new_postag = postag_dict[new_word]
43 | if new_postag != token['pos']: continue
44 | new_words.append(new_word)
45 | return new_words
46 | return func
47 |
48 | def alter_entity_glove(token, nearby_word_dict=None, **kwargs):
49 | # NOTE: Deprecated
50 | if token['ner'] not in ('PERSON', 'LOCATION', 'ORGANIZATION', 'MISC'): return None
51 | w = token['word'].lower()
52 | if w == token['word']: return None # Only do capitalized words
53 | if w not in nearby_word_dict: return None
54 | new_words = []
55 | for x in nearby_word_dict[w][1:3]:
56 | if token['word'] == w.upper():
57 | new_words.append(x['word'].upper())
58 | else:
59 | new_words.append(x['word'].title())
60 | return new_words
61 |
62 | def alter_entity_type(token, **kwargs):
63 | pos = token['pos']
64 | ner = token['ner']
65 | word = token['word']
66 | is_abbrev = word == word.upper() and not word == word.lower()
67 | if token['pos'] not in (
68 | 'JJ', 'JJR', 'JJS', 'NN', 'NNS', 'NNP', 'NNPS', 'RB', 'RBR', 'RBS',
69 | 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'):
70 | # Don't alter non-content words
71 | return None
72 | if ner == 'PERSON':
73 | return ['Jackson']
74 | elif ner == 'LOCATION':
75 | return ['Berlin']
76 | elif ner == 'ORGANIZATION':
77 | if is_abbrev: return ['UNICEF']
78 | return ['Acme']
79 | elif ner == 'MISC':
80 | return ['Neptune']
81 | elif pos == 'NNP':
82 | if is_abbrev: return ['XKCD']
83 | return ['Dalek']
84 | elif pos == 'NNPS':
85 | return ['Daleks']
86 | return None
87 |
88 | def alter_wordnet_antonyms(token, **kwargs):
89 | if token['pos'] not in POS_TO_WORDNET: return None
90 | w = token['word'].lower()
91 | wn_pos = POS_TO_WORDNET[token['pos']]
92 | synsets = wn.synsets(w, wn_pos)
93 | if not synsets: return None
94 | synset = synsets[0]
95 | antonyms = []
96 | for lem in synset.lemmas():
97 | if lem.antonyms():
98 | for a in lem.antonyms():
99 | new_word = a.name()
100 | if '_' in a.name(): continue
101 | antonyms.append(new_word)
102 | return antonyms
103 |
104 | SPECIAL_ALTERATIONS = {
105 | 'States': 'Kingdom',
106 | 'US': 'UK',
107 | 'U.S': 'U.K.',
108 | 'U.S.': 'U.K.',
109 | 'UK': 'US',
110 | 'U.K.': 'U.S.',
111 | 'U.K': 'U.S.',
112 | 'largest': 'smallest',
113 | 'smallest': 'largest',
114 | 'highest': 'lowest',
115 | 'lowest': 'highest',
116 | 'May': 'April',
117 | 'Peyton': 'Trevor',
118 | }
119 |
120 | DO_NOT_ALTER = ['many', 'such', 'few', 'much', 'other', 'same', 'general',
121 | 'type', 'record', 'kind', 'sort', 'part', 'form', 'terms', 'use',
122 | 'place', 'way', 'old', 'young', 'bowl', 'united', 'one',
123 | 'likely', 'different', 'square', 'war', 'republic', 'doctor', 'color']
124 |
125 | BAD_ALTERATIONS = ['mx2004', 'planet', 'u.s.', 'Http://Www.Co.Mo.Md.Us']
126 |
127 | HIGH_CONF_ALTER_RULES = collections.OrderedDict([
128 | ('special', alter_special),
129 | ('wn_antonyms', alter_wordnet_antonyms),
130 | ('nearbyNum', alter_nearby(['CD'], ignore_pos=True)),
131 | ('nearbyProperNoun', alter_nearby(['NNP', 'NNPS'])),
132 | ('nearbyProperNoun', alter_nearby(['NNP', 'NNPS'], ignore_pos=True)),
133 | ('nearbyEntityNouns', alter_nearby(['NN', 'NNS'], is_ner=True)),
134 | ('nearbyEntityJJ', alter_nearby(['JJ', 'JJR', 'JJS'], is_ner=True)),
135 | ('entityType', alter_entity_type),
136 | #('entity_glove', alter_entity_glove),
137 | ])
138 | ALL_ALTER_RULES = collections.OrderedDict(list(HIGH_CONF_ALTER_RULES.items()) + [
139 | ('nearbyAdj', alter_nearby(['JJ', 'JJR', 'JJS'])),
140 | ('nearbyNoun', alter_nearby(['NN', 'NNS'])),
141 | #('nearbyNoun', alter_nearby(['NN', 'NNS'], ignore_pos=True)),
142 | ])
143 |
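A minimal sketch of how these alteration rules might be applied to a single CoreNLP-style token. The token dict and the empty nearby-word and POS-tag dictionaries below are made up for illustration (normally the generator supplies precomputed GloVe-neighbour and POS dictionaries), so only the dictionary-free rules can fire here:

    from adversarialnlp.generators.addsent.rules import HIGH_CONF_ALTER_RULES

    # Hypothetical CoreNLP-style token plus empty auxiliary dictionaries.
    token = {'word': 'largest', 'originalText': 'largest', 'pos': 'JJS', 'ner': 'O'}
    nearby_word_dict, postag_dict = {}, {}

    for rule_name, rule_fn in HIGH_CONF_ALTER_RULES.items():
        new_words = rule_fn(token, nearby_word_dict=nearby_word_dict, postag_dict=postag_dict)
        if new_words:
            print(rule_name, new_words)  # e.g. special -> ['smallest']
            break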
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/rules/answer_rules.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from adversarialnlp.generators.addsent.utils import rejoin
4 |
5 | MONTHS = ['january', 'february', 'march', 'april', 'may', 'june', 'july',
6 | 'august', 'september', 'october', 'november', 'december']
7 |
8 | def ans_number(a, tokens, q, **kwargs):
9 | out_toks = []
10 | seen_num = False
11 | for t in tokens:
12 | ner = t['ner']
13 | pos = t['pos']
14 | w = t['word']
15 | out_tok = {'before': t['before']}
16 |
17 | # Split on dashes
18 | leftover = ''
19 | dash_toks = w.split('-')
20 | if len(dash_toks) > 1:
21 | w = dash_toks[0]
22 | leftover = '-'.join(dash_toks[1:])
23 |
24 | # Try to get a number out
25 | value = None
26 | if w != '%':
27 | # Percent sign should just pass through
28 | try:
29 | value = float(w.replace(',', ''))
30 | except:
31 | try:
32 | norm_ner = t['normalizedNER']
33 | if norm_ner[0] in ('%', '>', '<'):
34 | norm_ner = norm_ner[1:]
35 | value = float(norm_ner)
36 | except:
37 | pass
38 | if not value and (
39 | ner == 'NUMBER' or
40 | (ner == 'PERCENT' and pos == 'CD')):
41 | # Force this to be a number anyways
42 | value = 10
43 | if value:
44 | if math.isinf(value) or math.isnan(value): value = 9001
45 | seen_num = True
46 | if w in ('thousand', 'million', 'billion', 'trillion'):
47 | if w == 'thousand':
48 | new_val = 'million'
49 | else:
50 | new_val = 'thousand'
51 | else:
52 | if value < 2500 and value > 1000:
53 | new_val = str(value - 75)
54 | else:
55 | # Change leading digit
56 | if value == int(value):
57 | val_chars = list('%d' % value)
58 | else:
59 | val_chars = list('%g' % value)
60 | c = val_chars[0]
61 | for i in range(len(val_chars)):
62 | c = val_chars[i]
63 | if c >= '0' and c <= '9':
64 | val_chars[i] = str(max((int(c) + 5) % 10, 1))
65 | break
66 | new_val = ''.join(val_chars)
67 | if leftover:
68 | new_val = '%s-%s' % (new_val, leftover)
69 | out_tok['originalText'] = new_val
70 | else:
71 | out_tok['originalText'] = t['originalText']
72 | out_toks.append(out_tok)
73 | if seen_num:
74 | return rejoin(out_toks).strip()
75 | else:
76 | return None
77 |
78 | def ans_date(a, tokens, q, **kwargs):
79 | out_toks = []
80 | if not all(t['ner'] == 'DATE' for t in tokens):
81 | return None
82 | for t in tokens:
83 | if t['pos'] == 'CD' or t['word'].isdigit():
84 | try:
85 | value = int(t['word'])
86 | except:
87 | value = 10 # fallback
88 | if value > 50: new_val = str(value - 25) # Year
89 | else: # Day of month
90 | if value > 15: new_val = str(value - 11)
91 | else: new_val = str(value + 11)
92 | else:
93 | if t['word'].lower() in MONTHS:
94 | m_ind = MONTHS.index(t['word'].lower())
95 | new_val = MONTHS[(m_ind + 6) % 12].title()
96 | else:
97 | # Give up
98 | new_val = t['originalText']
99 | out_toks.append({'before': t['before'], 'originalText': new_val})
100 | new_ans = rejoin(out_toks).strip()
101 | if new_ans == a['text']: return None
102 | return new_ans
103 |
104 | def ans_entity_full(ner_tag, new_ans):
105 | """Returns a function that yields new_ans iff every token has |ner_tag|."""
106 | def func(a, tokens, q, **kwargs):
107 | for t in tokens:
108 | if t['ner'] != ner_tag: return None
109 | return new_ans
110 | return func
111 |
112 | def ans_abbrev(new_ans):
113 | def func(a, tokens, q, **kwargs):
114 | s = a['text']
115 | if s == s.upper() and s != s.lower():
116 | return new_ans
117 | return None
118 | return func
119 |
120 | def ans_match_wh(wh_word, new_ans):
121 | """Returns a function that yields new_ans if the question starts with |wh_word|."""
122 | def func(a, tokens, q, **kwargs):
123 | if q.lower().startswith(wh_word + ' '):
124 | return new_ans
125 | return None
126 | return func
127 |
128 | def ans_pos(pos, new_ans, end=False, add_dt=False):
129 | """Returns a function that yields new_ans if the first/last token has |pos|."""
130 | def func(a, tokens, q, determiner, **kwargs):
131 | if end:
132 | t = tokens[-1]
133 | else:
134 | t = tokens[0]
135 | if t['pos'] != pos: return None
136 | if add_dt and determiner:
137 | return '%s %s' % (determiner, new_ans)
138 | return new_ans
139 | return func
140 |
141 |
142 | def ans_catch_all(new_ans):
143 | def func(a, tokens, q, **kwargs):
144 | return new_ans
145 | return func
146 |
147 | ANSWER_RULES = [
148 | ('date', ans_date),
149 | ('number', ans_number),
150 | ('ner_person', ans_entity_full('PERSON', 'Jeff Dean')),
151 | ('ner_location', ans_entity_full('LOCATION', 'Chicago')),
152 | ('ner_organization', ans_entity_full('ORGANIZATION', 'Stark Industries')),
153 | ('ner_misc', ans_entity_full('MISC', 'Jupiter')),
154 | ('abbrev', ans_abbrev('LSTM')),
155 | ('wh_who', ans_match_wh('who', 'Jeff Dean')),
156 | ('wh_when', ans_match_wh('when', '1956')),
157 | ('wh_where', ans_match_wh('where', 'Chicago')),
158 | ('wh_how_many', ans_match_wh('how many', '42')),
159 | # Starts with verb
160 | ('pos_begin_vb', ans_pos('VB', 'learn')),
161 | ('pos_end_vbd', ans_pos('VBD', 'learned')),
162 | ('pos_end_vbg', ans_pos('VBG', 'learning')),
163 | ('pos_end_vbp', ans_pos('VBP', 'learns')),
164 | ('pos_end_vbz', ans_pos('VBZ', 'learns')),
165 | # Ends with some POS tag
166 | ('pos_end_nn', ans_pos('NN', 'hamster', end=True, add_dt=True)),
167 | ('pos_end_nnp', ans_pos('NNP', 'Central Park', end=True, add_dt=True)),
168 | ('pos_end_nns', ans_pos('NNS', 'hamsters', end=True, add_dt=True)),
169 | ('pos_end_nnps', ans_pos('NNPS', 'Kew Gardens', end=True, add_dt=True)),
170 | ('pos_end_jj', ans_pos('JJ', 'deep', end=True)),
171 | ('pos_end_jjr', ans_pos('JJR', 'deeper', end=True)),
172 | ('pos_end_jjs', ans_pos('JJS', 'deepest', end=True)),
173 | ('pos_end_rb', ans_pos('RB', 'silently', end=True)),
174 | ('pos_end_vbg', ans_pos('VBG', 'learning', end=True)),
175 | ('catch_all', ans_catch_all('aliens')),
176 | ]
177 |
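ANSWER_RULES is an ordered list: each rule is tried on the (answer, CoreNLP tokens, question) triple and the first fake answer returned wins. A rough sketch of that selection loop, with a hand-written token list standing in for real CoreNLP output:

    from adversarialnlp.generators.addsent.rules import ANSWER_RULES

    # Hypothetical answer object and CoreNLP-style tokens for the question
    # "Who founded the company?" with gold answer "Elon Musk".
    answer = {'text': 'Elon Musk', 'answer_start': 0}
    tokens = [{'word': 'Elon', 'originalText': 'Elon', 'before': '', 'pos': 'NNP', 'ner': 'PERSON'},
              {'word': 'Musk', 'originalText': 'Musk', 'before': ' ', 'pos': 'NNP', 'ner': 'PERSON'}]
    question = 'Who founded the company?'

    for rule_name, rule_fn in ANSWER_RULES:
        fake_answer = rule_fn(answer, tokens, question, determiner='')
        if fake_answer:
            print(rule_name, fake_answer)  # ner_person -> Jeff Dean
            break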
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/rules/conversion_rules.py:
--------------------------------------------------------------------------------
1 | from pattern import en as patten
2 |
3 | CONST_PARSE_MACROS = {
4 | '$Noun': '$NP/$NN/$NNS/$NNP/$NNPS',
5 | '$Verb': '$VB/$VBD/$VBP/$VBZ',
6 | '$Part': '$VBN/$VG',
7 | '$Be': 'is/are/was/were',
8 | '$Do': "do/did/does/don't/didn't/doesn't",
9 | '$WHP': '$WHADJP/$WHADVP/$WHNP/$WHPP',
10 | }
11 |
12 | # Map to pattern.en aliases
13 | # http://www.clips.ua.ac.be/pages/pattern-en#conjugation
14 | POS_TO_PATTERN = {
15 | 'vb': 'inf', # Infinitive
16 | 'vbp': '1sg', # non-3rd-person singular present
17 | 'vbz': '3sg', # 3rd-person singular present
18 | 'vbg': 'part', # gerund or present participle
19 | 'vbd': 'p', # past
20 | 'vbn': 'ppart', # past participle
21 | }
22 | # Tenses prioritized by likelihood of arising
23 | PATTERN_TENSES = ['inf', '3sg', 'p', 'part', 'ppart', '1sg']
24 |
25 | def _check_match(node, pattern_tok):
26 | if pattern_tok in CONST_PARSE_MACROS:
27 | pattern_tok = CONST_PARSE_MACROS[pattern_tok]
28 | if ':' in pattern_tok:
29 | # ':' means you match the LHS category and start with something on the right
30 | lhs, rhs = pattern_tok.split(':')
31 | match_lhs = _check_match(node, lhs)
32 | if not match_lhs: return False
33 | phrase = node.get_phrase().lower()
34 | retval = any(phrase.startswith(w) for w in rhs.split('/'))
35 | return retval
36 | elif '/' in pattern_tok:
37 | return any(_check_match(node, t) for t in pattern_tok.split('/'))
38 | return ((pattern_tok.startswith('$') and pattern_tok[1:] == node.tag) or
39 | (node.word and pattern_tok.lower() == node.word.lower()))
40 |
41 | def _recursive_match_pattern(pattern_toks, stack, matches):
42 | """Recursively try to match a pattern, greedily."""
43 | if len(matches) == len(pattern_toks):
44 | # We matched everything in the pattern; also need stack to be empty
45 | return len(stack) == 0
46 | if len(stack) == 0: return False
47 | cur_tok = pattern_toks[len(matches)]
48 | node = stack.pop()
49 | # See if we match the current token at this level
50 | is_match = _check_match(node, cur_tok)
51 | if is_match:
52 | cur_num_matches = len(matches)
53 | matches.append(node)
54 | new_stack = list(stack)
55 | success = _recursive_match_pattern(pattern_toks, new_stack, matches)
56 | if success: return True
57 | # Backtrack
58 | while len(matches) > cur_num_matches:
59 | matches.pop()
60 | # Recurse to children
61 | if not node.children: return False # No children to recurse on, we failed
62 | stack.extend(node.children[::-1]) # Leftmost children should be popped first
63 | return _recursive_match_pattern(pattern_toks, stack, matches)
64 |
65 | def match_pattern(pattern, const_parse):
66 | pattern_toks = pattern.split(' ')
67 | whole_phrase = const_parse.get_phrase()
68 | if whole_phrase.endswith('?') or whole_phrase.endswith('.'):
69 | # Match trailing punctuation as needed
70 | pattern_toks.append(whole_phrase[-1])
71 | matches = []
72 | success = _recursive_match_pattern(pattern_toks, [const_parse], matches)
73 | if success:
74 | return matches
75 | else:
76 | return None
77 |
78 | def run_postprocessing(s, rules, all_args):
79 | rule_list = rules.split(',')
80 | for rule in rule_list:
81 | if rule == 'lower':
82 | s = s.lower()
83 | elif rule.startswith('tense-'):
84 | ind = int(rule[6:])
85 | orig_vb = all_args[ind]
86 | tenses = patten.tenses(orig_vb)
87 | for tense in PATTERN_TENSES: # Prioritize by PATTERN_TENSES
88 | if tense in tenses:
89 | break
90 | else: # Default to first tense
91 | tense = PATTERN_TENSES[0]
92 | s = patten.conjugate(s, tense)
93 | elif rule in POS_TO_PATTERN:
94 | s = patten.conjugate(s, POS_TO_PATTERN[rule])
95 | return s
96 |
97 | def convert_whp(node, q, a, tokens, quiet=False):
98 | if node.tag in ('WHNP', 'WHADJP', 'WHADVP', 'WHPP'):
99 | # Apply WHP rules
100 | cur_phrase = node.get_phrase()
101 | cur_tokens = tokens[node.get_start_index():node.get_end_index()]
102 | for r in WHP_RULES:
103 | phrase = r.convert(cur_phrase, a, cur_tokens, node, run_fix_style=False)
104 | if phrase:
105 | if not quiet:
106 | print(f" WHP Rule '{r.name}': {phrase}")
107 | return phrase
108 | return None
109 |
110 | ### Rules for converting questions into declarative sentences
111 | def fix_style(s):
112 | """Minor, general style fixes for questions."""
113 | s = s.replace('?', '') # Delete question marks anywhere in sentence.
114 | s = s.strip(' .')
115 | if s[0] == s[0].lower():
116 | s = s[0].upper() + s[1:]
117 | return s + '.'
118 |
119 | class ConversionRule(object):
120 | def convert(self, q, a, tokens, const_parse, run_fix_style=True):
121 | raise NotImplementedError
122 |
123 | class ConstituencyRule(ConversionRule):
124 | """A rule for converting question to sentence based on constituency parse."""
125 | def __init__(self, in_pattern, out_pattern, postproc=None):
126 | self.in_pattern = in_pattern # e.g. "where did $NP $VP"
127 | self.out_pattern = out_pattern #unicode(out_pattern)
128 | # e.g. "{1} did {2} at {0}." Answer is always 0
129 | self.name = in_pattern
130 | if postproc:
131 | self.postproc = postproc
132 | else:
133 | self.postproc = {}
134 |
135 | def convert(self, q, a, tokens, const_parse, run_fix_style=True) -> str:
136 | pattern_toks = self.in_pattern.split(' ') # Don't care about trailing punctuation
137 | match = match_pattern(self.in_pattern, const_parse)
138 | appended_clause = False
139 | if not match:
140 | # Try adding a PP at the beginning
141 | appended_clause = True
142 | new_pattern = '$PP , ' + self.in_pattern
143 | pattern_toks = new_pattern.split(' ')
144 | match = match_pattern(new_pattern, const_parse)
145 | if not match:
146 | # Try adding an SBAR at the beginning
147 | new_pattern = '$SBAR , ' + self.in_pattern
148 | pattern_toks = new_pattern.split(' ')
149 | match = match_pattern(new_pattern, const_parse)
150 | if not match: return None
151 | appended_clause_match = None
152 | fmt_args = [a]
153 | for t, m in zip(pattern_toks, match):
154 | if t.startswith('$') or '/' in t:
155 | # First check if it's a WHP
156 | phrase = convert_whp(m, q, a, tokens)
157 | if not phrase:
158 | phrase = m.get_phrase()
159 | fmt_args.append(phrase)
160 | if appended_clause:
161 | appended_clause_match = fmt_args[1]
162 | fmt_args = [a] + fmt_args[2:]
163 | for i in range(len(fmt_args)):
164 | if i in self.postproc:
165 | # Run postprocessing filters
166 | fmt_args[i] = run_postprocessing(fmt_args[i], self.postproc[i], fmt_args)
167 | output = self.gen_output(fmt_args)
168 | if appended_clause:
169 | output = appended_clause_match + ', ' + output
170 | if run_fix_style:
171 | output = fix_style(output)
172 | return output
173 |
174 |
175 | def gen_output(self, fmt_args):
176 | """By default, use self.out_pattern. Can be overridden."""
177 | return self.out_pattern.format(*fmt_args)
178 |
179 | class ReplaceRule(ConversionRule):
180 | """A simple rule that replaces some tokens with the answer."""
181 | def __init__(self, target, replacement='{}', start=False):
182 | self.target = target
183 | self.replacement = replacement #unicode(replacement)
184 | self.name = 'replace(%s)' % target
185 | self.start = start
186 |
187 | def convert(self, q, a, tokens, const_parse, run_fix_style=True):
188 | t_toks = self.target.split(' ')
189 | q_toks = q.rstrip('?.').split(' ')
190 | replacement_text = self.replacement.format(a)
191 | for i in range(len(q_toks)):
192 | if self.start and i != 0: continue
193 | if ' '.join(q_toks[i:i + len(t_toks)]).rstrip(',').lower() == self.target:
194 | begin = q_toks[:i]
195 | end = q_toks[i + len(t_toks):]
196 | output = ' '.join(begin + [replacement_text] + end)
197 | if run_fix_style:
198 | output = fix_style(output)
199 | return output
200 | return None
201 |
202 | class FindWHPRule(ConversionRule):
203 | """A rule that looks for $WHP's from right to left and does replacements."""
204 | name = 'FindWHP'
205 | def _recursive_convert(self, node, q, a, tokens, found_whp):
206 | if node.word:
207 | return node.word, found_whp
208 | if not found_whp:
209 | whp_phrase = convert_whp(node, q, a, tokens)
210 | if whp_phrase:
211 | return whp_phrase, True
212 | child_phrases = []
213 | for c in node.children[::-1]:
214 | c_phrase, found_whp = self._recursive_convert(c, q, a, tokens, found_whp)
215 | child_phrases.append(c_phrase)
216 | out_toks = []
217 | for i, p in enumerate(child_phrases[::-1]):
218 | if i == 0 or p.startswith("'"):
219 | out_toks.append(p)
220 | else:
221 | out_toks.append(' ' + p)
222 | return ''.join(out_toks), found_whp
223 |
224 | def convert(self, q, a, tokens, const_parse, run_fix_style=True):
225 | out_phrase, found_whp = self._recursive_convert(const_parse, q, a, tokens, False)
226 | if found_whp:
227 | if run_fix_style:
228 | out_phrase = fix_style(out_phrase)
229 | return out_phrase
230 | return None
231 |
232 | class AnswerRule(ConversionRule):
233 | """Just return the answer."""
234 | name = 'AnswerRule'
235 | def convert(self, q, a, tokens, const_parse, run_fix_style=True):
236 | return a
237 |
238 | CONVERSION_RULES = [
239 | # Special rules
240 | ConstituencyRule('$WHP:what $Be $NP called that $VP', '{2} that {3} {1} called {1}'),
241 |
242 | # What type of X
243 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP/$Noun $Be $NP", '{5} {4} a {1} {3}'),
244 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP/$Noun $Be $VP", '{1} {3} {4} {5}'),
245 | #ConstituencyRule("$WHP:what/which type/sort/kind/group of $NP $VP", '{1} {3} {4}'),
246 |
247 | # How $JJ
248 | ConstituencyRule('how $JJ $Be $NP $IN $NP', '{3} {2} {0} {1} {4} {5}'),
249 | ConstituencyRule('how $JJ $Be $NP $SBAR', '{3} {2} {0} {1} {4}'),
250 | ConstituencyRule('how $JJ $Be $NP', '{3} {2} {0} {1}'),
251 |
252 | # When/where $Verb
253 | ConstituencyRule('$WHP:when/where $Do $NP', '{3} occurred in {1}'),
254 | ConstituencyRule('$WHP:when/where $Do $NP $Verb', '{3} {4} in {1}', {4: 'tense-2'}),
255 | ConstituencyRule('$WHP:when/where $Do $NP $Verb $NP/$PP', '{3} {4} {5} in {1}', {4: 'tense-2'}),
256 | ConstituencyRule('$WHP:when/where $Do $NP $Verb $NP $PP', '{3} {4} {5} {6} in {1}', {4: 'tense-2'}),
257 | ConstituencyRule('$WHP:when/where $Be $NP', '{3} {2} in {1}'),
258 | ConstituencyRule('$WHP:when/where $Verb $NP $VP/$ADJP', '{3} {2} {4} in {1}'),
259 |
260 | # What/who/how $Do
261 | ConstituencyRule("$WHP:what/which/who $Do $NP do", '{3} {1}', {0: 'tense-2'}),
262 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb", '{3} {4} {1}', {4: 'tense-2'}),
263 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $IN/$NP", '{3} {4} {5} {1}', {4: 'tense-2', 0: 'vbg'}),
264 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $PP", '{3} {4} {1} {5}', {4: 'tense-2', 0: 'vbg'}),
265 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb $NP $VP", '{3} {4} {5} {6} {1}', {4: 'tense-2'}),
266 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb to $VB", '{3} {4} to {5} {1}', {4: 'tense-2'}),
267 | ConstituencyRule("$WHP:what/which/who $Do $NP $Verb to $VB $VP", '{3} {4} to {5} {1} {6}', {4: 'tense-2'}),
268 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $NP $IN $VP", '{3} {4} {5} {6} {1} {7}', {4: 'tense-2'}),
269 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $PP/$S/$VP/$SBAR/$SQ", '{3} {4} {1} {5}', {4: 'tense-2'}),
270 | ConstituencyRule("$WHP:what/which/who/how $Do $NP $Verb $PP $PP/$S/$VP/$SBAR", '{3} {4} {1} {5} {6}', {4: 'tense-2'}),
271 |
272 | # What/who/how $Be
273 | # Watch out for things that end in a preposition
274 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP of $NP $Verb/$Part $IN", '{3} of {4} {2} {5} {6} {1}'),
275 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $NP $IN", '{3} {2} {4} {5} {1}'),
276 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $VP/$IN", '{3} {2} {4} {1}'),
277 | ConstituencyRule("$WHP:what/which/who $Be/$MD $NP $IN $NP/$VP", '{1} {2} {3} {4} {5}'),
278 | ConstituencyRule('$WHP:what/which/who $Be/$MD $NP $Verb $PP', '{3} {2} {4} {1} {5}'),
279 | ConstituencyRule('$WHP:what/which/who $Be/$MD $NP/$VP/$PP', '{1} {2} {3}'),
280 | ConstituencyRule("$WHP:how $Be/$MD $NP $VP", '{3} {2} {4} by {1}'),
281 |
282 | # What/who $Verb
283 | ConstituencyRule("$WHP:what/which/who $VP", '{1} {2}'),
284 |
285 | # $IN what/which $NP
286 | ConstituencyRule('$IN what/which $NP $Do $NP $Verb $NP', '{5} {6} {7} {1} the {3} of {0}',
287 | {1: 'lower', 6: 'tense-4'}),
288 | ConstituencyRule('$IN what/which $NP $Be $NP $VP/$ADJP', '{5} {4} {6} {1} the {3} of {0}',
289 | {1: 'lower'}),
290 | ConstituencyRule('$IN what/which $NP $Verb $NP/$ADJP $VP', '{5} {4} {6} {1} the {3} of {0}',
291 | {1: 'lower'}),
292 | FindWHPRule(),
293 | ]
294 |
295 | # Rules for going from WHP to an answer constituent
296 | WHP_RULES = [
297 | # WHPP rules
298 | ConstituencyRule('$IN what/which type/sort/kind/group of $NP/$Noun', '{1} {0} {4}'),
299 | ConstituencyRule('$IN what/which type/sort/kind/group of $NP/$Noun $PP', '{1} {0} {4} {5}'),
300 | ConstituencyRule('$IN what/which $NP', '{1} the {3} of {0}'),
301 | ConstituencyRule('$IN $WP/$WDT', '{1} {0}'),
302 |
303 | # what/which
304 | ConstituencyRule('what/which type/sort/kind/group of $NP/$Noun', '{0} {3}'),
305 | ConstituencyRule('what/which type/sort/kind/group of $NP/$Noun $PP', '{0} {3} {4}'),
306 | ConstituencyRule('what/which $NP', 'the {2} of {0}'),
307 |
308 | # How many
309 | ConstituencyRule('how many/much $NP', '{0} {2}'),
310 |
311 | # Replace
312 | ReplaceRule('what'),
313 | ReplaceRule('who'),
314 | ReplaceRule('how many'),
315 | ReplaceRule('how much'),
316 | ReplaceRule('which'),
317 | ReplaceRule('where'),
318 | ReplaceRule('when'),
319 | ReplaceRule('why'),
320 | ReplaceRule('how'),
321 |
322 | # Just give the answer
323 | AnswerRule(),
324 | ]
325 |
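CONVERSION_RULES are tried in order on a question, its (fake) answer, the CoreNLP tokens, and a compressed constituency parse; the first rule whose pattern matches yields the declarative sentence. A hedged sketch of one way to drive these rules (the helper name convert_question is hypothetical, and the parse string and tokens would normally come from CoreNLP via read_const_parse):

    from adversarialnlp.generators.addsent.rules import CONVERSION_RULES
    from adversarialnlp.generators.addsent.utils import read_const_parse

    def convert_question(question, answer, tokens, parse_str):
        """Return (rule name, declarative sentence) for the first matching rule, or None."""
        const_parse = read_const_parse(parse_str)
        for rule in CONVERSION_RULES:
            sentence = rule.convert(question, answer, tokens, const_parse)
            if sentence:
                return rule.name, sentence
        return None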
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/squad_reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from typing import Dict, List, Tuple
4 |
5 | from adversarialnlp.common.file_utils import download_files
6 |
7 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
8 |
9 |
10 | def squad_reader(file_path: str = None) -> List[Tuple[Dict, str]]:
11 | r""" Reads a JSON-formatted SQuAD file and returns a list of (question_answer, paragraph) pairs.
12 |
13 | Args:
14 | file_path: Path to a JSON-formatted SQuAD file.
15 | If no path is provided, download and use the SQuAD v1.1 training dataset.
16 |
17 | Return:
18 | list of tuples (question_answer, paragraph).
19 | """
20 | if file_path is None:
21 | file_path = download_files(fnames=['train-v1.1.json'],
22 | paths='https://rajpurkar.github.io/SQuAD-explorer/dataset/',
23 | local_folder='squad')
24 | file_path = file_path[0]
25 |
26 | logger.info("Reading file at %s", file_path)
27 | with open(file_path) as dataset_file:
28 | dataset_json = json.load(dataset_file)
29 | dataset = dataset_json['data']
30 | logger.info("Reading the dataset")
31 | out_data = []
32 | for article in dataset:
33 | for paragraph_json in article['paragraphs']:
34 | paragraph = paragraph_json["context"]
35 | for question_answer in paragraph_json['qas']:
36 | question_answer["question"] = question_answer["question"].strip().replace("\n", "")
37 | out_data.append((question_answer, paragraph))
38 | return out_data
39 |
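A short usage sketch: called with no argument, squad_reader downloads the SQuAD v1.1 training set and returns (question_answer, paragraph) pairs that can be used as seeds for the AddSent generator.

    from adversarialnlp.generators.addsent.squad_reader import squad_reader

    seeds = squad_reader()                 # downloads train-v1.1.json on first call
    question_answer, paragraph = seeds[0]
    print(question_answer['question'])
    print(paragraph[:80])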
--------------------------------------------------------------------------------
/adversarialnlp/generators/addsent/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for AddSent generator."""
2 | from typing import List, Dict, Tuple, Optional
3 |
4 | class ConstituencyParse(object):
5 | """A CoreNLP constituency parse (or a node in a parse tree).
6 |
7 | Word-level constituents have |word| and |index| set and no children.
8 | Phrase-level constituents have no |word| or |index| and have at least one child.
9 | """
10 | def __init__(self, tag, children=None, word=None, index=None):
11 | self.tag = tag
12 | if children:
13 | self.children = children
14 | else:
15 | self.children = None
16 | self.word = word
17 | self.index = index
18 |
19 | @classmethod
20 | def _recursive_parse_corenlp(cls, tokens, i, j):
21 | orig_i = i
22 | if tokens[i] == '(':
23 | tag = tokens[i + 1]
24 | children = []
25 | i = i + 2
26 | while True:
27 | child, i, j = cls._recursive_parse_corenlp(tokens, i, j)
28 | if isinstance(child, cls):
29 | children.append(child)
30 | if tokens[i] == ')':
31 | return cls(tag, children), i + 1, j
32 | else:
33 | if tokens[i] != ')':
34 | raise ValueError('Expected ")" following leaf')
35 | return cls(tag, word=child, index=j), i + 1, j + 1
36 | else:
37 | # Only other possibility is it's a word
38 | return tokens[i], i + 1, j
39 |
40 | @classmethod
41 | def from_corenlp(cls, s):
42 | """Parses the "parse" attribute returned by CoreNLP parse annotator."""
43 | # "parse": "(ROOT\n (SBARQ\n (WHNP (WDT What)\n (NP (NN portion)\n (PP (IN of)\n (NP\n (NP (NNS households))\n (PP (IN in)\n (NP (NNP Jacksonville)))))))\n (SQ\n (VP (VBP have)\n (NP (RB only) (CD one) (NN person))))\n (. ? )))",
44 | s_spaced = s.replace('\n', ' ').replace('(', ' ( ').replace(')', ' ) ')
45 | tokens = [t for t in s_spaced.split(' ') if t]
46 | tree, index, num_words = cls._recursive_parse_corenlp(tokens, 0, 0)
47 | if index != len(tokens):
48 | raise ValueError('Only parsed %d of %d tokens' % (index, len(tokens)))
49 | return tree
50 |
51 | def is_singleton(self):
52 | if self.word:
53 | return True
54 | if len(self.children) > 1:
55 | return False
56 | return self.children[0].is_singleton()
57 |
58 | def print_tree(self, indent=0):
59 | spaces = ' ' * indent
60 | if self.word:
61 | print(f"{spaces}{self.tag}: {self.word} ({self.index})")
62 | else:
63 | print(f"{spaces}{self.tag}")
64 | for c in self.children:
65 | c.print_tree(indent=indent + 1)
66 |
67 | def get_phrase(self):
68 | if self.word:
69 | return self.word
70 | toks = []
71 | for i, c in enumerate(self.children):
72 | p = c.get_phrase()
73 | if i == 0 or p.startswith("'"):
74 | toks.append(p)
75 | else:
76 | toks.append(' ' + p)
77 | return ''.join(toks)
78 |
79 | def get_start_index(self):
80 | if self.index is not None:
81 | return self.index
82 | return self.children[0].get_start_index()
83 |
84 | def get_end_index(self):
85 | if self.index is not None:
86 | return self.index + 1
87 | return self.children[-1].get_end_index()
88 |
89 | @classmethod
90 | def _recursive_replace_words(cls, tree, new_words, i):
91 | if tree.word:
92 | new_word = new_words[i]
93 | return (cls(tree.tag, word=new_word, index=tree.index), i + 1)
94 | new_children = []
95 | for c in tree.children:
96 | new_child, i = cls._recursive_replace_words(c, new_words, i)
97 | new_children.append(new_child)
98 | return cls(tree.tag, children=new_children), i
99 |
100 | @classmethod
101 | def replace_words(cls, tree, new_words):
102 | """Return a new tree, with new words replacing old ones."""
103 | new_tree, i = cls._recursive_replace_words(tree, new_words, 0)
104 | if i != len(new_words):
105 | raise ValueError('len(new_words) == %d != i == %d' % (len(new_words), i))
106 | return new_tree
107 |
108 | def rejoin(tokens: List[Dict[str, str]], sep: str = None) -> str:
109 | """Rejoin tokens into the original sentence.
110 |
111 | Args:
112 | tokens: a list of dicts containing 'originalText' and 'before' fields.
113 | All other fields will be ignored.
114 | sep: if provided, join tokens with this separator instead of each
115 | token's 'before' field (so the original spacing is not preserved).
116 | Returns: the original sentence that generated this CoreNLP token list.
117 | """
118 | if sep is None:
119 | return ''.join('%s%s' % (t['before'], t['originalText']) for t in tokens)
120 | else:
121 | # Use the given separator instead
122 | return sep.join(t['originalText'] for t in tokens)
123 |
124 |
125 | def get_tokens_for_answers(answer_objs: List[Tuple[int, Dict]], corenlp_obj: Dict) -> Tuple[int, List]:
126 | """Get CoreNLP tokens corresponding to a SQuAD answer object."""
127 | first_a_toks = None
128 | for i, a_obj in enumerate(answer_objs):
129 | a_toks = []
130 | answer_start = a_obj['answer_start']
131 | answer_end = answer_start + len(a_obj['text'])
132 | for sent in corenlp_obj['sentences']:
133 | for tok in sent['tokens']:
134 | if tok['characterOffsetBegin'] >= answer_end:
135 | continue
136 | if tok['characterOffsetEnd'] <= answer_start:
137 | continue
138 | a_toks.append(tok)
139 | if rejoin(a_toks).strip() == a_obj['text']:
140 | # Make sure that the tokens reconstruct the answer
141 | return i, a_toks
142 | if i == 0:
143 | first_a_toks = a_toks
144 | # None of the extracted token lists reconstruct the answer
145 | # Default to the first
146 | return 0, first_a_toks
147 |
148 | def get_determiner_for_answers(answer_objs: List[Dict]) -> Optional[str]:
149 | for ans in answer_objs:
150 | words = ans['text'].split(' ')
151 | if words[0].lower() == 'the':
152 | return 'the'
153 | if words[0].lower() in ('a', 'an'):
154 | return 'a'
155 | return None
156 |
157 | def compress_whnp(tree, inside_whnp=False):
158 | if not tree.children: return tree # Reached leaf
159 | # Compress all children
160 | for i, c in enumerate(tree.children):
161 | tree.children[i] = compress_whnp(c, inside_whnp=inside_whnp or tree.tag == 'WHNP')
162 | if tree.tag != 'WHNP':
163 | if inside_whnp:
164 | # Wrap everything in an NP
165 | return ConstituencyParse('NP', children=[tree])
166 | return tree
167 | wh_word = None
168 | new_np_children = []
169 | new_siblings = []
170 | for i, c in enumerate(tree.children):
171 | if i == 0:
172 | if c.tag in ('WHNP', 'WHADJP', 'WHADVP', 'WHPP'):
173 | wh_word = c.children[0]
174 | new_np_children.extend(c.children[1:])
175 | elif c.tag in ('WDT', 'WP', 'WP$', 'WRB'):
176 | wh_word = c
177 | else:
178 | # No WH-word at start of WHNP
179 | return tree
180 | else:
181 | if c.tag == 'SQ': # Due to bad parse, SQ may show up here
182 | new_siblings = tree.children[i:]
183 | break
184 | # Wrap everything in an NP
185 | new_np_children.append(ConstituencyParse('NP', children=[c]))
186 | if new_np_children:
187 | new_np = ConstituencyParse('NP', children=new_np_children)
188 | new_tree = ConstituencyParse('WHNP', children=[wh_word, new_np])
189 | else:
190 | new_tree = tree
191 | if new_siblings:
192 | new_tree = ConstituencyParse('SBARQ', children=[new_tree] + new_siblings)
193 | return new_tree
194 |
195 | def read_const_parse(parse_str):
196 | tree = ConstituencyParse.from_corenlp(parse_str)
197 | new_tree = compress_whnp(tree)
198 | return new_tree
199 |
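These helpers convert CoreNLP output back and forth: rejoin reconstructs the original string from a token list, while read_const_parse turns the bracketed "parse" string into a ConstituencyParse tree with WHNP nodes compressed. A small sketch using a hand-written token list and parse string (both made up for illustration):

    from adversarialnlp.generators.addsent.utils import read_const_parse, rejoin

    tokens = [{'before': '', 'originalText': 'Who'},
              {'before': ' ', 'originalText': 'wrote'},
              {'before': ' ', 'originalText': 'it'},
              {'before': '', 'originalText': '?'}]
    print(rejoin(tokens))                  # 'Who wrote it?'

    parse_str = '(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBD wrote) (NP (PRP it)))) (. ?)))'
    tree = read_const_parse(parse_str)
    print(tree.get_phrase())               # 'Who wrote it ?'
    print(tree.get_start_index(), tree.get_end_index())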
--------------------------------------------------------------------------------
/adversarialnlp/generators/generator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict, Iterable
3 | from collections import defaultdict
4 | import itertools
5 |
6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
7 |
8 | class Generator():
9 | r"""An abstract ``Generator`` class.
10 |
11 | A ``Generator`` takes as input an iterable of seeds (for example,
12 | samples from a training dataset) and edits them to generate
13 | potential adversarial examples.
14 |
15 | This class is an abstract class. To implement a ``Generator``, you
16 | should override the `generate_from_seed(self, seed: any)` method
17 | with a specific method to use for yielding adversarial samples
18 | from a seed sample.
19 |
20 | Optionally, you should also:
21 | - define a typing class for the ``seed`` objects
22 | - define a default seed source in the ``__init__`` method, for
23 | example by downloading an appropriate dataset. See the
24 | ``AddSentGenerator`` class for an example.
25 |
26 | Args:
27 | default_seeds: Default Iterable to use as source of seeds.
28 | quiet: If True, suppress debugging output.
29 |
30 | Inputs:
31 | **seed_instances** (optional): Instances to use as seeds
32 | for adversarial example generation. If None, uses the
33 | ``default_seeds`` provided at class instantiation.
34 | Defaults to None.
35 | **num_epochs** (optional): How many times should we iterate
36 | over the seeds? If None, we will iterate over them forever.
37 | Defaults to None.
38 | **shuffle** (optional): Shuffle the instances before iteration.
39 | If True, we will shuffle the instances before iterating.
40 | Defaults to True.
41 |
42 | Yields:
43 | **adversarial_examples** (Iterable): Adversarial examples
44 | generated from the seeds.
45 |
46 | Examples::
47 |
48 | >> generator = Generator()
49 | >> examples = generator(num_epochs=1)
50 | """
51 |
52 | def __init__(self,
53 | default_seeds: Iterable = None,
54 | quiet: bool = False):
55 | self.default_seeds = default_seeds
56 | self.quiet: bool = quiet
57 |
58 | self._epochs: Dict[int, int] = defaultdict(int)
59 |
60 | def generate_from_seed(self, seed: Any):
61 | r"""Generate an adversarial example from a seed.
62 | """
63 | raise NotImplementedError
64 |
65 | def __call__(self,
66 | seeds: Iterable = None,
67 | num_epochs: int = None,
68 | shuffle: bool = True) -> Iterable:
69 | r"""Generate adversarial examples using _generate_from_seed.
70 |
71 | Args:
72 | seeds: Instances to use as seed for adversarial
73 | example generation.
74 | num_epochs: How many times should we iterate over the seeds?
75 | If None, we will iterate over it forever.
76 | shuffle: Shuffle the instances before iteration.
77 | If True, we will shuffle the instances before iterating.
78 |
79 | Yields: adversarial_examples
80 | adversarial_examples: Adversarial examples generated
81 | from the seeds.
82 | """
83 | if seeds is None:
84 | if self.default_seeds is not None:
85 | seeds = self.default_seeds
86 | else:
87 | return
88 | # Instances is likely to be a list, which cannot be used as a key,
89 | # so we take the object id instead.
90 | key = id(seeds)
91 | starting_epoch = self._epochs[key]
92 |
93 | if num_epochs is None:
94 | epochs: Iterable[int] = itertools.count(starting_epoch)
95 | else:
96 | epochs = range(starting_epoch, starting_epoch + num_epochs)
97 |
98 | for epoch in epochs:
99 | self._epochs[key] = epoch
100 | for seed in seeds:
101 | yield from self.generate_from_seed(seed)
102 |
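Since Generator is abstract, a subclass only needs to override generate_from_seed and yield adversarial candidates; the base __call__ handles epoch bookkeeping. A toy subclass, purely for illustration (the class name and seeds are made up):

    from adversarialnlp.generators.generator import Generator

    class UppercaseGenerator(Generator):
        """Toy generator that 'attacks' a sentence by shouting it."""
        def generate_from_seed(self, seed: str):
            yield seed.upper()

    generator = UppercaseGenerator(default_seeds=["a cat sat", "it rained"])
    for example in generator(num_epochs=1):
        print(example)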
--------------------------------------------------------------------------------
/adversarialnlp/generators/swag/__init__.py:
--------------------------------------------------------------------------------
1 | from .swag_generator import SwagGenerator
2 | from .activitynet_captions_reader import ActivityNetCaptionsDatasetReader
3 |
--------------------------------------------------------------------------------
/adversarialnlp/generators/swag/activitynet_captions_reader.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import json
3 | import logging
4 | from overrides import overrides
5 | from unidecode import unidecode
6 |
7 | from allennlp.common.file_utils import cached_path
8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
9 | from allennlp.data.fields import TextField, MetadataField
10 | from allennlp.data.instance import Instance
11 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer
12 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
13 |
14 | from adversarialnlp.generators.swag.utils import pairwise, postprocess
15 |
16 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
17 |
18 |
19 | @DatasetReader.register("activitynet_captions")
20 | class ActivityNetCaptionsDatasetReader(DatasetReader):
21 | r""" Reads ActivityNet Captions JSON files and creates a dataset suitable for crafting
22 | adversarial examples with swag using these captions.
23 |
24 | Expected format:
25 | JSON dict[video_id, video_obj] where
26 | video_id: str,
27 | video_obj: {
28 | "duration": float,
29 | "timestamps": list of pairs of float,
30 | "sentences": list of strings
31 | }
32 |
33 | The output of ``read`` is a list of ``Instance`` s with the fields:
34 | video_id: ``MetadataField``
35 | first_sentence: ``TextField``
36 | second_sentence: ``TextField``
37 |
38 | The instances are created from all consecutive pairs of sentences
39 | associated with each video.
40 | Ex: if a video has three associated sentences s1, s2, s3, ``read``
41 | will generate two instances:
42 |
43 | 1. Instance("first_sentence" = s1, "second_sentence" = s2)
44 | 2. Instance("first_sentence" = s2, "second_sentence" = s3)
45 |
46 | Args:
47 | lazy : If True, training will start sooner, but will take
48 | longer per batch. This allows training with datasets that
49 | are too large to fit in memory. Passed to DatasetReader.
50 | tokenizer : Tokenizer to use to split the sentences
51 | into words or other kinds of tokens.
52 | token_indexers : Indexers used to define input token
53 | representations.
54 | """
55 | def __init__(self,
56 | lazy: bool = False,
57 | tokenizer: Tokenizer = None,
58 | token_indexers: Dict[str, TokenIndexer] = None) -> None:
59 | super().__init__(lazy)
60 | self._tokenizer = tokenizer or WordTokenizer()
61 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
62 |
63 | @overrides
64 | def _read(self, file_path):
65 | with open(cached_path(file_path), "r") as data_file:
66 | logger.info("Reading instances from: %s", file_path)
67 | json_data = json.load(data_file)
68 | for video_id, value in json_data.items():
69 | sentences = [postprocess(unidecode(x.strip()))
70 | for x in value['sentences']]
71 | for first_sentence, second_sentence in pairwise(sentences):
72 | yield self.text_to_instance(video_id, first_sentence, second_sentence)
73 |
74 | @overrides
75 | def text_to_instance(self,
76 | video_id: str,
77 | first_sentence: str,
78 | second_sentence: str) -> Instance: # type: ignore
79 | # pylint: disable=arguments-differ
80 | tokenized_first_sentence = self._tokenizer.tokenize(first_sentence)
81 | tokenized_second_sentence = self._tokenizer.tokenize(second_sentence)
82 | first_sentence_field = TextField(tokenized_first_sentence, self._token_indexers)
83 | second_sentence_field = TextField(tokenized_second_sentence, self._token_indexers)
84 | fields = {'video_id': MetadataField(video_id),
85 | 'first_sentence': first_sentence_field,
86 | 'second_sentence': second_sentence_field}
87 | return Instance(fields)
88 |
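A usage sketch for the reader; the path below is the small fixture shipped in the repository's test folder, so each yielded Instance carries a video id and a consecutive sentence pair.

    from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader

    reader = ActivityNetCaptionsDatasetReader()
    instances = reader.read('adversarialnlp/tests/fixtures/activitynet_captions.json')
    for instance in instances:
        print(instance.fields['video_id'].metadata,
              instance.fields['first_sentence'].tokens)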
--------------------------------------------------------------------------------
/adversarialnlp/generators/swag/openai_transformer_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union, Optional
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from allennlp.common.checks import ConfigurationError
7 | from allennlp.data.vocabulary import Vocabulary
8 | from allennlp.models.model import Model
9 | from allennlp.modules.openai_transformer import OpenaiTransformer
10 | from allennlp.modules.token_embedders import OpenaiTransformerEmbedder
11 | from allennlp.nn.util import get_text_field_mask, remove_sentence_boundaries
12 |
13 |
14 | @Model.register('openai-transformer-language-model')
15 | class OpenAITransformerLanguageModel(Model):
16 | """
17 | The ``OpenAITransformerLanguageModel`` is a wrapper around the ``OpenaiTransformer`` module.
18 |
19 | Parameters
20 | ----------
21 | vocab: ``Vocabulary``
22 | remove_bos_eos: ``bool``, optional (default: True)
23 | Typically the provided token indexes will be augmented with
24 | begin-sentence and end-sentence tokens. If this flag is True
25 | the corresponding embeddings will be removed from the return values.
26 | """
27 | def __init__(self,
28 | vocab: Vocabulary,
29 | openai_token_embedder: OpenaiTransformerEmbedder,
30 | remove_bos_eos: bool = True) -> None:
31 | super().__init__(vocab)
32 | model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz"
33 | # Inputs are expected to be indexed with an OpenaiTransformerBytePairIndexer built from the same model_path.
34 | transformer = OpenaiTransformer(model_path=model_path)
35 | self._token_embedders = OpenaiTransformerEmbedder(transformer=transformer, top_layer_only=True)
36 | self._remove_bos_eos = remove_bos_eos
37 |
38 | def _get_target_token_embedding(self,
39 | token_embeddings: torch.Tensor,
40 | mask: torch.Tensor,
41 | direction: int) -> torch.Tensor:
42 | # Need to shift the mask in the correct direction
43 | zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
44 | if direction == 0:
45 | # forward direction, get token to right
46 | shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
47 | else:
48 | shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
49 | return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim)
50 |
51 | def _compute_loss(self,
52 | lm_embeddings: torch.Tensor,
53 | token_embeddings: torch.Tensor,
54 | forward_targets: torch.Tensor,
55 | backward_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
56 | # lm_embeddings is shape (batch_size, timesteps, dim * 2)
57 | # forward_targets, backward_targets are shape (batch_size, timesteps)
58 | # masked with 0
59 | forward_embeddings, backward_embeddings = lm_embeddings.chunk(2, -1)
60 | losses: List[torch.Tensor] = []
61 | for idx, embedding, targets in ((0, forward_embeddings, forward_targets),
62 | (1, backward_embeddings, backward_targets)):
63 | mask = targets > 0
64 | # we need to subtract 1 to undo the padding id since the softmax
65 | # does not include a padding dimension
66 | non_masked_targets = targets.masked_select(mask) - 1
67 | non_masked_embedding = embedding.masked_select(
68 | mask.unsqueeze(-1)
69 | ).view(-1, self._forward_dim)
70 | # note: need to return average loss across forward and backward
71 | # directions, but total sum loss across all batches.
72 | # Assuming batches include full sentences, forward and backward
73 | # directions have the same number of samples, so sum up loss
74 | # here then divide by 2 just below
75 | if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
76 | losses.append(self._softmax_loss(non_masked_embedding, non_masked_targets))
77 | else:
78 | # we also need the token embeddings corresponding to the
79 | # the targets
80 | raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.")
81 | # pylint: disable=unreachable
82 | non_masked_token_embedding = self._get_target_token_embedding(token_embeddings, mask, idx)
83 | losses.append(self._softmax(non_masked_embedding,
84 | non_masked_targets,
85 | non_masked_token_embedding))
86 |
87 | return losses[0], losses[1]
88 |
89 | def forward(self, # type: ignore
90 | source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
91 | """
92 | Computes the averaged forward and backward LM loss from the batch.
93 |
94 | By convention, the input dict is required to have at least a ``"tokens"``
95 | entry that's the output of a ``SingleIdTokenIndexer``, which is used
96 | to compute the language model targets.
97 |
99 | If the model was instantiated with ``remove_bos_eos=True``,
99 | then it is expected that each of the input sentences was augmented with
100 | begin-sentence and end-sentence tokens.
101 |
102 | Parameters
103 | ----------
104 | source: ``Dict[str, torch.LongTensor]``, required.
105 | The output of ``Batch.as_tensor_dict()`` for a batch of sentences.
106 |
107 | Returns
108 | -------
109 | Dict with keys:
110 |
111 | ``'loss'``: ``torch.Tensor``
112 | averaged forward/backward negative log likelihood
113 | ``'forward_loss'``: ``torch.Tensor``
114 | forward direction negative log likelihood
115 | ``'backward_loss'``: ``torch.Tensor``
116 | backward direction negative log likelihood
117 | ``'lm_embeddings'``: ``torch.Tensor``
118 | (batch_size, timesteps, embed_dim) tensor of top layer contextual representations
119 | ``'mask'``: ``torch.Tensor``
120 | (batch_size, timesteps) mask for the embeddings
121 | """
122 | # pylint: disable=arguments-differ
123 | mask = get_text_field_mask(source)
124 |
125 | # We must have token_ids so that we can compute targets
126 | token_ids = source.get("tokens")
127 | if token_ids is None:
128 | raise ConfigurationError("Your data must have a 'tokens': SingleIdTokenIndexer() "
129 | "in order to use the BidirectionalLM")
130 |
131 | # Use token_ids to compute targets
132 | forward_targets = torch.zeros_like(token_ids)
133 | backward_targets = torch.zeros_like(token_ids)
134 | forward_targets[:, 0:-1] = token_ids[:, 1:]
135 | backward_targets[:, 1:] = token_ids[:, 0:-1]
136 |
137 | embeddings = self._text_field_embedder(source)
138 |
139 | # Apply LayerNorm if appropriate.
140 | embeddings = self._layer_norm(embeddings)
141 |
142 | contextual_embeddings = self._contextualizer(embeddings, mask)
143 |
144 | # add dropout
145 | contextual_embeddings = self._dropout(contextual_embeddings)
146 |
147 | # compute softmax loss
148 | forward_loss, backward_loss = self._compute_loss(contextual_embeddings,
149 | embeddings,
150 | forward_targets,
151 | backward_targets)
152 |
153 | num_targets = torch.sum((forward_targets > 0).long())
154 | if num_targets > 0:
155 | average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
156 | else:
157 | average_loss = torch.tensor(0.0).to(forward_targets.device) # pylint: disable=not-callable
158 | # this is stored to compute perplexity if needed
159 | self._last_average_loss[0] = average_loss.detach().item()
160 |
161 | if num_targets > 0:
162 | # loss is directly minimized
163 | if self._loss_scale == 'n_samples':
164 | scale_factor = num_targets.float()
165 | else:
166 | scale_factor = self._loss_scale
167 |
168 | return_dict = {
169 | 'loss': average_loss * scale_factor,
170 | 'forward_loss': forward_loss * scale_factor / num_targets.float(),
171 | 'backward_loss': backward_loss * scale_factor / num_targets.float()
172 | }
173 | else:
174 | # average_loss zero tensor, return it for all
175 | return_dict = {
176 | 'loss': average_loss,
177 | 'forward_loss': average_loss,
178 | 'backward_loss': average_loss
179 | }
180 |
181 | if self._remove_bos_eos:
182 | contextual_embeddings, mask = remove_sentence_boundaries(contextual_embeddings, mask)
183 |
184 | return_dict.update({
185 | 'lm_embeddings': contextual_embeddings,
186 | 'mask': mask
187 | })
188 |
189 | return return_dict
190 |
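The target construction in forward() simply shifts the token ids: forward targets are the next token, backward targets the previous one, with 0 marking positions that have no target. A standalone illustration of that shift, using hypothetical ids and independent of the model class:

    import torch

    token_ids = torch.tensor([[5, 7, 9, 2]])      # hypothetical word ids for one sentence

    forward_targets = torch.zeros_like(token_ids)
    backward_targets = torch.zeros_like(token_ids)
    forward_targets[:, 0:-1] = token_ids[:, 1:]   # predict the next token
    backward_targets[:, 1:] = token_ids[:, 0:-1]  # predict the previous token

    print(forward_targets)   # tensor([[7, 9, 2, 0]])
    print(backward_targets)  # tensor([[0, 5, 7, 9]])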
--------------------------------------------------------------------------------
/adversarialnlp/generators/swag/simple_bilm.py:
--------------------------------------------------------------------------------
1 | """
2 | A wrapper around AI2's ELMo-style LM to allow for a language-modeling objective.
3 | """
4 |
5 | from typing import Optional, Tuple
6 | from typing import Union, List, Dict
7 |
8 | import numpy as np
9 | import torch
10 | from allennlp.common.checks import ConfigurationError
11 | from allennlp.data import Token, Vocabulary, Instance
12 | from allennlp.data.dataset import Batch
13 | from allennlp.data.fields import TextField
14 | from allennlp.data.token_indexers import SingleIdTokenIndexer
15 | from allennlp.modules.augmented_lstm import AugmentedLstm
16 | from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import PytorchSeq2SeqWrapper
17 | from allennlp.nn.util import sequence_cross_entropy_with_logits
18 | from torch.autograd import Variable
19 | from torch.nn import functional as F
20 | from torch.nn.utils.rnn import PackedSequence
21 |
22 |
23 | def _de_duplicate_generations(generations):
24 | """
25 | Given a list of lists of strings, filter out the duplicates and return the indices
26 | corresponding to the unique ones.
27 | :param generations:
28 | :return:
29 | """
30 | dup_set = set()
31 | unique_idx = []
32 | for i, gen_i in enumerate(generations):
33 | gen_i_str = ' '.join(gen_i)
34 | if gen_i_str not in dup_set:
35 | unique_idx.append(i)
36 | dup_set.add(gen_i_str)
37 | return [generations[i] for i in unique_idx], np.array(unique_idx)
38 |
39 |
40 | class StackedLstm(torch.nn.Module):
41 | """
42 | A stacked LSTM.
43 |
44 | Parameters
45 | ----------
46 | input_size : int, required
47 | The dimension of the inputs to the LSTM.
48 | hidden_size : int, required
49 | The dimension of the outputs of the LSTM.
50 | num_layers : int, required
51 | The number of stacked LSTMs to use.
52 | recurrent_dropout_probability: float, optional (default = 0.0)
53 | The dropout probability to be used in a dropout scheme as stated in
54 | `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
55 | <https://arxiv.org/abs/1512.05287>`_ .
56 | use_input_projection_bias : bool, optional (default = True)
57 | Whether or not to use a bias on the input projection layer. This is mainly here
58 | for backwards compatibility reasons and will be removed (and set to False)
59 | in future releases.
60 |
61 | Returns
62 | -------
63 | output_accumulator : PackedSequence
64 | The outputs of the interleaved LSTMs per timestep. A tensor of shape
65 | (batch_size, max_timesteps, hidden_size) where for a given batch
66 | element, all outputs past the sequence length for that batch are
67 | zero tensors.
68 | """
69 |
70 | def __init__(self,
71 | input_size: int,
72 | hidden_size: int,
73 | num_layers: int,
74 | recurrent_dropout_probability: float = 0.0,
75 | use_highway: bool = True,
76 | use_input_projection_bias: bool = True,
77 | go_forward: bool = True) -> None:
78 | super(StackedLstm, self).__init__()
79 |
80 | # Required to be wrapped with a :class:`PytorchSeq2SeqWrapper`.
81 | self.input_size = input_size
82 | self.hidden_size = hidden_size
83 | self.num_layers = num_layers
84 |
85 | layers = []
86 | lstm_input_size = input_size
87 | for layer_index in range(num_layers):
88 | layer = AugmentedLstm(lstm_input_size, hidden_size, go_forward,
89 | recurrent_dropout_probability=recurrent_dropout_probability,
90 | use_highway=use_highway,
91 | use_input_projection_bias=use_input_projection_bias)
92 | lstm_input_size = hidden_size
93 | self.add_module('layer_{}'.format(layer_index), layer)
94 | layers.append(layer)
95 | self.lstm_layers = layers
96 |
97 | def forward(self, # pylint: disable=arguments-differ
98 | inputs: PackedSequence,
99 | initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
100 | """
101 | Parameters
102 | ----------
103 | inputs : ``PackedSequence``, required.
104 | A batch first ``PackedSequence`` to run the stacked LSTM over.
105 | initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
106 | A tuple (state, memory) representing the initial hidden state and memory
107 | of the LSTM. Each tensor has shape (1, batch_size, output_dimension).
108 |
109 | Returns
110 | -------
111 | output_sequence : PackedSequence
112 | The encoded sequence of shape (batch_size, sequence_length, hidden_size)
113 | final_states: torch.Tensor
114 | The per-layer final (state, memory) states of the LSTM, each with shape
115 | (num_layers, batch_size, hidden_size).
116 | """
117 | if not initial_state:
118 | hidden_states = [None] * len(self.lstm_layers)
119 | elif initial_state[0].size()[0] != len(self.lstm_layers):
120 | raise ConfigurationError("Initial states were passed to forward() but the number of "
121 | "initial states does not match the number of layers.")
122 | else:
123 | hidden_states = list(zip(initial_state[0].split(1, 0),
124 | initial_state[1].split(1, 0)))
125 |
126 | output_sequence = inputs
127 | final_states = []
128 | for i, state in enumerate(hidden_states):
129 | layer = getattr(self, 'layer_{}'.format(i))
130 | # The state is duplicated to mirror the Pytorch API for LSTMs.
131 | output_sequence, final_state = layer(output_sequence, state)
132 | final_states.append(final_state)
133 |
134 | final_state_tuple = tuple(torch.cat(state_list, 0) for state_list in zip(*final_states))
135 | return output_sequence, final_state_tuple
136 |
137 |
138 | class SimpleBiLM(torch.nn.Module):
139 | def __init__(self,
140 | vocab: Vocabulary,
141 | recurrent_dropout_probability: float = 0.0,
142 | embedding_dropout_probability: float = 0.0,
143 | input_size=512,
144 | hidden_size=512) -> None:
145 | """
146 | :param vocab: Vocabulary used to map tokens to ids
147 | :param recurrent_dropout_probability: recurrent dropout to add to the LSTM layers
148 | :param embedding_dropout_probability: dropout to apply to the (tied) word embeddings
149 | :param input_size / hidden_size: dimensions of the stacked LSTMs
150 | """
151 | super(SimpleBiLM, self).__init__()
152 |
153 | self.forward_lm = PytorchSeq2SeqWrapper(StackedLstm(
154 | input_size=input_size, hidden_size=hidden_size, num_layers=2, go_forward=True,
155 | recurrent_dropout_probability=recurrent_dropout_probability,
156 | use_input_projection_bias=False, use_highway=True), stateful=True)
157 | self.reverse_lm = PytorchSeq2SeqWrapper(StackedLstm(
158 | input_size=input_size, hidden_size=hidden_size, num_layers=2, go_forward=False,
159 | recurrent_dropout_probability=recurrent_dropout_probability,
160 | use_input_projection_bias=False, use_highway=True), stateful=True)
161 |
162 | # This will also be the encoder
163 | self.decoder = torch.nn.Linear(512, vocab.get_vocab_size(namespace='tokens'))
164 |
165 | self.vocab = vocab
166 | self.register_buffer('eos_tokens', torch.LongTensor([vocab.get_token_index(tok) for tok in
167 | ['.', '!', '?', '@@UNKNOWN@@', '@@PADDING@@', '@@bos@@',
168 | '@@eos@@']]))
169 | self.register_buffer('invalid_tokens', torch.LongTensor([vocab.get_token_index(tok) for tok in
170 | ['@@UNKNOWN@@', '@@PADDING@@', '@@bos@@', '@@eos@@',
171 | '@@NEWLINE@@']]))
172 | self.embedding_dropout_probability = embedding_dropout_probability
173 |
174 | def embed_words(self, words):
175 | # assert words.dim() == 2
176 | return F.embedding(words, self.decoder.weight)
177 | # if not self.training:
178 | # return F.embedding(words, self.decoder.weight)
179 | # Embedding dropout
180 | # vocab_size = self.decoder.weight.size(0)
181 | # mask = Variable(
182 | # self.decoder.weight.data.new(vocab_size, 1).bernoulli_(1 - self.embedding_dropout_probability).expand_as(
183 | # self.decoder.weight) / (1 - self.embedding_dropout_probability))
184 |
185 | # padding_idx = 0
186 | # embeds = self.decoder._backend.Embedding.apply(words, mask * self.decoder.weight, padding_idx, None,
187 | # 2, False, False)
188 | # return embeds
189 |
190 | def timestep_to_ids(self, timestep_tokenized: List[str]):
191 | """ Just a single timestep (so dont add BOS or EOS"""
192 | return torch.tensor([self.vocab.get_token_index(x) for x in timestep_tokenized])[:, None]
193 |
194 | def batch_to_ids(self, stories_tokenized: List[List[str]]):
195 | """
196 | Convert tokenized stories to a padded tensor of word ids (adding @@bos@@/@@eos@@ markers).
197 | :param stories_tokenized: A list of tokenized sentences.
198 | :return: A tensor of padded word ids.
199 | """
200 | batch = Batch([Instance(
201 | {'story': TextField([Token('@@bos@@')] + [Token(x) for x in story] + [Token('@@eos@@')],
202 | token_indexers={
203 | 'tokens': SingleIdTokenIndexer(namespace='tokens', lowercase_tokens=True)})})
204 | for story in stories_tokenized])
205 | batch.index_instances(self.vocab)
206 | words = {k: v['tokens'] for k, v in batch.as_tensor_dict().items()}['story']
207 | return words
208 |
209 | def conditional_generation(self, context: List[str], gt_completion: List[str],
210 | batch_size: int = 128, max_gen_length: int = 25,
211 | same_length_as_gt: bool = False, first_is_gold: bool = False):
212 | """
213 | Generate conditioned on the context. While we're at it, we'll also score the ground-truth completion going forward.
214 | :param context: List of tokens to condition on. We'll add the BOS marker to it
215 | :param gt_completion: The ground-truth completion
216 | :param batch_size: Number of sentences to generate
217 | :param max_gen_length: Max length for generated sentences (irrelevant if same_length_as_gt=True)
218 | :param same_length_as_gt: set to True if you want all the sents to have the same length as the gt_completion
219 | :param first_is_gold: set to True if you want the first sample to be the gt_completion
220 | :return:
221 | """
222 | # Forward condition on context, then repeat to be the right batch size:
223 | # (layer_index, batch_size, fwd hidden dim)
224 | log_probs = self(self.batch_to_ids([context]), use_forward=True,
225 | use_reverse=False, compute_logprobs=True)
226 | forward_logprobs = log_probs['forward_logprobs']
227 | self.forward_lm._states = tuple(x.repeat(1, batch_size, 1).contiguous() for x in self.forward_lm._states)
228 | # Each item will be (token, score)
229 | generations = [[(context[-1], 0.0)] for i in range(batch_size)]
230 | mask = forward_logprobs.new(batch_size).long().fill_(1)
231 |
232 | gt_completion_padded = [self.vocab.get_token_index(gt_token) for gt_token in
233 | [x.lower() for x in gt_completion] + ['@@PADDING@@'] * (
234 | max_gen_length - len(gt_completion))]
235 |
236 | for index, gt_token_ind in enumerate(gt_completion_padded):
237 | embeds = self.embed_words(self.timestep_to_ids([gen[-1][0] for gen in generations]))
238 | next_dists = F.softmax(self.decoder(self.forward_lm(embeds, mask[:, None]))[:, 0], dim=1)
239 |
240 |             # Adjust the distribution (disallowing BOS, EOS and other invalid tokens)
241 | sampling_probs = next_dists.clone()
242 | sampling_probs[:, self.invalid_tokens] = 0.0
243 |
244 | if first_is_gold:
245 | # Gold truth is first row
246 | sampling_probs[0].zero_()
247 | sampling_probs[0, gt_token_ind] = 1
248 |
249 | if same_length_as_gt:
250 | if index == (len(gt_completion) - 1):
251 | sampling_probs.zero_()
252 | sampling_probs[:, gt_token_ind] = 1
253 | else:
254 | sampling_probs[:, self.eos_tokens] = 0.0
255 |
256 | sampling_probs = sampling_probs / sampling_probs.sum(1, keepdim=True)
257 |
258 | next_preds = torch.multinomial(sampling_probs, 1).squeeze(1)
259 | next_scores = np.log(next_dists[
260 | torch.arange(0, next_dists.size(0),
261 | out=mask.data.new(next_dists.size(0))),
262 | next_preds,
263 | ].cpu().detach().numpy())
264 | for i, (gen_list, pred_id, score_i, mask_i) in enumerate(
265 | zip(generations, next_preds.cpu().detach().numpy(), next_scores, mask.data.cpu().detach().numpy())):
266 | if mask_i:
267 | gen_list.append((self.vocab.get_token_from_index(pred_id), score_i))
268 | is_eos = (next_preds[:, None] == self.eos_tokens[None]).max(1)[0]
269 | mask[is_eos] = 0
270 | if mask.sum().item() == 0:
271 | break
272 | generation_scores = np.zeros((len(generations), max([len(g) - 1 for g in generations])), dtype=np.float32)
273 | for i, gen in enumerate(generations):
274 | for j, (_, v) in enumerate(gen[1:]):
275 | generation_scores[i, j] = v
276 |
277 | generation_toks, idx = _de_duplicate_generations([[tok for (tok, score) in gen[1:]] for gen in generations])
278 | return generation_toks, generation_scores[idx], forward_logprobs.cpu().detach().numpy()
279 |
280 | def _chunked_logsoftmaxes(self, activation, word_targets, chunk_size=256):
281 | """
282 |         Do the softmax in chunks so memory doesn't explode.
283 |         :param activation: [batch, T, dim]
284 |         :param word_targets: [batch, T] indices
285 |         :param chunk_size: you might need to tune this based on GPU specs
286 | :return:
287 | """
288 | all_logprobs = []
289 | num_chunks = (activation.size(0) - 1) // chunk_size + 1
290 | for activation_chunk, target_chunk in zip(torch.chunk(activation, num_chunks, dim=0),
291 | torch.chunk(word_targets, num_chunks, dim=0)):
292 | assert activation_chunk.size()[:2] == target_chunk.size()[:2]
293 | targets_flat = target_chunk.view(-1)
294 | time_indexer = torch.arange(0, targets_flat.size(0),
295 | out=target_chunk.data.new(targets_flat.size(0))) % target_chunk.size(1)
296 | batch_indexer = torch.arange(0, targets_flat.size(0),
297 | out=target_chunk.data.new(targets_flat.size(0))) / target_chunk.size(1)
298 | all_logprobs.append(F.log_softmax(self.decoder(activation_chunk), 2)[
299 | batch_indexer, time_indexer, targets_flat].view(*target_chunk.size()))
300 | return torch.cat(all_logprobs, 0)
301 |
302 | def forward(self, words: torch.Tensor, use_forward: bool = True, use_reverse: bool = True,
303 | compute_logprobs: bool = False) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
304 | """
305 | use this for training the LM
306 | :param words: [batch_size, N] words. assuming you're starting with BOS and ending with EOS here
307 | :return:
308 | """
309 | encoded_inputs = self.embed_words(words)
310 | mask = (words != 0).long()[:, 2:]
311 | word_targets = words[:, 1:-1].contiguous()
312 |
313 | result_dict = {
314 | 'mask': mask,
315 | 'word_targets': word_targets,
316 | }
317 | # TODO: try to reduce duplicate code here
318 | if use_forward:
319 | self.forward_lm.reset_states()
320 | forward_activation = self.forward_lm(encoded_inputs[:, :-2], mask)
321 |
322 | if compute_logprobs:
323 | # being memory efficient here is critical if the input tensors are large
324 | result_dict['forward_logprobs'] = self._chunked_logsoftmaxes(forward_activation,
325 | word_targets) * mask.float()
326 | else:
327 |
328 | result_dict['forward_logits'] = self.decoder(forward_activation)
329 | result_dict['forward_loss'] = sequence_cross_entropy_with_logits(result_dict['forward_logits'],
330 | word_targets,
331 | mask)
332 | if use_reverse:
333 | self.reverse_lm.reset_states()
334 | reverse_activation = self.reverse_lm(encoded_inputs[:, 2:], mask)
335 | if compute_logprobs:
336 | result_dict['reverse_logprobs'] = self._chunked_logsoftmaxes(reverse_activation,
337 | word_targets) * mask.float()
338 | else:
339 | result_dict['reverse_logits'] = self.decoder(reverse_activation)
340 | result_dict['reverse_loss'] = sequence_cross_entropy_with_logits(result_dict['reverse_logits'],
341 | word_targets,
342 | mask)
343 | return result_dict
344 |
--------------------------------------------------------------------------------
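A minimal usage sketch for ``SimpleBiLM``, assuming the vocabulary directory and checkpoint that ``SwagGenerator`` (next file) downloads are already available locally; the call pattern mirrors ``SwagGenerator.generate_from_seed``, and the paths and example context below are placeholders.

import torch
from allennlp.data import Vocabulary

from adversarialnlp.generators.swag.simple_bilm import SimpleBiLM
from adversarialnlp.generators.swag.utils import optimistic_restore

# Placeholder paths: the same vocabulary folder and 'lm-fold-0.bin' checkpoint
# that SwagGenerator.__init__ fetches via download_files.
vocab = Vocabulary.from_files('swag_lm/vocabulary')
language_model = SimpleBiLM(vocab=vocab,
                            recurrent_dropout_probability=0.2,
                            embedding_dropout_probability=0.2)
optimistic_restore(language_model,
                   torch.load('swag_lm/lm-fold-0.bin', map_location='cpu')['state_dict'])

# Condition on the first sentence plus the start of the second one, then sample
# candidate endings; conditional_generation returns (tokens, scores, context logprobs).
context = ['a', 'weight', 'lifting', 'tutorial', 'is', 'given', '.', 'the', 'coach']
endings, ending_scores, context_logprobs = language_model.conditional_generation(
    context,
    gt_completion=['helps', 'the', 'guy', 'in', 'red', '.'],
    batch_size=8,
    max_gen_length=25)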
/adversarialnlp/generators/swag/swag_generator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=invalid-name,arguments-differ
2 | from typing import Any, List, Iterable, Tuple
3 | import logging
4 |
5 |
6 | import torch
7 | from allennlp.common.util import JsonDict
8 | from allennlp.common.file_utils import cached_path
9 | from allennlp.data import Instance, Token, Vocabulary
10 | from allennlp.data.fields import TextField
11 | from allennlp.pretrained import PretrainedModel
12 |
13 | from adversarialnlp.common.file_utils import download_files, DATA_ROOT
14 | from adversarialnlp.generators import Generator
15 | from adversarialnlp.generators.swag.simple_bilm import SimpleBiLM
16 | from adversarialnlp.generators.swag.utils import optimistic_restore
17 | from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader
18 |
19 | BATCH_SIZE = 1
20 | BEAM_SIZE = 8 * BATCH_SIZE
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | class SwagGenerator(Generator):
25 | """
26 |     ``SwagGenerator`` inherits from the ``Generator`` class.
27 |     This ``Generator`` generates adversarial examples from seeds using
28 |     the method described in
29 |     `SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference <https://arxiv.org/abs/1808.05326>`_.
30 |
31 |     This method goes schematically as follows:
32 |     - In a seed sample containing a pair of sequential sentences (e.g. video captions),
33 |       the second sentence is split into noun and verb phrases.
34 |     - A language model generates several possible endings following the second sentence's noun phrase.
35 |
36 |     Args, inputs and yields:
37 |         See the ``Generator`` class.
38 |
39 | Seeds:
40 | AllenNLP ``Instance`` containing two ``TextField``:
41 |         `first_sentence` and `second_sentence`, respectively containing
42 |         the first and the second consecutive sentences.
43 |
44 | default_seeds:
45 | If no seeds are provided, the default_seeds are the training set
46 | of the
47 |         `ActivityNet Captions dataset <https://cs.stanford.edu/people/ranjaykrishna/densevid/>`_.
48 |
49 | """
50 | def __init__(self,
51 | default_seeds: Iterable = None,
52 | quiet: bool = False):
53 | super().__init__(default_seeds, quiet)
54 |
55 | lm_files = download_files(fnames=['vocabulary.zip',
56 | 'lm-fold-0.bin'],
57 | local_folder='swag_lm')
58 |
59 | activity_data_files = download_files(fnames=['captions.zip'],
60 | paths='https://cs.stanford.edu/people/ranjaykrishna/densevid/',
61 | local_folder='activitynet_captions')
62 |
63 | const_parser_files = cached_path('https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz',
64 | cache_dir=str(DATA_ROOT / 'allennlp_constituency_parser'))
65 |
66 | self.const_parser = PretrainedModel(const_parser_files, 'constituency-parser').predictor()
67 | vocab = Vocabulary.from_files(lm_files[0])
68 | self.language_model = SimpleBiLM(vocab=vocab, recurrent_dropout_probability=0.2,
69 | embedding_dropout_probability=0.2)
70 | optimistic_restore(self.language_model, torch.load(lm_files[1], map_location='cpu')['state_dict'])
71 |
72 | if default_seeds is None:
73 | self.default_seeds = ActivityNetCaptionsDatasetReader().read(activity_data_files[0] + '/train.json')
74 | else:
75 | self.default_seeds = default_seeds
76 |
77 |     def _find_VP(self, tree: JsonDict) -> List[Tuple[str, Any]]:
78 | r"""Recurse on a constituency parse tree until we find verb phrases"""
79 |
80 | # Recursion is annoying because we need to check whether each is a list or not
81 | def _recurse_on_children():
82 | assert 'children' in tree
83 | result = []
84 | for child in tree['children']:
85 | res = self._find_VP(child)
86 | if isinstance(res, tuple):
87 | result.append(res)
88 | else:
89 | result.extend(res)
90 | return result
91 |
92 | if 'VP' in tree['attributes']:
93 | # # Now we'll get greedy and see if we can find something better
94 | # if 'children' in tree and len(tree['children']) > 1:
95 | # recurse_result = _recurse_on_children()
96 | # if all([x[1] in ('VP', 'NP', 'CC') for x in recurse_result]):
97 | # return recurse_result
98 | return [(tree['word'], 'VP')]
99 | # base cases
100 | if 'NP' in tree['attributes']:
101 | return [(tree['word'], 'NP')]
102 | # No children
103 | if not 'children' in tree:
104 | return [(tree['word'], tree['attributes'][0])]
105 |
106 | # If a node only has 1 child then we'll have to stick with that
107 | if len(tree['children']) == 1:
108 | return _recurse_on_children()
109 | # try recursing on everything
110 | return _recurse_on_children()
111 |
112 |     def _split_on_final_vp(self, sentence: TextField) -> Tuple[List[str], List[str]]:
113 | r"""Splits a sentence on the final verb phrase """
114 | sentence_txt = ' '.join(t.text for t in sentence.tokens)
115 | res = self.const_parser.predict(sentence_txt)
116 | res_chunked = self._find_VP(res['hierplane_tree']['root'])
117 | is_vp: List[int] = [i for i, (word, pos) in enumerate(res_chunked) if pos == 'VP']
118 | if not is_vp:
119 | return None, None
120 | vp_ind = max(is_vp)
121 | not_vp = [token for x in res_chunked[:vp_ind] for token in x[0].split(' ')]
122 | is_vp = [token for x in res_chunked[vp_ind:] for token in x[0].split(' ')]
123 | return not_vp, is_vp
124 |
125 | def generate_from_seed(self, seed: Tuple):
126 | """Edit a seed example.
127 | """
128 | first_sentence: TextField = seed.fields["first_sentence"]
129 | second_sentence: TextField = seed.fields["second_sentence"]
130 | eos_bounds = [i + 1 for i, x in enumerate(first_sentence.tokens) if x.text in ('.', '?', '!')]
131 | if not eos_bounds:
132 | first_sentence = TextField(tokens=first_sentence.tokens + [Token(text='.')],
133 | token_indexers=first_sentence.token_indexers)
134 | context_len = len(first_sentence.tokens)
135 | if context_len < 6 or context_len > 100:
136 |             print("skipping on {} (too short or long)".format(
137 |                 ' '.join(t.text for t in first_sentence.tokens + second_sentence.tokens)))
138 | return
139 | # Something I should have done:
140 | # make sure that there aren't multiple periods, etc. in s2 or in the middle
141 | eos_bounds_s2 = [i + 1 for i, x in enumerate(second_sentence.tokens) if x.text in ('.', '?', '!')]
142 |         if eos_bounds_s2 and (len(eos_bounds_s2) > 1 or max(eos_bounds_s2) != len(second_sentence.tokens)):
143 | return
144 | elif not eos_bounds_s2:
145 | second_sentence = TextField(tokens=second_sentence.tokens + [Token(text='.')],
146 | token_indexers=second_sentence.token_indexers)
147 |
148 | # Now split on the VP
149 | startphrase, endphrase = self._split_on_final_vp(second_sentence)
150 | if startphrase is None or not startphrase or len(endphrase) < 5 or len(endphrase) > 25:
151 |             print("skipping on {}->{},{}".format(' '.join(t.text for t in first_sentence.tokens + second_sentence.tokens),
152 | startphrase, endphrase), flush=True)
153 | return
154 |
155 | # if endphrase contains unk then it's hopeless
156 | # if any(vocab.get_token_index(tok.lower()) == vocab.get_token_index(vocab._oov_token)
157 | # for tok in endphrase):
158 | # print("skipping on {} (unk!)".format(' '.join(s1_toks + s2_toks)))
159 | # return
160 |
161 | context = [token.text for token in first_sentence.tokens] + startphrase
162 |
163 | lm_out = self.language_model.conditional_generation(context, gt_completion=endphrase,
164 | batch_size=BEAM_SIZE,
165 | max_gen_length=25)
166 | gens0, fwd_scores, ctx_scores = lm_out
167 | if len(gens0) < BATCH_SIZE:
168 | print("Couldn't generate enough candidates so skipping")
169 | return
170 | gens0 = gens0[:BATCH_SIZE]
171 | yield gens0
172 | # fwd_scores = fwd_scores[:BATCH_SIZE]
173 |
174 | # # Now get the backward scores.
175 | # full_sents = [context + gen for gen in gens0] # NOTE: #1 is GT
176 | # result_dict = self.language_model(self.language_model.batch_to_ids(full_sents),
177 | # use_forward=False, use_reverse=True, compute_logprobs=True)
178 | # ending_lengths = (fwd_scores < 0).sum(1)
179 | # ending_lengths_float = ending_lengths.astype(np.float32)
180 | # rev_scores = result_dict['reverse_logprobs'].cpu().detach().numpy()
181 |
182 | # forward_logperp_ending = -fwd_scores.sum(1) / ending_lengths_float
183 | # reverse_logperp_ending = -rev_scores[:, context_len:].sum(1) / ending_lengths_float
184 | # forward_logperp_begin = -ctx_scores.mean()
185 | # reverse_logperp_begin = -rev_scores[:, :context_len].mean(1)
186 | # eos_logperp = -fwd_scores[np.arange(fwd_scores.shape[0]), ending_lengths - 1]
187 | # # print("Time elapsed {:.3f}".format(time() - tic), flush=True)
188 |
189 | # scores = np.exp(np.column_stack((
190 | # forward_logperp_ending,
191 | # reverse_logperp_ending,
192 | # reverse_logperp_begin,
193 | # eos_logperp,
194 | # np.ones(forward_logperp_ending.shape[0], dtype=np.float32) * forward_logperp_begin,
195 | # )))
196 |
197 | # PRINTOUT
198 | # low2high = scores[:, 2].argsort()
199 | # print("\n\n Dataset={} ctx: {} (perp={:.3f})\n~~~\n".format(item['dataset'], ' '.join(context),
200 | # np.exp(forward_logperp_begin)), flush=True)
201 | # for i, ind in enumerate(low2high.tolist()):
202 | # gen_i = ' '.join(gens0[ind])
203 | # if (ind == 0) or (i < 128):
204 | # print("{:3d}/{:4d}) ({}, end|ctx:{:5.1f} end:{:5.1f} ctx|end:{:5.1f} EOS|(ctx, end):{:5.1f}) {}".format(
205 | # i, len(gens0), 'GOLD' if ind == 0 else ' ', *scores[ind][:-1], gen_i), flush=True)
206 | # gt_score = low2high.argsort()[0]
207 |
208 | # item_full = deepcopy(item)
209 | # item_full['sent1'] = first_sentence
210 | # item_full['startphrase'] = startphrase
211 | # item_full['context'] = context
212 | # item_full['generations'] = gens0
213 | # item_full['postags'] = [ # parse real fast
214 | # [x.orth_.lower() if pos_vocab.get_token_index(x.orth_.lower()) != 1 else x.pos_ for x in y]
215 | # for y in spacy_model.pipe([startphrase + gen for gen in gens0], batch_size=BATCH_SIZE)]
216 | # item_full['scores'] = pd.DataFrame(data=scores, index=np.arange(scores.shape[0]),
217 | # columns=['end-from-ctx', 'end', 'ctx-from-end', 'eos-from-ctxend', 'ctx'])
218 |
219 | # generated_examples.append(gens0)
220 | # if len(generated_examples) > 0:
221 | # yield generated_examples
222 | # generated_examples = []
223 |
224 | # from adversarialnlp.common.file_utils import FIXTURES_ROOT
225 | # generator = SwagGenerator()
226 | # test_instances = ActivityNetCaptionsDatasetReader().read(FIXTURES_ROOT / 'activitynet_captions.json')
227 | # batches = list(generator(test_instances, num_epochs=1))
228 | # assert len(batches) != 0
229 |
--------------------------------------------------------------------------------
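A minimal end-to-end sketch for ``SwagGenerator``, following the commented-out example at the end of the file and the fixture used by the tests further down:

from adversarialnlp.common.file_utils import FIXTURES_ROOT
from adversarialnlp.generators.swag.swag_generator import SwagGenerator
from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader

generator = SwagGenerator()  # downloads the LM, vocabulary and constituency parser on first use
seeds = ActivityNetCaptionsDatasetReader().read(FIXTURES_ROOT / 'activitynet_captions.json')
batches = list(generator(seeds, num_epochs=1))  # each batch holds candidate second-sentence endings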
/adversarialnlp/generators/swag/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from itertools import tee
3 |
4 | from num2words import num2words
5 |
6 | def optimistic_restore(network, state_dict):
7 | mismatch = False
8 | own_state = network.state_dict()
9 | for name, param in state_dict.items():
10 | if name not in own_state:
11 | print("Unexpected key {} in state_dict with size {}".format(name, param.size()))
12 | mismatch = True
13 | elif param.size() == own_state[name].size():
14 | own_state[name].copy_(param)
15 | else:
16 | print("Network has {} with size {}, ckpt has {}".format(name,
17 | own_state[name].size(),
18 | param.size()))
19 | mismatch = True
20 |
21 | missing = set(own_state.keys()) - set(state_dict.keys())
22 | if len(missing) > 0:
23 | print("We couldn't find {}".format(','.join(missing)))
24 | mismatch = True
25 | return not mismatch
26 |
27 | def pairwise(iterable):
28 | "s -> (s0,s1), (s1,s2), (s2, s3), ..."
29 | a, b = tee(iterable)
30 | next(b, None)
31 | return zip(a, b)
32 |
33 | def n2w_1k(num, use_ordinal=False):
34 | if num > 1000:
35 | return ''
36 | return num2words(num, to='ordinal' if use_ordinal else 'cardinal')
37 |
38 | def postprocess(sentence):
39 | """
40 | make sure punctuation is followed by a space
41 | :param sentence:
42 | :return:
43 | """
44 | sentence = remove_allcaps(sentence)
45 | # Aggressively get rid of some punctuation markers
46 | sent0 = re.sub(r'^.*(\\|/|!!!|~|=|#|@|\*|¡|©|¿|«|»|¬|{|}|\||\(|\)|\+|\]|\[).*$',
47 | ' ', sentence, flags=re.MULTILINE|re.IGNORECASE)
48 |
49 | # Less aggressively get rid of quotes, apostrophes
50 | sent1 = re.sub(r'"', ' ', sent0)
51 | sent2 = re.sub(r'`', '\'', sent1)
52 |
53 | # match ordinals
54 | sent3 = re.sub(r'(\d+(?:rd|st|nd))',
55 | lambda x: n2w_1k(int(x.group(0)[:-2]), use_ordinal=True), sent2)
56 |
57 | #These things all need to be followed by spaces or else we'll run into problems
58 | sent4 = re.sub(r'[:;,\"\!\.\-\?](?! )', lambda x: x.group(0) + ' ', sent3)
59 |
60 | #These things all need to be preceded by spaces or else we'll run into problems
61 | sent5 = re.sub(r'(?! )[-]', lambda x: ' ' + x.group(0), sent4)
62 |
63 | # Several spaces
64 | sent6 = re.sub(r'\s\s+', ' ', sent5)
65 |
66 | sent7 = sent6.strip()
67 | return sent7
68 |
69 | def remove_allcaps(sent):
70 | """
71 |     Given a sentence, lowercase the words that are (mostly) ALL CAPS.
72 | :param sent: string, like SOMEONE wheels SOMEONE on, mouthing silent words of earnest prayer.
73 | :return: Someone wheels someone on, mouthing silent words of earnest prayer.
74 | """
75 | # Remove all caps
76 | def _sanitize(word, is_first):
77 | if word == "I":
78 | return word
79 | num_capitals = len([x for x in word if not x.islower()])
80 | if num_capitals > len(word) // 2:
81 | # We have an all caps word here.
82 | if is_first:
83 | return word[0] + word[1:].lower()
84 | return word.lower()
85 | return word
86 |
87 | return ' '.join([_sanitize(word, i == 0) for i, word in enumerate(sent.split(' '))])
88 |
--------------------------------------------------------------------------------
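A few illustrative calls for the helpers above; the expected values follow from the docstrings and from num2words' ordinal spelling:

from adversarialnlp.generators.swag.utils import pairwise, n2w_1k, remove_allcaps

list(pairwise([1, 2, 3, 4]))                   # [(1, 2), (2, 3), (3, 4)]
n2w_1k(3, use_ordinal=True)                    # 'third'
remove_allcaps("SOMEONE wheels SOMEONE on.")   # 'Someone wheels someone on.'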
/adversarialnlp/pruners/__init__.py:
--------------------------------------------------------------------------------
1 | from .pruner import Pruner
--------------------------------------------------------------------------------
/adversarialnlp/pruners/pruner.py:
--------------------------------------------------------------------------------
1 | from allennlp.common import Registrable
2 |
3 | class Pruner(Registrable):
4 | """
5 |     ``Pruner`` is used to filter potential adversarial samples.
6 |
7 | Parameters
8 | ----------
9 | dataset_reader : ``DatasetReader``
10 | The ``DatasetReader`` object that will be used to sample training examples.
11 |
12 | """
13 |     def __init__(self) -> None:
14 | super().__init__()
15 |
--------------------------------------------------------------------------------
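``Pruner`` is a ``Registrable`` stub, so a concrete pruner would typically register itself under a name and implement the filtering logic. A hypothetical sketch (the ``prune`` method and its signature are illustrative, not part of the current API):

from typing import Iterable

from adversarialnlp.pruners import Pruner


@Pruner.register("keep_all")
class KeepAllPruner(Pruner):
    """Hypothetical pruner that keeps every candidate adversarial sample."""

    def prune(self, samples: Iterable) -> Iterable:
        # A real pruner would score each candidate and drop weak ones.
        return list(samples)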
/adversarialnlp/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import logging
3 | import os
4 | import sys
5 |
6 | if os.environ.get("ALLENNLP_DEBUG"):
7 | LEVEL = logging.DEBUG
8 | else:
9 | LEVEL = logging.INFO
10 |
11 | sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
13 | level=LEVEL)
14 |
15 | from adversarialnlp.commands import main # pylint: disable=wrong-import-position
16 |
17 | if __name__ == "__main__":
18 | main(prog="adversarialnlp")
19 |
--------------------------------------------------------------------------------
/adversarialnlp/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/__init__.py
--------------------------------------------------------------------------------
/adversarialnlp/tests/dataset_readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/dataset_readers/__init__.py
--------------------------------------------------------------------------------
/adversarialnlp/tests/dataset_readers/activitynet_captions_test.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-self-use,invalid-name
2 | import pytest
3 |
4 | from allennlp.common.util import ensure_list
5 |
6 | from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader
7 | from adversarialnlp.common.file_utils import FIXTURES_ROOT
8 |
9 | class TestActivityNetCaptionsReader():
10 | @pytest.mark.parametrize("lazy", (True, False))
11 | def test_read_from_file(self, lazy):
12 | reader = ActivityNetCaptionsDatasetReader(lazy=lazy)
13 | instances = reader.read(FIXTURES_ROOT / 'activitynet_captions.json')
14 | instances = ensure_list(instances)
15 |
16 | instance1 = {"video_id": "v_uqiMw7tQ1Cc",
17 | "first_sentence": "A weight lifting tutorial is given .".split(),
18 | "second_sentence": "The coach helps the guy in red with the proper body placement and lifting technique .".split()}
19 |
20 | instance2 = {"video_id": "v_bXdq2zI1Ms0",
21 | "first_sentence": "A man is seen speaking to the camera and pans out into more men standing behind him .".split(),
22 | "second_sentence": "The first man then begins performing martial arts moves while speaking to he camera .".split()}
23 |
24 | instance3 = {"video_id": "v_bXdq2zI1Ms0",
25 | "first_sentence": "The first man then begins performing martial arts moves while speaking to he camera .".split(),
26 | "second_sentence": "He continues moving around and looking to the camera .".split()}
27 |
28 | assert len(instances) == 3
29 |
30 | for instance, expected_instance in zip(instances, [instance1, instance2, instance3]):
31 | fields = instance.fields
32 | assert [t.text for t in fields["first_sentence"].tokens] == expected_instance["first_sentence"]
33 | assert [t.text for t in fields["second_sentence"].tokens] == expected_instance["second_sentence"]
34 | assert fields["video_id"].metadata == expected_instance["video_id"]
35 |
--------------------------------------------------------------------------------
/adversarialnlp/tests/fixtures/activitynet_captions.json:
--------------------------------------------------------------------------------
1 | {"v_uqiMw7tQ1Cc": {"duration": 55.15, "timestamps": [[0.28, 55.15], [13.79, 54.32]], "sentences": ["A weight lifting tutorial is given.", " The coach helps the guy in red with the proper body placement and lifting technique."]}, "v_bXdq2zI1Ms0": {"duration": 73.1, "timestamps": [[0, 10.23], [10.6, 39.84], [38.01, 73.1]], "sentences": ["A man is seen speaking to the camera and pans out into more men standing behind him.", " The first man then begins performing martial arts moves while speaking to he camera.", " He continues moving around and looking to the camera."]}}
--------------------------------------------------------------------------------
/adversarialnlp/tests/fixtures/squad.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": [{
3 | "title": "Super_Bowl_50",
4 | "paragraphs": [{
5 | "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.",
6 | "qas": [{
7 | "answers": [{
8 | "answer_start": 177,
9 | "text": "Denver Broncos"
10 | }, {
11 | "answer_start": 177,
12 | "text": "Denver Broncos"
13 | }, {
14 | "answer_start": 177,
15 | "text": "Denver Broncos"
16 | }],
17 | "question": "Which NFL team represented the AFC at Super Bowl 50?",
18 | "id": "56be4db0acb8001400a502ec"
19 | }, {
20 | "answers": [{
21 | "answer_start": 177,
22 | "text": "Denver Broncos"
23 | }, {
24 | "answer_start": 177,
25 | "text": "Denver Broncos"
26 | }, {
27 | "answer_start": 177,
28 | "text": "Denver Broncos"
29 | }],
30 | "question": "Which NFL team won Super Bowl 50?",
31 | "id": "56be4db0acb8001400a502ef"
32 | }]
33 | }, {
34 | "context": "The Panthers finished the regular season with a 15\u20131 record, and quarterback Cam Newton was named the NFL Most Valuable Player (MVP). They defeated the Arizona Cardinals 49\u201315 in the NFC Championship Game and advanced to their second Super Bowl appearance since the franchise was founded in 1995. The Broncos finished the regular season with a 12\u20134 record, and denied the New England Patriots a chance to defend their title from Super Bowl XLIX by defeating them 20\u201318 in the AFC Championship Game. They joined the Patriots, Dallas Cowboys, and Pittsburgh Steelers as one of four teams that have made eight appearances in the Super Bowl.",
35 | "qas": [{
36 | "answers": [{
37 | "answer_start": 77,
38 | "text": "Cam Newton"
39 | }, {
40 | "answer_start": 77,
41 | "text": "Cam Newton"
42 | }, {
43 | "answer_start": 77,
44 | "text": "Cam Newton"
45 | }],
46 | "question": "Which Carolina Panthers player was named Most Valuable Player?",
47 | "id": "56be4e1facb8001400a502f6"
48 | }, {
49 | "answers": [{
50 | "answer_start": 467,
51 | "text": "8"
52 | }, {
53 | "answer_start": 601,
54 | "text": "eight"
55 | }, {
56 | "answer_start": 601,
57 | "text": "eight"
58 | }],
59 | "question": "How many appearances have the Denver Broncos made in the Super Bowl?",
60 | "id": "56be4e1facb8001400a502f9"
61 | }]
62 | }]
63 | },{
64 | "title": "Warsaw",
65 | "paragraphs": [{
66 | "context": "One of the most famous people born in Warsaw was Maria Sk\u0142odowska-Curie, who achieved international recognition for her research on radioactivity and was the first female recipient of the Nobel Prize. Famous musicians include W\u0142adys\u0142aw Szpilman and Fr\u00e9d\u00e9ric Chopin. Though Chopin was born in the village of \u017belazowa Wola, about 60 km (37 mi) from Warsaw, he moved to the city with his family when he was seven months old. Casimir Pulaski, a Polish general and hero of the American Revolutionary War, was born here in 1745.",
67 | "qas": [{
68 | "answers": [{
69 | "answer_start": 188,
70 | "text": "Nobel Prize"
71 | }, {
72 | "answer_start": 188,
73 | "text": "Nobel Prize"
74 | }, {
75 | "answer_start": 188,
76 | "text": "Nobel Prize"
77 | }],
78 | "question": "What was Maria Curie the first female recipient of?",
79 | "id": "5733a5f54776f41900660f45"
80 | }, {
81 | "answers": [{
82 | "answer_start": 517,
83 | "text": "1745"
84 | }, {
85 | "answer_start": 517,
86 | "text": "1745"
87 | }, {
88 | "answer_start": 517,
89 | "text": "1745"
90 | }],
91 | "question": "What year was Casimir Pulaski born in Warsaw?",
92 | "id": "5733a5f54776f41900660f48"
93 | }]
94 | }, {
95 | "context": "The Saxon Garden, covering the area of 15.5 ha, was formally a royal garden. There are over 100 different species of trees and the avenues are a place to sit and relax. At the east end of the park, the Tomb of the Unknown Soldier is situated. In the 19th century the Krasi\u0144ski Palace Garden was remodelled by Franciszek Szanior. Within the central area of the park one can still find old trees dating from that period: maidenhair tree, black walnut, Turkish hazel and Caucasian wingnut trees. With its benches, flower carpets, a pond with ducks on and a playground for kids, the Krasi\u0144ski Palace Garden is a popular strolling destination for the Varsovians. The Monument of the Warsaw Ghetto Uprising is also situated here. The \u0141azienki Park covers the area of 76 ha. The unique character and history of the park is reflected in its landscape architecture (pavilions, sculptures, bridges, cascades, ponds) and vegetation (domestic and foreign species of trees and bushes). What makes this park different from other green spaces in Warsaw is the presence of peacocks and pheasants, which can be seen here walking around freely, and royal carps in the pond. The Wilan\u00f3w Palace Park, dates back to the second half of the 17th century. It covers the area of 43 ha. Its central French-styled area corresponds to the ancient, baroque forms of the palace. The eastern section of the park, closest to the Palace, is the two-level garden with a terrace facing the pond. The park around the Kr\u00f3likarnia Palace is situated on the old escarpment of the Vistula. The park has lanes running on a few levels deep into the ravines on both sides of the palace.",
96 | "qas": [{
97 | "answers": [{
98 | "answer_start": 92,
99 | "text": "100"
100 | }, {
101 | "answer_start": 87,
102 | "text": "over 100"
103 | }, {
104 | "answer_start": 92,
105 | "text": "100"
106 | }],
107 | "question": "Over how many species of trees can be found in the Saxon Garden?",
108 | "id": "57336755d058e614000b5a3d"
109 | }]
110 | }]
111 | }],
112 | "version": "1.1"
113 | }
--------------------------------------------------------------------------------
/adversarialnlp/tests/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/adversarialnlp/543c02111c57bf245f2aa145c0e5a4879d151001/adversarialnlp/tests/generators/__init__.py
--------------------------------------------------------------------------------
/adversarialnlp/tests/generators/addsent_generator_test.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-self-use,invalid-name
2 | from typing import List
3 | import pytest
4 |
5 | from adversarialnlp.generators.addsent.addsent_generator import AddSentGenerator
6 | from adversarialnlp.generators.addsent.squad_reader import squad_reader
7 | from adversarialnlp.common.file_utils import FIXTURES_ROOT
8 |
9 |
10 | # class GeneratorTest(AllenNlpTestCase):
11 | # def setUp(self):
12 | # super(GeneratorTest, self).setUp()
13 | # self.token_indexers = {"tokens": SingleIdTokenIndexer()}
14 | # self.vocab = Vocabulary()
15 | # self.this_index = self.vocab.add_token_to_namespace('this')
16 | # self.is_index = self.vocab.add_token_to_namespace('is')
17 | # self.a_index = self.vocab.add_token_to_namespace('a')
18 | # self.sentence_index = self.vocab.add_token_to_namespace('sentence')
19 | # self.another_index = self.vocab.add_token_to_namespace('another')
20 | # self.yet_index = self.vocab.add_token_to_namespace('yet')
21 | # self.very_index = self.vocab.add_token_to_namespace('very')
22 | # self.long_index = self.vocab.add_token_to_namespace('long')
23 | # instances = [
24 | # self.create_instance(["this", "is", "a", "sentence"], ["this", "is", "another", "sentence"]),
25 | # self.create_instance(["yet", "another", "sentence"],
26 | # ["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]),
27 | # ]
28 |
29 | # class LazyIterable:
30 | # def __iter__(self):
31 | # return (instance for instance in instances)
32 |
33 | # self.instances = instances
34 | # self.lazy_instances = LazyIterable()
35 |
36 | # def create_instance(self, first_sentence: List[str], second_sentence: List[str]):
37 | # first_tokens = [Token(t) for t in first_sentence]
38 | # second_tokens = [Token(t) for t in second_sentence]
39 | # instance = Instance({'first_sentence': TextField(first_tokens, self.token_indexers),
40 | # 'second_sentence': TextField(second_tokens, self.token_indexers)})
41 | # return instance
42 |
43 | # def assert_instances_are_correct(self, candidate_instances):
44 | # # First we need to remove padding tokens from the candidates.
45 | # # pylint: disable=protected-access
46 | # candidate_instances = [tuple(w for w in instance if w != 0) for instance in candidate_instances]
47 | # expected_instances = [tuple(instance.fields["first_sentence"]._indexed_tokens["tokens"])
48 | # for instance in self.instances]
49 | # assert set(candidate_instances) == set(expected_instances)
50 |
51 |
52 | class TestAddSentGenerator():
53 | # The Generator should work the same for lazy and non lazy datasets,
54 | # so each remaining test runs over both.
55 | def test_yield_one_epoch_generation_over_the_data_once(self):
56 | generator = AddSentGenerator()
57 | test_instances = squad_reader(FIXTURES_ROOT / 'squad.json')
58 | batches = list(generator(test_instances, num_epochs=1))
59 | # We just want to get the single-token array for the text field in the instance.
60 | # instances = [tuple(instance.detach().cpu().numpy())
61 | # for batch in batches
62 | # for instance in batch['text']["tokens"]]
63 | assert len(batches) == 5
64 | # self.assert_instances_are_correct(instances)
65 |
--------------------------------------------------------------------------------
/adversarialnlp/tests/generators/swag_generator_test.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-self-use,invalid-name
2 | from typing import List
3 | import pytest
4 |
5 | from allennlp.data.fields import TextField
6 | from allennlp.common.util import ensure_list
7 | from allennlp.common.testing import AllenNlpTestCase
8 | from allennlp.data import Instance, Token, Vocabulary
9 | from allennlp.data.iterators import BasicIterator
10 | from allennlp.data.token_indexers import SingleIdTokenIndexer
11 | from allennlp.data.dataset_readers.dataset_reader import _LazyInstances
12 |
13 | from adversarialnlp.generators.swag.swag_generator import SwagGenerator
14 | from adversarialnlp.generators.swag.activitynet_captions_reader import ActivityNetCaptionsDatasetReader
15 | from adversarialnlp.common.file_utils import FIXTURES_ROOT
16 |
17 |
18 | class GeneratorTest(AllenNlpTestCase):
19 | def setUp(self):
20 | super(GeneratorTest, self).setUp()
21 | self.token_indexers = {"tokens": SingleIdTokenIndexer()}
22 | self.vocab = Vocabulary()
23 | self.this_index = self.vocab.add_token_to_namespace('this')
24 | self.is_index = self.vocab.add_token_to_namespace('is')
25 | self.a_index = self.vocab.add_token_to_namespace('a')
26 | self.sentence_index = self.vocab.add_token_to_namespace('sentence')
27 | self.another_index = self.vocab.add_token_to_namespace('another')
28 | self.yet_index = self.vocab.add_token_to_namespace('yet')
29 | self.very_index = self.vocab.add_token_to_namespace('very')
30 | self.long_index = self.vocab.add_token_to_namespace('long')
31 | instances = [
32 | self.create_instance(["this", "is", "a", "sentence"], ["this", "is", "another", "sentence"]),
33 | self.create_instance(["yet", "another", "sentence"],
34 | ["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]),
35 | ]
36 |
37 | class LazyIterable:
38 | def __iter__(self):
39 | return (instance for instance in instances)
40 |
41 | self.instances = instances
42 | self.lazy_instances = LazyIterable()
43 |
44 | def create_instance(self, first_sentence: List[str], second_sentence: List[str]):
45 | first_tokens = [Token(t) for t in first_sentence]
46 | second_tokens = [Token(t) for t in second_sentence]
47 | instance = Instance({'first_sentence': TextField(first_tokens, self.token_indexers),
48 | 'second_sentence': TextField(second_tokens, self.token_indexers)})
49 | return instance
50 |
51 | def assert_instances_are_correct(self, candidate_instances):
52 | # First we need to remove padding tokens from the candidates.
53 | # pylint: disable=protected-access
54 | candidate_instances = [tuple(w for w in instance if w != 0) for instance in candidate_instances]
55 | expected_instances = [tuple(instance.fields["first_sentence"]._indexed_tokens["tokens"])
56 | for instance in self.instances]
57 | assert set(candidate_instances) == set(expected_instances)
58 |
59 |
60 | class TestSwagGenerator(GeneratorTest):
61 | # The Generator should work the same for lazy and non lazy datasets,
62 | # so each remaining test runs over both.
63 | def test_yield_one_epoch_generation_over_the_data_once(self):
64 | for test_instances in (self.instances, self.lazy_instances):
65 | generator = SwagGenerator(num_examples=1)
66 | test_instances = ActivityNetCaptionsDatasetReader().read(FIXTURES_ROOT / 'activitynet_captions.json')
67 | batches = list(generator(test_instances))
68 | # We just want to get the single-token array for the text field in the instance.
69 | instances = [tuple(instance.detach().cpu().numpy())
70 | for batch in batches
71 | for instance in batch['text']["tokens"]]
72 | assert len(instances) == 5
73 | self.assert_instances_are_correct(instances)
74 |
--------------------------------------------------------------------------------
/adversarialnlp/version.py:
--------------------------------------------------------------------------------
1 | _MAJOR = "0"
2 | _MINOR = "1"
3 | _REVISION = "1-unreleased"
4 |
5 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
6 | VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION)
7 |
--------------------------------------------------------------------------------
/bin/adversarialnlp:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python -m adversarialnlp.run "$@"
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = AdversarialNLP
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/common.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Common
5 | ======
6 |
7 | .. automodule:: adversarialnlp.common
8 | .. currentmodule:: adversarialnlp.common
9 |
10 |
11 | Files
12 | -----
13 |
14 | .. autofunction:: adversarialnlp.common.file_utils.download_files
15 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # AdversarialNLP documentation build configuration file, created by
5 | # sphinx-quickstart on Wed Oct 24 11:35:14 2018.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | #
20 | # import os
21 | # import sys
22 | # sys.path.insert(0, os.path.abspath('.'))
23 |
24 |
25 | # -- General configuration ------------------------------------------------
26 |
27 | # If your documentation needs a minimal Sphinx version, state it here.
28 | #
29 | # needs_sphinx = '1.0'
30 |
31 | # Add any Sphinx extension module names here, as strings. They can be
32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
33 | # ones.
34 | extensions = ['sphinx.ext.autodoc',
35 | 'sphinx.ext.intersphinx',
36 | 'sphinx.ext.mathjax',
37 | 'sphinx.ext.ifconfig',
38 | 'sphinx.ext.viewcode',
39 | 'sphinx.ext.napoleon']
40 |
41 | # Add any paths that contain templates here, relative to this directory.
42 | templates_path = ['_templates']
43 |
44 | # The suffix(es) of source filenames.
45 | # You can specify multiple suffix as a list of string:
46 | #
47 | # source_suffix = ['.rst', '.md']
48 | source_suffix = '.rst'
49 |
50 | # The master toctree document.
51 | master_doc = 'index'
52 |
53 | # General information about the project.
54 | project = 'AdversarialNLP'
55 | copyright = '2018, Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers'
56 | author = 'Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers'
57 |
58 | # The version info for the project you're documenting, acts as replacement for
59 | # |version| and |release|, also used in various other places throughout the
60 | # built documents.
61 | #
62 | # The short X.Y version.
63 | version = '0.1'
64 | # The full version, including alpha/beta/rc tags.
65 | release = '0.1'
66 |
67 | # The language for content autogenerated by Sphinx. Refer to documentation
68 | # for a list of supported languages.
69 | #
70 | # This is also used if you do content translation via gettext catalogs.
71 | # Usually you set "language" from the command line for these cases.
72 | language = 'en'
73 |
74 | # List of patterns, relative to source directory, that match files and
75 | # directories to ignore when looking for source files.
76 | # This patterns also effect to html_static_path and html_extra_path
77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
78 |
79 | # The name of the Pygments (syntax highlighting) style to use.
80 | pygments_style = 'sphinx'
81 |
82 | # If true, `todo` and `todoList` produce output, else they produce nothing.
83 | todo_include_todos = False
84 |
85 |
86 | # -- Options for HTML output ----------------------------------------------
87 |
88 | # The theme to use for HTML and HTML Help pages. See the documentation for
89 | # a list of builtin themes.
90 | #
91 | html_theme = 'alabaster'
92 |
93 | # Theme options are theme-specific and customize the look and feel of a theme
94 | # further. For a list of options available for each theme, see the
95 | # documentation.
96 | #
97 | # html_theme_options = {}
98 |
99 | # Add any paths that contain custom static files (such as style sheets) here,
100 | # relative to this directory. They are copied after the builtin static files,
101 | # so a file named "default.css" will overwrite the builtin "default.css".
102 | html_static_path = ['_static']
103 |
104 |
105 | # -- Options for HTMLHelp output ------------------------------------------
106 |
107 | # Output file base name for HTML help builder.
108 | htmlhelp_basename = 'AdversarialNLPdoc'
109 |
110 |
111 | # -- Options for LaTeX output ---------------------------------------------
112 |
113 | latex_elements = {
114 | # The paper size ('letterpaper' or 'a4paper').
115 | #
116 | # 'papersize': 'letterpaper',
117 |
118 | # The font size ('10pt', '11pt' or '12pt').
119 | #
120 | # 'pointsize': '10pt',
121 |
122 | # Additional stuff for the LaTeX preamble.
123 | #
124 | # 'preamble': '',
125 |
126 | # Latex figure (float) alignment
127 | #
128 | # 'figure_align': 'htbp',
129 | }
130 |
131 | # Grouping the document tree into LaTeX files. List of tuples
132 | # (source start file, target name, title,
133 | # author, documentclass [howto, manual, or own class]).
134 | latex_documents = [
135 | (master_doc, 'AdversarialNLP.tex', 'AdversarialNLP Documentation',
136 | 'Mohit Iyyer, Pasquale Minervini, Victor Sanh, Thomas Wolf, Rowan Zellers', 'manual'),
137 | ]
138 |
139 |
140 | # -- Options for manual page output ---------------------------------------
141 |
142 | # One entry per manual page. List of tuples
143 | # (source start file, name, description, authors, manual section).
144 | man_pages = [
145 | (master_doc, 'adversarialnlp', 'AdversarialNLP Documentation',
146 | [author], 1)
147 | ]
148 |
149 |
150 | # -- Options for Texinfo output -------------------------------------------
151 |
152 | # Grouping the document tree into Texinfo files. List of tuples
153 | # (source start file, target name, title, author,
154 | # dir menu entry, description, category)
155 | texinfo_documents = [
156 | (master_doc, 'AdversarialNLP', 'AdversarialNLP Documentation',
157 |      author, 'AdversarialNLP', 'A generic library for crafting and using adversarial NLP examples.',
158 | 'Miscellaneous'),
159 | ]
160 |
161 |
162 |
163 |
164 | # Example configuration for intersphinx: refer to the Python standard library.
165 | intersphinx_mapping = {'https://docs.python.org/': None}
166 |
--------------------------------------------------------------------------------
/docs/generators.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Generators
5 | ==========
6 |
7 | .. automodule:: adversarialnlp.generators
8 | .. currentmodule:: adversarialnlp.generators
9 |
10 | Generator
11 | ----------
12 |
13 | .. autoclass:: adversarialnlp.generators.Generator
14 |
15 | AddSent
16 | ----------
17 |
18 | .. autoclass:: adversarialnlp.generators.addsent.AddSentGenerator
19 | .. autofunction:: adversarialnlp.generators.addsent.squad_reader
20 |
21 | SWAG
22 | ----
23 |
24 | .. autoclass:: adversarialnlp.generators.swag.SwagGenerator
25 | .. autoclass:: adversarialnlp.generators.swag.ActivityNetCaptionsDatasetReader
26 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. AdversarialNLP documentation master file, created by
2 | sphinx-quickstart on Wed Oct 24 11:35:14 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | :github_url: https://github.com/huggingface/adversarialnlp
7 |
8 | AdversarialNLP documentation
9 | ============================
10 |
11 | AdversarialNLP is a generic library for crafting and using Adversarial NLP examples.
12 |
13 | .. toctree::
14 | :maxdepth: 1
15 | :caption: Contents
16 |
17 | readme
18 |
19 | common
20 | generators
21 |
22 |
23 | Indices and tables
24 | ==================
25 |
26 | * :ref:`genindex`
27 | * :ref:`modindex`
28 | * :ref:`search`
29 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=AdversarialNLP
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/readme.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../README.md
--------------------------------------------------------------------------------
/docs/readthedoc_requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | typing
3 | pytest
4 | PyYAML==3.13
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | testpaths = tests/
3 | python_paths = ./
4 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 |
3 | build:
4 | image: latest
5 |
6 | python:
7 | version: 3.6
8 | setup_py_install: true
9 |
10 | requirements_file: docs/readthedoc_requirements.txt
11 |
12 | formats: []
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Library dependencies for the python code. You need to install these with
2 | # `pip install -r requirements.txt` before you can run this.
3 | # NOTE: all essential packages must be placed under a section named 'ESSENTIAL ...'
4 | # so that the script `./scripts/check_requirements_and_setup.py` can find them.
5 |
6 | #### ESSENTIAL LIBRARIES FOR MAIN FUNCTIONALITY ####
7 |
8 | # This installs Pytorch for CUDA 8 only. If you are using a newer version,
9 | # please visit http://pytorch.org/ and install the relevant version.
10 | torch>=0.4.1,<0.5.0
11 |
12 | # Parameter parsing (but not on Windows).
13 | jsonnet==0.10.0 ; sys.platform != 'win32'
14 |
15 | # Adds an @overrides decorator for better documentation and error checking when using subclasses.
16 | overrides
17 |
18 | # Used by some old code. We moved away from it because it's too slow, but some old code still
19 | # imports this.
20 | nltk
21 |
22 | # Mainly used for the faster tokenizer.
23 | spacy>=2.0,<2.1
24 |
25 | # Used by span prediction models.
26 | numpy
27 |
28 | # Used for reading configuration info out of numpy-style docstrings.
29 | numpydoc==0.8.0
30 |
31 | # Used in coreference resolution evaluation metrics.
32 | scipy
33 | scikit-learn
34 |
35 | # Write logs for training visualisation with the Tensorboard application
36 | # Install the Tensorboard application separately (part of tensorflow) to view them.
37 | tensorboardX==1.2
38 |
39 | # Required by torch.utils.ffi
40 | cffi==1.11.2
41 |
42 | # aws commandline tools for running on Docker remotely.
43 | # second requirement is to get botocore < 1.11, to avoid the below bug
44 | awscli>=1.11.91
45 |
46 | # Accessing files from S3 directly.
47 | boto3
48 |
49 | # REST interface for models
50 | flask==0.12.4
51 | flask-cors==3.0.3
52 | gevent==1.3.6
53 |
54 | # Used by semantic parsing code to strip diacritics from unicode strings.
55 | unidecode
56 |
57 | # Used by semantic parsing code to parse SQL
58 | parsimonious==0.8.0
59 |
60 | # Used by semantic parsing code to format and postprocess SQL
61 | sqlparse==0.2.4
62 |
63 | # For text normalization
64 | ftfy
65 |
66 | #### ESSENTIAL LIBRARIES USED IN SCRIPTS ####
67 |
68 | # Plot graphs for learning rate finder
69 | matplotlib==2.2.3
70 |
71 | # Used for downloading datasets over HTTP
72 | requests>=2.18
73 |
74 | # progress bars in data cleaning scripts
75 | tqdm>=4.19
76 |
77 | # In SQuAD eval script, we use this to see if we likely have some tokenization problem.
78 | editdistance
79 |
80 | # For pretrained model weights
81 | h5py
82 |
83 | # For timezone utilities
84 | pytz==2017.3
85 |
86 | # Reads Universal Dependencies files.
87 | conllu==0.11
88 |
89 | #### ESSENTIAL TESTING-RELATED PACKAGES ####
90 |
91 | # We'll use pytest to run our tests; this isn't really necessary to run the code, but it is to run
92 | # the tests. With this here, you can run the tests with `py.test` from the base directory.
93 | pytest
94 |
95 | # Allows marking tests as flaky, to be rerun if they fail
96 | flaky
97 |
98 | # Required to mock out `requests` calls
99 | responses>=0.7
100 |
101 | # For mocking s3.
102 | moto==1.3.4
103 |
104 | #### TESTING-RELATED PACKAGES ####
105 |
106 | # Checks style, syntax, and other useful errors.
107 | pylint==1.8.1
108 |
109 | # Tutorial notebooks
110 | # see: https://github.com/jupyter/jupyter/issues/370 for ipykernel
111 | ipykernel<5.0.0
112 | jupyter
113 |
114 | # Static type checking
115 | mypy==0.521
116 |
117 | # Allows generation of coverage reports with pytest.
118 | pytest-cov
119 |
120 | # Allows codecov to generate coverage reports
121 | coverage
122 | codecov
123 |
124 | # Required to run sanic tests
125 | aiohttp
126 |
127 | #### DOC-RELATED PACKAGES ####
128 |
129 | # Builds our documentation.
130 | sphinx==1.5.3
131 |
132 | # Watches the documentation directory and rebuilds on changes.
133 | sphinx-autobuild
134 |
135 | # doc theme
136 | sphinx_rtd_theme
137 |
138 | # Only used to convert our readme to reStructuredText on Pypi.
139 | pypandoc
140 |
141 | # Pypi uploads
142 | twine==1.11.0
143 |
144 | #### GENERATOR-RELATED PACKAGES ####
145 |
146 | # Used by AddSent.
147 | psutil
148 | pattern
149 |
150 | # Used by SWAG.
151 | allennlp
152 | num2words
153 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from setuptools import setup, find_packages
3 |
4 | # PEP0440 compatible formatted version, see:
5 | # https://www.python.org/dev/peps/pep-0440/
6 | #
7 | # release markers:
8 | # X.Y
9 | # X.Y.Z # For bugfix releases
10 | #
11 | # pre-release markers:
12 | # X.YaN # Alpha release
13 | # X.YbN # Beta release
14 | # X.YrcN # Release Candidate
15 | # X.Y # Final release
16 |
17 | # version.py defines the VERSION and VERSION_SHORT variables.
18 | # We use exec here so we don't import allennlp whilst setting up.
19 | VERSION = {}
20 | with open("adversarialnlp/version.py", "r") as version_file:
21 | exec(version_file.read(), VERSION)
22 |
23 | # make pytest-runner a conditional requirement,
24 | # per: https://github.com/pytest-dev/pytest-runner#considerations
25 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
26 | pytest_runner = ['pytest-runner'] if needs_pytest else []
27 |
28 | with open('requirements.txt', 'r') as f:
29 | install_requires = [l for l in f.readlines() if not l.startswith('# ')]
30 |
31 | setup_requirements = [
32 | # add other setup requirements as necessary
33 | ] + pytest_runner
34 |
35 | setup(name='adversarialnlp',
36 | version=VERSION["VERSION"],
37 |       description='A generic library for crafting adversarial NLP examples, built on AllenNLP and PyTorch.',
38 | long_description=open("README.md").read(),
39 | long_description_content_type="text/markdown",
40 | classifiers=[
41 | 'Intended Audience :: Science/Research',
42 | 'Development Status :: 3 - Alpha',
43 | 'License :: OSI Approved :: Apache Software License',
44 | 'Programming Language :: Python :: 3.6',
45 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
46 | ],
47 | keywords='adversarialnlp allennlp NLP deep learning machine reading',
48 | url='https://github.com/huggingface/adversarialnlp',
49 | author='Thomas WOLF',
50 | author_email='thomas@huggingface.co',
51 | license='Apache',
52 | packages=find_packages(exclude=["*.tests", "*.tests.*",
53 | "tests.*", "tests"]),
54 | install_requires=install_requires,
55 | scripts=["bin/adversarialnlp"],
56 | setup_requires=setup_requirements,
57 | tests_require=[
58 | 'pytest',
59 | ],
60 | include_package_data=True,
61 | python_requires='>=3.6.1',
62 | zip_safe=False)
63 |
--------------------------------------------------------------------------------
/tutorials/usage.py:
--------------------------------------------------------------------------------
1 | from adversarialnlp import Adversarial
2 | from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader
3 |
4 | adversarial = Adversarial(dataset_reader=SquadReader, editor='lstm_lm', num_samples=10)
5 | examples = adversarial.generate()
--------------------------------------------------------------------------------
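Alongside the high-level ``Adversarial`` interface shown in tutorials/usage.py, the generator classes can be driven directly, as in the tests. A minimal sketch with the AddSent generator and the bundled SQuAD fixture:

from adversarialnlp.common.file_utils import FIXTURES_ROOT
from adversarialnlp.generators.addsent.addsent_generator import AddSentGenerator
from adversarialnlp.generators.addsent.squad_reader import squad_reader

generator = AddSentGenerator()
seeds = squad_reader(FIXTURES_ROOT / 'squad.json')
adversarial_batches = list(generator(seeds, num_epochs=1))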