├── .editorconfig ├── .github └── ISSUE_TEMPLATE.md ├── .gitignore ├── .pylintrc ├── .travis.yml ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── docs ├── Makefile ├── authors.rst ├── conf.py ├── contributing.rst ├── history.rst ├── index.rst ├── installation.rst ├── make.bat ├── readme.rst └── usage.rst ├── extension ├── .editorconfig ├── bind.cc ├── elmo_character_encoder.cc ├── elmo_character_encoder.h ├── scalar_mix.cc └── scalar_mix.h ├── pytorch_fast_elmo ├── __init__.py ├── factory.py ├── model.py ├── tool │ ├── __init__.py │ ├── cli.py │ ├── inspect.py │ └── profile.py └── utils.py ├── requirements_dev.txt ├── requirements_prod.txt ├── setup.cfg ├── setup.py ├── tests ├── fixtures │ ├── lm_embd.txt │ ├── lm_weights.hdf5 │ ├── options.json │ └── vocab.txt ├── test_elmo.py ├── test_scalar_mix.py └── test_utils.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * pytorch-fast-elmo version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 
15 | ``` 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | .vscode/ 105 | data/ 106 | .*DS_Store 107 | -------------------------------------------------------------------------------- /.pylintrc: 
-------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | init-hook='import sys; sys.path.append("./")' 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=CVS,custom_extensions 13 | 14 | # Add files or directories matching the regex patterns to the blacklist. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=yes 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code 35 | extension-pkg-whitelist=numpy,torch,spacy,_jsonnet 36 | 37 | # Allow optimization of some AST trees. This will activate a peephole AST 38 | # optimizer, which will apply various small optimizations. For instance, it can 39 | # be used to obtain the result of joining multiple strings with the addition 40 | # operator. Joining a lot of strings can lead to a maximum recursion error in 41 | # Pylint and this flag can prevent that. It has one side effect, the resulting 42 | # AST will be different than the one from reality. This option is deprecated 43 | # and it will be removed in Pylint 2.0. 44 | optimize-ast=no 45 | 46 | 47 | [MESSAGES CONTROL] 48 | 49 | # Only show warnings with the listed confidence levels. 
Leave empty to show 50 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 51 | confidence= 52 | 53 | # Enable the message, report, category or checker with the given id(s). You can 54 | # either give multiple identifiers separated by comma (,) or put this option 55 | # multiple times (only on the command line, not in the configuration file where 56 | # it should appear only once). See also the "--disable" option for examples. 57 | #enable= 58 | 59 | # Disable the message, report, category or checker with the given id(s). You 60 | # can either give multiple identifiers separated by comma (,) or put this 61 | # option multiple times (only on the command line, not in the configuration 62 | # file where it should appear only once). You can also use "--disable=all" to 63 | # disable everything first and then reenable specific checks. For example, if 64 | # you want to run only the similarities checker, you can use "--disable=all 65 | # --enable=similarities". If you want to run only the classes checker, but have 66 | # no Warning level messages displayed, use "--disable=all --enable=classes 67 | # --disable=W" 68 | 
disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,missing-docstring,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks,too-many-instance-attributes,fixme,too-few-public-methods,no-else-return 69 | 70 | 71 | [REPORTS] 72 | 73 | # Set the output format. Available formats are text, parseable, colorized, msvs 74 | # (visual studio) and html. You can also give a reporter class, eg 75 | # mypackage.mymodule.MyReporterClass. 76 | output-format=text 77 | 78 | # Put messages in a separate file for each module / package specified on the 79 | # command line instead of printing them on stdout. Reports (if any) will be 80 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 81 | # and it will be removed in Pylint 2.0. 82 | files-output=no 83 | 84 | # Tells whether to display a full report or only the messages 85 | reports=yes 86 | 87 | # Python expression which should return a note less than 10 (10 is the highest 88 | # note). You have access to the variables errors warning, statement which 89 | # respectively contain the number of errors / warnings messages and the total 90 | # number of statements analyzed. 
This is used by the global evaluation report 91 | # (RP0004). 92 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 93 | 94 | # Template used to display messages. This is a python new-style format string 95 | # used to format the message information. See doc for all details 96 | #msg-template= 97 | 98 | 99 | [LOGGING] 100 | 101 | # Logging modules to check that the string format arguments are in logging 102 | # function parameter format 103 | logging-modules=logging 104 | 105 | 106 | [TYPECHECK] 107 | 108 | # Tells whether missing members accessed in mixin class should be ignored. A 109 | # mixin class is detected if its name ends with "mixin" (case insensitive). 110 | ignore-mixin-members=yes 111 | 112 | # List of module names for which member attributes should not be checked 113 | # (useful for modules/projects where namespaces are manipulated during runtime 114 | # and thus existing member attributes cannot be deduced by static analysis. It 115 | # supports qualified module names, as well as Unix pattern matching. 116 | ignored-modules= 117 | 118 | # List of class names for which member attributes should not be checked (useful 119 | # for classes with dynamically set attributes). This supports the use of 120 | # qualified names. 121 | ignored-classes=optparse.Values,thread._local,_thread._local,responses 122 | 123 | # List of members which are set dynamically and missed by pylint inference 124 | # system, and so shouldn't trigger E1101 when accessed. Python regular 125 | # expressions are accepted. 126 | generated-members=torch.* 127 | 128 | # List of decorators that produce context managers, such as 129 | # contextlib.contextmanager. Add to this list to register other decorators that 130 | # produce valid context managers. 131 | contextmanager-decorators=contextlib.contextmanager 132 | 133 | 134 | [SIMILARITIES] 135 | 136 | # Minimum lines number of a similarity. 
137 | min-similarity-lines=4 138 | 139 | # Ignore comments when computing similarities. 140 | ignore-comments=yes 141 | 142 | # Ignore docstrings when computing similarities. 143 | ignore-docstrings=yes 144 | 145 | # Ignore imports when computing similarities. 146 | ignore-imports=no 147 | 148 | 149 | [FORMAT] 150 | 151 | # Maximum number of characters on a single line. Ideally, lines should be under 100 characters, 152 | # but we allow some leeway before calling it an error. 153 | max-line-length=115 154 | 155 | # Regexp for a line that is allowed to be longer than the limit. 156 | ignore-long-lines=^\s*(# )??$ 157 | 158 | # Allow the body of an if to be on the same line as the test if there is no 159 | # else. 160 | single-line-if-stmt=no 161 | 162 | # List of optional constructs for which whitespace checking is disabled. `dict- 163 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 164 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 165 | # `empty-line` allows space-only lines. 166 | no-space-check=trailing-comma,dict-separator 167 | 168 | # Maximum number of lines in a module 169 | max-module-lines=1000 170 | 171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 172 | # tab). 173 | indent-string=' ' 174 | 175 | # Number of spaces of indent required inside a hanging or continued line. 176 | indent-after-paren=8 177 | 178 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 179 | expected-line-ending-format= 180 | 181 | 182 | [BASIC] 183 | 184 | # Good variable names which should always be accepted, separated by a comma 185 | good-names=i,j,k,ex,Run,_ 186 | 187 | # Bad variable names which should always be refused, separated by a comma 188 | bad-names=foo,bar,baz,toto,tutu,tata 189 | 190 | # Colon-delimited sets of names that determine each other's naming style when 191 | # the name regexes allow several styles. 
192 | name-group= 193 | 194 | # Include a hint for the correct naming format with invalid-name 195 | include-naming-hint=no 196 | 197 | # List of decorators that produce properties, such as abc.abstractproperty. Add 198 | # to this list to register other decorators that produce valid properties. 199 | property-classes=abc.abstractproperty 200 | 201 | # Regular expression matching correct function names 202 | function-rgx=[a-z_][a-z0-9_]{2,40}$ 203 | 204 | # Naming hint for function names 205 | function-name-hint=[a-z_][a-z0-9_]{2,40}$ 206 | 207 | # Regular expression matching correct variable names 208 | variable-rgx=[a-z_][a-z0-9_]{2,40}$ 209 | 210 | # Naming hint for variable names 211 | variable-name-hint=[a-z_][a-z0-9_]{2,40}$ 212 | 213 | # Regular expression matching correct constant names 214 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 215 | 216 | # Naming hint for constant names 217 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 218 | 219 | # Regular expression matching correct attribute names 220 | attr-rgx=[a-z_][a-z0-9_]{2,40}$ 221 | 222 | # Naming hint for attribute names 223 | attr-name-hint=[a-z_][a-z0-9_]{2,40}$ 224 | 225 | # Regular expression matching correct argument names 226 | argument-rgx=[a-z_][a-z0-9_]{2,40}$ 227 | 228 | # Naming hint for argument names 229 | argument-name-hint=[a-z_][a-z0-9_]{2,40}$ 230 | 231 | # Regular expression matching correct class attribute names 232 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 233 | 234 | # Naming hint for class attribute names 235 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,40}|(__.*__))$ 236 | 237 | # Regular expression matching correct inline iteration names 238 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 239 | 240 | # Naming hint for inline iteration names 241 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 242 | 243 | # Regular expression matching correct class names 244 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 245 | 246 | # Naming hint for class names 247 | 
class-name-hint=[A-Z_][a-zA-Z0-9]+$ 248 | 249 | # Regular expression matching correct module names 250 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 251 | 252 | # Naming hint for module names 253 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 254 | 255 | # Regular expression matching correct method names 256 | method-rgx=[a-z_][a-z0-9_]{2,40}$ 257 | 258 | # Naming hint for method names 259 | method-name-hint=[a-z_][a-z0-9_]{2,40}$ 260 | 261 | # Regular expression which should only match function or class names that do 262 | # not require a docstring. 263 | no-docstring-rgx=^_ 264 | 265 | # Minimum line length for functions/classes that require docstrings, shorter 266 | # ones are exempt. 267 | docstring-min-length=-1 268 | 269 | 270 | [ELIF] 271 | 272 | # Maximum number of nested blocks for function / method body 273 | max-nested-blocks=5 274 | 275 | 276 | [VARIABLES] 277 | 278 | # Tells whether we should check for unused import in __init__ files. 279 | init-import=no 280 | 281 | # A regular expression matching the name of dummy variables (i.e. expectedly 282 | # not used). 283 | dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy 284 | 285 | # List of additional names supposed to be defined in builtins. Remember that 286 | # you should avoid to define new builtins when possible. 287 | additional-builtins= 288 | 289 | # List of strings which can identify a callback function by name. A callback 290 | # name must start or end with one of those strings. 291 | callbacks=cb_,_cb 292 | 293 | # List of qualified module names which can have objects that can redefine 294 | # builtins. 295 | redefining-builtins-modules=six.moves,future.builtins 296 | 297 | 298 | [SPELLING] 299 | 300 | # Spelling dictionary name. Available dictionaries: none. To make it working 301 | # install python-enchant package. 302 | spelling-dict= 303 | 304 | # List of comma separated words that should not be checked. 
305 | spelling-ignore-words= 306 | 307 | # A path to a file that contains private dictionary; one word per line. 308 | spelling-private-dict-file= 309 | 310 | # Tells whether to store unknown words to indicated private dictionary in 311 | # --spelling-private-dict-file option instead of raising a message. 312 | spelling-store-unknown-words=no 313 | 314 | 315 | [MISCELLANEOUS] 316 | 317 | # List of note tags to take in consideration, separated by a comma. 318 | notes=FIXME,XXX,TODO 319 | 320 | 321 | [DESIGN] 322 | 323 | # Maximum number of arguments for function / method 324 | max-args=5 325 | 326 | # Argument names that match this expression will be ignored. Default to name 327 | # with leading underscore 328 | ignored-argument-names=_.* 329 | 330 | # Maximum number of locals for function / method body 331 | max-locals=15 332 | 333 | # Maximum number of return / yield for function / method body 334 | max-returns=6 335 | 336 | # Maximum number of branch for function / method body 337 | max-branches=12 338 | 339 | # Maximum number of statements in function / method body 340 | max-statements=50 341 | 342 | # Maximum number of parents for a class (see R0901). 343 | max-parents=7 344 | 345 | # Maximum number of attributes for a class (see R0902). 346 | max-attributes=7 347 | 348 | # Minimum number of public methods for a class (see R0903). 349 | min-public-methods=2 350 | 351 | # Maximum number of public methods for a class (see R0904). 352 | max-public-methods=20 353 | 354 | # Maximum number of boolean expressions in a if statement 355 | max-bool-expr=5 356 | 357 | 358 | [CLASSES] 359 | 360 | # List of method names used to declare (i.e. assign) instance attributes. 361 | defining-attr-methods=__init__,__new__,setUp 362 | 363 | # List of valid names for the first argument in a class method. 364 | valid-classmethod-first-arg=cls 365 | 366 | # List of valid names for the first argument in a metaclass class method. 
367 | valid-metaclass-classmethod-first-arg=mcs 368 | 369 | # List of member names, which should be excluded from the protected access 370 | # warning. 371 | exclude-protected=_asdict,_fields,_replace,_source,_make 372 | 373 | 374 | [IMPORTS] 375 | 376 | # Deprecated modules which should not be used, separated by a comma 377 | deprecated-modules=regsub,TERMIOS,Bastion,rexec 378 | 379 | # Create a graph of every (i.e. internal and external) dependencies in the 380 | # given file (report RP0402 must not be disabled) 381 | import-graph= 382 | 383 | # Create a graph of external dependencies in the given file (report RP0402 must 384 | # not be disabled) 385 | ext-import-graph= 386 | 387 | # Create a graph of internal dependencies in the given file (report RP0402 must 388 | # not be disabled) 389 | int-import-graph= 390 | 391 | # Force import order to recognize a module as part of the standard 392 | # compatibility libraries. 393 | known-standard-library= 394 | 395 | # Force import order to recognize a module as part of a third party library. 396 | known-third-party=enchant 397 | 398 | # Analyse import fallback blocks. This can be used to support both Python 2 and 399 | # 3 compatible code, which means that the block might have code that exists 400 | # only in one or another interpreter, leading to false positives when analysed. 401 | analyse-fallback-blocks=no 402 | 403 | 404 | [EXCEPTIONS] 405 | 406 | # Exceptions that will emit a warning when being caught. 
Defaults to 407 | # "Exception" 408 | overgeneral-exceptions=Exception 409 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | dist: xenial 3 | python: 4 | - "3.6" 5 | install: 6 | - pip install -U pip 7 | - pip install -U torch thinc spacy pluggy 8 | - pip install -U -r requirements_dev.txt --progress-bar off 9 | - python setup.py develop 10 | script: 11 | - py.test 12 | - pylint pytorch_fast_elmo 13 | - yapf -d -r pytorch_fast_elmo 14 | - mypy pytorch_fast_elmo --strict --ignore-missing-imports 15 | deploy: 16 | provider: pypi 17 | distributions: sdist 18 | user: huntzhan 19 | password: 20 | secure: fifYVP1JpGUQTsToNDvReKBbS1o/54GHdi13I/bgU+l3AQq5yulChvukGlwJNsGTYL4DT6exWdIGv6o6gZbZgecAdDikVJnih8oxYFLqJyLEWWZ7nrJEO6OvXfTflNaRSabIduGNRxgknt66FCKLmolcSEpMht6+FSnBy/Ig3VP8R581CHKqUusMCDd3eBNA2Ag86GdLrVEMDh7t3xWzXXjRNnA+MHrG1SUROIR0qEOBz+vtd7elgoZ7ciV7G4FgAlAzSzFsicfDCaAmpu6yCcA1gqS2zf+8x8hIWAFjA0glvkB7243WMVPM6GxN+SPpD2zhgtQRvT5+PNZzvFTE9Gtw76Jvs+5IDpqaPQ7ugeKktIwu3XPgo+dRblx75AINLvQLi+dekU92sh4u/ilLEMaKmEF5WNmgSTMrFbX+mbVF4d5d2tig6Ni3/QA5MNM73p6e5YmAoRGYEB15rdBQTDGrtCOW21LVOFS3BfuQ4Hg5W3izNi2Vm9Kw5qutYfTa5t1SVXdJPRnOx/pT5ipf60UOuz6AUcYNbS348Yd7vjX12/LjsH52pPWn9m7z5xN6G+YKVFy9sfA2sBhehkiQYj5DGp3Hzmeg4p7ufmA4+lrm3Dkb9dHQ48oT/S0hCdb9iKHK06rwH7xzVRXb2GDcn7B79yqsDgG63W4Vb8cNoRM= 21 | on: 22 | tags: true 23 | repo: cnt-dev/pytorch-fast-elmo 24 | python: 3.6 25 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Hunt Zhan 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 
14 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/cnt-dev/pytorch-fast-elmo/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it. 37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | pytorch-fast-elmo could always use more documentation, whether as part of the 42 | official pytorch-fast-elmo docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/cnt-dev/pytorch-fast-elmo/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute? 
Here's how to set up `pytorch_fast_elmo` for local development. 61 | 62 | 1. Fork the `pytorch-fast-elmo` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/pytorch-fast-elmo.git 66 | 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 68 | 69 | $ mkvirtualenv pytorch_fast_elmo 70 | $ cd pytorch-fast-elmo/ 71 | $ python setup.py develop 72 | 73 | 4. Create a branch for local development:: 74 | 75 | $ git checkout -b name-of-your-bugfix-or-feature 76 | 77 | Now you can make your changes locally. 78 | 79 | 5. When you're done making changes, check that your changes pass flake8 and the 80 | tests, including testing other Python versions with tox:: 81 | 82 | $ flake8 pytorch_fast_elmo tests 83 | $ python setup.py test or py.test 84 | $ tox 85 | 86 | To get flake8 and tox, just pip install them into your virtualenv. 87 | 88 | 6. Commit your changes and push your branch to GitHub:: 89 | 90 | $ git add . 91 | $ git commit -m "Your detailed description of your changes." 92 | $ git push origin name-of-your-bugfix-or-feature 93 | 94 | 7. Submit a pull request through the GitHub website. 95 | 96 | Pull Request Guidelines 97 | ----------------------- 98 | 99 | Before you submit a pull request, check that it meets these guidelines: 100 | 101 | 1. The pull request should include tests. 102 | 2. If the pull request adds functionality, the docs should be updated. Put 103 | your new functionality into a function with a docstring, and add the 104 | feature to the list in README.rst. 105 | 3. The pull request should work for Python 3.6, the version tested on Travis CI. Check 106 | https://travis-ci.org/cnt-dev/pytorch-fast-elmo/pull_requests 107 | and make sure that the tests pass for all supported Python versions. 
108 | 109 | Tips 110 | ---- 111 | 112 | To run a subset of tests:: 113 | 114 | $ py.test tests/test_elmo.py 115 | 116 | 117 | Deploying 118 | --------- 119 | 120 | A reminder for the maintainers on how to deploy. 121 | Make sure all your changes are committed (including an entry in HISTORY.rst). 122 | Then run:: 123 | 124 | $ bumpversion patch # possible: major / minor / patch 125 | $ git push 126 | $ git push --tags 127 | 128 | Travis will then deploy to PyPI if tests pass. 129 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2019-01-02) 6 | ------------------ 7 | 8 | * First release on PyPI. 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019, Hunt Zhan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | include requirements_*.txt 7 | 8 | recursive-include extension * 9 | recursive-include tests * 10 | recursive-exclude * __pycache__ 11 | recursive-exclude * *.py[co] 12 | 13 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | try: 8 | from urllib import pathname2url 9 | except: 10 | from urllib.request import pathname2url 11 | 12 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 13 | endef 14 | export BROWSER_PYSCRIPT 15 | 16 | define PRINT_HELP_PYSCRIPT 17 | import re, sys 18 | 19 | for line in sys.stdin: 20 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 21 | if match: 22 | target, help = match.groups() 23 | print("%-20s %s" % (target, help)) 24 | endef 25 | export PRINT_HELP_PYSCRIPT 26 | 27 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 28 | 29 | help: 30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 31 | 32 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 33 | 34 | clean-build: ## remove build artifacts 35 | rm -fr build/ 36 | rm -fr dist/ 37 
| rm -fr .eggs/ 38 | find . -name '*.egg-info' -exec rm -fr {} + 39 | find . -name '*.egg' -exec rm -f {} + 40 | 41 | clean-pyc: ## remove Python file artifacts 42 | find . -name '*.pyc' -exec rm -f {} + 43 | find . -name '*.pyo' -exec rm -f {} + 44 | find . -name '*~' -exec rm -f {} + 45 | find . -name '__pycache__' -exec rm -fr {} + 46 | 47 | clean-test: ## remove test and coverage artifacts 48 | rm -fr .tox/ 49 | rm -f .coverage 50 | rm -fr htmlcov/ 51 | rm -fr .pytest_cache 52 | 53 | lint: ## check style with flake8 54 | flake8 pytorch_fast_elmo tests 55 | 56 | test: ## run tests quickly with the default Python 57 | py.test 58 | 59 | test-all: ## run tests on every Python version with tox 60 | tox 61 | 62 | coverage: ## check code coverage quickly with the default Python 63 | coverage run --source pytorch_fast_elmo -m pytest 64 | coverage report -m 65 | coverage html 66 | $(BROWSER) htmlcov/index.html 67 | 68 | docs: ## generate Sphinx HTML documentation, including API docs 69 | rm -f docs/pytorch_fast_elmo.rst 70 | rm -f docs/modules.rst 71 | sphinx-apidoc -o docs/ pytorch_fast_elmo 72 | $(MAKE) -C docs clean 73 | $(MAKE) -C docs html 74 | $(BROWSER) docs/_build/html/index.html 75 | 76 | servedocs: docs ## compile the docs watching for changes 77 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 78 | 79 | release: dist ## package and upload a release 80 | twine upload dist/* 81 | 82 | dist: clean ## builds source and wheel package 83 | python setup.py sdist 84 | python setup.py bdist_wheel 85 | ls -l dist 86 | 87 | install: clean ## install the package to the active Python's site-packages 88 | python setup.py install 89 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | pytorch-fast-elmo 3 | ================= 4 | 5 | 6 | .. 
image:: https://img.shields.io/pypi/v/pytorch_fast_elmo.svg 7 | :target: https://pypi.python.org/pypi/pytorch_fast_elmo 8 | 9 | .. image:: https://img.shields.io/travis/cnt-dev/pytorch-fast-elmo.svg 10 | :target: https://travis-ci.org/cnt-dev/pytorch-fast-elmo 11 | 12 | .. image:: https://img.shields.io/badge/License-MIT-yellow.svg 13 | :target: https://github.com/cnt-dev/pytorch-fast-elmo/blob/master/LICENSE 14 | 15 | 16 | Introduction 17 | ------------ 18 | 19 | A fast ELMo implementation with the following features: 20 | 21 | - **Lower execution overhead.** The core components are reimplemented in Libtorch in order to reduce the Python execution overhead (**45%** speedup). 22 | - **A more flexible design.** Thanks to the redesigned workflow, users can easily extend or change the ELMo behavior. 23 | 24 | Benchmark 25 | --------- 26 | 27 | Hardware: 28 | 29 | - CPU: i7-7800X 30 | - GPU: 1080Ti 31 | 32 | Options: 33 | 34 | - Batch size: 32 35 | - Warm up iterations: 20 36 | - Test iterations: 1000 37 | - Word length: [1, 20] 38 | - Sentence length: [1, 30] 39 | - Random seed: 10000 40 | 41 | +--------------------------------------+------------------------+------------------------+ 42 | | Item | Mean Of Durations (ms) | cumtime(synchronize)% | 43 | +======================================+========================+========================+ 44 | | Fast ELMo (CUDA, no synchronize) | **31** | N/A | 45 | +--------------------------------------+------------------------+------------------------+ 46 | | AllenNLP ELMo (CUDA, no synchronize) | 56 | N/A | 47 | +--------------------------------------+------------------------+------------------------+ 48 | | Fast ELMo (CUDA, synchronize) | 47 | **26.13%** | 49 | +--------------------------------------+------------------------+------------------------+ 50 | | AllenNLP ELMo (CUDA, synchronize) | 57 | 0.02% | 51 | +--------------------------------------+------------------------+------------------------+ 52 | | Fast ELMo (CPU) | 1277 | N/A | 53 | 
+--------------------------------------+------------------------+------------------------+ 54 | | AllenNLP ELMo (CPU) | 1453 | N/A | 55 | +--------------------------------------+------------------------+------------------------+ 56 | 57 | Usage 58 | ----- 59 | 60 | Please install **torch==1.0.0** first. Then, simply run this command to install. 61 | 62 | .. code-block:: bash 63 | 64 | pip install pytorch-fast-elmo 65 | 66 | 67 | ``FastElmo`` should have the same behavior as AllenNLP's ``ELMo``. 68 | 69 | .. code-block:: python 70 | 71 | from pytorch_fast_elmo import FastElmo, batch_to_char_ids 72 | 73 | options_file = '/path/to/elmo_2x4096_512_2048cnn_2xhighway_options.json' 74 | weight_file = '/path/to/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5' 75 | 76 | elmo = FastElmo(options_file, weight_file) 77 | 78 | sentences = [['First', 'sentence', '.'], ['Another', '.']] 79 | character_ids = batch_to_char_ids(sentences) 80 | 81 | embeddings = elmo(character_ids) 82 | 83 | 84 | Use ``FastElmoWordEmbedding`` if you have disabled ``char_cnn`` in ``bilm-tf``, or have exported the Char CNN representation to a weight file. 85 | 86 | .. code-block:: python 87 | 88 | from pytorch_fast_elmo import FastElmoWordEmbedding, load_and_build_vocab2id, batch_to_word_ids 89 | 90 | options_file = '/path/to/elmo_2x4096_512_2048cnn_2xhighway_options.json' 91 | weight_file = '/path/to/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5' 92 | 93 | vocab_file = '/path/to/vocab.txt' 94 | embedding_file = '/path/to/cached_elmo_embedding.hdf5' 95 | 96 | elmo = FastElmoWordEmbedding( 97 | options_file, 98 | weight_file, 99 | # Could be omitted if the embedding weight is in `weight_file`. 
100 | word_embedding_weight_file=embedding_file, 101 | ) 102 | vocab2id = load_and_build_vocab2id(vocab_file) 103 | 104 | sentences = [['First', 'sentence', '.'], ['Another', '.']] 105 | word_ids = batch_to_word_ids(sentences, vocab2id) 106 | 107 | embeddings = elmo(word_ids) 108 | 109 | 110 | CLI commands: 111 | 112 | .. code-block:: bash 113 | 114 | # Cache the Char CNN representation. 115 | fast-elmo cache-char-cnn ./vocab.txt ./options.json ./lm_weights.hdf5 ./lm_embd.hdf5 116 | 117 | # Export word embedding. 118 | fast-elmo export-word-embd ./vocab.txt ./no-char-cnn.hdf5 ./embd.txt 119 | 120 | 121 | Credits 122 | ------- 123 | 124 | This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. 125 | 126 | .. _Cookiecutter: https://github.com/audreyr/cookiecutter 127 | .. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage 128 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = pytorch_fast_elmo 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # pytorch_fast_elmo documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another 17 | # directory, add these directories to sys.path here. If the directory is 18 | # relative to the documentation root, use os.path.abspath to make it 19 | # absolute, like shown here. 20 | # 21 | import os 22 | import sys 23 | sys.path.insert(0, os.path.abspath('..')) 24 | 25 | import pytorch_fast_elmo 26 | 27 | # -- General configuration --------------------------------------------- 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 35 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # The suffix(es) of source filenames. 41 | # You can specify multiple suffix as a list of string: 42 | # 43 | # source_suffix = ['.rst', '.md'] 44 | source_suffix = '.rst' 45 | 46 | # The master toctree document. 47 | master_doc = 'index' 48 | 49 | # General information about the project. 
50 | project = u'pytorch-fast-elmo' 51 | copyright = u"2019, Hunt Zhan" 52 | author = u"Hunt Zhan" 53 | 54 | # The version info for the project you're documenting, acts as replacement 55 | # for |version| and |release|, also used in various other places throughout 56 | # the built documents. 57 | # 58 | # The short X.Y version. 59 | version = pytorch_fast_elmo.__version__ 60 | # The full version, including alpha/beta/rc tags. 61 | release = pytorch_fast_elmo.__version__ 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # These patterns also affect html_static_path and html_extra_path 73 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 74 | 75 | # The name of the Pygments (syntax highlighting) style to use. 76 | pygments_style = 'sphinx' 77 | 78 | # If true, `todo` and `todoList` produce output, else they produce nothing. 79 | todo_include_todos = False 80 | 81 | 82 | # -- Options for HTML output ------------------------------------------- 83 | 84 | # The theme to use for HTML and HTML Help pages. See the documentation for 85 | # a list of builtin themes. 86 | # 87 | html_theme = 'alabaster' 88 | 89 | # Theme options are theme-specific and customize the look and feel of a 90 | # theme further. For a list of options available for each theme, see the 91 | # documentation. 92 | # 93 | # html_theme_options = {} 94 | 95 | # Add any paths that contain custom static files (such as style sheets) here, 96 | # relative to this directory. They are copied after the builtin static files, 97 | # so a file named "default.css" will overwrite the builtin "default.css". 
98 | html_static_path = ['_static'] 99 | 100 | 101 | # -- Options for HTMLHelp output --------------------------------------- 102 | 103 | # Output file base name for HTML help builder. 104 | htmlhelp_basename = 'pytorch_fast_elmodoc' 105 | 106 | 107 | # -- Options for LaTeX output ------------------------------------------ 108 | 109 | latex_elements = { 110 | # The paper size ('letterpaper' or 'a4paper'). 111 | # 112 | # 'papersize': 'letterpaper', 113 | 114 | # The font size ('10pt', '11pt' or '12pt'). 115 | # 116 | # 'pointsize': '10pt', 117 | 118 | # Additional stuff for the LaTeX preamble. 119 | # 120 | # 'preamble': '', 121 | 122 | # Latex figure (float) alignment 123 | # 124 | # 'figure_align': 'htbp', 125 | } 126 | 127 | # Grouping the document tree into LaTeX files. List of tuples 128 | # (source start file, target name, title, author, documentclass 129 | # [howto, manual, or own class]). 130 | latex_documents = [ 131 | (master_doc, 'pytorch_fast_elmo.tex', 132 | u'pytorch-fast-elmo Documentation', 133 | u'Hunt Zhan', 'manual'), 134 | ] 135 | 136 | 137 | # -- Options for manual page output ------------------------------------ 138 | 139 | # One entry per manual page. List of tuples 140 | # (source start file, name, description, authors, manual section). 141 | man_pages = [ 142 | (master_doc, 'pytorch_fast_elmo', 143 | u'pytorch-fast-elmo Documentation', 144 | [author], 1) 145 | ] 146 | 147 | 148 | # -- Options for Texinfo output ---------------------------------------- 149 | 150 | # Grouping the document tree into Texinfo files. 
List of tuples 151 | # (source start file, target name, title, author, 152 | # dir menu entry, description, category) 153 | texinfo_documents = [ 154 | (master_doc, 'pytorch_fast_elmo', 155 | u'pytorch-fast-elmo Documentation', 156 | author, 157 | 'pytorch_fast_elmo', 158 | 'One line description of project.', 159 | 'Miscellaneous'), 160 | ] 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to pytorch-fast-elmo's documentation! 2 | ============================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | readme 9 | installation 10 | usage 11 | modules 12 | contributing 13 | authors 14 | history 15 | 16 | Indices and tables 17 | ================== 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | 8 | Stable release 9 | -------------- 10 | 11 | To install pytorch-fast-elmo, run this command in your terminal: 12 | 13 | .. code-block:: console 14 | 15 | $ pip install pytorch_fast_elmo 16 | 17 | This is the preferred method to install pytorch-fast-elmo, as it will always install the most recent stable release. 
18 | 19 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 20 | you through the process. 21 | 22 | .. _pip: https://pip.pypa.io 23 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 24 | 25 | 26 | From sources 27 | ------------ 28 | 29 | The sources for pytorch-fast-elmo can be downloaded from the `Github repo`_. 30 | 31 | You can either clone the public repository: 32 | 33 | .. code-block:: console 34 | 35 | $ git clone git://github.com/cnt-dev/pytorch-fast-elmo 36 | 37 | Or download the `tarball`_: 38 | 39 | .. code-block:: console 40 | 41 | $ curl -OL https://github.com/cnt-dev/pytorch-fast-elmo/tarball/master 42 | 43 | Once you have a copy of the source, you can install it with: 44 | 45 | .. code-block:: console 46 | 47 | $ python setup.py install 48 | 49 | 50 | .. _Github repo: https://github.com/cnt-dev/pytorch-fast-elmo 51 | .. _tarball: https://github.com/cnt-dev/pytorch-fast-elmo/tarball/master 52 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=pytorch_fast_elmo 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Usage 3 | ===== 4 | 5 | To use pytorch-fast-elmo in a project:: 6 | 7 | import pytorch_fast_elmo 8 | -------------------------------------------------------------------------------- /extension/.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 2 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /extension/bind.cc: -------------------------------------------------------------------------------- 1 | #include "extension/elmo_character_encoder.h" 2 | #include "extension/scalar_mix.h" 3 | 4 | template 5 | py::class_ patch_methods( 6 | py::class_ module) { 7 | module.attr("cuda") = nullptr; 8 | module.def( 9 | "cuda", 10 | [](ModuleType& module, torch::optional device) { 11 | if (device.has_value()) { 12 | module.to("cuda:" + std::to_string(device.value())); 13 | } else { 14 | module.to(at::kCUDA); 15 | } 16 | return 
module; 17 | }); 18 | module.def( 19 | "cuda", 20 | [](ModuleType& module) { 21 | module.to(at::kCUDA); 22 | return module; 23 | }); 24 | 25 | module.attr("cpu") = nullptr; 26 | module.def( 27 | "cpu", 28 | [](ModuleType& module) { 29 | module.to(at::kCPU); 30 | return module; 31 | }); 32 | 33 | return module; 34 | } 35 | 36 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 37 | patch_methods( 38 | torch::python::bind_module( 39 | m, "ElmoCharacterEncoder")) 40 | 41 | .def( 42 | py::init< 43 | int64_t, int64_t, 44 | cnt::ElmoCharacterEncoderFiltersType, 45 | std::string, 46 | int64_t, int64_t>(), 47 | // Required. 48 | py::arg("char_embedding_cnt"), 49 | py::arg("char_embedding_dim"), 50 | py::arg("filters"), 51 | py::arg("activation"), 52 | py::arg("num_highway_layers"), 53 | py::arg("output_dim")) 54 | 55 | .def( 56 | "__call__", 57 | &cnt::ElmoCharacterEncoderImpl::forward); 58 | 59 | patch_methods( 60 | torch::python::bind_module( 61 | m, "ScalarMix")) 62 | 63 | .def( 64 | py::init< 65 | int64_t, 66 | bool, 67 | std::vector, 68 | bool>(), 69 | // Required. 70 | py::arg("mixture_size"), 71 | // Optional. 72 | py::arg("do_layer_norm") = false, 73 | py::arg("initial_scalar_parameters") = std::vector(), 74 | py::arg("trainable") = true) 75 | 76 | .def( 77 | "__call__", 78 | &cnt::ScalarMixImpl::forward) 79 | 80 | .def( 81 | "__call__", 82 | [](cnt::ScalarMixImpl& module, 83 | const std::vector &tensors) { 84 | return module.forward(tensors, torch::Tensor()); 85 | }); 86 | } 87 | -------------------------------------------------------------------------------- /extension/elmo_character_encoder.cc: -------------------------------------------------------------------------------- 1 | #include "extension/elmo_character_encoder.h" 2 | #include 3 | 4 | namespace cnt { 5 | 6 | HighwayImpl::HighwayImpl( 7 | int64_t input_dim, 8 | int64_t num_layers, 9 | TorchActivationType activation) 10 | : 11 | input_dim_(input_dim), 12 | activation_(activation) { 13 | // Build layers. 
14 | for (int64_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { 15 | // Output: concat(H, T). 16 | auto layer = torch::nn::Linear(input_dim_, input_dim_ * 2); 17 | // Initially biased towards carry behavior. 18 | layer->bias 19 | .detach() 20 | .narrow(0, input_dim_, input_dim_) 21 | .fill_(-1); 22 | // Note: Libtorch 1.0 doesn't support module list, 23 | // so we need to register all layers explicitly. 24 | register_module( 25 | "layers_" + std::to_string(layer_idx), 26 | layer); 27 | layers_.push_back(layer); 28 | } 29 | } 30 | 31 | torch::Tensor HighwayImpl::forward(torch::Tensor inputs) { 32 | auto cur_inputs = inputs; 33 | 34 | for (auto layer : layers_) { 35 | auto proj_inputs = layer(cur_inputs); 36 | 37 | auto transform = proj_inputs.narrow(-1, 0, input_dim_); 38 | auto transform_gate = proj_inputs.narrow(-1, input_dim_, input_dim_); 39 | 40 | transform = activation_(transform); 41 | transform_gate = torch::sigmoid(transform_gate); 42 | cur_inputs = transform_gate * transform + (1 - transform_gate) * cur_inputs; 43 | } 44 | return cur_inputs; 45 | } 46 | 47 | ElmoCharacterEncoderImpl::ElmoCharacterEncoderImpl( 48 | int64_t char_embedding_cnt, 49 | int64_t char_embedding_dim, 50 | ElmoCharacterEncoderFiltersType filters, 51 | std::string activation, 52 | int64_t num_highway_layers, 53 | int64_t output_dim) { 54 | // Build char embedding. 55 | char_embedding_ = torch::nn::Embedding( 56 | // Add offset 1 for padding. 57 | char_embedding_cnt + 1, 58 | char_embedding_dim); 59 | register_module("char_embedding", char_embedding_); 60 | 61 | // Build CNN. 62 | int64_t total_out_channels = 0; 63 | for (int64_t conv_idx = 0; 64 | conv_idx < static_cast(filters.size()); 65 | ++conv_idx) { 66 | // Config. 
67 | auto kernel_size = std::get<0>(filters[conv_idx]); 68 | auto out_channels = std::get<1>(filters[conv_idx]); 69 | 70 | total_out_channels += out_channels; 71 | 72 | auto conv_options = 73 | torch::nn::Conv1dOptions( 74 | char_embedding_dim, 75 | out_channels, 76 | kernel_size) 77 | // Explicitly set bias. 78 | .with_bias(true); 79 | 80 | // Build. 81 | auto conv = torch::nn::Conv1d(conv_options); 82 | register_module( 83 | "char_conv_" + std::to_string(conv_idx), 84 | conv); 85 | convolutions_.push_back(conv); 86 | } 87 | 88 | // Bind CNN activation. 89 | if (activation == "tanh") { 90 | activation_ = &torch::tanh; 91 | } else if (activation == "relu") { 92 | activation_ = &torch::relu; 93 | } else { 94 | throw std::invalid_argument("Invalid activation."); 95 | } 96 | 97 | // Build highway layers. 98 | highway_ = Highway( 99 | total_out_channels, 100 | num_highway_layers, 101 | // hardcoded as bilm-tf. 102 | &torch::relu); 103 | register_module("highway", highway_); 104 | 105 | // Build projection. 106 | output_proj_ = torch::nn::Linear(total_out_channels, output_dim); 107 | register_module("output_proj", output_proj_); 108 | } 109 | 110 | torch::Tensor ElmoCharacterEncoderImpl::forward(torch::Tensor inputs) { 111 | // Of shape `(*, char_embedding_dim, max_chars_per_token)`. 112 | auto char_embds = char_embedding_(inputs).transpose(1, 2); 113 | 114 | // Apply CNN. 115 | std::vector conv_outputs(convolutions_.size()); 116 | for (int64_t conv_idx = 0; 117 | conv_idx < static_cast(convolutions_.size()); 118 | ++conv_idx) { 119 | // `(*, C_out, L_out)` 120 | auto convolved = convolutions_[conv_idx](char_embds); 121 | // `(*, C_out)` 122 | convolved = std::get<0>(torch::max(convolved, -1)); // NOLINT 123 | convolved = activation_(convolved); 124 | 125 | conv_outputs[conv_idx] = convolved; 126 | } 127 | // `(*, total_out_channels)` 128 | auto char_repr = torch::cat(conv_outputs, -1); 129 | 130 | // Apply highway. 
131 | // `(*, total_out_channels)` 132 | char_repr = highway_(char_repr); 133 | 134 | // Apply output projection. 135 | // `(*, output_dim)` 136 | char_repr = output_proj_(char_repr); 137 | 138 | return char_repr; 139 | } 140 | 141 | } // namespace cnt 142 | -------------------------------------------------------------------------------- /extension/elmo_character_encoder.h: -------------------------------------------------------------------------------- 1 | #ifndef EXTENSION_ELMO_CHARACTER_ENCODER_H_ 2 | #define EXTENSION_ELMO_CHARACTER_ENCODER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace cnt { 10 | 11 | using TorchActivationType = 12 | torch::Tensor (*)(const torch::Tensor &); 13 | using ElmoCharacterEncoderFiltersType = 14 | std::vector>; 15 | 16 | struct HighwayImpl : torch::nn::Module { 17 | // https://arxiv.org/abs/1505.00387 18 | HighwayImpl( 19 | int64_t input_dim, 20 | int64_t num_layers, 21 | TorchActivationType activation); 22 | 23 | torch::Tensor forward(torch::Tensor inputs); 24 | 25 | int64_t input_dim_ = -1; 26 | std::vector layers_ = {}; 27 | TorchActivationType activation_ = nullptr; 28 | }; 29 | 30 | TORCH_MODULE(Highway); 31 | 32 | struct ElmoCharacterEncoderImpl : torch::nn::Module { 33 | ElmoCharacterEncoderImpl( 34 | // Char embedding. 35 | int64_t char_embedding_cnt, 36 | int64_t char_embedding_dim, 37 | 38 | // CNN filters: [[, ], ...] 39 | // Example: [[1, 4], [2, 8], [3, 16], [4, 32], [5, 64]] 40 | ElmoCharacterEncoderFiltersType filters, 41 | // CNN activation (supports "relu", "tanh"). 42 | std::string activation, 43 | 44 | // The number of highways. 45 | int64_t num_highway_layers, 46 | // The final projection size. 47 | int64_t output_dim); 48 | 49 | // Inputs: inputs 50 | // - **inputs** of shape `(*, max_characters_per_token)`: 51 | // tensor of `PackedSequence.data`. 
52 | // 53 | // Outputs: output 54 | // - **output** of shape `(*, output_dim)`: 55 | // tensor containing the representations of character. 56 | // 57 | // Note: Different to AllenNLP's implementation, 58 | // BOS/EOS will not be injected here. 59 | torch::Tensor forward(torch::Tensor inputs); 60 | 61 | // Char embedding. 62 | torch::nn::Embedding char_embedding_ = nullptr; 63 | 64 | // CNN. 65 | std::vector convolutions_ = {}; 66 | TorchActivationType activation_ = nullptr; 67 | 68 | // Highway. 69 | Highway highway_ = nullptr; 70 | 71 | // Output projection. 72 | torch::nn::Linear output_proj_ = nullptr; 73 | }; 74 | 75 | TORCH_MODULE(ElmoCharacterEncoder); 76 | 77 | } // namespace cnt 78 | 79 | #endif // EXTENSION_ELMO_CHARACTER_ENCODER_H_ 80 | -------------------------------------------------------------------------------- /extension/scalar_mix.cc: -------------------------------------------------------------------------------- 1 | #include "extension/scalar_mix.h" 2 | #include 3 | 4 | namespace cnt { 5 | 6 | ScalarMixImpl::ScalarMixImpl( 7 | int64_t mixture_size, 8 | bool do_layer_norm, 9 | std::vector initial_scalar_parameters, 10 | bool trainable) 11 | : 12 | mixture_size_(mixture_size), 13 | do_layer_norm_(do_layer_norm) { 14 | if (initial_scalar_parameters.empty()) { 15 | // Initialize with 1/n. 16 | initial_scalar_parameters.insert( 17 | initial_scalar_parameters.end(), 18 | mixture_size, 19 | 1.0 / static_cast(mixture_size)); 20 | } else if ( 21 | static_cast(initial_scalar_parameters.size()) != \ 22 | mixture_size) { 23 | throw std::invalid_argument( 24 | "initial_scalar_parameters & mixture_size not match."); 25 | } 26 | 27 | // Build scalar_parameters & gamma. 28 | // scalar_parameters. 
29 | for (int64_t idx = 0; 30 | idx < static_cast(initial_scalar_parameters.size()); 31 | ++idx) { 32 | auto scalar = torch::zeros({1}, torch::dtype(torch::kFloat32)); 33 | scalar.detach().fill_(initial_scalar_parameters[idx]); 34 | 35 | register_parameter("scalar_" + std::to_string(idx), scalar, trainable); 36 | scalar_parameters_.push_back(scalar); 37 | } 38 | // gamma. 39 | gamma_ = torch::ones({1}, torch::dtype(torch::kFloat32)); 40 | register_parameter("gamma", gamma_, trainable); 41 | } 42 | 43 | // Inputs: tensor, broadcast_mask, num_elements_not_masked 44 | // - **tensor** of shape `(*, features)`: 45 | // where * means any number of additional dimensions 46 | // - **broadcast_mask** of shape `(*, 1)`: 47 | // where * means the dimensions of **tensor**. 48 | // - **num_elements_not_masked** of shape `(1,)`: 49 | // the number of valid elements 50 | // 51 | // Outputs: output 52 | // - **output** of shape `(*, features)`: 53 | // normalized **tensor**. 54 | // 55 | // Note: Masked elements in output tensor won't be zeros. 56 | inline torch::Tensor apply_layer_norm( 57 | torch::Tensor tensor, 58 | torch::Tensor broadcast_mask, 59 | torch::Tensor num_elements_not_masked) { 60 | auto tensor_masked = tensor * broadcast_mask; 61 | auto mean = torch::sum(tensor_masked) / num_elements_not_masked; 62 | auto variance = 63 | torch::sum( 64 | torch::pow( 65 | (tensor_masked - mean) * broadcast_mask, 66 | 2)) / num_elements_not_masked; 67 | return (tensor - mean) / torch::sqrt(variance + 1E-12); 68 | } 69 | 70 | // We assume 1. the shapes of `tensors` are identical. 71 | // 2. the shape of `mask`, if `mask` is provided, 72 | // should match the prefix of the shape of `tensors`. 73 | torch::Tensor ScalarMixImpl::forward( 74 | const std::vector &tensors, 75 | torch::Tensor mask) { 76 | // Check the length of `tensors`. 
77 | if (static_cast(tensors.size()) != mixture_size_) { 78 | throw std::invalid_argument( 79 | "tensors & mixture_size not match."); 80 | } 81 | // Check the mask. 82 | if (do_layer_norm_ && !mask.defined()) { 83 | if (tensors[0].dim() == 2) { 84 | // To handle the packed sequences. 85 | mask = torch::ones({tensors[0].size(0)}); 86 | } else { 87 | throw std::invalid_argument( 88 | "do_layer_norm but mask is not defined."); 89 | } 90 | } 91 | 92 | auto normed_weights = torch::split( 93 | torch::softmax( 94 | torch::cat(scalar_parameters_), 95 | 0), 96 | 1); 97 | 98 | torch::Tensor broadcast_mask; 99 | torch::Tensor num_elements_not_masked; 100 | if (do_layer_norm_) { 101 | auto mask_float = mask.to(torch::kFloat32); 102 | broadcast_mask = mask_float.unsqueeze(-1); 103 | auto input_dim = tensors[0].size(-1); 104 | num_elements_not_masked = torch::sum(mask_float) * input_dim; 105 | } 106 | 107 | torch::Tensor total; 108 | for (int64_t idx = 0; idx < mixture_size_; ++idx) { 109 | auto tensor = tensors[idx]; 110 | if (do_layer_norm_) { 111 | tensor = apply_layer_norm( 112 | tensor, 113 | broadcast_mask, 114 | num_elements_not_masked); 115 | } 116 | auto weighted_tensor = normed_weights[idx] * tensor; 117 | if (idx == 0) { 118 | total = weighted_tensor; 119 | } else { 120 | total += weighted_tensor; 121 | } 122 | } 123 | return gamma_ * total; 124 | } 125 | 126 | } // namespace cnt 127 | -------------------------------------------------------------------------------- /extension/scalar_mix.h: -------------------------------------------------------------------------------- 1 | #ifndef EXTENSION_SCALAR_MIX_H_ 2 | #define EXTENSION_SCALAR_MIX_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace cnt { 8 | 9 | struct ScalarMixImpl : torch::nn::Module { 10 | ScalarMixImpl( 11 | int64_t mixture_size, 12 | bool do_layer_norm, 13 | std::vector initial_scalar_parameters, 14 | bool trainable); 15 | 16 | torch::Tensor forward( 17 | const std::vector &tensors, 18 | torch::Tensor mask); 
19 | 20 | int64_t mixture_size_ = -1; 21 | bool do_layer_norm_ = false; 22 | 23 | std::vector scalar_parameters_ = {}; 24 | torch::Tensor gamma_; 25 | }; 26 | 27 | TORCH_MODULE(ScalarMix); 28 | 29 | } // namespace cnt 30 | 31 | #endif // EXTENSION_SCALAR_MIX_H_ 32 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Top-level package for pytorch-fast-elmo.""" 3 | 4 | __author__ = """Hunt Zhan""" 5 | __email__ = 'huntzhan.dev@gmail.com' 6 | __version__ = '0.6.12' 7 | 8 | # To avoid `undefined symbol` error. 9 | import torch 10 | 11 | # pylint: disable=no-name-in-module 12 | from pytorch_stateful_lstm import StatefulUnidirectionalLstm 13 | # from _pytorch_fast_elmo import ElmoCharacterEncoder, ScalarMix 14 | from _pytorch_fast_elmo import ElmoCharacterEncoder 15 | 16 | from pytorch_fast_elmo.utils import ( 17 | batch_to_char_ids, 18 | load_and_build_vocab2id, 19 | batch_to_word_ids, 20 | ) 21 | 22 | from pytorch_fast_elmo.factory import ( 23 | ElmoCharacterEncoderFactory, 24 | ElmoWordEmbeddingFactory, 25 | ElmoLstmFactory, 26 | ElmoVocabProjectionFactory, 27 | ) 28 | 29 | from pytorch_fast_elmo.model import ( 30 | ScalarMix, 31 | FastElmoBase, 32 | FastElmo, 33 | FastElmoWordEmbedding, 34 | FastElmoPlainEncoder, 35 | FastElmoWordEmbeddingPlainEncoder, 36 | FastElmoForwardVocabDistrib, 37 | FastElmoBackwardVocabDistrib, 38 | FastElmoWordEmbeddingForwardVocabDistrib, 39 | FastElmoWordEmbeddingBackwardVocabDistrib, 40 | ) 41 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Follows AllenNLP. 
3 | """ 4 | # pylint: disable=attribute-defined-outside-init 5 | 6 | from typing import Dict, Tuple, Any, Optional, List 7 | import json 8 | import math 9 | 10 | import torch 11 | import h5py 12 | import numpy as np 13 | 14 | from pytorch_stateful_lstm import StatefulUnidirectionalLstm 15 | from _pytorch_fast_elmo import ElmoCharacterEncoder # pylint: disable=no-name-in-module 16 | 17 | 18 | def load_options(options_file: Optional[str]): # type: ignore 19 | if options_file is None: 20 | return None 21 | else: 22 | with open(options_file) as fin: 23 | return json.load(fin) 24 | 25 | 26 | def freeze_parameters(named_parameters: Dict[str, torch.Tensor]) -> None: 27 | for param in named_parameters.values(): 28 | param.requires_grad = False 29 | 30 | 31 | class FactoryBase: 32 | 33 | def __init__( 34 | self, 35 | options_file: Optional[str], 36 | weight_file: Optional[str], 37 | ) -> None: 38 | self.options = load_options(options_file) 39 | self.weight_file = weight_file 40 | 41 | 42 | class ElmoCharacterEncoderFactory(FactoryBase): 43 | 44 | @staticmethod 45 | def from_scratch( 46 | char_embedding_cnt: int, 47 | char_embedding_dim: int, 48 | filters: List[Tuple[int, int]], 49 | activation: str, 50 | num_highway_layers: int, 51 | output_dim: int, 52 | ) -> 'ElmoCharacterEncoderFactory': 53 | factory = ElmoCharacterEncoderFactory(None, None) 54 | factory.options = { 55 | 'n_characters': char_embedding_cnt, 56 | 'char_cnn': { 57 | 'embedding': { 58 | 'dim': char_embedding_dim 59 | }, 60 | 'filters': filters, 61 | 'activation': activation, 62 | 'n_highway': num_highway_layers, 63 | }, 64 | 'lstm': { 65 | 'projection_dim': output_dim 66 | }, 67 | } 68 | return factory 69 | 70 | def create(self, requires_grad: bool = False) -> ElmoCharacterEncoder: 71 | assert self.options and 'char_cnn' in self.options 72 | 73 | # Collect parameters for the construction of `ElmoCharacterEncoder`. 
def _load_cnn_weights(self) -> None:
    """
    Copy the character CNN filters from the weight file into `self.named_parameters`.

    For the filter at index `i` of `self.filters`, the datasets `CNN/W_cnn_{i}`
    and `CNN/b_cnn_{i}` are copied into the parameters `char_conv_{i}.weight`
    and `char_conv_{i}.bias`.

    Raises:
        ValueError: if a loaded weight or bias does not have the expected shape.
    """
    # Open the weight file once for all filters instead of once per filter.
    with h5py.File(self.weight_file, 'r') as fin:
        for conv_idx, (kernel_size, out_channels) in enumerate(self.filters):
            weight = fin['CNN'][f'W_cnn_{conv_idx}'][...]
            bias = fin['CNN'][f'b_cnn_{conv_idx}'][...]

            # Drop the leading singleton axis and reverse the remaining axes, so
            # that the result is (out_channels, char_embedding_dim, kernel_size).
            w_reshaped = np.transpose(weight.squeeze(axis=0), axes=(2, 1, 0))
            if w_reshaped.shape != (out_channels, self.char_embedding_dim, kernel_size):
                raise ValueError("Invalid weight file")

            weight_name = f'char_conv_{conv_idx}.weight'
            self.named_parameters[weight_name].data.copy_(torch.FloatTensor(w_reshaped))

            if bias.shape != (out_channels,):
                raise ValueError("Invalid weight file")

            bias_name = f'char_conv_{conv_idx}.bias'
            self.named_parameters[bias_name].data.copy_(torch.FloatTensor(bias))
def create(
        self,
        requires_grad: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Returns (embedding, lstm_bos, lstm_eos)

    With a weight file, the embedding matrix is read from the `embedding`
    dataset of an HDF5 file, or parsed from a TXT file with one
    `token v1 v2 ...` row per line (an optional first line `<count> <dim>` is
    treated as a header). Row 0 of the returned embedding is an all-zero
    padding row; `lstm_bos` / `lstm_eos` are rows 1 and 2 of the embedding.

    Without a weight file, the embedding is initialized from a normal
    distribution using the sizes in `self.options`, and `lstm_bos` /
    `lstm_eos` are `None`.

    Raises:
        ValueError: if the TXT rows have inconsistent dimensions, if the TXT
            file contains no embedding row, or if the number of loaded rows
            differs from `n_tokens_vocab` in the options.
    """
    if self.weight_file:
        # Load `embd_weight` from hdf5 or txt.
        if self.weight_file_is_hdf5():
            # HDF5 format.
            with h5py.File(self.weight_file, 'r') as fin:
                assert 'embedding' in fin
                embd_weight = fin['embedding'][...]

        else:
            # TXT format.
            loaded_dim = 0
            rows: List[List[float]] = []
            # Tokens may be non-ASCII, so decode explicitly as UTF-8 rather
            # than with the locale's default encoding.
            with open(self.weight_file, encoding='utf-8') as fin:
                for idx, line in enumerate(fin):
                    fields = line.split()
                    if not fields:
                        continue

                    # L0: optional header, only the dimension is used.
                    if idx == 0 and len(fields) == 2:
                        loaded_dim = int(fields[1])
                        continue

                    token = fields[0]
                    values = fields[1:]

                    if loaded_dim == 0:
                        loaded_dim = len(values)
                    elif loaded_dim != len(values):
                        raise ValueError(f'Dimension not match on L{idx}: {token}, '
                                         f'should be {loaded_dim}-D.')

                    rows.append([float(value) for value in values])

            if not rows:
                # Fail with a readable message instead of a numpy internal one.
                raise ValueError(f'No embedding found in {self.weight_file}')
            embd_weight = np.asarray(rows, dtype=np.float32)

        # Since bilm-tf doesn't include padding,
        # we need to prepend a padding row in index 0.
        self.word_embedding_cnt = embd_weight.shape[0]
        self.word_embedding_dim = embd_weight.shape[1]

        # Check with options if `n_tokens_vocab` exists.
        # `self.options` is None when no options file was given.
        if self.options \
                and 'n_tokens_vocab' in self.options \
                and self.options['n_tokens_vocab'] != self.word_embedding_cnt:
            raise ValueError('n_tokens_vocab not match')

        embd = torch.zeros(
            (self.word_embedding_cnt + 1, self.word_embedding_dim),
            dtype=torch.float,
        )

        embd.data[1:, :].copy_(torch.FloatTensor(embd_weight))
        embd.requires_grad = requires_grad

        # NOTE(review): presumably the first two vocabulary entries are the
        # sentence-begin and sentence-end tokens -- confirm against the vocab file.
        lstm_bos_repr = embd.data[1]
        lstm_eos_repr = embd.data[2]

    else:
        assert requires_grad
        assert self.options['n_tokens_vocab'] > 0

        self.word_embedding_cnt = self.options['n_tokens_vocab']
        self.word_embedding_dim = self.options['word_embedding_dim']

        embd = torch.zeros(
            (self.word_embedding_cnt + 1, self.word_embedding_dim),
            dtype=torch.float,
        )
        torch.nn.init.normal_(embd)
        embd.requires_grad = True

        # `exec_managed_lstm_bos_eos` should be disabled in this case.
        lstm_bos_repr = None
        lstm_eos_repr = None

    return embd, lstm_bos_repr, lstm_eos_repr
333 | assert self.options and 'lstm' in self.options 334 | 335 | self.num_layers = self.options['lstm']['n_layers'] 336 | self.input_size = self.options['lstm']['projection_dim'] 337 | self.cell_size = self.options['lstm']['dim'] 338 | self.cell_clip = self.options['lstm']['cell_clip'] 339 | self.proj_clip = self.options['lstm']['proj_clip'] 340 | self.truncated_bptt = self.options.get('unroll_steps', 20) 341 | self.use_skip_connections = True 342 | 343 | if self.options['lstm'].get('_hidden_size', 0) > 0: 344 | self.hidden_size = self.options['lstm']['_hidden_size'] 345 | else: 346 | self.hidden_size = self.input_size 347 | 348 | self.named_parameters: Dict[str, torch.Tensor] = {} 349 | 350 | fwd_lstm = None 351 | if enable_forward: 352 | fwd_lstm = StatefulUnidirectionalLstm( 353 | go_forward=True, 354 | num_layers=self.num_layers, 355 | input_size=self.input_size, 356 | hidden_size=self.hidden_size, 357 | cell_size=self.cell_size, 358 | cell_clip=self.cell_clip, 359 | proj_clip=self.proj_clip, 360 | truncated_bptt=self.truncated_bptt, 361 | use_skip_connections=self.use_skip_connections, 362 | ) 363 | fwd_lstm_named_parameters = fwd_lstm.named_parameters() 364 | self.named_parameters.update(fwd_lstm_named_parameters) 365 | 366 | bwd_lstm = None 367 | if enable_backward: 368 | bwd_lstm = StatefulUnidirectionalLstm( 369 | go_forward=False, 370 | num_layers=self.num_layers, 371 | input_size=self.input_size, 372 | hidden_size=self.hidden_size, 373 | cell_size=self.cell_size, 374 | cell_clip=self.cell_clip, 375 | proj_clip=self.proj_clip, 376 | truncated_bptt=self.truncated_bptt, 377 | use_skip_connections=self.use_skip_connections, 378 | ) 379 | bwd_lstm_named_parameters = bwd_lstm.named_parameters() 380 | self.named_parameters.update(bwd_lstm_named_parameters) 381 | 382 | if enable_forward and enable_backward: 383 | if set(fwd_lstm_named_parameters.keys()) & \ 384 | set(bwd_lstm_named_parameters.keys()): 385 | raise ValueError('key conflict.') 386 | 387 | # Load 
weights. 388 | if self.weight_file: 389 | with h5py.File(self.weight_file, 'r') as fin: 390 | for layer_idx in range(self.num_layers): 391 | for direction, prefix in enumerate([ 392 | 'uni_lstm.forward_layer_', 393 | 'uni_lstm.backward_layer_', 394 | ]): 395 | good_forward = (direction == 0 and enable_forward) 396 | good_backward = (direction == 1 and enable_backward) 397 | if good_forward or good_backward: 398 | dataset = fin[f'RNN_{direction}']\ 399 | ['RNN']\ 400 | ['MultiRNNCell']\ 401 | [f'Cell{layer_idx}']\ 402 | ['LSTMCell'] 403 | self._load_lstm(prefix + str(layer_idx), dataset) 404 | else: 405 | if enable_forward: 406 | assert forward_requires_grad 407 | if enable_backward: 408 | assert backward_requires_grad 409 | 410 | if enable_forward and not forward_requires_grad: 411 | freeze_parameters(fwd_lstm_named_parameters) 412 | 413 | if enable_backward and not backward_requires_grad: 414 | freeze_parameters(bwd_lstm_named_parameters) 415 | 416 | return fwd_lstm, bwd_lstm 417 | 418 | def _load_lstm(self, prefix: str, dataset: Any) -> None: 419 | cell_size = self.cell_size 420 | input_size = self.input_size 421 | 422 | tf_weights = np.transpose(dataset['W_0'][...]) 423 | torch_weights = tf_weights.copy() 424 | 425 | input_weights = torch_weights[:, :input_size] 426 | recurrent_weights = torch_weights[:, input_size:] 427 | tf_input_weights = tf_weights[:, :input_size] 428 | tf_recurrent_weights = tf_weights[:, input_size:] 429 | 430 | for torch_w, tf_w in [[input_weights, tf_input_weights], 431 | [recurrent_weights, tf_recurrent_weights]]: 432 | torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :] 433 | torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :] 434 | 435 | self.named_parameters[prefix + '.input_linearity_weight'].data.copy_( 436 | torch.FloatTensor(input_weights),) 437 | self.named_parameters[prefix + '.hidden_linearity_weight'].data.copy_( 438 | torch.FloatTensor(recurrent_weights),) 
439 | 440 | tf_bias = dataset['B'][...] 441 | tf_bias[(2 * cell_size):(3 * cell_size)] += 1 442 | torch_bias = tf_bias.copy() 443 | torch_bias[(1 * cell_size):(2 * cell_size)] = tf_bias[(2 * cell_size):(3 * cell_size)] 444 | torch_bias[(2 * cell_size):(3 * cell_size)] = tf_bias[(1 * cell_size):(2 * cell_size)] 445 | 446 | self.named_parameters[prefix + '.hidden_linearity_bias'].data.copy_( 447 | torch.FloatTensor(torch_bias),) 448 | 449 | proj_weights = np.transpose(dataset['W_P_0'][...]) 450 | 451 | self.named_parameters[prefix + '.proj_linearity_weight'].data.copy_( 452 | torch.FloatTensor(proj_weights),) 453 | 454 | 455 | class ElmoVocabProjectionFactory(FactoryBase): 456 | 457 | @staticmethod 458 | def from_scratch( 459 | input_size: int, 460 | proj_size: int, 461 | ) -> 'ElmoVocabProjectionFactory': 462 | factory = ElmoVocabProjectionFactory(None, None) 463 | factory.options = { 464 | 'lstm': { 465 | 'projection_dim': input_size 466 | }, 467 | 'n_tokens_vocab': proj_size, 468 | } 469 | return factory 470 | 471 | def create( 472 | self, 473 | requires_grad: bool = False, 474 | ) -> Tuple[torch.Tensor, torch.Tensor]: 475 | """ 476 | Returns (weight, bias) for affine transformation. 477 | """ 478 | assert self.options \ 479 | and 'n_tokens_vocab' in self.options \ 480 | and 'lstm' in self.options 481 | 482 | self.input_size = self.options['lstm']['projection_dim'] 483 | self.proj_size = self.options['n_tokens_vocab'] 484 | assert self.input_size > 0 and self.proj_size > 0 485 | 486 | # Note: no padding zero. 487 | weight = torch.zeros( 488 | (self.proj_size, self.input_size), 489 | dtype=torch.float, 490 | ) 491 | bias = torch.zeros( 492 | (self.proj_size,), 493 | dtype=torch.float, 494 | ) 495 | 496 | if self.weight_file: 497 | with h5py.File(self.weight_file, 'r') as fin: 498 | if 'softmax' not in fin: 499 | raise ValueError('softmax not in weight file.') 500 | loaded_weight = fin['softmax']['W'][...] 501 | loaded_bias = fin['softmax']['b'][...] 
502 | 503 | weight.data.copy_(torch.FloatTensor(loaded_weight)) 504 | weight.requires_grad = requires_grad 505 | 506 | bias.data.copy_(torch.FloatTensor(loaded_bias)) 507 | bias.requires_grad = requires_grad 508 | 509 | else: 510 | assert requires_grad 511 | # init. 512 | torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) 513 | weight.requires_grad = True 514 | 515 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight) # pylint: disable=protected-access 516 | bound = 1 / math.sqrt(fan_in) 517 | torch.nn.init.uniform_(bias, -bound, bound) 518 | bias.requires_grad = True 519 | 520 | return weight, bias 521 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provide helper classes/functions to execute ELMo. 3 | """ 4 | # pylint: disable=no-self-use,arguments-differ,too-many-public-methods,too-many-lines 5 | from typing import List, Tuple, Optional, Dict, Union, Any, Set 6 | from collections import OrderedDict 7 | 8 | import torch 9 | from torch.nn.utils.rnn import PackedSequence 10 | from torch.nn import ParameterList, Parameter 11 | 12 | from pytorch_fast_elmo.factory import ( 13 | ElmoCharacterEncoderFactory, 14 | ElmoWordEmbeddingFactory, 15 | ElmoLstmFactory, 16 | ElmoVocabProjectionFactory, 17 | ) 18 | from pytorch_fast_elmo import utils 19 | 20 | # TODO: use in inference. 21 | # from _pytorch_fast_elmo import ScalarMix # pylint: disable=no-name-in-module 22 | 23 | 24 | def _raise_if_kwargs_is_invalid(allowed: Set[str], kwargs: Dict[str, Any]) -> None: 25 | invalid_keys = set(kwargs) - allowed 26 | if invalid_keys: 27 | msg = '\n'.join('invalid kwargs: {}'.format(key) for key in invalid_keys) 28 | raise ValueError(msg) 29 | 30 | 31 | # Implement in Python as a temporary solution. 
class ScalarMix(torch.nn.Module):  # type: ignore
    """
    Weighted sum of `mixture_size` tensors.

    Computes `gamma * sum(w[k] * tensors[k])`, where `w` is the softmax of the
    scalar parameters. When `do_layer_norm` is set, each tensor is first
    normalized with the mean and variance of its unmasked elements.
    """

    def __init__(
            self,
            mixture_size: int,
            do_layer_norm: bool = False,
            initial_scalar_parameters: Optional[List[float]] = None,
            trainable: bool = True,
    ) -> None:
        """
        Raises:
            ValueError: if `initial_scalar_parameters` is given and its length
                differs from `mixture_size`.
        """
        super().__init__()
        self.mixture_size = mixture_size
        self.do_layer_norm = do_layer_norm

        if initial_scalar_parameters is None:
            initial_scalar_parameters = [1.0 / mixture_size] * mixture_size
        elif len(initial_scalar_parameters) != mixture_size:
            raise ValueError("initial_scalar_parameters & mixture_size not match.")

        self.scalar_parameters = ParameterList([
            Parameter(
                torch.FloatTensor([val]),
                requires_grad=trainable,
            ) for val in initial_scalar_parameters
        ])
        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)

    @staticmethod
    def _apply_layer_norm(
            tensor: torch.Tensor,
            broadcast_mask: torch.Tensor,
            num_elements_not_masked: torch.Tensor,
    ) -> torch.Tensor:
        """
        Normalize `tensor` with the mean/variance computed over unmasked elements.
        """
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(torch.pow(
            (tensor_masked - mean) * broadcast_mask,
            2,
        )) / num_elements_not_masked
        return (tensor - mean) / torch.sqrt(variance + 1E-12)

    def forward(
            self,
            tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Mix `tensors` into a single tensor of the same shape.

        `mask` is only used when `do_layer_norm` is set. If it is omitted and
        the tensors are 2-D, every row is treated as unmasked.

        Raises:
            ValueError: if `len(tensors) != mixture_size`, or if
                `do_layer_norm` is set, `mask` is omitted and the tensors are
                not 2-D.
        """
        if len(tensors) != self.mixture_size:
            raise ValueError("tensors & mixture_size not match.")
        if self.do_layer_norm and mask is None:
            if tensors[0].ndimension() == 2:
                # Build the all-ones mask on the same device as the inputs;
                # `torch.ones` would always allocate it on the CPU.
                mask = tensors[0].new_ones((tensors[0].shape[0],))
            else:
                raise ValueError("do_layer_norm but mask is not defined.")

        # One tensor of shape (1,) per mixture component.
        normed_weights = torch.split(
            torch.softmax(
                torch.cat(list(self.scalar_parameters)),
                0,
            ),
            1,
        )
        if self.do_layer_norm:
            mask_float = mask.float()
            broadcast_mask = mask_float.unsqueeze(-1)
            input_dim = tensors[0].size(-1)
            num_elements_not_masked = torch.sum(mask_float) * input_dim

        pieces = []
        for idx in range(self.mixture_size):
            tensor = tensors[idx]
            if self.do_layer_norm:
                tensor = self._apply_layer_norm(
                    tensor,
                    broadcast_mask,
                    num_elements_not_masked,
                )
            weighted_tensor = normed_weights[idx] * tensor
            pieces.append(weighted_tensor)

        return self.gamma * sum(pieces)
159 | char_cnn_char_embedding_cnt: int = 261, 160 | char_cnn_char_embedding_dim: int = 16, 161 | char_cnn_filters: List[Tuple[int, int]] = _CHAR_CNN_FILTERS, 162 | char_cnn_activation: str = 'relu', 163 | char_cnn_num_highway_layers: int = 2, 164 | char_cnn_output_dim: int = 512, 165 | 166 | # Word Embedding. 167 | disable_word_embedding: bool = True, 168 | word_embedding_weight_file: Optional[str] = None, 169 | word_embedding_requires_grad: bool = False, 170 | # From scratch. 171 | word_embedding_cnt: int = 0, 172 | word_embedding_dim: int = 512, 173 | 174 | # The Forward LSTM. 175 | disable_forward_lstm: bool = False, 176 | forward_lstm_requires_grad: bool = False, 177 | # The Backward LSTM. 178 | disable_backward_lstm: bool = False, 179 | backward_lstm_requires_grad: bool = False, 180 | # From scratch. 181 | lstm_num_layers: int = 2, 182 | lstm_input_size: int = 512, 183 | lstm_hidden_size: int = 512, 184 | lstm_cell_size: int = 4096, 185 | lstm_cell_clip: float = 3.0, 186 | lstm_proj_clip: float = 3.0, 187 | lstm_truncated_bptt: int = 20, 188 | # Provide the BOS/EOS representations of shape `(projection_dim,)` 189 | # if char CNN is disabled. 190 | lstm_bos_repr: Optional[torch.Tensor] = None, 191 | lstm_eos_repr: Optional[torch.Tensor] = None, 192 | 193 | # The final softmax layer. 
194 | disable_vocab_projection: bool = True, 195 | vocab_projection_requires_grad: bool = False, 196 | vocab_projection_input_size: int = 0, 197 | vocab_projection_proj_size: int = 0, 198 | ) -> None: 199 | super().__init__() 200 | 201 | self.disable_char_cnn = disable_char_cnn 202 | self.disable_word_embedding = disable_word_embedding 203 | self.disable_forward_lstm = disable_forward_lstm 204 | self.disable_backward_lstm = disable_backward_lstm 205 | self.disable_scalar_mix = disable_scalar_mix 206 | self.disable_vocab_projection = disable_vocab_projection 207 | 208 | self.exec_managed_lstm_bos_eos = exec_managed_lstm_bos_eos 209 | self.exec_managed_lstm_reset_states = exec_managed_lstm_reset_states 210 | self.exec_sort_batch = exec_sort_batch 211 | 212 | # Char CNN. 213 | if options_file: 214 | self.char_cnn_factory = ElmoCharacterEncoderFactory( 215 | options_file, 216 | weight_file, 217 | ) 218 | else: 219 | # From scratch. 220 | self.char_cnn_factory = ElmoCharacterEncoderFactory.from_scratch( 221 | char_embedding_cnt=char_cnn_char_embedding_cnt, 222 | char_embedding_dim=char_cnn_char_embedding_dim, 223 | filters=char_cnn_filters, 224 | activation=char_cnn_activation, 225 | num_highway_layers=char_cnn_num_highway_layers, 226 | output_dim=char_cnn_output_dim, 227 | ) 228 | 229 | if not disable_char_cnn: 230 | self._add_cpp_module_to_buffer( 231 | 'char_cnn', 232 | self.char_cnn_factory.create(requires_grad=char_cnn_requires_grad), 233 | ) 234 | 235 | # Word Embedding. 236 | if options_file: 237 | self.word_embedding_factory = ElmoWordEmbeddingFactory( 238 | options_file, 239 | word_embedding_weight_file or weight_file, 240 | ) 241 | else: 242 | # From scratch. 
243 | self.word_embedding_factory = ElmoWordEmbeddingFactory.from_scratch( 244 | cnt=word_embedding_cnt, 245 | dim=word_embedding_dim, 246 | ) 247 | if exec_managed_lstm_bos_eos: 248 | raise ValueError('exec_managed_lstm_bos_eos should be disabled.') 249 | 250 | if not disable_word_embedding: 251 | # Not a cpp extension. 252 | word_embedding_weight, lstm_bos_repr, lstm_eos_repr = \ 253 | self.word_embedding_factory.create(requires_grad=word_embedding_requires_grad) 254 | self.register_buffer('word_embedding_weight', word_embedding_weight) 255 | 256 | # LSTM. 257 | if options_file: 258 | self.lstm_factory = ElmoLstmFactory( 259 | options_file, 260 | weight_file, 261 | ) 262 | else: 263 | # From scratch. 264 | self.lstm_factory = ElmoLstmFactory.from_scratch( 265 | num_layers=lstm_num_layers, 266 | input_size=lstm_input_size, 267 | hidden_size=lstm_hidden_size, 268 | cell_size=lstm_cell_size, 269 | cell_clip=lstm_cell_clip, 270 | proj_clip=lstm_proj_clip, 271 | truncated_bptt=lstm_truncated_bptt, 272 | ) 273 | 274 | if not (disable_forward_lstm and disable_backward_lstm): 275 | forward_lstm, backward_lstm = self.lstm_factory.create( 276 | enable_forward=not disable_forward_lstm, 277 | forward_requires_grad=forward_lstm_requires_grad, 278 | enable_backward=not disable_backward_lstm, 279 | backward_requires_grad=backward_lstm_requires_grad, 280 | ) 281 | if not disable_forward_lstm: 282 | self._add_cpp_module_to_buffer('forward_lstm', forward_lstm) 283 | if not disable_backward_lstm: 284 | self._add_cpp_module_to_buffer('backward_lstm', backward_lstm) 285 | 286 | # Cache BOS/EOS reprs. 
287 | if exec_managed_lstm_bos_eos: 288 | if disable_char_cnn: 289 | if lstm_bos_repr is None or lstm_eos_repr is None: 290 | raise ValueError('BOS/EOS not provided.') 291 | 292 | else: 293 | lstm_bos_repr, lstm_eos_repr = utils.get_bos_eos_token_repr( 294 | self.char_cnn_factory, 295 | self.char_cnn, 296 | ) 297 | self.register_buffer( 298 | 'lstm_bos_repr', 299 | lstm_bos_repr, 300 | ) 301 | self.register_buffer( 302 | 'lstm_eos_repr', 303 | lstm_eos_repr, 304 | ) 305 | 306 | # Vocabulary projection. 307 | if options_file: 308 | self.vocab_projection_factory = ElmoVocabProjectionFactory( 309 | options_file, 310 | weight_file, 311 | ) 312 | else: 313 | self.vocab_projection_factory = ElmoVocabProjectionFactory.from_scratch( 314 | vocab_projection_input_size, 315 | vocab_projection_proj_size, 316 | ) 317 | 318 | if not disable_vocab_projection: 319 | self.vocab_projection_weight, self.vocab_projection_bias = \ 320 | self.vocab_projection_factory.create(requires_grad=vocab_projection_requires_grad) 321 | 322 | # ScalarMix 323 | if not disable_scalar_mix: 324 | self.scalar_mixes: List[ScalarMix] = [] 325 | 326 | for idx in range(num_output_representations): 327 | scalar_mix = ScalarMix( 328 | # char cnn + lstm. 
def state_dict(  # type: ignore
        self,
        destination=None,
        prefix='',
        keep_vars=False,
):
    """
    Return the state dict with every entry of `self._buffers` left out.

    `self._buffers` is swapped for an empty dict while the parent
    implementation runs, and is always restored afterwards, even if the parent
    call raises.
    """
    tmp_buffers = self._buffers
    self._buffers = OrderedDict()  # type: ignore
    try:
        # Keyword arguments are accepted by every version of
        # `torch.nn.Module.state_dict`.
        return super().state_dict(
            destination=destination,
            prefix=prefix,
            keep_vars=keep_vars,
        )
    finally:
        self._buffers = tmp_buffers

def _load_from_state_dict(  # type: ignore
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
):
    """
    Load from `state_dict` while `self._buffers` is hidden.

    `self._buffers` is swapped for an empty dict while the parent
    implementation runs, and is always restored afterwards, even if the parent
    call raises.
    """
    tmp_buffers = self._buffers
    self._buffers = OrderedDict()
    try:
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                      unexpected_keys, error_msgs)
    finally:
        self._buffers = tmp_buffers
def get_batched_lstm_bos_eos_repr(self, attr_name: str, batch_size: int) -> PackedSequence:
    """
    Repeat the vector stored in attribute `attr_name` `batch_size` times and
    wrap it as a `PackedSequence` with the single batch size `batch_size`.

    If the stored tensor is not on CUDA and `self._get_lstm_device()` returns
    a non-negative device index, the tensor is moved to that device and the
    attribute is overwritten with the moved tensor.
    """
    cached = getattr(self, attr_name)

    # The LSTM device is only queried while the cached tensor is not on CUDA.
    target_device = -1 if cached.is_cuda else self._get_lstm_device()
    if target_device >= 0:
        # Move to GPU permanently.
        cached = cached.cuda(target_device)
        setattr(self, attr_name, cached)

    rows = cached.unsqueeze(0).expand(batch_size, -1)
    step_sizes = torch.LongTensor([batch_size])
    return PackedSequence(rows, step_sizes)
def exec_word_embedding(self, inputs: PackedSequence) -> PackedSequence:
    """
    Word embedding.

    Looks up the packed ids of `inputs` in `self.word_embedding_weight`
    (index 0 is the padding index) and keeps the batch sizes of `inputs`.
    """
    token_ids = inputs.data
    vectors = torch.nn.functional.embedding(
        token_ids,
        self.word_embedding_weight,
        padding_idx=0,
    )
    return PackedSequence(vectors, inputs.batch_sizes)
def exec_backward_lstm(
        self,
        inputs: PackedSequence,
) -> List[PackedSequence]:
    """
    Backward LSTM.

    Runs `self.backward_lstm` on the packed inputs and returns one
    `PackedSequence` per output of the LSTM, each with the batch sizes of
    `inputs`.

    With `exec_managed_lstm_bos_eos`, the EOS step is executed before the
    inputs and the BOS step after them. Otherwise, with
    `exec_managed_lstm_reset_states`, the LSTM states are reset first.
    """
    managed_bos_eos = self.exec_managed_lstm_bos_eos
    batch_sizes = inputs.batch_sizes

    if managed_bos_eos:
        # The first entry of `batch_sizes` is the size of the first time step.
        first_step_size = int(batch_sizes.data[0])
        # EOS.
        self.exec_backward_lstm_eos(first_step_size)
    elif self.exec_managed_lstm_reset_states:
        self.backward_lstm.reset_states()

    # Feed inputs.
    layer_outputs, _ = self.backward_lstm(inputs.data, batch_sizes)

    if managed_bos_eos:
        # BOS.
        self.exec_backward_lstm_bos(first_step_size)

    # To list of `PackedSequence`.
    packed_layers = []
    for layer_output in layer_outputs:
        packed_layers.append(PackedSequence(layer_output, batch_sizes))
    return packed_layers
def exec_scalar_mix(self, packed_sequences: List[PackedSequence]) -> List[PackedSequence]:
    """
    Scalar Mix.

    Applies every module of `self.scalar_mixes` to the data of all
    `packed_sequences`, followed by `self.repr_dropout` when it is set.
    Returns one `PackedSequence` per scalar mix, using the batch sizes of the
    first input.
    """

    def mix_with(scalar_mix: Any) -> PackedSequence:
        blended = scalar_mix([layer.data for layer in packed_sequences])
        if self.repr_dropout is not None:
            blended = self.repr_dropout(blended)
        return PackedSequence(blended, packed_sequences[0].batch_sizes)

    return [mix_with(scalar_mix) for scalar_mix in self.scalar_mixes]
579 | conbimed_repr = self.combine_char_cnn_and_bilstm_outputs( 580 | token_repr, 581 | self.concat_packed_sequences(bilstm_repr), 582 | ) 583 | mixed_reprs = self.exec_scalar_mix(conbimed_repr) 584 | return mixed_reprs 585 | 586 | def pack_inputs( 587 | self, 588 | inputs: torch.Tensor, 589 | lengths: Optional[torch.Tensor] = None, 590 | ) -> PackedSequence: 591 | return utils.pack_inputs(inputs, lengths=lengths) 592 | 593 | def unpack_output( 594 | self, 595 | output: PackedSequence, 596 | ) -> torch.Tensor: 597 | return utils.unpack_outputs(output) 598 | 599 | def unpack_outputs( 600 | self, 601 | mixed_reprs: List[PackedSequence], 602 | ) -> List[torch.Tensor]: 603 | """ 604 | Unpack the outputs of scalar mixtures. 605 | """ 606 | return [self.unpack_output(mixed_repr) for mixed_repr in mixed_reprs] 607 | 608 | def to_allennlp_elmo_output_format( 609 | self, 610 | unpacks: List[torch.Tensor], 611 | mask: torch.Tensor, 612 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 613 | return {'elmo_representations': unpacks, 'mask': mask} 614 | 615 | def preprocess_inputs( 616 | self, 617 | inputs: torch.Tensor, 618 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 619 | lengths = utils.get_lengths_of_zero_padded_batch(inputs) 620 | original_lengths = lengths 621 | restoration_index: Optional[torch.Tensor] = None 622 | 623 | if self.exec_sort_batch: 624 | inputs, permutation_index, restoration_index = \ 625 | utils.sort_batch_by_length(inputs, lengths) 626 | lengths = lengths.index_select(0, permutation_index) 627 | self.exec_bilstm_permutate_states(permutation_index) 628 | 629 | return inputs, lengths, original_lengths, restoration_index 630 | 631 | def postprocess_outputs( 632 | self, 633 | unpacked_tensors: List[torch.Tensor], 634 | restoration_index: Optional[torch.Tensor], 635 | inputs: torch.Tensor, 636 | original_lengths: torch.Tensor, 637 | ) -> Tuple[List[torch.Tensor], torch.Tensor]: 638 | mask = 
utils.generate_mask_from_lengths( 639 | inputs.shape[0], 640 | inputs.shape[1], 641 | original_lengths, 642 | ) 643 | if self.exec_sort_batch: 644 | assert restoration_index is not None 645 | unpacked_tensors = [ 646 | tensor.index_select(0, restoration_index) for tensor in unpacked_tensors 647 | ] 648 | self.exec_bilstm_permutate_states(restoration_index) 649 | 650 | return unpacked_tensors, mask 651 | 652 | def forward_with_sorting_and_packing( 653 | self, 654 | inputs: torch.Tensor, 655 | ) -> Tuple[List[torch.Tensor], torch.Tensor]: 656 | inputs, lengths, original_lengths, restoration_index = \ 657 | self.preprocess_inputs(inputs) 658 | 659 | packed_inputs = self.pack_inputs(inputs, lengths) 660 | packed_outputs = self.execute(packed_inputs) 661 | 662 | unpacked_outputs = self.unpack_outputs(packed_outputs) 663 | unpacked_outputs, mask = self.postprocess_outputs( 664 | unpacked_outputs, 665 | restoration_index, 666 | inputs, 667 | original_lengths, 668 | ) 669 | return unpacked_outputs, mask 670 | 671 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 672 | raise NotImplementedError() 673 | 674 | def forward_like_allennlp( 675 | self, 676 | inputs: torch.Tensor, 677 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 678 | outputs, mask = self.forward_with_sorting_and_packing(inputs) 679 | return self.to_allennlp_elmo_output_format(outputs, mask) 680 | 681 | def forward(self): # type: ignore 682 | raise NotImplementedError() 683 | 684 | 685 | class FastElmo(FastElmoBase): 686 | 687 | def __init__( 688 | self, 689 | options_file: Optional[str], 690 | weight_file: str, 691 | **kwargs: Any, 692 | ) -> None: 693 | _raise_if_kwargs_is_invalid( 694 | self.COMMON_PARAMS | set([ 695 | # Fine-tuning is not fully supported by pytorch. 
696 | # 'char_cnn_requires_grad', 697 | # 'forward_lstm_requires_grad', 698 | # 'backward_lstm_requires_grad', 699 | ]), 700 | kwargs) 701 | super().__init__(options_file, weight_file, **kwargs) 702 | 703 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 704 | token_repr = self.exec_char_cnn(inputs) 705 | mixed_reprs = self.exec_bilstm_and_scalar_mix(token_repr) 706 | return mixed_reprs 707 | 708 | def forward( # type: ignore 709 | self, 710 | inputs: torch.Tensor, 711 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 712 | """ 713 | The default workflow (same as AllenNLP). 714 | 715 | `inputs` of shape `(batch_size, max_timesteps, max_characters_per_token) 716 | """ 717 | return self.forward_like_allennlp(inputs) 718 | 719 | 720 | class FastElmoWordEmbedding(FastElmoBase): 721 | 722 | def __init__( 723 | self, 724 | options_file: Optional[str], 725 | weight_file: str, 726 | **kwargs: Any, 727 | ) -> None: 728 | _raise_if_kwargs_is_invalid( 729 | self.COMMON_PARAMS | { 730 | 'word_embedding_weight_file', 731 | # Fine-tuning is not fully supported by pytorch. 
732 | # 'word_embedding_requires_grad', 733 | # 'forward_lstm_requires_grad', 734 | # 'backward_lstm_requires_grad', 735 | }, 736 | kwargs) 737 | 738 | kwargs['disable_char_cnn'] = True 739 | kwargs['disable_word_embedding'] = False 740 | super().__init__(options_file, weight_file, **kwargs) 741 | 742 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 743 | token_repr = self.exec_word_embedding(inputs) 744 | mixed_reprs = self.exec_bilstm_and_scalar_mix(token_repr) 745 | return mixed_reprs 746 | 747 | def forward( # type: ignore 748 | self, 749 | inputs: torch.Tensor, 750 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 751 | """ 752 | `inputs` of shape `(batch_size, max_timesteps) 753 | """ 754 | return self.forward_like_allennlp(inputs) 755 | 756 | 757 | class FastElmoPlainEncoderBase(FastElmoBase): # pylint: disable=abstract-method 758 | 759 | def exec_context_independent_repr(self, inputs: PackedSequence) -> PackedSequence: 760 | raise NotImplementedError() 761 | 762 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 763 | token_repr = self.exec_context_independent_repr(inputs) 764 | # BiLSTM. 765 | bilstm_repr = self.exec_bilstm(token_repr) 766 | # Scalar Mix. 767 | conbimed_repr = self.combine_char_cnn_and_bilstm_outputs( 768 | token_repr, 769 | self.concat_packed_sequences(bilstm_repr), 770 | ) 771 | return conbimed_repr 772 | 773 | def forward( # type: ignore 774 | self, 775 | inputs: torch.Tensor, 776 | ) -> Tuple[List[torch.Tensor], torch.Tensor]: 777 | """ 778 | No scalar mix. 
779 | """ 780 | return self.forward_with_sorting_and_packing(inputs) 781 | 782 | 783 | class FastElmoPlainEncoder(FastElmoPlainEncoderBase): 784 | 785 | def __init__( 786 | self, 787 | options_file: Optional[str], 788 | weight_file: str, 789 | **kwargs: Any, 790 | ) -> None: 791 | _raise_if_kwargs_is_invalid(self.EXEC_PARAMS, kwargs) 792 | kwargs['disable_scalar_mix'] = True 793 | super().__init__(options_file, weight_file, **kwargs) 794 | 795 | def exec_context_independent_repr(self, inputs: PackedSequence) -> PackedSequence: 796 | return self.exec_char_cnn(inputs) 797 | 798 | 799 | class FastElmoWordEmbeddingPlainEncoder(FastElmoPlainEncoderBase): 800 | 801 | def __init__( 802 | self, 803 | options_file: Optional[str], 804 | weight_file: str, 805 | **kwargs: Any, 806 | ) -> None: 807 | _raise_if_kwargs_is_invalid(self.EXEC_PARAMS, kwargs) 808 | kwargs['disable_char_cnn'] = True 809 | kwargs['disable_word_embedding'] = False 810 | kwargs['disable_scalar_mix'] = True 811 | super().__init__(options_file, weight_file, **kwargs) 812 | 813 | def exec_context_independent_repr(self, inputs: PackedSequence) -> PackedSequence: 814 | return self.exec_word_embedding(inputs) 815 | 816 | 817 | class FastElmoUnidirectionalVocabDistribBase(FastElmoBase): # pylint: disable=abstract-method 818 | 819 | def exec_forward_vocab_prob_distrib(self, token_repr: PackedSequence) -> List[PackedSequence]: 820 | fwd_lstm_last = self.exec_forward_lstm(token_repr)[-1] 821 | fwd_vocab_distrib = self.exec_vocab_projection(fwd_lstm_last) 822 | return [fwd_vocab_distrib] 823 | 824 | def exec_backward_vocab_prob_distrib(self, token_repr: PackedSequence) -> List[PackedSequence]: 825 | bwd_lstm_last = self.exec_backward_lstm(token_repr)[-1] 826 | bwd_vocab_distrib = self.exec_vocab_projection(bwd_lstm_last) 827 | return [bwd_vocab_distrib] 828 | 829 | def forward( # type: ignore 830 | self, 831 | inputs: torch.Tensor, 832 | ) -> Tuple[torch.Tensor, torch.Tensor]: 833 | (vocab_distrib,), mask = 
self.forward_with_sorting_and_packing(inputs) 834 | return vocab_distrib, mask 835 | 836 | 837 | class FastElmoForwardVocabDistrib(FastElmoUnidirectionalVocabDistribBase): 838 | 839 | def __init__( 840 | self, 841 | options_file: Optional[str], 842 | weight_file: str, 843 | **kwargs: Any, 844 | ) -> None: 845 | _raise_if_kwargs_is_invalid( 846 | {'exec_managed_lstm_bos_eos'}, 847 | kwargs, 848 | ) 849 | kwargs['disable_backward_lstm'] = True 850 | kwargs['disable_vocab_projection'] = False 851 | kwargs['disable_scalar_mix'] = True 852 | super().__init__(options_file, weight_file, **kwargs) 853 | 854 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 855 | token_repr = self.exec_char_cnn(inputs) 856 | return self.exec_forward_vocab_prob_distrib(token_repr) 857 | 858 | 859 | class FastElmoBackwardVocabDistrib(FastElmoUnidirectionalVocabDistribBase): 860 | 861 | def __init__( 862 | self, 863 | options_file: Optional[str], 864 | weight_file: str, 865 | **kwargs: Any, 866 | ) -> None: 867 | _raise_if_kwargs_is_invalid( 868 | {'exec_managed_lstm_bos_eos'}, 869 | kwargs, 870 | ) 871 | kwargs['disable_forward_lstm'] = True 872 | kwargs['disable_vocab_projection'] = False 873 | kwargs['disable_scalar_mix'] = True 874 | super().__init__(options_file, weight_file, **kwargs) 875 | 876 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 877 | token_repr = self.exec_char_cnn(inputs) 878 | return self.exec_backward_vocab_prob_distrib(token_repr) 879 | 880 | 881 | class FastElmoWordEmbeddingForwardVocabDistrib(FastElmoUnidirectionalVocabDistribBase): 882 | 883 | def __init__( 884 | self, 885 | options_file: Optional[str], 886 | weight_file: str, 887 | **kwargs: Any, 888 | ) -> None: 889 | _raise_if_kwargs_is_invalid( 890 | {'exec_managed_lstm_bos_eos', 'word_embedding_weight_file'}, 891 | kwargs, 892 | ) 893 | kwargs['disable_char_cnn'] = True 894 | kwargs['disable_word_embedding'] = False 895 | kwargs['disable_backward_lstm'] = True 896 | 
kwargs['disable_vocab_projection'] = False 897 | kwargs['disable_scalar_mix'] = True 898 | super().__init__(options_file, weight_file, **kwargs) 899 | 900 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 901 | token_repr = self.exec_word_embedding(inputs) 902 | return self.exec_forward_vocab_prob_distrib(token_repr) 903 | 904 | 905 | class FastElmoWordEmbeddingBackwardVocabDistrib(FastElmoUnidirectionalVocabDistribBase): 906 | 907 | def __init__( 908 | self, 909 | options_file: Optional[str], 910 | weight_file: str, 911 | **kwargs: Any, 912 | ) -> None: 913 | _raise_if_kwargs_is_invalid( 914 | {'exec_managed_lstm_bos_eos', 'word_embedding_weight_file'}, 915 | kwargs, 916 | ) 917 | kwargs['disable_char_cnn'] = True 918 | kwargs['disable_word_embedding'] = False 919 | kwargs['disable_forward_lstm'] = True 920 | kwargs['disable_vocab_projection'] = False 921 | kwargs['disable_scalar_mix'] = True 922 | super().__init__(options_file, weight_file, **kwargs) 923 | 924 | def execute(self, inputs: PackedSequence) -> List[PackedSequence]: 925 | token_repr = self.exec_word_embedding(inputs) 926 | return self.exec_backward_vocab_prob_distrib(token_repr) 927 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/tool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huntzhan/pytorch-fast-elmo/5800005df3b3341e20e89c6b9e9ca98d4fa26fc5/pytorch_fast_elmo/tool/__init__.py -------------------------------------------------------------------------------- /pytorch_fast_elmo/tool/cli.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use 2 | import cProfile 3 | import pstats 4 | import io 5 | 6 | import fire 7 | from pytorch_fast_elmo import utils 8 | from pytorch_fast_elmo.tool import profile, inspect 9 | 10 | 11 | class Main: 12 | 13 | def cache_char_cnn( # type: ignore 14 | self, 
15 | vocab_txt, 16 | options_file, 17 | weight_file, 18 | txt_out, 19 | max_characters_per_token=utils.ElmoCharacterIdsConst.MAX_WORD_LENGTH, 20 | cuda_device=-1, 21 | batch_size=256, 22 | ): 23 | utils.cache_char_cnn_vocab( 24 | vocab_txt, 25 | options_file, 26 | weight_file, 27 | txt_out, 28 | max_characters_per_token, 29 | cuda_device, 30 | batch_size, 31 | ) 32 | 33 | def export_word_embd( # type: ignore 34 | self, 35 | vocab_txt, 36 | weight_file, 37 | txt_out, 38 | ): 39 | utils.export_word_embd( 40 | vocab_txt, 41 | weight_file, 42 | txt_out, 43 | ) 44 | 45 | def profile_full( # type: ignore 46 | self, 47 | mode, 48 | options_file, 49 | weight_file, 50 | cuda_device=-1, 51 | cuda_synchronize=False, 52 | batch_size=32, 53 | warmup_size=20, 54 | iteration_size=1000, 55 | word_min=1, 56 | word_max=20, 57 | sent_min=1, 58 | sent_max=30, 59 | random_seed=10000, 60 | profiler=False, 61 | output_file=None, 62 | ): 63 | sstream = io.StringIO() 64 | 65 | if profiler: 66 | cpr = cProfile.Profile() 67 | cpr.enable() 68 | 69 | mean, median, stdev = profile.profile_full_elmo( 70 | mode, 71 | options_file, 72 | weight_file, 73 | cuda_device, 74 | cuda_synchronize, 75 | batch_size, 76 | warmup_size, 77 | iteration_size, 78 | word_min, 79 | word_max, 80 | sent_min, 81 | sent_max, 82 | random_seed, 83 | ) 84 | 85 | sstream.write(f'Finish {iteration_size} iterations.\n') 86 | sstream.write(f'Mode: {mode}\n') 87 | sstream.write(f'Duration Mean: {mean}\n') 88 | sstream.write(f'Duration Median: {median}\n') 89 | sstream.write(f'Duration Stdev: {stdev}\n\n') 90 | 91 | if profiler: 92 | cpr.disable() 93 | pstats.Stats(cpr, stream=sstream).sort_stats('cumulative').print_stats() 94 | 95 | if output_file: 96 | with open(output_file, 'w') as fout: 97 | fout.write(sstream.getvalue()) 98 | else: 99 | print(sstream.getvalue()) 100 | 101 | def sample_sentence( # type: ignore 102 | self, 103 | options_file, 104 | weight_file, 105 | vocab_txt, 106 | output_json, 107 | enable_trace=False, 
108 | go_forward=True, 109 | no_char_cnn=False, 110 | char_cnn_maxlen=0, 111 | next_token_top_k=5, 112 | sample_size=1, 113 | sample_constrain_txt=None, 114 | warm_up_txt=None, 115 | cuda_device=-1, 116 | ): 117 | inspect.sample_sentence( 118 | options_file, 119 | weight_file, 120 | vocab_txt, 121 | output_json, 122 | enable_trace, 123 | no_char_cnn, 124 | char_cnn_maxlen, 125 | go_forward, 126 | next_token_top_k, 127 | sample_size, 128 | sample_constrain_txt, 129 | warm_up_txt, 130 | cuda_device, 131 | ) 132 | 133 | def encode_sentences( # type: ignore 134 | self, 135 | options_file, 136 | weight_file, 137 | vocab_txt, 138 | input_txt, 139 | output_hdf5, 140 | no_char_cnn=False, 141 | char_cnn_maxlen=0, 142 | scalar_mix=None, 143 | warm_up_txt=None, 144 | cuda_device=-1, 145 | ): 146 | inspect.encode_sentences( 147 | options_file, 148 | weight_file, 149 | vocab_txt, 150 | input_txt, 151 | output_hdf5, 152 | no_char_cnn, 153 | char_cnn_maxlen, 154 | scalar_mix, 155 | warm_up_txt, 156 | cuda_device, 157 | ) 158 | 159 | 160 | def main(): # type: ignore 161 | fire.Fire(Main) 162 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/tool/inspect.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List, Dict, Any, Callable 2 | import json 3 | import logging 4 | import itertools 5 | 6 | import torch 7 | import numpy as np 8 | import h5py 9 | 10 | from pytorch_fast_elmo import ( 11 | batch_to_char_ids, 12 | load_and_build_vocab2id, 13 | batch_to_word_ids, 14 | FastElmoBase, 15 | FastElmo, 16 | FastElmoWordEmbedding, 17 | FastElmoPlainEncoder, 18 | FastElmoWordEmbeddingPlainEncoder, 19 | FastElmoForwardVocabDistrib, 20 | FastElmoBackwardVocabDistrib, 21 | FastElmoWordEmbeddingForwardVocabDistrib, 22 | FastElmoWordEmbeddingBackwardVocabDistrib, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 26 | 27 | 28 | def 
_generate_vocab2id_id2vocab(vocab_txt: str,) -> Tuple[Dict[str, int], Dict[int, str]]: 29 | vocab2id = load_and_build_vocab2id(vocab_txt) 30 | id2vocab = {token_id: token for token, token_id in vocab2id.items()} 31 | return vocab2id, id2vocab 32 | 33 | 34 | def _generate_batch_to_ids( 35 | vocab2id: Dict[str, int], 36 | char_cnn_maxlen: int, 37 | no_char_cnn: bool, 38 | cuda_device: int, 39 | ) -> Callable[[List[List[str]]], torch.Tensor]: 40 | if no_char_cnn: 41 | 42 | def batch_to_ids(batch: List[List[str]]) -> torch.Tensor: 43 | tensor = batch_to_word_ids(batch, vocab2id) 44 | if cuda_device >= 0: 45 | tensor = tensor.cuda(cuda_device) 46 | return tensor 47 | else: 48 | 49 | def batch_to_ids(batch: List[List[str]]) -> torch.Tensor: 50 | if char_cnn_maxlen == 0: 51 | tensor = batch_to_char_ids(batch) 52 | else: 53 | tensor = batch_to_char_ids(batch, char_cnn_maxlen) 54 | if cuda_device >= 0: 55 | tensor = tensor.cuda(cuda_device) 56 | return tensor 57 | 58 | return batch_to_ids 59 | 60 | 61 | def _warm_up( 62 | warm_up_txt: str, 63 | batch_to_ids: Callable[[List[List[str]]], torch.Tensor], 64 | elmo: FastElmoBase, 65 | ) -> None: 66 | sentences_token_ids = [] 67 | with open(warm_up_txt) as fin: 68 | for line in fin: 69 | sent = line.split() 70 | if not sent: 71 | continue 72 | token_ids = batch_to_ids([sent]) 73 | sentences_token_ids.append(token_ids) 74 | 75 | for token_ids in sentences_token_ids: 76 | with torch.no_grad(): 77 | elmo(token_ids) 78 | 79 | 80 | def sample_sentence( 81 | options_file: str, 82 | weight_file: str, 83 | vocab_txt: str, 84 | output_json: str, 85 | enable_trace: bool, 86 | no_char_cnn: bool, 87 | char_cnn_maxlen: int, 88 | go_forward: bool, 89 | next_token_top_k: int, 90 | sample_size: int, 91 | sample_constrain_txt: Optional[str], 92 | warm_up_txt: Optional[str], 93 | cuda_device: int, 94 | ) -> None: 95 | if no_char_cnn: 96 | if go_forward: 97 | fast_elmo_cls = FastElmoWordEmbeddingForwardVocabDistrib 98 | else: 99 | fast_elmo_cls = 
FastElmoWordEmbeddingBackwardVocabDistrib 100 | else: 101 | if go_forward: 102 | fast_elmo_cls = FastElmoForwardVocabDistrib 103 | else: 104 | fast_elmo_cls = FastElmoBackwardVocabDistrib 105 | 106 | vocab2id, id2vocab = _generate_vocab2id_id2vocab(vocab_txt) 107 | batch_to_ids = _generate_batch_to_ids( 108 | vocab2id, 109 | char_cnn_maxlen, 110 | no_char_cnn, 111 | cuda_device, 112 | ) 113 | 114 | elmo = fast_elmo_cls(options_file, weight_file) 115 | if cuda_device >= 0: 116 | elmo = elmo.cuda(cuda_device) 117 | 118 | # Warm up. 119 | if warm_up_txt: 120 | _warm_up(warm_up_txt, batch_to_ids, elmo) 121 | 122 | # Manually deal with BOS/EOS. 123 | elmo.exec_managed_lstm_bos_eos = False 124 | 125 | if sample_constrain_txt: 126 | with open(sample_constrain_txt) as fin: 127 | lines = fin.readlines() 128 | if not lines: 129 | raise ValueError('No content in sample_constrain_txt.') 130 | if len(lines) > 1: 131 | logging.warning('Multiple lines in sample_constrain_txt, only use the 1st line.') 132 | sample_constrain_tokens = lines[0].split() 133 | if not go_forward: 134 | sample_constrain_tokens.reverse() 135 | 136 | infos: List[Any] = [] 137 | for _ in range(sample_size): 138 | if go_forward: 139 | cur_token, end_token = '', '' 140 | else: 141 | cur_token, end_token = '', '' 142 | 143 | if sample_constrain_txt: 144 | for token in itertools.chain([cur_token], sample_constrain_tokens[:-1]): 145 | with torch.no_grad(): 146 | elmo(batch_to_ids([[token]])) 147 | cur_token = sample_constrain_tokens[-1] 148 | 149 | info: List[Any] = [] 150 | while cur_token != end_token: 151 | batched = batch_to_ids([[cur_token]]) 152 | with torch.no_grad(): 153 | output, _ = elmo(batched) 154 | if cuda_device >= 0: 155 | output = output.cpu() 156 | 157 | probs, indices = torch.topk(output.view(-1), next_token_top_k) 158 | probs = probs.numpy() 159 | indices = indices.numpy() 160 | 161 | next_token_id = np.random.choice(indices, p=probs / probs.sum()) + 1 162 | next_token = 
id2vocab[next_token_id] 163 | 164 | info_probs = sorted( 165 | dict(zip(map(id2vocab.get, indices + 1), probs.tolist())).items(), 166 | key=lambda p: p[1], 167 | reverse=True, 168 | ) 169 | info.append({ 170 | 'cur': cur_token, 171 | 'next': next_token, 172 | 'probs': info_probs, 173 | }) 174 | 175 | cur_token = next_token 176 | 177 | # Ending. 178 | with torch.no_grad(): 179 | elmo(batch_to_ids([[end_token]])) 180 | # Save info. 181 | infos.append({'text': ''.join(step['cur'] for step in info)}) 182 | if enable_trace: 183 | infos[-1]['trace'] = info 184 | if sample_constrain_txt: 185 | infos[-1]['text'] = ''.join(sample_constrain_tokens[:-1]) + infos[-1]['text'] 186 | if not go_forward: 187 | infos[-1]['text'] = infos[-1]['text'][::-1] 188 | 189 | # Output to JSON. 190 | with open(output_json, 'w') as fout: 191 | json.dump(infos, fout, ensure_ascii=False, indent=2) 192 | 193 | 194 | def encode_sentences( 195 | options_file: str, 196 | weight_file: str, 197 | vocab_txt: str, 198 | input_txt: str, 199 | output_hdf5: str, 200 | no_char_cnn: bool, 201 | char_cnn_maxlen: int, 202 | scalar_mix: Optional[Tuple[float]], 203 | warm_up_txt: Optional[str], 204 | cuda_device: int, 205 | ) -> None: 206 | if scalar_mix is None: 207 | if no_char_cnn: 208 | fast_elmo_cls = FastElmoWordEmbeddingPlainEncoder 209 | else: 210 | fast_elmo_cls = FastElmoPlainEncoder 211 | 212 | elmo = fast_elmo_cls( 213 | options_file, 214 | weight_file, 215 | ) 216 | 217 | else: 218 | if no_char_cnn: 219 | fast_elmo_cls = FastElmoWordEmbedding 220 | else: 221 | fast_elmo_cls = FastElmo 222 | 223 | elmo = fast_elmo_cls( 224 | options_file, 225 | weight_file, 226 | scalar_mix_parameters=list(scalar_mix), 227 | ) 228 | 229 | if cuda_device >= 0: 230 | elmo = elmo.cuda(cuda_device) 231 | 232 | vocab2id, _ = _generate_vocab2id_id2vocab(vocab_txt) 233 | batch_to_ids = _generate_batch_to_ids( 234 | vocab2id, 235 | char_cnn_maxlen, 236 | no_char_cnn, 237 | cuda_device, 238 | ) 239 | 240 | # Warm up. 
241 | if warm_up_txt: 242 | _warm_up(warm_up_txt, batch_to_ids, elmo) 243 | 244 | sentences: List[Tuple[int, List[str]]] = [] 245 | with open(input_txt) as fin: 246 | for sentence_id, line in enumerate(fin): 247 | tokens = line.split() 248 | if not tokens: 249 | logger.warning('Ignore sentence_id = %s', sentence_id) 250 | continue 251 | sentences.append((sentence_id, tokens)) 252 | 253 | with h5py.File(output_hdf5, 'w') as fout: 254 | for sentence_id, tokens in sentences: 255 | token_ids = batch_to_ids([tokens]) 256 | 257 | if scalar_mix is None: 258 | with torch.no_grad(): 259 | layer_reprs, _ = elmo(token_ids) 260 | # (layers, timesteps, hidden_size) 261 | encoded = torch.cat(layer_reprs, dim=0) 262 | else: 263 | with torch.no_grad(): 264 | out = elmo(token_ids) 265 | # (1, timesteps, hidden_size) 266 | encoded = out['elmo_representations'][0] 267 | 268 | if cuda_device >= 0: 269 | encoded = encoded.cpu() 270 | 271 | fout.create_dataset( 272 | str(sentence_id), 273 | encoded.shape, 274 | dtype='float32', 275 | data=encoded.numpy(), 276 | ) 277 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/tool/profile.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Tuple 2 | import random 3 | import string 4 | import time 5 | import statistics 6 | 7 | import torch 8 | from pytorch_fast_elmo import batch_to_char_ids, FastElmo 9 | 10 | 11 | class SentenceGenerator: 12 | 13 | def __init__( 14 | self, 15 | word_min: int, 16 | word_max: int, 17 | sent_min: int, 18 | sent_max: int, 19 | ) -> None: 20 | self.word_min = word_min 21 | self.word_max = word_max 22 | self.sent_min = sent_min 23 | self.sent_max = sent_max 24 | 25 | def generate_sentence(self) -> List[str]: 26 | return [ 27 | ''.join( 28 | random.choices( 29 | string.ascii_lowercase, 30 | k=random.randint(self.word_min, self.word_max), 31 | )) for _ in range(random.randint(self.sent_min, 
self.sent_max)) 32 | ] 33 | 34 | def generate_batch(self, batch_size: int) -> List[List[str]]: 35 | return [self.generate_sentence() for _ in range(batch_size)] 36 | 37 | 38 | def load_fast_elmo( 39 | options_file: str, 40 | weight_file: str, 41 | ) -> FastElmo: 42 | return FastElmo( 43 | options_file, 44 | weight_file, 45 | scalar_mix_parameters=[1.0, 1.0, 1.0], 46 | ) 47 | 48 | 49 | def load_allennlp_elmo( 50 | options_file: str, 51 | weight_file: str, 52 | ) -> Any: 53 | from allennlp.modules.elmo import Elmo 54 | return Elmo( 55 | options_file, 56 | weight_file, 57 | num_output_representations=1, 58 | dropout=0.0, 59 | scalar_mix_parameters=[1.0, 1.0, 1.0], 60 | ) 61 | 62 | 63 | def profile_full_elmo( 64 | mode: str, 65 | options_file: str, 66 | weight_file: str, 67 | cuda_device: int, 68 | cuda_synchronize: bool, 69 | batch_size: int, 70 | warmup_size: int, 71 | iteration_size: int, 72 | word_min: int, 73 | word_max: int, 74 | sent_min: int, 75 | sent_max: int, 76 | random_seed: int, 77 | ) -> Tuple[float, float, float]: 78 | random.seed(random_seed) 79 | 80 | module: Any = None 81 | if mode == 'fast-elmo': 82 | module = load_fast_elmo(options_file, weight_file) 83 | elif mode == 'allennlp-elmo': 84 | module = load_allennlp_elmo(options_file, weight_file) 85 | else: 86 | raise ValueError('invalid mode') 87 | 88 | sent_gen = SentenceGenerator( 89 | word_min, 90 | word_max, 91 | sent_min, 92 | sent_max, 93 | ) 94 | 95 | if cuda_device >= 0: 96 | module = module.cuda(cuda_device) 97 | 98 | durations: List[float] = [] 99 | 100 | for idx in range(warmup_size + iteration_size): 101 | batch = sent_gen.generate_batch(batch_size) 102 | char_ids = batch_to_char_ids(batch) 103 | 104 | if cuda_device >= 0: 105 | char_ids = char_ids.cuda(cuda_device) 106 | if cuda_synchronize: 107 | torch.cuda.synchronize() 108 | 109 | start = time.time() 110 | with torch.no_grad(): 111 | module(char_ids) 112 | if cuda_device >= 0 and cuda_synchronize: 113 | torch.cuda.synchronize() 114 | 
end = time.time() 115 | 116 | if idx >= warmup_size: 117 | durations.append(end - start) 118 | 119 | mean = statistics.mean(durations) 120 | median = statistics.median(durations) 121 | stdev = statistics.stdev(durations) 122 | 123 | return mean, median, stdev 124 | -------------------------------------------------------------------------------- /pytorch_fast_elmo/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterable, Dict, Optional 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence 5 | import numpy as np 6 | 7 | from pytorch_fast_elmo.factory import ElmoCharacterEncoderFactory, ElmoWordEmbeddingFactory 8 | from _pytorch_fast_elmo import ElmoCharacterEncoder # pylint: disable=no-name-in-module 9 | 10 | 11 | def load_vocab(vocab_txt: str) -> List[str]: 12 | """ 13 | Use the same format as bilm-tf. 14 | """ 15 | vocab = [] 16 | with open(vocab_txt) as fin: 17 | for line in fin: 18 | word = line.strip() 19 | if word: 20 | vocab.append(word) 21 | return vocab 22 | 23 | 24 | def build_vocab2id(vocab: List[str]) -> Dict[str, int]: 25 | """ 26 | Adding one will be applied for padding. 
27 | """ 28 | assert len(vocab) > 3 29 | assert vocab[:3] == ['', '', ''] 30 | return {word: word_id for word_id, word in enumerate(vocab, start=1)} 31 | 32 | 33 | def load_and_build_vocab2id(vocab_txt: str) -> Dict[str, int]: 34 | return build_vocab2id(load_vocab(vocab_txt)) 35 | 36 | 37 | def get_lengths_of_zero_padded_batch(inputs: torch.Tensor) -> torch.Tensor: 38 | if inputs.dim() == 2: 39 | lengths = (inputs > 0).long().sum(dim=-1) 40 | elif inputs.dim() == 3: 41 | lengths = ((inputs > 0).long().sum(dim=-1) > 0).long().sum(dim=-1) 42 | else: 43 | raise ValueError("inputs should be 2D or 3D.") 44 | 45 | return lengths 46 | 47 | 48 | def pack_inputs( 49 | inputs: torch.Tensor, 50 | lengths: Optional[torch.Tensor] = None, 51 | ) -> PackedSequence: 52 | """ 53 | Pack inputs of shape `(batch_size, timesteps, x)` or `(batch_size, timesteps)`. 54 | Padding value should be 0. 55 | """ 56 | if lengths is None: 57 | lengths = get_lengths_of_zero_padded_batch(inputs) 58 | return pack_padded_sequence(inputs, lengths, batch_first=True) 59 | 60 | 61 | def generate_mask_from_lengths( 62 | batch_size: int, 63 | max_timesteps: int, 64 | lengths: torch.Tensor, 65 | ) -> torch.Tensor: 66 | ones = lengths.new_ones(batch_size, max_timesteps, dtype=torch.long) 67 | range_tensor = ones.cumsum(dim=-1) 68 | return (lengths.unsqueeze(1) >= range_tensor).long() 69 | 70 | 71 | def unpack_outputs(inputs: PackedSequence) -> torch.Tensor: 72 | """ 73 | Unpack the final result and return `(tensor, mask)`. 74 | """ 75 | tensor, _ = pad_packed_sequence(inputs, batch_first=True) 76 | return tensor 77 | 78 | 79 | class ElmoCharacterIdsConst: 80 | """ 81 | From Allennlp. 
82 | """ 83 | MAX_WORD_LENGTH = 50 84 | 85 | BEGINNING_OF_SENTENCE_CHARACTER = 256 # 86 | END_OF_SENTENCE_CHARACTER = 257 # 87 | BEGINNING_OF_WORD_CHARACTER = 258 # 88 | END_OF_WORD_CHARACTER = 259 # 89 | PADDING_CHARACTER = 260 # 90 | 91 | 92 | def make_padded_char_ids( 93 | char_ids: Iterable[int], # +1 should have been applied. 94 | max_word_length: int = ElmoCharacterIdsConst.MAX_WORD_LENGTH, 95 | padding_character: int = ElmoCharacterIdsConst.PADDING_CHARACTER + 1, 96 | beginning_of_word_character: int = ElmoCharacterIdsConst.BEGINNING_OF_WORD_CHARACTER + 1, 97 | end_of_word_character: int = ElmoCharacterIdsConst.END_OF_WORD_CHARACTER + 1, 98 | ) -> List[int]: 99 | padded = [padding_character] * max_word_length 100 | 101 | padded[0] = beginning_of_word_character 102 | idx = 1 103 | for char_id in char_ids: 104 | if idx >= max_word_length: 105 | break 106 | padded[idx] = char_id 107 | idx += 1 108 | 109 | idx = min(idx, max_word_length - 1) 110 | padded[idx] = end_of_word_character 111 | 112 | return padded 113 | 114 | 115 | def make_bos(max_word_length: int = ElmoCharacterIdsConst.MAX_WORD_LENGTH) -> List[int]: 116 | return make_padded_char_ids( 117 | (ElmoCharacterIdsConst.BEGINNING_OF_SENTENCE_CHARACTER + 1,), 118 | max_word_length, 119 | ) 120 | 121 | 122 | def make_eos(max_word_length: int = ElmoCharacterIdsConst.MAX_WORD_LENGTH) -> List[int]: 123 | return make_padded_char_ids( 124 | (ElmoCharacterIdsConst.END_OF_SENTENCE_CHARACTER + 1,), 125 | max_word_length, 126 | ) 127 | 128 | 129 | def word_to_char_ids(word: str) -> List[int]: 130 | # +1 is applied here. 131 | return [char_id + 1 for char_id in word.encode('utf-8', 'ignore')] 132 | 133 | 134 | _WORD_TO_CHAR_IDS_EXCEPTION = { 135 | '': make_bos, 136 | '': make_eos, 137 | } 138 | 139 | 140 | def batch_to_char_ids( 141 | batch: List[List[str]], 142 | max_characters_per_token: int = ElmoCharacterIdsConst.MAX_WORD_LENGTH, 143 | ) -> torch.Tensor: 144 | """ 145 | From Allennlp. 146 | 147 | Note: 148 | 1. 
BOS/EOS will be treated specially. 149 | 2. UNK will be treated as normal string, same as bilm-tf. 150 | 151 | Return tensor of shape `(batch_size, max_timesteps, max_characters_per_token)`. 152 | """ 153 | max_timesteps = max(len(row) for row in batch) 154 | zeros = torch.LongTensor([0] * max_characters_per_token) 155 | 156 | rows = [] 157 | for words in batch: 158 | row = [] 159 | for word in words: 160 | special_gen = _WORD_TO_CHAR_IDS_EXCEPTION.get(word) 161 | if special_gen is None: 162 | char_ids = make_padded_char_ids( 163 | word_to_char_ids(word), 164 | max_characters_per_token, 165 | ) 166 | else: 167 | char_ids = special_gen(max_characters_per_token) 168 | 169 | # of shape `(max_characters_per_token,)` 170 | row.append(torch.LongTensor(char_ids)) 171 | 172 | # Add padding. 173 | row.extend([zeros] * (max_timesteps - len(row))) 174 | # Stack to shape `(max_timesteps, max_characters_per_token)` 175 | rows.append(torch.stack(row)) 176 | 177 | # Stack to shape `(batch_size, max_timesteps, max_characters_per_token)` 178 | return torch.stack(rows) 179 | 180 | 181 | def batch_to_word_ids(batch: List[List[str]], vocab2id: Dict[str, int]) -> torch.Tensor: 182 | """ 183 | For word embedding. 184 | 185 | 1. UNK will be mapped to 3 since we assume the vocab starts with `, , `. 186 | 187 | Return tensor of shape `(batch_size, max_timesteps)`. 188 | """ 189 | max_timesteps = max(len(row) for row in batch) 190 | 191 | rows = [] 192 | for words in batch: 193 | row = [vocab2id.get(word, 3) for word in words] 194 | row.extend([0] * (max_timesteps - len(row))) 195 | rows.append(row) 196 | 197 | return torch.LongTensor(rows) 198 | 199 | 200 | def get_bos_eos_token_repr( 201 | char_cnn_factory: ElmoCharacterEncoderFactory, 202 | char_cnn: ElmoCharacterEncoder, 203 | ) -> Tuple[torch.Tensor, torch.Tensor]: 204 | # [, , , max(kernal)...] 
def sort_batch_by_length(
        batch: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Similar to AllenNLP.

    `batch` of shape `(batch_size, max_timesteps, *)` should be zero padded.
    `lengths` of shape `(batch_size,)`. When omitted, it is derived from
    `batch` with `get_lengths_of_zero_padded_batch`.

    Returns `(sorted_batch, permutation_index, restoration_index)` where
    `sorted_batch` is `batch.index_select(0, permutation_index)` (rows ordered
    by decreasing length) and
    `sorted_batch.index_select(0, restoration_index)` gives back `batch`.
    """
    if lengths is None:
        lengths = get_lengths_of_zero_padded_batch(batch)
    _, permutation_index = lengths.sort(0, descending=True)
    sorted_batch = batch.index_select(0, permutation_index)
    # Sorting the permutation itself in ascending order yields its inverse.
    _, restoration_index = permutation_index.sort(0, descending=False)
    return sorted_batch, permutation_index, restoration_index


def export_word_embedding_to_txt(
        vocab: List[str],
        embedding_weight: np.ndarray,
        txt_out: str,
) -> None:
    """
    Write one `[token] [dim1] [dim2]...` line per vocabulary entry to `txt_out`.

    Row `i` of `embedding_weight` is paired with `vocab[i]`.

    Raises `ValueError` when the number of rows differs from `len(vocab)`.
    """
    if embedding_weight.shape[0] != len(vocab):
        raise ValueError('Size not match. '
                         f'embd({embedding_weight.shape[0]}) '
                         '!= '
                         f'vocab({len(vocab)})')

    # NOTE(review): the file is written with the platform default encoding.
    # Consider `encoding='utf-8'` once the code that reads this file back is
    # confirmed to decode it the same way.
    with open(txt_out, 'w') as fout:
        for token, embd in zip(vocab, embedding_weight):
            embd_txt = ' '.join(map(str, embd))
            fout.write(f'{token} {embd_txt}\n')
def export_word_embd(
        vocab_txt: str,
        weight_file: str,
        txt_out: str,
) -> None:
    """
    Dump the word embedding found in `weight_file` to the text file `txt_out`.

    The vocabulary is read from `vocab_txt`; the output is produced by
    `export_word_embedding_to_txt`.
    """
    vocab = load_vocab(vocab_txt)

    factory = ElmoWordEmbeddingFactory(None, weight_file)
    # `create` returns three values; only the embedding weight is used here.
    padded_weight, _, _ = factory.create(requires_grad=False)

    # remove padding (row 0) before converting to numpy.
    weight_without_padding = padded_weight.data[1:, :].numpy()

    export_word_embedding_to_txt(vocab, weight_without_padding, txt_out)
def load_requirements(path):
    """Return the requirement entries listed in the file at `path`.

    Every line is stripped of surrounding whitespace; blank lines and lines
    starting with `#` are skipped.
    """
    requirements = []
    with open(path) as fin:
        for raw_line in fin:
            entry = raw_line.strip()
            if not entry or entry.startswith('#'):
                continue
            requirements.append(entry)
    return requirements
load_requirements('requirements_prod.txt') 27 | test_requirements = load_requirements('requirements_dev.txt') 28 | 29 | setup( 30 | author="Hunt Zhan", 31 | author_email='huntzhan.dev@gmail.com', 32 | classifiers=[ 33 | 'Development Status :: 2 - Pre-Alpha', 34 | 'Intended Audience :: Developers', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Natural Language :: English', 37 | 'Programming Language :: Python :: 3', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Programming Language :: Python :: 3.7', 40 | ], 41 | description="None", 42 | install_requires=requirements, 43 | license="MIT license", 44 | long_description=readme + '\n\n' + history, 45 | include_package_data=True, 46 | keywords='pytorch_fast_elmo', 47 | name='pytorch_fast_elmo', 48 | packages=find_packages(), 49 | test_suite='tests', 50 | tests_require=test_requirements, 51 | url='https://github.com/cnt-dev/pytorch-fast-elmo', 52 | version='0.6.12', 53 | zip_safe=False, 54 | # Pytorch Cpp Extension. 55 | ext_modules=[ 56 | CppExtension( 57 | '_pytorch_fast_elmo', 58 | [ 59 | 'extension/elmo_character_encoder.cc', 60 | 'extension/scalar_mix.cc', 61 | 'extension/bind.cc', 62 | ], 63 | include_dirs=[dirname(abspath(__file__))], 64 | ), 65 | ], 66 | cmdclass={'build_ext': BuildExtension}, 67 | entry_points={'console_scripts': ['fast-elmo = pytorch_fast_elmo.tool.cli:main']}, 68 | ) 69 | -------------------------------------------------------------------------------- /tests/fixtures/lm_embd.txt: -------------------------------------------------------------------------------- 1 | 0.28980404 -0.6028592 -0.18264832 0.3293563 -0.22317028 0.20625375 0.11600229 0.2848622 0.13669783 -0.00091633946 -0.124367476 -0.5858446 0.3517094 -0.8553426 -0.18970014 0.09590681 2 | -2.3861542 0.8002946 -4.202066 0.6665349 -2.9172802 2.3743784 -4.9812274 3.8177035 0.5781853 -1.8342785 -2.5637474 -0.76124954 4.0692067 -1.6655566 -2.3589456 0.7653032 3 | -0.1412394 0.7715637 -0.06868322 -0.17460269 -0.4982049 -0.53049004 
-0.7671114 -0.67538446 -0.16095534 -0.020927057 0.2418477 0.5309684 -0.40139574 0.6756878 0.10029215 -0.30584192 4 | a 1.0359238 -1.696596 -0.23661254 1.2023154 -0.40868482 -0.50711024 0.5740293 -0.092993006 0.6653838 0.5726978 0.4281926 -0.8583081 -0.3748623 -1.0206424 -0.247112 -0.84500986 5 | Another 1.5667292 -1.4926988 0.86566085 1.64292 -0.98098135 -1.9184027 1.0054154 -2.0008366 0.65204847 1.3102181 1.1542757 0.016974151 -2.091839 0.16773838 0.2033734 -2.1522627 6 | one -0.09851174 -0.00082571805 0.02395771 -0.10074097 -0.19198728 0.028538883 -0.2812282 -0.1334887 -0.14111939 -0.14035721 -0.049450208 -0.06812096 0.19856848 -0.1357014 0.19158512 0.3794132 7 | Here -0.17405833 0.6001479 0.13315843 -0.010353893 -0.2604631 -0.26127774 -0.5337093 -0.36444092 -0.06510824 0.040666707 -0.05684178 0.4655811 -0.11033942 0.55485034 0.11241583 -0.1417979 8 | This 0.20986725 -0.53361917 -0.189162 0.40430778 -0.5312984 -0.46517226 0.1680393 -0.6759055 0.13581483 0.24582757 0.33422536 -0.14842531 -0.67372555 -0.014158517 0.18692917 -0.48351508 9 | 's 0.42661577 -1.0821646 -0.039161295 0.3855788 -0.15253794 0.046862587 0.6531919 -0.09456499 0.16414921 0.26111037 0.04874822 -0.8619473 0.008180812 -0.99453473 0.1559616 0.19066153 10 | sentence -0.33722532 1.2612866 0.4395743 -0.8230878 0.15801784 -0.29353058 -0.47678584 -0.8535238 -0.50125676 -0.25310087 0.28450802 0.649899 -0.41487288 0.9483386 0.47737962 0.4027161 11 | is 0.2643016 -1.0500723 -0.04420519 0.25495237 -0.18994996 0.24982716 0.66718554 -0.010221064 0.043384895 0.14484806 -0.12825778 -0.9615593 0.22108398 -1.182286 0.1437658 0.46505222 12 | ELMo -0.24151926 0.41801804 -0.007428486 -0.44236958 -0.18258658 0.11114813 -0.29526627 -0.24564306 -0.2782054 -0.111587785 0.04202531 0.1244072 -0.06128703 0.27814174 0.22751448 0.46913505 13 | . 
-4.1568537 3.17313 -4.4991755 -0.63724 -3.4168785 4.137848 -7.0601287 4.6883783 -0.5197414 -3.0221457 -3.8914802 -0.735818 6.223864 -1.3600775 -2.6550674 2.0737975 14 | helps -0.5200626 0.6567105 -0.4354482 -0.71946436 -0.3435796 0.18210343 -0.622038 -0.24847429 -0.21154717 -0.35597056 -0.019365553 0.097794384 0.06000697 0.39206824 0.23893628 0.5111789 15 | disambiguate -0.35081345 1.5379206 0.7070806 -1.3555737 0.0134746805 -0.14858353 -0.22784677 -1.197601 -0.5240481 -0.047964945 0.20505252 0.66521466 -0.6952449 1.4304036 1.0336764 0.7010869 16 | The 1.4023353 -1.9006824 -0.37156108 2.3824224 -0.9490761 -1.6122202 0.33879024 -0.5810026 0.7497753 0.81845075 0.76323366 -0.28443652 -1.5327418 -0.7988391 -0.5818471 -2.197631 17 | from -0.42211163 1.1988133 0.46265048 -0.6553217 0.10978282 -0.1263574 -0.3945774 -0.59299177 -0.42469966 -0.10309285 0.030529015 0.59077454 -0.24393405 0.86818707 0.45730734 0.4582367 18 | -------------------------------------------------------------------------------- /tests/fixtures/lm_weights.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huntzhan/pytorch-fast-elmo/5800005df3b3341e20e89c6b9e9ca98d4fa26fc5/tests/fixtures/lm_weights.hdf5 -------------------------------------------------------------------------------- /tests/fixtures/options.json: -------------------------------------------------------------------------------- 1 | { 2 | "lstm": { 3 | "cell_clip": 3, 4 | "use_skip_connections": true, 5 | "n_layers": 2, 6 | "proj_clip": 3, 7 | "projection_dim": 16, 8 | "dim": 64 9 | }, 10 | "char_cnn": { 11 | "embedding": { 12 | "dim": 4 13 | }, 14 | "filters": [ 15 | [1, 4], 16 | [2, 8], 17 | [3, 16], 18 | [4, 32], 19 | [5, 64] 20 | ], 21 | "n_highway": 2, 22 | "n_characters": 261, 23 | "max_characters_per_token": 50, 24 | "activation": "relu" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /tests/fixtures/vocab.txt: 
def test_elmo_character_encoder_simple():
    """Smoke test: a hand-configured encoder maps `(20, 10)` character ids to `(20, 3)`."""
    filters = [
        [1, 5],
        [2, 4],
    ]
    encoder = ElmoCharacterEncoder(
            char_embedding_cnt=10,
            char_embedding_dim=3,
            filters=filters,
            activation='relu',
            num_highway_layers=2,
            output_dim=3,
    )

    # Random ids drawn from [0, 10), matching `char_embedding_cnt`.
    char_ids = torch.randint(0, 10, (20, 10))
    encoded = encoder(char_ids)

    assert list(encoded.size()) == [20, 3]
def _unpack(tensor, batch_sizes):
    """Rebuild a padded, batch-first tensor from packed data and its batch sizes."""
    packed = PackedSequence(tensor, batch_sizes)
    # `pad_packed_sequence` also returns the sequence lengths; only the padded
    # tensor is of interest to the callers.
    padded, _lengths = pad_packed_sequence(packed, batch_first=True)
    return padded
[ 124 | 'highway.layers_0.bias', 125 | 'highway.layers_0.weight', 126 | 'highway.layers_1.bias', 127 | 'highway.layers_1.weight', 128 | ] 129 | assert len(allennlp_parameters) == len(embedder_parameters) 130 | assert len(allennlp_parameters_diff) == len(embedder_parameters_diff) 131 | 132 | allennlp_embedder_named_parameters = dict(allennlp_embedder.named_parameters()) 133 | # Same. 134 | for allennlp_param, embedder_param in zip(allennlp_parameters, embedder_parameters): 135 | allennlp_w = allennlp_embedder_named_parameters[allennlp_param].data 136 | embedder_w = embedder.named_parameters()[embedder_param].data 137 | 138 | np.testing.assert_array_equal(embedder_w.numpy(), allennlp_w.numpy()) 139 | assert embedder_w.dtype == allennlp_w.dtype 140 | # Diff on highway. 141 | for allennlp_param, embedder_param in zip(allennlp_parameters_diff, embedder_parameters_diff): 142 | allennlp_w = allennlp_embedder_named_parameters[allennlp_param].data 143 | embedder_w = embedder.named_parameters()[embedder_param].data 144 | 145 | assert embedder_w.dtype == allennlp_w.dtype 146 | np.testing.assert_raises( 147 | AssertionError, 148 | np.testing.assert_array_equal, 149 | embedder_w.numpy(), 150 | allennlp_w.numpy(), 151 | ) 152 | 153 | sentences = [ 154 | ['ELMo', 'helps', 'disambiguate', 'ELMo', 'from', 'Elmo', '.'], 155 | ['The', 'sentence', '.'], 156 | ] 157 | # `(2, 7, 50)` 158 | character_ids = _sentences_to_ids(sentences) 159 | 160 | # AllenNLP. 161 | out = allennlp_embedder(character_ids) 162 | allennlp_token_embedding, _ = remove_sentence_boundaries(out['token_embedding'], out['mask']) 163 | assert list(allennlp_token_embedding.shape) == [2, 7, 16] 164 | 165 | # Ours. 
166 | inputs = pack_padded_sequence(character_ids, [7, 3], batch_first=True) 167 | out = embedder(inputs.data) 168 | ours_token_embedding = _unpack(out, inputs.batch_sizes) 169 | assert list(ours_token_embedding.shape) == [2, 7, 16] 170 | 171 | np.testing.assert_array_almost_equal( 172 | ours_token_embedding.data.numpy(), 173 | allennlp_token_embedding.data.numpy(), 174 | ) 175 | 176 | 177 | def test_elmo_lstm_factory_simple(): 178 | allennlp_elmo_bilm = _ElmoBiLm( 179 | ELMO_OPTIONS_FILE, 180 | ELMO_WEIGHT_FILE, 181 | ) 182 | 183 | embedder = ElmoCharacterEncoderFactory( 184 | ELMO_OPTIONS_FILE, 185 | ELMO_WEIGHT_FILE, 186 | ).create() 187 | fwd_lstm, bwd_lstm = ElmoLstmFactory( 188 | ELMO_OPTIONS_FILE, 189 | ELMO_WEIGHT_FILE, 190 | ).create( 191 | enable_forward=True, enable_backward=True) 192 | 193 | sentences_1 = [ 194 | ['ELMo', 'helps', 'disambiguate', 'ELMo', 'from', 'Elmo', '.'], 195 | ['The', 'sentence', '.'], 196 | ] 197 | sentences_2 = [ 198 | ["This", "is", "a", "sentence"], 199 | ["Here", "'s", "one"], 200 | ["Another", "one"], 201 | ] 202 | 203 | # Internal states should be updated. 204 | for sentences in [sentences_1, sentences_2] * 10: 205 | # `(2, 7, 50)` 206 | character_ids = _sentences_to_ids(sentences) 207 | 208 | # AllenNLP. 209 | allennlp_out = allennlp_elmo_bilm(character_ids) 210 | 211 | # Ours. 212 | inputs = character_ids 213 | _beginning_of_sentence_characters = torch.from_numpy( 214 | np.array(ELMoCharacterMapper.beginning_of_sentence_characters) + 1) 215 | _end_of_sentence_characters = torch.from_numpy( 216 | np.array(ELMoCharacterMapper.end_of_sentence_characters) + 1) 217 | # Add BOS/EOS 218 | mask = ((inputs > 0).long().sum(dim=-1) > 0).long() 219 | character_ids_with_bos_eos, mask_with_bos_eos = add_sentence_boundary_token_ids( 220 | inputs, 221 | mask, 222 | _beginning_of_sentence_characters, 223 | _end_of_sentence_characters, 224 | ) 225 | # Pack input. 
226 | lengths = mask_with_bos_eos.sum(dim=-1) 227 | inputs = pack_padded_sequence(character_ids_with_bos_eos, lengths, batch_first=True) 228 | char_repr = embedder(inputs.data) 229 | fwd_lstm_hiddens, _ = fwd_lstm(char_repr, inputs.batch_sizes) 230 | bwd_lstm_hiddens, _ = bwd_lstm(char_repr, inputs.batch_sizes) 231 | lstm_hiddens = [ 232 | torch.cat([fwd, bwd], dim=-1) 233 | for fwd, bwd in zip(fwd_lstm_hiddens, bwd_lstm_hiddens) 234 | ] 235 | # Unpack output. 236 | char_repr = _unpack(char_repr, inputs.batch_sizes) 237 | duplicated_char_repr = torch.cat( 238 | [char_repr, char_repr], 239 | dim=-1, 240 | ) * mask_with_bos_eos.float().unsqueeze(-1) 241 | lstm_hiddens = [_unpack(hx, inputs.batch_sizes) for hx in lstm_hiddens] 242 | 243 | # TODO: Investigate the numerical stability issue. 244 | # np.testing.assert_array_almost_equal( 245 | # duplicated_char_repr.data.numpy(), 246 | # allennlp_out['activations'][0].data.numpy(), 247 | # ) 248 | # np.testing.assert_array_almost_equal( 249 | # lstm_hiddens[0].data.numpy(), 250 | # allennlp_out['activations'][1].data.numpy(), 251 | # ) 252 | np.testing.assert_array_almost_equal( 253 | lstm_hiddens[1].data.numpy(), 254 | allennlp_out['activations'][2].data.numpy(), 255 | ) 256 | 257 | 258 | def test_fast_elmo_with_allennlp(): 259 | fast = FastElmo( 260 | ELMO_OPTIONS_FILE, 261 | ELMO_WEIGHT_FILE, 262 | num_output_representations=2, 263 | scalar_mix_parameters=[1.0, 1.0, 1.0], 264 | ) 265 | 266 | allennlp = Elmo( 267 | ELMO_OPTIONS_FILE, 268 | ELMO_WEIGHT_FILE, 269 | num_output_representations=2, 270 | dropout=0.0, 271 | scalar_mix_parameters=[1.0, 1.0, 1.0], 272 | ) 273 | 274 | sentences_1 = [ 275 | ['ELMo', 'helps', 'disambiguate', 'ELMo', 'from', 'Elmo', '.'], 276 | ['The', 'sentence', '.'], 277 | ] 278 | sentences_2 = [ 279 | ["This", "is", "a", "sentence"], 280 | ["Here", "'s", "one"], 281 | ["Another", "one"], 282 | ] 283 | 284 | for sentences in [sentences_1, sentences_2] * 10: 285 | random.shuffle(sentences) 286 | 
character_ids = _sentences_to_ids(sentences) 287 | 288 | fast_out = fast(character_ids) 289 | allennlp_out = allennlp(character_ids) 290 | 291 | for repr_idx in range(2): 292 | fast_mixed_repr = fast_out['elmo_representations'][repr_idx] 293 | allennlp_mixed_repr = allennlp_out['elmo_representations'][repr_idx] 294 | np.testing.assert_array_almost_equal( 295 | fast_mixed_repr, 296 | allennlp_mixed_repr, 297 | ) 298 | 299 | np.testing.assert_array_equal( 300 | fast_out['mask'], 301 | allennlp_out['mask'], 302 | ) 303 | 304 | 305 | def test_fast_elmo_with_allennlp_do_layer_norm(): 306 | fast = FastElmo( 307 | ELMO_OPTIONS_FILE, 308 | ELMO_WEIGHT_FILE, 309 | num_output_representations=1, 310 | scalar_mix_parameters=[1.0, 1.0, 1.0], 311 | do_layer_norm=True, 312 | ) 313 | 314 | allennlp = Elmo( 315 | ELMO_OPTIONS_FILE, 316 | ELMO_WEIGHT_FILE, 317 | num_output_representations=1, 318 | dropout=0.0, 319 | scalar_mix_parameters=[1.0, 1.0, 1.0], 320 | do_layer_norm=True, 321 | ) 322 | 323 | sentences = [ 324 | ['ELMo', 'helps', 'disambiguate', 'ELMo', 'from', 'Elmo', '.'], 325 | ['The', 'sentence', '.'], 326 | ] 327 | character_ids = _sentences_to_ids(sentences) 328 | 329 | fast_out = fast(character_ids) 330 | allennlp_out = allennlp(character_ids) 331 | 332 | # Since we don't include the BOS/EOS reprs during layer normalization, 333 | # the result will be different from AllenNLP's implementation. 
def test_fast_elmo_save_and_load():
    """A scalar-mix weight changed on one model survives a `state_dict` round trip."""
    source_model = FastElmo(
            ELMO_OPTIONS_FILE,
            ELMO_WEIGHT_FILE,
    )

    # Change weight and save.
    source_model.scalar_mix_0.scalar_parameters[0].data.fill_(42.)
    saved_state = source_model.state_dict()

    # Load into a freshly constructed model.
    restored_model = FastElmo(
            ELMO_OPTIONS_FILE,
            ELMO_WEIGHT_FILE,
    )
    restored_model.load_state_dict(saved_state)

    restored_value = float(restored_model.scalar_mix_0.scalar_parameters[0])
    assert restored_value == 42.
def test_fast_elmo_word_embedding():
    """The cached word-embedding path must reproduce the char-CNN path on the fixture vocab."""
    vocab = utils.load_vocab(CACHE_VOCAB_FILE)

    char_cnn_model = FastElmo(
            ELMO_OPTIONS_FILE,
            ELMO_WEIGHT_FILE,
    )
    word_embd_model = FastElmoWordEmbedding(
            ELMO_OPTIONS_FILE,
            ELMO_WEIGHT_FILE,
            word_embedding_weight_file=CACHE_EMBD_FILE,
    )

    # Test BOS/EOS & other words: the whole vocab is fed as a single sentence.
    sentence_batch = [vocab]

    word_ids = utils.batch_to_word_ids(sentence_batch, utils.build_vocab2id(vocab))
    from_word_embd = word_embd_model(word_ids)

    char_ids = utils.batch_to_char_ids(sentence_batch)
    from_char_cnn = char_cnn_model(char_ids)

    np.testing.assert_array_almost_equal(
            from_word_embd['elmo_representations'][0].data.numpy(),
            from_char_cnn['elmo_representations'][0].data.numpy(),
    )
def test_scalar_mix_can_run_forward():
    """The mix equals gamma times the softmax-weighted sum of the input tensors."""
    mixture = ScalarMix(3)
    tensors = [torch.randn([3, 4, 5]) for _ in range(3)]
    for k in range(3):
        mixture.scalar_parameters[k].data[0] = 0.1 * (k + 1)
    mixture.gamma.data[0] = 0.5
    result = mixture(tensors)

    # Reference computation in numpy.
    exp_weights = numpy.exp([0.1, 0.2, 0.3])
    normed_weights = exp_weights / numpy.sum(exp_weights)
    expected_result = sum(
            normed_weight * tensor.data.numpy()
            for normed_weight, tensor in zip(normed_weights, tensors))
    expected_result *= 0.5
    numpy.testing.assert_almost_equal(expected_result, result.data.numpy())


def test_scalar_mix_throws_error_on_incorrect_number_of_inputs():
    """Feeding five tensors to a three-way mix must raise `ValueError`."""
    mixture = ScalarMix(3)
    too_many_tensors = [torch.randn([3, 4, 5]) for _ in range(5)]
    with pytest.raises(ValueError):
        _ = mixture(too_many_tensors)


def test_scalar_mix_throws_error_on_incorrect_initial_scalar_parameters_length():
    """Two initial scalars for a three-way mix must raise `ValueError`."""
    with pytest.raises(ValueError):
        ScalarMix(3, initial_scalar_parameters=[0.0, 0.0])


def test_scalar_mix_trainable_with_initial_scalar_parameters():
    """With `trainable=False` the scalars keep their initial values and need no gradient."""
    initial_scalar_parameters = [1.0, 2.0, 3.0]
    mixture = ScalarMix(
            3,
            initial_scalar_parameters=initial_scalar_parameters,
            trainable=False,
    )
    for index, parameter in enumerate(mixture.scalar_parameters):
        assert parameter.requires_grad is False
        assert parameter.item() == initial_scalar_parameters[index]
mixture(tensors, mask) 58 | 59 | normed_weights = numpy.exp(weights) / numpy.sum(numpy.exp(weights)) 60 | expected_result = numpy.zeros((3, 4, 5)) 61 | for k in range(3): 62 | mean = numpy.mean(tensors[k].data.numpy()[numpy_mask == 1]) 63 | std = numpy.std(tensors[k].data.numpy()[numpy_mask == 1]) 64 | normed_tensor = (tensors[k].data.numpy() - mean) / (std + 1E-12) 65 | expected_result += (normed_tensor * normed_weights[k]) 66 | expected_result *= 0.5 67 | 68 | numpy.testing.assert_almost_equal(expected_result, result.data.numpy(), decimal=6) 69 | 70 | 71 | ''' 72 | def access_scalar_parameters(mixture, k): 73 | return mixture.named_parameters()['scalar_' + str(k)] 74 | 75 | 76 | def access_gamma(mixture): 77 | return mixture.named_parameters()['gamma'] 78 | 79 | 80 | def test_scalar_mix_can_run_forward(): 81 | mixture = ScalarMix(3) 82 | tensors = [torch.randn([3, 4, 5]) for _ in range(3)] 83 | for k in range(3): 84 | access_scalar_parameters(mixture, k).data[0] = 0.1 * (k + 1) 85 | access_gamma(mixture).data[0] = 0.5 86 | result = mixture(tensors) 87 | 88 | weights = [0.1, 0.2, 0.3] 89 | normed_weights = numpy.exp(weights) / numpy.sum(numpy.exp(weights)) 90 | expected_result = sum(normed_weights[k] * tensors[k].data.numpy() for k in range(3)) 91 | expected_result *= 0.5 92 | numpy.testing.assert_almost_equal(expected_result, result.data.numpy()) 93 | 94 | 95 | def test_scalar_mix_throws_error_on_incorrect_number_of_inputs(): 96 | mixture = ScalarMix(3) 97 | tensors = [torch.randn([3, 4, 5]) for _ in range(5)] 98 | with pytest.raises(ValueError): 99 | _ = mixture(tensors) 100 | 101 | 102 | def test_scalar_mix_throws_error_on_incorrect_initial_scalar_parameters_length(): 103 | with pytest.raises(ValueError): 104 | ScalarMix(3, initial_scalar_parameters=[0.0, 0.0]) 105 | 106 | 107 | def test_scalar_mix_trainable_with_initial_scalar_parameters(): 108 | initial_scalar_parameters = [1.0, 2.0, 3.0] 109 | mixture = ScalarMix(3, 
initial_scalar_parameters=initial_scalar_parameters, trainable=False) 110 | for k, initial_scalar_parameter in enumerate(initial_scalar_parameters): 111 | scalar_mix_parameter = access_scalar_parameters(mixture, k) 112 | assert scalar_mix_parameter.requires_grad is False 113 | assert scalar_mix_parameter.item() == initial_scalar_parameter 114 | 115 | 116 | def test_scalar_mix_layer_norm(): 117 | mixture = ScalarMix(3, do_layer_norm=True) 118 | 119 | tensors = [torch.randn([3, 4, 5]) for _ in range(3)] 120 | numpy_mask = numpy.ones((3, 4), dtype='int32') 121 | numpy_mask[1, 2:] = 0 122 | mask = torch.from_numpy(numpy_mask) 123 | 124 | weights = [0.1, 0.2, 0.3] 125 | for k in range(3): 126 | access_scalar_parameters(mixture, k).data[0] = weights[k] 127 | access_gamma(mixture).data[0] = 0.5 128 | result = mixture(tensors, mask) 129 | 130 | normed_weights = numpy.exp(weights) / numpy.sum(numpy.exp(weights)) 131 | expected_result = numpy.zeros((3, 4, 5)) 132 | for k in range(3): 133 | mean = numpy.mean(tensors[k].data.numpy()[numpy_mask == 1]) 134 | std = numpy.std(tensors[k].data.numpy()[numpy_mask == 1]) 135 | normed_tensor = (tensors[k].data.numpy() - mean) / (std + 1E-12) 136 | expected_result += (normed_tensor * normed_weights[k]) 137 | expected_result *= 0.5 138 | 139 | numpy.testing.assert_almost_equal(expected_result, result.data.numpy(), decimal=6) 140 | 141 | 142 | def test_scalar_mix_layer_norm_packed_sequence(): 143 | mixture = ScalarMix(3, do_layer_norm=True) 144 | tensors = [torch.randn([3, 4]) for _ in range(3)] 145 | mixture(tensors) 146 | ''' 147 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | import torch 4 | from allennlp.modules.elmo import batch_to_ids 5 | import numpy as np 6 | 7 | from pytorch_fast_elmo import utils 8 | from pytorch_fast_elmo import FastElmoBase 9 | 
10 | FIXTURES_FODLER = join(dirname(__file__), 'fixtures') 11 | ELMO_OPTIONS_FILE = join(FIXTURES_FODLER, 'options.json') 12 | ELMO_WEIGHT_FILE = join(FIXTURES_FODLER, 'lm_weights.hdf5') 13 | 14 | 15 | def test_batch_to_char_ids(): 16 | sentences = [ 17 | ["This", "is", "a", "sentence"], 18 | ["Here", "'s", "one"], 19 | ["Another", "one"], 20 | ] 21 | t1 = utils.batch_to_char_ids(sentences) 22 | t2 = batch_to_ids(sentences) 23 | np.testing.assert_array_equal(t1.numpy(), t2.numpy()) 24 | 25 | sentences = [["one"]] 26 | t1 = utils.batch_to_char_ids(sentences) 27 | t2 = batch_to_ids(sentences) 28 | np.testing.assert_array_equal(t1.numpy(), t2.numpy()) 29 | 30 | 31 | def test_cache_char_cnn_vocab(tmpdir): 32 | vocab = ['', '', '', 'ELMo', 'helps', 'disambiguate', 'ELMo', 'from', 'Elmo', '.'] 33 | vocab_path = tmpdir.join("vocab.txt") 34 | vocab_path.write('\n'.join(vocab)) 35 | 36 | embedding_path = tmpdir.join("embd.txt") 37 | 38 | utils.cache_char_cnn_vocab( 39 | vocab_path.realpath(), 40 | ELMO_OPTIONS_FILE, 41 | ELMO_WEIGHT_FILE, 42 | str(embedding_path.realpath()), 43 | batch_size=2, 44 | max_characters_per_token=15, 45 | ) 46 | 47 | fast_word_embd = FastElmoBase( 48 | ELMO_OPTIONS_FILE, 49 | None, 50 | disable_word_embedding=False, 51 | word_embedding_weight_file=str(embedding_path.realpath()), 52 | # Disable all other components. 53 | disable_char_cnn=True, 54 | disable_forward_lstm=True, 55 | disable_backward_lstm=True, 56 | disable_scalar_mix=True, 57 | ) 58 | fast_char_cnn = FastElmoBase( 59 | ELMO_OPTIONS_FILE, 60 | ELMO_WEIGHT_FILE, 61 | # Disable all other components. 
62 | disable_forward_lstm=True, 63 | disable_backward_lstm=True, 64 | disable_scalar_mix=True, 65 | ) 66 | 67 | embd_repr = fast_word_embd.exec_word_embedding( 68 | fast_word_embd.pack_inputs( 69 | utils.batch_to_word_ids( 70 | [['ELMo', 'helps', '!!!UNK!!!']], 71 | utils.load_and_build_vocab2id(vocab_path.realpath()), 72 | ))) 73 | char_cnn_repr = fast_char_cnn.exec_char_cnn( 74 | fast_char_cnn.pack_inputs( 75 | utils.batch_to_char_ids( 76 | [['ELMo', 'helps', '']], 77 | max_characters_per_token=15, 78 | ))) 79 | 80 | np.testing.assert_array_almost_equal( 81 | embd_repr.data.numpy(), 82 | char_cnn_repr.data.numpy(), 83 | ) 84 | np.testing.assert_array_equal(embd_repr.batch_sizes.numpy(), char_cnn_repr.batch_sizes.numpy()) 85 | 86 | 87 | def test_sort_batch_by_length(): 88 | for _ in range(100): 89 | batch = torch.randn(20, 40) 90 | lengths = torch.randint(0, 41, (20,), dtype=torch.long) 91 | mask = utils.generate_mask_from_lengths(20, 40, lengths) 92 | batch[mask != 1] = 0 93 | 94 | sorted_batch, _, restoration_index = utils.sort_batch_by_length(batch) 95 | np.testing.assert_array_equal( 96 | sorted_batch.index_select(0, restoration_index), 97 | batch, 98 | ) 99 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py36 3 | 4 | [travis] 5 | python = 6 | 3.6: py36 7 | 8 | [testenv] 9 | setenv = 10 | PYTHONPATH = {toxinidir} 11 | deps = 12 | -r{toxinidir}/requirements_dev.txt 13 | -r{toxinidir}/requirements_prod.txt 14 | 15 | commands = 16 | pip install -U pip 17 | py.test --basetemp={envtmpdir} 18 | 19 | pylint pytorch_fast_elmo 20 | yapf -d -r pytorch_fast_elmo 21 | mypy pytorch_fast_elmo --strict --ignore-missing-imports 22 | --------------------------------------------------------------------------------