├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── config ├── cnndm │ ├── transformer_cnndm_baseline.yml │ ├── transformer_cnndm_ctxattn.yml │ └── transformer_cnndm_psa.yml ├── imdb │ ├── transformer_imdb_cond.yml │ ├── transformer_imdb_ctxattn.yml │ └── transformer_imdb_psa.yml └── story_gen │ ├── transformer_story_baseline.yml │ ├── transformer_story_ctxattn.yml │ └── transformer_story_psa.yml ├── docs ├── Makefile ├── requirements.txt └── source │ ├── CONTRIBUTING.md │ ├── FAQ.md │ ├── Library.ipynb │ ├── Library.md │ ├── Summarization.md │ ├── _static │ └── theme_overrides.css │ ├── conf.py │ ├── examples.rst │ ├── extended.md │ ├── im2text.md │ ├── index.md │ ├── index.rst │ ├── main.md │ ├── modules.rst │ ├── onmt.inputters.rst │ ├── onmt.modules.rst │ ├── onmt.rst │ ├── onmt.translate.translation_server.rst │ ├── onmt.translation.rst │ ├── options │ ├── preprocess.rst │ ├── server.rst │ ├── train.rst │ └── translate.rst │ ├── quickstart.md │ ├── ref.rst │ ├── refs.bib │ └── speech2text.md ├── gpt2 ├── .gitignore ├── decode_text.py ├── download_model.py ├── encode_text.py └── vocab.txt ├── onmt ├── __init__.py ├── decoders │ ├── __init__.py │ ├── cnn_decoder.py │ ├── decoder.py │ ├── ensemble.py │ ├── rnn_uncond.py │ └── transformer.py ├── encoders │ ├── __init__.py │ ├── audio_encoder.py │ ├── cnn_encoder.py │ ├── embonly.py │ ├── encoder.py │ ├── image_encoder.py │ ├── imgvec_encoder.py │ ├── mean_encoder.py │ ├── rnn_encoder.py │ └── transformer.py ├── inputters │ ├── __init__.py │ ├── audio_dataset.py │ ├── datareader_base.py │ ├── dataset_base.py │ ├── image_dataset.py │ ├── image_vec_dataset.py │ ├── inputter.py │ ├── none_dataset.py │ └── text_dataset.py ├── model_builder.py ├── models │ ├── __init__.py │ ├── model.py │ ├── model_saver.py │ ├── simple_fusion_model.py │ ├── sru.py │ ├── stacked_rnn.py │ └── uncond_model.py ├── modules │ ├── __init__.py │ ├── average_attn.py │ ├── conv_multi_step_attention.py │ ├── 
copy_generator.py │ ├── embeddings.py │ ├── gate.py │ ├── global_attention.py │ ├── gpt_mlp.py │ ├── multi_headed_attn.py │ ├── position_ffn.py │ ├── simple_fusion_generator.py │ ├── sparse_activations.py │ ├── sparse_losses.py │ ├── structured_attention.py │ ├── util_class.py │ └── weight_norm.py ├── opts.py ├── tests │ ├── __init__.py │ ├── output_hyp.txt │ ├── pull_request_chk.sh │ ├── rebuild_test_models.sh │ ├── sample_glove.txt │ ├── test_attention.py │ ├── test_audio_dataset.py │ ├── test_beam.py │ ├── test_beam_search.py │ ├── test_copy_generator.py │ ├── test_embeddings.py │ ├── test_image_dataset.py │ ├── test_models.py │ ├── test_models.sh │ ├── test_preprocess.py │ ├── test_random_sampling.py │ ├── test_simple.py │ ├── test_structured_attention.py │ ├── test_text_dataset.py │ ├── test_translation_server.py │ └── utils_for_tests.py ├── train_single.py ├── trainer.py ├── translate │ ├── __init__.py │ ├── beam.py │ ├── beam_search.py │ ├── decode_strategy.py │ ├── penalties.py │ ├── random_sampling.py │ ├── translation.py │ ├── translation_server.py │ └── translator.py └── utils │ ├── __init__.py │ ├── cnn_factory.py │ ├── distributed.py │ ├── logging.py │ ├── loss.py │ ├── misc.py │ ├── optimizers.py │ ├── parse.py │ ├── report_manager.py │ ├── rnn_factory.py │ └── statistics.py ├── preprocess.py ├── requirements.txt ├── setup.py ├── train.py └── translate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # repo-specific stuff 2 | pred.txt 3 | multi-bleu.perl 4 | *.pt 5 | \#*# 6 | *~ 7 | .idea 8 | *.sublime-* 9 | .DS_Store 10 | data/ 11 | data.tar.gz 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | 
.installed.cfg 37 | *.egg 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # Tensorboard 114 | runs/ 115 | 116 | output 117 | models/**/*.pt 118 | models/**/*.txt 119 | *.png 120 | *.pem 121 | *.pkl 122 | *.h5 123 | *.zip 124 | *.bak 125 | 126 | allennlp_test_output_bpe* 127 | 128 | tmp.txt 129 | temp.txt 130 | 131 | bottom-up-summary 132 | slurm 133 | gpt2/models 134 | captioning_tmp 135 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | OpenNMT-py is a community developed project and we love developer contributions. 
4 | 5 | ## Guidelines 6 | Before sending a PR, please do this checklist first: 7 | 8 | - Please run `onmt/tests/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: 9 | 1. flake8 check for coding style; 10 | 2. unittest; 11 | 3. continuous integration tests listed in `.travis.yml`. 12 | - When adding or modifying a class constructor, please use the same argument naming style as its superclass in PyTorch. 13 | - If your change is based on a paper, please include a clear comment and reference in the code (more on that below). 14 | 15 | ### Docstrings 16 | Above all, try to follow the Google docstring format 17 | ([Napoleon example](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html), 18 | [Google styleguide](http://google.github.io/styleguide/pyguide.html)). 19 | This makes it easy to include your contributions in the Sphinx documentation. And, do feel free 20 | to autodoc your contributions in the API ``.rst`` files in the `docs/source` folder! If you do, check that 21 | your additions look right. 22 | 23 | ```bash 24 | cd docs 25 | # install some dependencies if necessary: 26 | # recommonmark, sphinx_rtd_theme, sphinxcontrib-bibtex 27 | make html 28 | firefox build/html/main.html # or your browser of choice 29 | ``` 30 | 31 | Some particular advice: 32 | - Try to follow Python 3 [``typing`` module](https://docs.python.org/3/library/typing.html) conventions when documenting types. 33 | - Exception: use "or" instead of unions for better readability. 34 | - For external types, use the full "import name". Common abbreviations (e.g. ``np``) are acceptable. 35 | For ``torch.Tensor`` types, the ``torch.`` is optional. 36 | - Please don't use ticks like `` (`str`) `` or rst directives like `` (:obj:`str`) ``. Napoleon handles types 37 | very well without additional help, so avoid the clutter.
38 | - [Google docstrings don't support multiple returns](https://stackoverflow.com/questions/29221551/can-sphinx-napoleon-document-function-returning-multiple-arguments). 39 | For multiple returns, the following works well with Sphinx and is still very readable. 40 | ```python 41 | def foo(a, b): 42 | """This is my docstring. 43 | 44 | Args: 45 | a (object): Something. 46 | b (class): Another thing. 47 | 48 | Returns: 49 | (object, class): 50 | 51 | * a: Something or rather with a long 52 | description that spills over. 53 | * b: And another thing. 54 | """ 55 | 56 | return a, b 57 | ``` 58 | - When citing a paper, avoid directly linking in the docstring! Add a BibTeX entry to `docs/source/refs.bib`. 59 | E.g., to cite "Attention Is All You Need", visit [arXiv](https://arxiv.org/abs/1706.03762), choose the 60 | [BibTeX](https://dblp.uni-trier.de/rec/bibtex/journals/corr/VaswaniSPUJGKP17) link, search `docs/source/refs.bib` 61 | using `CTRL-F` for `DBLP:journals/corr/VaswaniSPUJGKP17`, and if you do not find it, copy-paste the 62 | citation into `refs.bib`. Then, in your docstring, use ``:cite:`DBLP:journals/corr/VaswaniSPUJGKP17` ``. 63 | - However, a link is better than nothing. 64 | - Please document tensor shapes. Prefer the format 65 | ``` ``(a, b, c)`` ```. This style is easy to read, allows using ``x`` for multiplication, and is common 66 | (PyTorch uses a few variations on the parentheses format, AllenNLP uses exactly this format, Fairseq uses 67 | the parentheses format with single ticks). 68 | - Again, a different style is better than no shape documentation. 69 | - Please avoid unnecessary space characters, try to capitalize, and try to punctuate. 70 | 71 | For multi-line docstrings, add a blank line after the closing ``"""``. 72 | Don't use a blank line before the closing quotes. 73 | 74 | ``""" not this """`` ``"""This."""`` 75 | 76 | ```python 77 | """ 78 | Not this.
79 | """ 80 | ``` 81 | ```python 82 | """This.""" 83 | ``` 84 | 85 | This note is the least important. Focus on content first, but remember that consistent docs look good. 86 | - Be sensible about the first line. Generally, one stand-alone summary line (per the Google guidelines) is good. 87 | Sometimes, it's better to cut directly to the args or an extended description. It's always acceptable to have a 88 | "trailing" citation. -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-Present OpenNMT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Encoder-Agnostic Adaptation for Conditional Language Generation 2 | 3 | This repo contains the code used in [Encoder-Agnostic Adaptation for Conditional Language Generation](https://arxiv.org/abs/1908.06938), Zachary M. Ziegler, Luke Melas-Kyriazi, Sebastian Gehrmann and Alexander M. Rush. It extends [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py). 4 | 5 | This code was tested with `pytorch 1.0.1`. See requirements.txt for a complete list of dependencies. 6 | 7 | ## Download GPT2 weights 8 | 9 | `cd gpt2 && python download_model.py 124M` 10 | 11 | ## General notes 12 | 13 | All experiments use gradient accumulation to mimic the large batch sizes for which these hyperparameter settings were originally tuned (e.g., by Facebook). If you run into GPU memory issues, simply reduce the batch size and increase `accum_count` to keep the effective batch size the same. 14 | 15 | ## Data 16 | 17 | The BPEized data used in the experiments in the paper can be found [here](https://drive.google.com/file/d/1Z6AdOr2MtWlN7sYRTMibzAcghBjSBzZK/view?usp=sharing). To run any of these models with your own data, first BPEize it with `python gpt2/encode_text.py <filename>`. Before training, the raw data is preprocessed into binary data shards with the commands below.
18 | 19 | ## Class-conditional generation 20 | 21 | ### Preprocess 22 | 23 | `python preprocess.py -train_src data/imdb/train.src.bpe -train_tgt data/imdb/train.tgt.bpe -valid_src data/imdb/valid.src.bpe -valid_tgt data/imdb/valid.tgt.bpe -save_data data/imdb/IMDB_BPETGT -tgt_seq_length_trunc 400 -tgt_vocab gpt2/vocab.txt -fixed_vocab -free_src` 24 | 25 | ### Train 26 | **Baseline**: `python train.py -config config/imdb/transformer_imdb_cond.yml -run_name baseline` 27 | 28 | **Simple fusion**: `python train.py -config config/imdb/transformer_imdb_cond.yml -run_name simple_fusion -gpt2_params_path gpt2/models/124M/ -simple_fusion -dropout 0.1 -accum_count 30 -batch_size 1000 -valid_batch_size 16` 29 | 30 | **Repr-transformer**: `python train.py -config config/imdb/transformer_imdb_cond.yml -run_name repr_trans -GPT_representation_loc tgt -GPT_representation_mode elmo -gpt2_params_path gpt2/models/124M/ -position_encoding_learned_dec -word_vec_size 768 -rnn_size 768` 31 | 32 | **Context attention**: `python train.py -config config/imdb/transformer_imdb_ctxattn.yml -run_name ctxattn -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 33 | 34 | **Pseudo self attention**: `python train.py -config config/imdb/transformer_imdb_psa.yml -run_name psa -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 35 | 36 | ### Generation 37 | 38 | Generation is performed via random sampling. 
39 | 40 | `python translate.py -beam_size 1 -random_sampling_topk -1 -random_sampling_temp 0.7 -model <path/to/model.pt> -src data/imdb/test.src.bpe -min_length 1 -max_length 400 -verbose` 41 | 42 | ## Summarization 43 | 44 | ### Preprocess 45 | 46 | `python preprocess.py -train_src data/cnndm/train.txt.src.bpe -train_tgt data/cnndm/train.txt.tgt.bpe -valid_src data/cnndm/val.txt.src.bpe -valid_tgt data/cnndm/val.txt.tgt.bpe -save_data data/cnndm/CNNDM_BPE_COPY -src_seq_length_trunc 400 -tgt_seq_length_trunc 100 -src_vocab gpt2/vocab.txt -tgt_vocab gpt2/vocab.txt -dynamic_dict -fixed_vocab` 47 | 48 | ### Train 49 | The default settings use 4 GPUs (see config files). If you use more or fewer GPUs, modify the `world_size` and `gpu_ranks` values in the config file and adjust `accum_count` so the effective batch size remains the same. 50 | 51 | **Baseline**: `python train.py -config config/cnndm/transformer_cnndm_baseline.yml -run_name baseline` 52 | 53 | **Repr-transformer**: `python train.py -config config/cnndm/transformer_cnndm_baseline.yml -run_name repr_trans -GPT_representation_loc tgt -GPT_representation_mode elmo -gpt2_params_path gpt2/models/124M/ -position_encoding_learned_dec -word_vec_size 768 -rnn_size 768 -train_steps 50000` 54 | 55 | **Context attention**: `python train.py -config config/cnndm/transformer_cnndm_ctxattn.yml -run_name ctxattn -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 56 | 57 | **Pseudo self attention**: `python train.py -config config/cnndm/transformer_cnndm_psa.yml -run_name psa -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 58 | 59 | ### Generation 60 | 61 | Generation is performed via beam search. 62 | 63 | `python translate.py -beam_size 5 -model <path/to/model.pt> -src data/cnndm/test.txt.src.bpe -min_length 60 -verbose -block_ngram_repeat 3` 64 | 65 | ## Story generation 66 | The default settings use 4 GPUs (see config files).
If you use more or fewer GPUs, modify the `world_size` and `gpu_ranks` values in the config file and adjust `accum_count` so the effective batch size remains the same. 67 | 68 | ### Preprocess 69 | 70 | `python preprocess.py -train_src data/stories/train.wp_source.bpe -train_tgt data/stories/train.wp_target.bpe -valid_src data/stories/valid.wp_source.bpe -valid_tgt data/stories/valid.wp_target.bpe -save_data data/stories/STORIES_BPE -src_vocab gpt2/vocab.txt -tgt_vocab gpt2/vocab.txt -fixed_vocab` 71 | 72 | ### Train 73 | **Baseline**: `python train.py -config config/story_gen/transformer_story_baseline.yml -run_name baseline` 74 | 75 | **Repr-transformer**: `python train.py -config config/story_gen/transformer_story_baseline.yml -run_name repr_trans -GPT_representation_loc tgt -GPT_representation_mode elmo -gpt2_params_path gpt2/models/124M/ -position_encoding_learned_dec -word_vec_size 768 -rnn_size 768` 76 | 77 | **Context attention**: `python train.py -config config/story_gen/transformer_story_ctxattn.yml -run_name ctxattn -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 78 | 79 | **Pseudo self attention**: `python train.py -config config/story_gen/transformer_story_psa.yml -run_name psa -gpt2_params_path gpt2/models/124M/ -gpt2_init_embanddec` 80 | 81 | ### Generation 82 | 83 | Generation is performed via top-k/random sampling. 84 | 85 | `python translate.py -beam_size 1 -random_sampling_topk 100 -random_sampling_temp 0.9 -model <path/to/model.pt> -src data/stories/test.wp_source.bpe -max_length 1000 -verbose` 86 | 87 | ## Image captioning 88 | 89 | Coming soon...
90 | -------------------------------------------------------------------------------- /config/cnndm/transformer_cnndm_baseline.yml: -------------------------------------------------------------------------------- 1 | data: data/cnndm/CNNDM_BPE_COPY 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 10 4 | seed: 123 5 | warmup_steps: 4000 6 | train_steps: 30000 7 | valid_steps: 300 8 | report_every: 50 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 512 13 | rnn_size: 512 14 | enc_layers: 6 15 | dec_layers: 6 16 | transformer_ff: 2048 17 | heads: 8 18 | 19 | accum_count: 40 20 | optim: adam 21 | adam_beta1: 0.9 22 | adam_beta2: 0.98 23 | decay_method: invsq 24 | learning_rate: 0.0005 25 | warmup_init_factor: 5000 26 | max_grad_norm: 0.0 27 | 28 | batch_size: 2800 29 | batch_type: tokens 30 | normalization: tokens 31 | dropout: 0.3 32 | attn_dropout: 0.2 33 | label_smoothing: 0.1 34 | 35 | max_generator_batches: 0 36 | 37 | param_init: 0.0 38 | param_init_glorot: 'true' 39 | position_encoding: 'true' 40 | position_encoding_ctxsize: 1024 41 | share_decoder_embeddings: 'true' 42 | share_embeddings: 'true' 43 | 44 | copy_attn: 'true' 45 | 46 | world_size: 4 47 | gpu_ranks: 48 | - 0 49 | - 1 50 | - 2 51 | - 3 52 | tensorboard: 'true' 53 | -------------------------------------------------------------------------------- /config/cnndm/transformer_cnndm_ctxattn.yml: -------------------------------------------------------------------------------- 1 | data: data/cnndm/CNNDM_BPE_COPY 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 2 4 | seed: 123 5 | warmup_steps: 7000 6 | train_steps: 70000 7 | valid_steps: 500 8 | report_every: 100 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 4 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | heads: 12 18 | use_GPT_version_ctxattn: 'true' 19 | 20 | accum_count: 7 21 | optim: adam 22 | adam_beta1: 0.9 23 | adam_beta2: 
0.998 24 | decay_method: stlr 25 | learning_rate: 1e-3 26 | max_grad_norm: 0.0 27 | disc_ft: 1.2 28 | dec_lr_factor: 3 29 | 30 | batch_size: 2800 31 | batch_type: tokens 32 | normalization: tokens 33 | dropout: 0.2 34 | label_smoothing: 0.1 35 | 36 | max_generator_batches: 0 37 | 38 | param_init: 0.0 39 | param_init_glorot: 'true' 40 | position_encoding: 'true' 41 | position_encoding_learned: 'true' 42 | position_encoding_ctxsize: 1024 43 | share_decoder_embeddings: 'true' 44 | share_position_embeddings: 'true' 45 | share_embeddings: 'true' 46 | 47 | copy_attn: 'true' 48 | 49 | world_size: 4 50 | gpu_ranks: 51 | - 0 52 | - 1 53 | - 2 54 | - 3 55 | tensorboard: 'true' 56 | -------------------------------------------------------------------------------- /config/cnndm/transformer_cnndm_psa.yml: -------------------------------------------------------------------------------- 1 | data: data/cnndm/CNNDM_BPE_COPY 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 2 4 | seed: 123 5 | warmup_steps: 7000 6 | train_steps: 70000 7 | valid_steps: 500 8 | report_every: 100 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 4 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | heads: 12 18 | use_GPT_version_psa: 'true' 19 | 20 | accum_count: 7 21 | optim: adam 22 | adam_beta1: 0.9 23 | adam_beta2: 0.998 24 | decay_method: stlr 25 | learning_rate: 1e-3 26 | max_grad_norm: 0.0 27 | disc_ft: 1.2 28 | dec_lr_factor: 3 29 | 30 | batch_size: 2800 31 | batch_type: tokens 32 | normalization: tokens 33 | dropout: 0.2 34 | label_smoothing: 0.1 35 | 36 | max_generator_batches: 0 37 | 38 | param_init: 0.0 39 | param_init_glorot: 'true' 40 | position_encoding: 'true' 41 | position_encoding_learned: 'true' 42 | position_encoding_ctxsize: 1024 43 | share_decoder_embeddings: 'true' 44 | share_position_embeddings: 'true' 45 | share_embeddings: 'true' 46 | 47 | copy_attn: 'true' 48 | 49 | world_size: 4 50 | gpu_ranks: 51 | 
- 0 52 | - 1 53 | - 2 54 | - 3 55 | tensorboard: 'true' 56 | -------------------------------------------------------------------------------- /config/imdb/transformer_imdb_cond.yml: -------------------------------------------------------------------------------- 1 | data: data/imdb/IMDB_BPETGT 2 | save_checkpoint_steps: 1000 3 | keep_checkpoint: 3 4 | seed: 123 5 | warmup_steps: 1000 6 | train_steps: 10000 7 | valid_steps: 40 8 | report_every: 40 9 | 10 | decoder_type: transformer 11 | encoder_type: mean 12 | word_vec_size: 512 13 | rnn_size: 512 14 | enc_layers: 1 15 | dec_layers: 4 16 | transformer_ff: 2048 17 | heads: 8 18 | 19 | accum_count: 16 20 | optim: adam 21 | adam_beta1: 0.9 22 | adam_beta2: 0.998 23 | decay_method: noam 24 | learning_rate: 2.0 25 | max_grad_norm: 0.0 26 | 27 | batch_size: 2200 28 | batch_type: tokens 29 | normalization: tokens 30 | dropout: 0.2 31 | label_smoothing: 0.1 32 | 33 | max_generator_batches: 2 34 | 35 | param_init: 0.0 36 | param_init_glorot: 'true' 37 | position_encoding: 'true' 38 | position_encoding_ctxsize: 1024 39 | share_decoder_embeddings: 'true' 40 | 41 | world_size: 1 42 | gpu_ranks: 43 | - 0 44 | tensorboard: 'true' 45 | -------------------------------------------------------------------------------- /config/imdb/transformer_imdb_ctxattn.yml: -------------------------------------------------------------------------------- 1 | data: data/imdb/IMDB_BPETGT 2 | save_checkpoint_steps: 500 3 | keep_checkpoint: 3 4 | seed: 123 5 | warmup_steps: 500 6 | train_steps: 5000 7 | valid_steps: 40 8 | report_every: 40 9 | 10 | decoder_type: transformer 11 | encoder_type: mean 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 1 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | heads: 12 18 | use_GPT_version_ctxattn: 'true' 19 | 20 | accum_count: 18 21 | optim: adam 22 | adam_beta1: 0.9 23 | adam_beta2: 0.998 24 | decay_method: stlr 25 | learning_rate: 1e-3 26 | max_grad_norm: 0.0 27 | disc_ft: 1.3 28 | 29 | batch_size: 
2200 30 | batch_type: tokens 31 | normalization: tokens 32 | dropout: 0.1 33 | label_smoothing: 0.1 34 | 35 | max_generator_batches: 2 36 | 37 | param_init: 0.0 38 | param_init_glorot: 'true' 39 | position_encoding: 'true' 40 | position_encoding_learned_dec: 'true' 41 | position_encoding_ctxsize: 1024 42 | share_decoder_embeddings: 'true' 43 | 44 | world_size: 1 45 | gpu_ranks: 46 | - 0 47 | tensorboard: 'true' 48 | -------------------------------------------------------------------------------- /config/imdb/transformer_imdb_psa.yml: -------------------------------------------------------------------------------- 1 | data: data/imdb/IMDB_BPETGT 2 | save_checkpoint_steps: 500 3 | keep_checkpoint: 3 4 | seed: 123 5 | warmup_steps: 250 6 | train_steps: 2500 7 | valid_steps: 40 8 | report_every: 40 9 | 10 | decoder_type: transformer 11 | encoder_type: mean 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 1 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | heads: 12 18 | use_GPT_version_psa: 'true' 19 | 20 | accum_count: 18 21 | optim: adam 22 | adam_beta1: 0.9 23 | adam_beta2: 0.998 24 | decay_method: stlr 25 | learning_rate: 1e-3 26 | max_grad_norm: 0.0 27 | disc_ft: 1.3 28 | 29 | batch_size: 2200 30 | batch_type: tokens 31 | normalization: tokens 32 | dropout: 0.1 33 | label_smoothing: 0.1 34 | 35 | max_generator_batches: 2 36 | 37 | param_init: 0.0 38 | param_init_glorot: 'true' 39 | position_encoding: 'true' 40 | position_encoding_learned_dec: 'true' 41 | position_encoding_ctxsize: 1024 42 | share_decoder_embeddings: 'true' 43 | 44 | world_size: 1 45 | gpu_ranks: 46 | - 0 47 | tensorboard: 'true' 48 | -------------------------------------------------------------------------------- /config/story_gen/transformer_story_baseline.yml: -------------------------------------------------------------------------------- 1 | data: data/stories/STORIES_BPE 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 10 4 | seed: 123 5 | warmup_steps: 4000 6 | train_steps: 
60000 7 | valid_steps: 500 8 | report_every: 50 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 512 13 | rnn_size: 512 14 | enc_layers: 6 15 | dec_layers: 6 16 | transformer_ff: 2048 17 | heads: 8 18 | 19 | accum_count: 43 20 | optim: adam 21 | adam_beta1: 0.9 22 | adam_beta2: 0.98 23 | decay_method: invsq 24 | learning_rate: 0.0005 25 | warmup_init_factor: 5000 26 | max_grad_norm: 0.0 27 | 28 | batch_size: 2600 29 | valid_batch_size: 4 30 | batch_type: tokens 31 | normalization: tokens 32 | dropout: 0.3 33 | attn_dropout: 0.2 34 | label_smoothing: 0.1 35 | 36 | max_generator_batches: 0 37 | force_bs1: 'true' 38 | 39 | param_init: 0.0 40 | param_init_glorot: 'true' 41 | position_encoding: 'true' 42 | position_encoding_ctxsize: 1024 43 | share_decoder_embeddings: 'true' 44 | share_embeddings: 'true' # This is not quite the same, but probably should only have positive effect? 45 | 46 | world_size: 4 47 | gpu_ranks: 48 | - 0 49 | - 1 50 | - 2 51 | - 3 52 | tensorboard: 'true' 53 | -------------------------------------------------------------------------------- /config/story_gen/transformer_story_ctxattn.yml: -------------------------------------------------------------------------------- 1 | data: data/stories/STORIES_BPE 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 10 4 | seed: 123 5 | warmup_steps: 8000 6 | train_steps: 100000 7 | valid_steps: 800 8 | report_every: 100 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 6 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | enc_heads: 8 18 | dec_heads: 12 19 | use_GPT_version_ctxattn: 'true' 20 | 21 | accum_count: 14 22 | optim: adam 23 | adam_beta1: 0.9 24 | adam_beta2: 0.98 25 | decay_method: invsq 26 | learning_rate: 0.0005 27 | warmup_init_factor: 5000 28 | max_grad_norm: 0.0 29 | disc_ft: 1.2 30 | dec_lr_factor: 3 31 | 32 | batch_size: 2048 33 | valid_batch_size: 4 34 | batch_type: tokens 35 | 
normalization: tokens 36 | dropout: 0.25 37 | attn_dropout: 0.2 38 | label_smoothing: 0.1 39 | 40 | max_generator_batches: 0 41 | force_bs1: 'true' 42 | 43 | param_init: 0.0 44 | param_init_glorot: 'true' 45 | position_encoding: 'true' 46 | position_encoding_learned: 'true' 47 | position_encoding_ctxsize: 1024 48 | share_decoder_embeddings: 'true' 49 | share_embeddings: 'true' # This is not quite the same, but probably should only have positive effect? 50 | share_position_embeddings: 'true' 51 | 52 | world_size: 4 53 | gpu_ranks: 54 | - 0 55 | - 1 56 | - 2 57 | - 3 58 | tensorboard: 'true' 59 | -------------------------------------------------------------------------------- /config/story_gen/transformer_story_psa.yml: -------------------------------------------------------------------------------- 1 | data: data/stories/STORIES_BPE 2 | save_checkpoint_steps: 2000 3 | keep_checkpoint: 10 4 | seed: 123 5 | warmup_steps: 8000 6 | train_steps: 100000 7 | valid_steps: 800 8 | report_every: 100 9 | 10 | decoder_type: transformer 11 | encoder_type: transformer 12 | word_vec_size: 768 13 | rnn_size: 768 14 | enc_layers: 6 15 | dec_layers: 12 16 | transformer_ff: 3072 17 | enc_heads: 8 18 | dec_heads: 12 19 | use_GPT_version_psa: 'true' 20 | 21 | accum_count: 14 22 | optim: adam 23 | adam_beta1: 0.9 24 | adam_beta2: 0.98 25 | decay_method: invsq 26 | learning_rate: 0.0005 27 | warmup_init_factor: 5000 28 | max_grad_norm: 0.0 29 | dec_lr_factor: 1.5 30 | disc_ft: 1.1 31 | 32 | batch_size: 2048 33 | valid_batch_size: 4 34 | batch_type: tokens 35 | normalization: tokens 36 | dropout: 0.25 37 | attn_dropout: 0.2 38 | label_smoothing: 0.1 39 | 40 | max_generator_batches: 0 41 | force_bs1: 'true' 42 | 43 | param_init: 0.0 44 | param_init_glorot: 'true' 45 | position_encoding: 'true' 46 | position_encoding_learned: 'true' 47 | position_encoding_ctxsize: 1024 48 | share_decoder_embeddings: 'true' 49 | share_embeddings: 'true' # This is not quite the same, but probably should only 
have positive effect? 50 | share_position_embeddings: 'true' 51 | 52 | world_size: 4 53 | gpu_ranks: 54 | - 0 55 | - 1 56 | - 2 57 | - 3 58 | tensorboard: 'true' 59 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python3 -msphinx 7 | SPHINXPROJ = OpenNMT-py 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinxcontrib.bibtex 3 | sphinxcontrib.mermaid 4 | sphinx-rtd-theme 5 | recommonmark 6 | sphinx-argparse -------------------------------------------------------------------------------- /docs/source/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | OpenNMT-py is a community developed project and we love developer contributions. 4 | 5 | ## Guidelines 6 | Before sending a PR, please do this checklist first: 7 | 8 | - Please run `tools/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: 9 | 1. flake8 check for coding style; 10 | 2. unittest; 11 | 3. continuous integration tests listed in `.travis.yml`. 
12 | - When adding or modifying a class constructor, please use the same argument naming style as its superclass in PyTorch. 13 | - If your change is based on a paper, please include a clear comment and reference in the code (more on that below). 14 | 15 | ### Docstrings 16 | Above all, try to follow the Google docstring format 17 | ([Napoleon example](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html), 18 | [Google styleguide](http://google.github.io/styleguide/pyguide.html)). 19 | This makes it easy to include your contributions in the Sphinx documentation. And, do feel free 20 | to autodoc your contributions in the API ``.rst`` files in the `docs/source` folder! If you do, check that 21 | your additions look right. 22 | 23 | ```bash 24 | cd docs 25 | # install some dependencies if necessary: 26 | # recommonmark, sphinx_rtd_theme, sphinxcontrib-bibtex 27 | make html 28 | firefox build/html/main.html # or your browser of choice 29 | ``` 30 | 31 | Some particular advice: 32 | - Try to follow Python 3 [``typing`` module](https://docs.python.org/3/library/typing.html) conventions when documenting types. 33 | - Exception: use "or" instead of unions for better readability. 34 | - For external types, use the full "import name". Common abbreviations (e.g. ``np``) are acceptable. 35 | For ``torch.Tensor`` types, the ``torch.`` is optional. 36 | - Please don't use ticks like `` (`str`) `` or rst directives like `` (:obj:`str`) ``. Napoleon handles types 37 | very well without additional help, so avoid the clutter. 38 | - [Google docstrings don't support multiple returns](https://stackoverflow.com/questions/29221551/can-sphinx-napoleon-document-function-returning-multiple-arguments). 39 | For multiple returns, the following works well with Sphinx and is still very readable. 40 | ```python 41 | def foo(a, b): 42 | """This is my docstring. 43 | 44 | Args: 45 | a (object): Something. 46 | b (class): Another thing.
47 | 48 | Returns: 49 | (object, class): 50 | 51 | * a: Something or rather with a long 52 | description that spills over. 53 | * b: And another thing. 54 | """ 55 | 56 | return a, b 57 | ``` 58 | - When citing a paper, avoid directly linking in the docstring! Add a BibTeX entry to `docs/source/refs.bib`. 59 | E.g., to cite "Attention Is All You Need", visit [arXiv](https://arxiv.org/abs/1706.03762), choose the 60 | [BibTeX](https://dblp.uni-trier.de/rec/bibtex/journals/corr/VaswaniSPUJGKP17) link, search `docs/source/refs.bib` 61 | using `CTRL-F` for `DBLP:journals/corr/VaswaniSPUJGKP17`, and if you do not find it then copy-paste the 62 | citation into `refs.bib`. Then, in your docstring, use ``:cite:`DBLP:journals/corr/VaswaniSPUJGKP17` ``. 63 | - However, a link is better than nothing. 64 | - Please document tensor shapes. Prefer the format 65 | ``` ``(a, b, c)`` ```. This style is easy to read, allows using ``x`` for multiplication, and is common 66 | (PyTorch uses a few variations on the parentheses format, AllenNLP uses exactly this format, Fairseq uses 67 | the parentheses format with single ticks). 68 | - Again, a different style is better than no shape documentation. 69 | - Please avoid unnecessary space characters, try to capitalize, and try to punctuate. 70 | 71 | For multi-line docstrings, add a blank line after the closing ``"""``. 72 | Don't use a blank line before the closing quotes. 73 | 74 | ``""" not this """`` ``"""This."""`` 75 | 76 | ```python 77 | """ 78 | Not this. 79 | """ 80 | ``` 81 | ```python 82 | """This.""" 83 | ``` 84 | 85 | This note is the least important. Focus on content first, but remember that consistent docs look good. 86 | - Be sensible about the first line. Generally, one stand-alone summary line (per the Google guidelines) is good. 87 | Sometimes, it's better to cut directly to the args or an extended description. It's always acceptable to have a 88 | "trailing" citation.
-------------------------------------------------------------------------------- /docs/source/FAQ.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ## How do I use pretrained embeddings (e.g. GloVe)? 4 | 5 | Using vocabularies from OpenNMT-py preprocessing outputs, use `embeddings_to_torch.py` to generate encoder and decoder embeddings initialized with GloVe’s values. 6 | 7 | The script is a slightly modified version of ylhsieh’s one. 8 | 9 | Usage: 10 | 11 | ``` 12 | embeddings_to_torch.py [-h] -emb_file EMB_FILE -output_file OUTPUT_FILE -dict_file DICT_FILE [-verbose] 13 | 14 | emb_file: GloVe-like embedding file, i.e. CSV [word] [dim1] ... [dim_d] 15 | 16 | output_file: a filename to save the output as PyTorch serialized tensors 17 | 18 | dict_file: dict output from OpenNMT-py preprocessing 19 | ``` 20 | 21 | Example 22 | 23 | 24 | 1) get GloVe files: 25 | 26 | ``` 27 | mkdir "glove_dir" 28 | wget http://nlp.stanford.edu/data/glove.6B.zip 29 | unzip glove.6B.zip -d "glove_dir" 30 | ``` 31 | 32 | 2) prepare data: 33 | 34 | ``` 35 | python preprocess.py \ 36 | -train_src data/train.src.txt \ 37 | -train_tgt data/train.tgt.txt \ 38 | -valid_src data/valid.src.txt \ 39 | -valid_tgt data/valid.tgt.txt \ 40 | -save_data data/data 41 | ``` 42 | 43 | 3) prepare embeddings: 44 | 45 | ``` 46 | ./tools/embeddings_to_torch.py -emb_file "glove_dir/glove.6B.100d.txt" \ 47 | -dict_file "data/data.vocab.pt" \ 48 | -output_file "data/embeddings" 49 | ``` 50 | 51 | 4) train using pre-trained embeddings: 52 | 53 | ``` 54 | python train.py -save_model data/model \ 55 | -batch_size 64 \ 56 | -layers 2 \ 57 | -rnn_size 200 \ 58 | -word_vec_size 100 \ 59 | -pre_word_vecs_enc "data/embeddings.enc.pt" \ 60 | -pre_word_vecs_dec "data/embeddings.dec.pt" \ 61 | -data data/data 62 | ``` 63 | 64 | 65 | ## How do I use the Transformer model? Do you support multi-gpu? 66 | 67 | The transformer model is very sensitive to hyperparameters.
To run it 68 | effectively you need to set a bunch of different options that mimic the Google 69 | setup. We have confirmed the following command can replicate their WMT results. 70 | 71 | ``` 72 | python train.py -data /tmp/de2/data -save_model /tmp/extra \ 73 | -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \ 74 | -encoder_type transformer -decoder_type transformer -position_encoding \ 75 | -train_steps 200000 -max_generator_batches 2 -dropout 0.1 \ 76 | -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 \ 77 | -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 \ 78 | -max_grad_norm 0 -param_init 0 -param_init_glorot \ 79 | -label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 \ 80 | -world_size 4 -gpu_ranks 0 1 2 3 81 | ``` 82 | 83 | Here is what each of the parameters means: 84 | 85 | * `param_init_glorot` `-param_init 0`: correct initialization of parameters 86 | * `position_encoding`: add sinusoidal position encoding to each embedding 87 | * `optim adam`, `decay_method noam`, `warmup_steps 8000`: use a special learning rate schedule. 88 | * `batch_type tokens`, `normalization tokens`, `accum_count 2`: batch and normalize based on the number of tokens and not sentences. Accumulate gradients over two batches before each update. 89 | * `label_smoothing 0.1`: use label smoothing loss. 90 | 91 | Multi-GPU settings 92 | First you need to make sure you export CUDA_VISIBLE_DEVICES=0,1,2,3 93 | If you want to use GPU ids 1 and 3 of your OS, you will need to export CUDA_VISIBLE_DEVICES=1,3 94 | * `world_size 4 gpu_ranks 0 1 2 3`: This will use 4 GPUs on this node only. 95 | 96 | If you want to use 2 nodes with 2 GPUs each, you need to set -master_ip and -master_port, and 97 | * `world_size 4 gpu_ranks 0 1`: on the first node 98 | * `world_size 4 gpu_ranks 2 3`: on the second node 99 | * `accum_count 2`: This will accumulate over 2 batches before updating parameters.
100 | 101 | If you use a regular network card (1 Gbps), we suggest using a higher accum_count to minimize inter-node communication. 102 | 103 | ## How can I ensemble models at inference? 104 | 105 | You can specify several models in the translate.py command line: -model model1_seed1 model2_seed2 106 | Bear in mind that your models must share the same target vocabulary. 107 | 108 | 109 | -------------------------------------------------------------------------------- /docs/source/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions */ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | == Examples == 2 | 3 | 4 | .. include:: quickstart.md 5 | .. include:: extended.md 6 | -------------------------------------------------------------------------------- /docs/source/extended.md: -------------------------------------------------------------------------------- 1 | 2 | # Translation 3 | 4 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation. This example is for training on the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html). 5 | 6 | Step 0. Download the data.
7 | 8 | ```bash 9 | mkdir -p data/multi30k 10 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz 11 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz 12 | wget http://www.quest.dcs.shef.ac.uk/wmt17_files_mmt/mmt_task1_test2016.tar.gz && tar -xf mmt_task1_test2016.tar.gz -C data/multi30k && rm mmt_task1_test2016.tar.gz 13 | ``` 14 | 15 | Step 1. Preprocess the data. 16 | 17 | ```bash 18 | for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done 19 | for l in en de; do for f in data/multi30k/*.$l; do perl tools/tokenizer.perl -a -no-escape -l $l -q < $f > $f.atok; done; done 20 | python preprocess.py -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower 21 | ``` 22 | 23 | Step 2. Train the model. 24 | 25 | ```bash 26 | python train.py -data data/multi30k.atok.low -save_model multi30k_model -gpu_ranks 0 27 | ``` 28 | 29 | Step 3. Translate sentences. 30 | 31 | ```bash 32 | python translate.py -gpu 0 -model multi30k_model_*_e13.pt -src data/multi30k/test2016.en.atok -tgt data/multi30k/test2016.de.atok -replace_unk -verbose -output multi30k.test.pred.atok 33 | ``` 34 | 35 | And evaluate 36 | 37 | ```bash 38 | perl tools/multi-bleu.perl data/multi30k/test2016.de.atok < multi30k.test.pred.atok 39 | ``` 40 | -------------------------------------------------------------------------------- /docs/source/im2text.md: -------------------------------------------------------------------------------- 1 | # Image to Text 2 | 3 | A deep learning-based approach to learning the image-to-text conversion, built on top of the OpenNMT system. 
It is completely data-driven, hence can be used for a variety of image-to-text problems, such as image captioning, optical character recognition and LaTeX decompilation. 4 | 5 | Take LaTeX decompilation as an example, given a formula image: 6 | 7 |

8 | 9 | The goal is to infer the LaTeX source that can be compiled to such an image: 10 | 11 | ``` 12 | d s _ { 1 1 } ^ { 2 } = d x ^ { + } d x ^ { - } + l _ { p } ^ { 9 } \frac { p _ { - } } { r ^ { 7 } } \delta ( x ^ { - } ) d x ^ { - } d x ^ { - } + d x _ { 1 } ^ { 2 } + \; \cdots \; + d x _ { 9 } ^ { 2 } 13 | ``` 14 | 15 | The paper [[What You Get Is What You See: A Visual Markup Decompiler]](https://arxiv.org/pdf/1609.04938.pdf) provides more technical details of this model. 16 | 17 | ### Dependencies 18 | 19 | * `torchvision`: `conda install torchvision` 20 | * `Pillow`: `pip install Pillow` 21 | 22 | ### Quick Start 23 | 24 | To get started, we provide a toy Math-to-LaTex example. We assume that the working directory is `OpenNMT-py` throughout this document. 25 | 26 | Im2Text consists of four commands: 27 | 28 | 0) Download the data. 29 | 30 | ``` 31 | wget -O data/im2text.tgz http://lstm.seas.harvard.edu/latex/im2text_small.tgz; tar zxf data/im2text.tgz -C data/ 32 | ``` 33 | 34 | 1) Preprocess the data. 35 | 36 | ``` 37 | python preprocess.py -data_type img -src_dir data/im2text/images/ -train_src data/im2text/src-train.txt \ 38 | -train_tgt data/im2text/tgt-train.txt -valid_src data/im2text/src-val.txt \ 39 | -valid_tgt data/im2text/tgt-val.txt -save_data data/im2text/demo \ 40 | -tgt_seq_length 150 -tgt_words_min_frequency 2 -shard_size 500 -image_channel_size 1 41 | ``` 42 | 43 | 2) Train the model. 44 | 45 | ``` 46 | python train.py -model_type img -data data/im2text/demo -save_model demo-model -gpu_ranks 0 -batch_size 20 \ 47 | -max_grad_norm 20 -learning_rate 0.1 -word_vec_size 80 -encoder_type brnn -image_channel_size 1 48 | ``` 49 | 50 | 3) Translate the images. 
51 | 52 | ``` 53 | python translate.py -data_type img -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/im2text/images \ 54 | -src data/im2text/src-test.txt -output pred.txt -max_length 150 -beam_size 5 -gpu 0 -verbose 55 | ``` 56 | 57 | The above dataset is sampled from the [im2latex-100k-dataset](http://lstm.seas.harvard.edu/latex/im2text.tgz). We provide a trained model [[link]](http://lstm.seas.harvard.edu/latex/py-model.pt) on this dataset. 58 | 59 | ### Options 60 | 61 | * `-src_dir`: The directory containing the images. 62 | 63 | * `-train_tgt`: The file storing the tokenized labels, one label per line. It shall look like: 64 | ``` 65 | ... 66 | ... 67 | ... 68 | ... 69 | ``` 70 | 71 | * `-train_src`: The file storing the paths of the images (relative to `src_dir`). 72 | ``` 73 | 74 | 75 | 76 | ... 77 | ``` 78 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | 2 | .. toctree:: 3 | :maxdepth: 2 4 | 5 | index.md 6 | quickstart.md 7 | extended.md 8 | 9 | 10 | This portal provides a detailled documentation of the OpenNMT toolkit. It describes how to use the PyTorch project and how it works. 11 | 12 | 13 | 14 | ## Installation 15 | 16 | 1\. [Install PyTorch](http://pytorch.org/) 17 | 18 | 2\. Clone the OpenNMT-py repository: 19 | 20 | ```bash 21 | git clone https://github.com/OpenNMT/OpenNMT-py 22 | cd OpenNMT-py 23 | ``` 24 | 25 | 3\. Install required libraries 26 | 27 | ```bash 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | And you are ready to go! Take a look at the [quickstart](quickstart.md) to familiarize yourself with the main training workflow. 32 | 33 | Alternatively you can use Docker to install with `nvidia-docker`. The main Dockerfile is included 34 | in the root directory. 
35 | 36 | ## Citation 37 | 38 | When using OpenNMT for research please cite our 39 | [OpenNMT technical report](https://doi.org/10.18653/v1/P17-4012) 40 | 41 | ``` 42 | @inproceedings{opennmt, 43 | author = {Guillaume Klein and 44 | Yoon Kim and 45 | Yuntian Deng and 46 | Jean Senellart and 47 | Alexander M. Rush}, 48 | title = {OpenNMT: Open-Source Toolkit for Neural Machine Translation}, 49 | booktitle = {Proc. ACL}, 50 | year = {2017}, 51 | url = {https://doi.org/10.18653/v1/P17-4012}, 52 | doi = {10.18653/v1/P17-4012} 53 | } 54 | ``` 55 | 56 | ## Additional resources 57 | 58 | You can find additional help or tutorials in the following resources: 59 | 60 | * [Gitter channel](https://gitter.im/OpenNMT/openmt-py) 61 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Contents 2 | -------- 3 | 4 | .. toctree:: 5 | :caption: Getting Started 6 | :maxdepth: 2 7 | 8 | main.md 9 | quickstart.md 10 | FAQ.md 11 | CONTRIBUTING.md 12 | ref.rst 13 | 14 | 15 | .. toctree:: 16 | :caption: Examples 17 | :maxdepth: 2 18 | 19 | Library.md 20 | extended.md 21 | Summarization.md 22 | im2text.md 23 | speech2text.md 24 | 25 | 26 | .. toctree:: 27 | :caption: Scripts 28 | :maxdepth: 2 29 | 30 | options/preprocess.rst 31 | options/train.rst 32 | options/translate.rst 33 | options/server.rst 34 | 35 | 36 | .. toctree:: 37 | :caption: API 38 | :maxdepth: 2 39 | 40 | onmt.rst 41 | onmt.modules.rst 42 | onmt.translation.rst 43 | onmt.translate.translation_server.rst 44 | onmt.inputters.rst -------------------------------------------------------------------------------- /docs/source/main.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | 4 | This portal provides a detailed documentation of the OpenNMT toolkit. It describes how to use the PyTorch project and how it works. 
5 | 6 | 7 | 8 | ## Installation 9 | 10 | 1\. [Install PyTorch](http://pytorch.org/) 11 | 12 | 2\. Clone the OpenNMT-py repository: 13 | 14 | ```bash 15 | git clone https://github.com/OpenNMT/OpenNMT-py 16 | cd OpenNMT-py 17 | ``` 18 | 19 | 3\. Install required libraries 20 | 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | And you are ready to go! Take a look at the [quickstart](quickstart) to familiarize yourself with the main training workflow. 26 | 27 | Alternatively you can use Docker to install with `nvidia-docker`. The main Dockerfile is included 28 | in the root directory. 29 | 30 | ## Citation 31 | 32 | When using OpenNMT for research please cite our 33 | [OpenNMT technical report](https://doi.org/10.18653/v1/P17-4012) 34 | 35 | ``` 36 | @inproceedings{opennmt, 37 | author = {Guillaume Klein and 38 | Yoon Kim and 39 | Yuntian Deng and 40 | Jean Senellart and 41 | Alexander M. Rush}, 42 | title = {OpenNMT: Open-Source Toolkit for Neural Machine Translation}, 43 | booktitle = {Proc. ACL}, 44 | year = {2017}, 45 | url = {https://doi.org/10.18653/v1/P17-4012}, 46 | doi = {10.18653/v1/P17-4012} 47 | } 48 | ``` 49 | 50 | ## Additional resources 51 | 52 | You can find additional help or tutorials in the following resources: 53 | 54 | * [Gitter channel](https://gitter.im/OpenNMT/openmt-py) 55 | 56 | * [Forum](http://forum.opennmt.net/) -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | onmt 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | onmt 8 | -------------------------------------------------------------------------------- /docs/source/onmt.inputters.rst: -------------------------------------------------------------------------------- 1 | Data Loaders 2 | ================= 3 | 4 | Data Readers 5 | ------------- 6 | 7 | .. 
autoexception:: onmt.inputters.datareader_base.MissingDependencyException 8 | 9 | .. autoclass:: onmt.inputters.DataReaderBase 10 | :members: 11 | 12 | .. autoclass:: onmt.inputters.TextDataReader 13 | :members: 14 | 15 | .. autoclass:: onmt.inputters.ImageDataReader 16 | :members: 17 | 18 | .. autoclass:: onmt.inputters.AudioDataReader 19 | :members: 20 | 21 | 22 | Dataset 23 | -------- 24 | 25 | .. autoclass:: onmt.inputters.Dataset 26 | :members: 27 | -------------------------------------------------------------------------------- /docs/source/onmt.modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ============= 3 | 4 | Core Modules 5 | ------------ 6 | 7 | .. autoclass:: onmt.modules.Embeddings 8 | :members: 9 | 10 | 11 | Encoders 12 | --------- 13 | 14 | .. autoclass:: onmt.encoders.EncoderBase 15 | :members: 16 | 17 | .. autoclass:: onmt.encoders.MeanEncoder 18 | :members: 19 | 20 | .. autoclass:: onmt.encoders.RNNEncoder 21 | :members: 22 | 23 | 24 | Decoders 25 | --------- 26 | 27 | 28 | .. autoclass:: onmt.decoders.DecoderBase 29 | :members: 30 | 31 | .. autoclass:: onmt.decoders.decoder.RNNDecoderBase 32 | :members: 33 | 34 | .. autoclass:: onmt.decoders.StdRNNDecoder 35 | :members: 36 | 37 | .. autoclass:: onmt.decoders.InputFeedRNNDecoder 38 | :members: 39 | 40 | Attention 41 | ---------- 42 | 43 | .. autoclass:: onmt.modules.AverageAttention 44 | :members: 45 | 46 | .. autoclass:: onmt.modules.GlobalAttention 47 | :members: 48 | 49 | 50 | 51 | Architecture: Transformer 52 | ---------------------------- 53 | 54 | .. autoclass:: onmt.modules.PositionalEncoding 55 | :members: 56 | 57 | .. autoclass:: onmt.modules.position_ffn.PositionwiseFeedForward 58 | :members: 59 | 60 | .. autoclass:: onmt.encoders.TransformerEncoder 61 | :members: 62 | 63 | .. autoclass:: onmt.decoders.TransformerDecoder 64 | :members: 65 | 66 | .. 
autoclass:: onmt.modules.MultiHeadedAttention 67 | :members: 68 | :undoc-members: 69 | 70 | 71 | Architecture: Conv2Conv 72 | ---------------------------- 73 | 74 | (These methods are from a user contribution 75 | and have not been thoroughly tested.) 76 | 77 | 78 | .. autoclass:: onmt.encoders.CNNEncoder 79 | :members: 80 | 81 | 82 | .. autoclass:: onmt.decoders.CNNDecoder 83 | :members: 84 | 85 | .. autoclass:: onmt.modules.ConvMultiStepAttention 86 | :members: 87 | 88 | .. autoclass:: onmt.modules.WeightNormConv2d 89 | :members: 90 | 91 | Architecture: SRU 92 | ---------------------------- 93 | 94 | .. autoclass:: onmt.models.sru.SRU 95 | :members: 96 | 97 | 98 | Alternative Encoders 99 | -------------------- 100 | 101 | onmt\.encoders\.AudioEncoder 102 | 103 | .. autoclass:: onmt.encoders.AudioEncoder 104 | :members: 105 | 106 | 107 | onmt\.encoders\.ImageEncoder 108 | 109 | .. autoclass:: onmt.encoders.ImageEncoder 110 | :members: 111 | 112 | 113 | Copy Attention 114 | -------------- 115 | 116 | .. autoclass:: onmt.modules.CopyGenerator 117 | :members: 118 | 119 | 120 | Structured Attention 121 | ------------------------------------------- 122 | 123 | .. autoclass:: onmt.modules.structured_attention.MatrixTree 124 | :members: 125 | -------------------------------------------------------------------------------- /docs/source/onmt.rst: -------------------------------------------------------------------------------- 1 | Framework 2 | ================= 3 | 4 | Model 5 | ----- 6 | 7 | .. autoclass:: onmt.models.NMTModel 8 | :members: 9 | 10 | Trainer 11 | ------- 12 | 13 | .. autoclass:: onmt.Trainer 14 | :members: 15 | 16 | 17 | .. autoclass:: onmt.utils.Statistics 18 | :members: 19 | 20 | Loss 21 | ---- 22 | 23 | 24 | .. autoclass:: onmt.utils.loss.LossComputeBase 25 | :members: 26 | 27 | 28 | Optimizer 29 | --------- 30 | 31 | .. 
autoclass:: onmt.utils.Optimizer 32 | :members: 33 | -------------------------------------------------------------------------------- /docs/source/onmt.translate.translation_server.rst: -------------------------------------------------------------------------------- 1 | Server 2 | ====== 3 | 4 | 5 | Models 6 | ------------- 7 | 8 | .. autoclass:: onmt.translate.translation_server.ServerModel 9 | :members: 10 | 11 | 12 | Core Server 13 | ------------ 14 | 15 | .. autoexception:: onmt.translate.translation_server.ServerModelError 16 | 17 | .. autoclass:: onmt.translate.translation_server.Timer 18 | :members: 19 | 20 | .. autoclass:: onmt.translate.translation_server.TranslationServer 21 | :members: 22 | -------------------------------------------------------------------------------- /docs/source/onmt.translation.rst: -------------------------------------------------------------------------------- 1 | Translation 2 | ================== 3 | 4 | Translations 5 | ------------- 6 | 7 | .. autoclass:: onmt.translate.Translation 8 | :members: 9 | 10 | Translator Class 11 | ----------------- 12 | 13 | .. autoclass:: onmt.translate.Translator 14 | :members: 15 | 16 | .. autoclass:: onmt.translate.TranslationBuilder 17 | :members: 18 | 19 | 20 | Decoding Strategies 21 | -------------------- 22 | .. autoclass:: onmt.translate.DecodeStrategy 23 | :members: 24 | 25 | .. autoclass:: onmt.translate.BeamSearch 26 | :members: 27 | 28 | .. autofunction:: onmt.translate.random_sampling.sample_with_temperature 29 | 30 | .. autoclass:: onmt.translate.RandomSampling 31 | :members: 32 | 33 | Scoring 34 | -------- 35 | .. autoclass:: onmt.translate.penalties.PenaltyBuilder 36 | :members: 37 | 38 | .. 
autoclass:: onmt.translate.GNMTGlobalScorer 39 | :members: 40 | -------------------------------------------------------------------------------- /docs/source/options/preprocess.rst: -------------------------------------------------------------------------------- 1 | Preprocess 2 | ========== 3 | 4 | .. argparse:: 5 | :filename: ../preprocess.py 6 | :func: _get_parser 7 | :prog: preprocess.py -------------------------------------------------------------------------------- /docs/source/options/server.rst: -------------------------------------------------------------------------------- 1 | Server 2 | ========= 3 | 4 | .. argparse:: 5 | :filename: ../server.py 6 | :func: _get_parser 7 | :prog: server.py -------------------------------------------------------------------------------- /docs/source/options/train.rst: -------------------------------------------------------------------------------- 1 | Train 2 | ===== 3 | 4 | .. argparse:: 5 | :filename: ../train.py 6 | :func: _get_parser 7 | :prog: train.py -------------------------------------------------------------------------------- /docs/source/options/translate.rst: -------------------------------------------------------------------------------- 1 | Translate 2 | ========= 3 | 4 | .. argparse:: 5 | :filename: ../translate.py 6 | :func: _get_parser 7 | :prog: translate.py -------------------------------------------------------------------------------- /docs/source/quickstart.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Quickstart 4 | 5 | 6 | ### Step 1: Preprocess the data 7 | 8 | ```bash 9 | python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo 10 | ``` 11 | 12 | We will be working with some example data in `data/` folder. 
13 | 14 | The data consists of parallel source (`src`) and target (`tgt`) data containing one sentence per line with tokens separated by a space: 15 | 16 | * `src-train.txt` 17 | * `tgt-train.txt` 18 | * `src-val.txt` 19 | * `tgt-val.txt` 20 | 21 | Validation files are required and used to evaluate the convergence of the training. They usually contain no more than 5000 sentences. 22 | 23 | ```text 24 | $ head -n 3 data/src-train.txt 25 | It is not acceptable that , with the help of the national bureaucracies , Parliament 's legislative prerogative should be made null and void by means of implementing provisions whose content , purpose and extent are not laid down in advance . 26 | Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobic Fitness , Group Fitness , Postural Gym , Stretching and Pilates; from 2004 , he has been collaborating with Antiche Terme as personal Trainer and Instructor of Stretching , Pilates and Postural Gym . 27 | " Two soldiers came up to me and told me that if I refuse to sleep with them , they will kill me . They beat me and ripped my clothes . 28 | ``` 29 | 30 | ### Step 2: Train the model 31 | 32 | ```bash 33 | python train.py -data data/demo -save_model demo-model 34 | ``` 35 | 36 | The main train command is quite simple. Minimally it takes a data file 37 | and a save file. This will run the default model, which consists of a 38 | 2-layer LSTM with 500 hidden units on both the encoder and decoder. 39 | If you want to train on GPUs, set, for example, 40 | CUDA_VISIBLE_DEVICES=1,3 41 | and `-world_size 2 -gpu_ranks 0 1` to use (say) GPUs 1 and 3 on this node only. 42 | To know more about distributed training on a single node or multiple nodes, read the FAQ section. 43 | 44 | ### Step 3: Translate 45 | 46 | ```bash 47 | python translate.py -model demo-model_XYZ.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose 48 | ``` 49 | 50 | Now you have a model which you can use to predict on new data.
We do this by running beam search. This will output predictions into `pred.txt`. 51 | 52 | Note: 53 | 54 | The predictions are going to be quite terrible, as the demo dataset is small. Try running on some larger datasets! For example you can download millions of parallel sentences for [translation](http://www.statmt.org/wmt16/translation-task.html) or [summarization](https://github.com/harvardnlp/sent-summary). 55 | -------------------------------------------------------------------------------- /docs/source/ref.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | References 3 | ========== 4 | 5 | 6 | 7 | References 8 | 9 | .. bibliography:: refs.bib 10 | 11 | -------------------------------------------------------------------------------- /docs/source/speech2text.md: -------------------------------------------------------------------------------- 1 | # Speech to Text 2 | 3 | A deep learning-based approach to learning the speech-to-text conversion, built on top of the OpenNMT system. 4 | 5 | Given raw audio, we first apply short-time Fourier transform (STFT), then apply Convolutional Neural Networks to get the source features. Based on this source representation, we use an LSTM decoder with attention to produce the text character by character. 6 | 7 | ### Dependencies 8 | 9 | * `torchaudio`: `sudo apt-get install -y sox libsox-dev libsox-fmt-all; pip install git+https://github.com/pytorch/audio` 10 | * `librosa`: `pip install librosa` 11 | 12 | ### Quick Start 13 | 14 | To get started, we provide a toy speech-to-text example. We assume that the working directory is `OpenNMT-py` throughout this document. 15 | 16 | 0) Download the data. 17 | 18 | ``` 19 | wget -O data/speech.tgz http://lstm.seas.harvard.edu/latex/speech.tgz; tar zxf data/speech.tgz -C data/ 20 | ``` 21 | 22 | 23 | 1) Preprocess the data. 
24 | 25 | ``` 26 | python preprocess.py -data_type audio -src_dir data/speech/an4_dataset -train_src data/speech/src-train.txt -train_tgt data/speech/tgt-train.txt -valid_src data/speech/src-val.txt -valid_tgt data/speech/tgt-val.txt -shard_size 300 -save_data data/speech/demo 27 | ``` 28 | 29 | 2) Train the model. 30 | 31 | ``` 32 | python train.py -model_type audio -enc_rnn_size 512 -dec_rnn_size 512 -audio_enc_pooling 1,1,2,2 -dropout 0 -enc_layers 4 -dec_layers 1 -rnn_type LSTM -data data/speech/demo -save_model demo-model -global_attention mlp -gpu_ranks 0 -batch_size 8 -optim adam -max_grad_norm 100 -learning_rate 0.0003 -learning_rate_decay 0.8 -train_steps 100000 33 | ``` 34 | 35 | 3) Translate the speechs. 36 | 37 | ``` 38 | python translate.py -data_type audio -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/speech/an4_dataset -src data/speech/src-val.txt -output pred.txt -gpu 0 -verbose 39 | ``` 40 | 41 | 42 | ### Options 43 | 44 | * `-src_dir`: The directory containing the audio files. 45 | 46 | * `-train_tgt`: The file storing the tokenized labels, one label per line. It shall look like: 47 | ``` 48 | ... 49 | ... 50 | ... 51 | ... 52 | ``` 53 | 54 | * `-train_src`: The file storing the paths of the audio files (relative to `src_dir`). 55 | ``` 56 | 57 | 58 | 59 | ... 60 | ``` 61 | 62 | * `sample_rate`: Sample rate. Default: 16000. 63 | * `window_size`: Window size for spectrogram in seconds. Default: 0.02. 64 | * `window_stride`: Window stride for spectrogram in seconds. Default: 0.01. 65 | * `window`: Window type for spectrogram generation. Default: hamming. 66 | 67 | ### Acknowledgement 68 | 69 | Our preprocessing and CNN encoder is adapted from [deepspeech.pytorch](https://github.com/SeanNaren/deepspeech.pytorch). 
# (end of /.gitignore: line 70 is blank)
# --------------------------------------------------------------------------------
# /gpt2/.gitignore:
# --------------------------------------------------------------------------------
# 117M
# 345M
# --------------------------------------------------------------------------------
# /gpt2/decode_text.py:
# --------------------------------------------------------------------------------
"""Decode a GPT-2 byte-level-BPE token file back into plain text.

Each input line holds space-separated BPE tokens (the format written by
encode_text.py). The tokens are joined, mapped back to raw bytes through the
tokenizer's ``byte_decoder`` and decoded as UTF-8; one example is written per
output line. The number of processed lines is printed at the end.
"""
import argparse

from pytorch_pretrained_bert import GPT2Tokenizer

parser = argparse.ArgumentParser()
# --src is mandatory: every code path below reads or opens it.
parser.add_argument('--src', '-src', type=str, required=True)
parser.add_argument('--dst', '-dst', type=str)

args = parser.parse_args()
enc = GPT2Tokenizer.from_pretrained('gpt2')

if args.dst is None:
    # Derive the output path by stripping a known encoding suffix.
    if args.src.endswith('.bpe'):
        args.dst = args.src[:-len('.bpe')]
    elif args.src.endswith('.encoded'):
        args.dst = args.src[:-len('.encoded')]
    else:
        raise ValueError('dst needed or src that ends in .bpe or .encoded')

i = 0
# The byte-level BPE alphabet is non-ASCII unicode, so pin UTF-8 instead of
# relying on the platform default encoding.
with open(args.dst, 'w', encoding='utf-8') as fw, \
        open(args.src, 'r', encoding='utf-8') as f:
    for i, line in enumerate(f, start=1):
        # Tokens are space-separated; concatenate them back into one
        # byte-encoded string.
        text = ''.join(line.strip().split(' '))

        decoded = bytearray([enc.byte_decoder[c] for c in text]).decode(
            'utf-8', errors=enc.errors)
        decoded = decoded.replace('\n', '')  # We need one example per line
        decoded = decoded.replace('\r', '')
        fw.write(decoded + '\n')
print(i)
# --------------------------------------------------------------------------------
# /gpt2/download_model.py:
# --------------------------------------------------------------------------------
"""Download a pretrained GPT-2 checkpoint into ``models/<model_name>/``.

Usage: ``python download_model.py <model_name>`` (e.g. ``124M``).
"""
import os
import sys

import requests
from tqdm import tqdm

if len(sys.argv) != 2:
    print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
    sys.exit(1)

model = sys.argv[1]

subdir = os.path.join('models', model)
os.makedirs(subdir, exist_ok=True)
subdir = subdir.replace('\\', '/')  # needed for Windows

for filename in ['checkpoint', 'encoder.json', 'hparams.json',
                 'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
                 'model.ckpt.meta', 'vocab.bpe']:

    r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True)
    # Abort on HTTP errors (e.g. a 404) instead of saving the error response
    # body as if it were a model file.
    r.raise_for_status()

    with open(os.path.join(subdir, filename), 'wb') as f:
        # content-length may be missing; tqdm then shows an open-ended bar.
        file_size = int(r.headers.get("content-length", 0)) or None
        chunk_size = 1000
        with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
            # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                # Count the bytes actually received; the last chunk is
                # usually shorter than chunk_size.
                pbar.update(len(chunk))
# --------------------------------------------------------------------------------
# /gpt2/encode_text.py:
# --------------------------------------------------------------------------------
"""Encode a text file with GPT-2's byte-level BPE.

Usage: ``python encode_text.py <filename> [file_prefix]``. Reads
``<file_prefix><filename>`` and writes one line of space-separated BPE tokens
per input line to ``<file_prefix><filename>.bpe`` (``.bpe.tldr`` when
``with_tldr`` is on and the file name contains ``src``).
"""
import sys

import regex as re
from pytorch_pretrained_bert import GPT2Tokenizer

# GPT-2 pre-tokenization pattern; \p{...} classes need the `regex` module.
pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
enc = GPT2Tokenizer.from_pretrained('gpt2')

filename = sys.argv[1]
# Optional path prefix prepended to both the input and the output file name.
file_prefix = sys.argv[2] if len(sys.argv) > 2 else ''

with_tldr = False        # append "\nTL;DR:" to lines of files named *src*
replace_newline = False  # turn the dataset's newline placeholder into '\n'
# NOTE(review): the placeholder literal appears to have been lost in this copy
# (replacing '' would insert '\n' between every character); '<newline>' is
# assumed here -- confirm against the dataset.
newline_token = '<newline>'
tok_trunc = 1000000      # keep at most this many BPE tokens per line

write_name = file_prefix + filename + '.bpe'
if with_tldr and 'src' in filename:
    write_name += '.tldr'

with open(file_prefix + filename, 'r', encoding='utf-8') as f, \
        open(write_name, 'w', encoding='utf-8') as fw:
    for line in f:
        txt = line.strip()  # strip so the trailing newline is not encoded
        if with_tldr and 'src' in filename:
            txt += '\nTL;DR:'

        if replace_newline:
            txt = txt.replace(newline_token, '\n')

        bpe_tokens = []
        for token in re.findall(pat, txt):
            # Map the token's UTF-8 bytes into the tokenizer's byte alphabet,
            # then apply BPE merges.
            token = ''.join(enc.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(enc.bpe(token).split(' '))
        fw.write(' '.join(bpe_tokens[:tok_trunc]) + '\n')
# --------------------------------------------------------------------------------
# /onmt/__init__.py:
# --------------------------------------------------------------------------------
""" Main entry point of the ONMT library """
from __future__ import division, print_function

import onmt.inputters
import onmt.encoders
import onmt.decoders
import onmt.models
import onmt.utils
import onmt.modules
from onmt.trainer import Trainer
import sys
import onmt.utils.optimizers
# Alias the legacy `onmt.Optim` module/class names (presumably so objects
# pickled against the old names still load -- confirm).
onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer
sys.modules["onmt.Optim"] = onmt.utils.optimizers

# `__all__` must contain names (str). Module objects here make
# `from onmt import *` raise TypeError. The imported submodules are bound as
# attributes of this package, so these names resolve.
__all__ = ["inputters", "encoders", "decoders", "models",
           "utils", "modules", "Trainer"]

__version__ = "0.8.2"
# --------------------------------------------------------------------------------
# /onmt/decoders/__init__.py:
# --------------------------------------------------------------------------------
"""Module defining decoders."""
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \
    StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder
from onmt.decoders.cnn_decoder import CNNDecoder
from onmt.decoders.rnn_uncond import RNNUncondDecoder


# Maps decoder-type names to their classes.
str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,
           "cnn": CNNDecoder, "transformer": TransformerDecoder,
           "rnn_uncond": RNNUncondDecoder}

__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder",
           "InputFeedRNNDecoder", "str2dec",
           "RNNUncondDecoder"]
# --------------------------------------------------------------------------------
# /onmt/decoders/cnn_decoder.py:
# --------------------------------------------------------------------------------
"""Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn

from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase

SCALE_WEIGHT = 0.5 ** 0.5


class CNNDecoder(DecoderBase):
    """Decoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.

    Consists of residual convolutional layers, with ConvMultiStepAttention.
    Decoding is not incremental: every call re-runs the convolutions over
    all target tokens seen so far (kept in ``state["previous_input"]``).
    """

    def __init__(self, num_layers, hidden_size, attn_type,
                 copy_attn, cnn_kernel_width, dropout, embeddings,
                 copy_attn_type):
        super(CNNDecoder, self).__init__()

        self.cnn_kernel_width = cnn_kernel_width
        self.embeddings = embeddings

        # Decoder State
        self.state = {}

        input_size = self.embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.conv_layers = nn.ModuleList(
            [GatedConv(hidden_size, cnn_kernel_width, dropout, True)
             for i in range(num_layers)]
        )
        self.attn_layers = nn.ModuleList(
            [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
        )

        # CNNDecoder has its own attention mechanism.
        # Set up a separate copy attention layer if needed.
        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
        if copy_attn:
            self.copy_attn = GlobalAttention(
                hidden_size, attn_type=copy_attn_type)
        else:
            self.copy_attn = None

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings,
            opt.copy_attn_type)

    def init_state(self, _, memory_bank, enc_hidden):
        """Init decoder state.

        ``state["src"]`` is the scaled sum of the two encoder outputs and is
        used as the attention values in :meth:`forward`.
        """
        self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT
        self.state["previous_input"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, 1)`` to every tensor held in the state."""
        self.state["src"] = fn(self.state["src"], 1)
        if self.state["previous_input"] is not None:
            self.state["previous_input"] = fn(self.state["previous_input"], 1)

    def detach_state(self):
        """Detach the stored target history from the autograd graph."""
        # Before the first forward() there is no history to detach.
        previous_input = self.state.get("previous_input")
        if previous_input is not None:
            self.state["previous_input"] = previous_input.detach()

    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""

        # Re-run over the full target history; the outputs for the history
        # positions are sliced off again below.
        if self.state["previous_input"] is not None:
            tgt = torch.cat([self.state["previous_input"], tgt], 0)

        attns = {"std": []}
        if self.copy_attn is not None:
            attns["copy"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        tgt_emb = emb.transpose(0, 1).contiguous()
        # The output of CNNEncoder.
        src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
        # The combination of output of CNNEncoder and source embeddings.
        src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()

        emb_reshape = tgt_emb.contiguous().view(
            tgt_emb.size(0) * tgt_emb.size(1), -1)
        linear_out = self.linear(emb_reshape)
        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
        x = shape_transform(x)

        # Left zero-padding of width (kernel - 1) along dim 2 so each
        # convolution only sees current and earlier positions.
        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)

        pad = pad.type_as(x)
        base_target_emb = x

        for conv, attention in zip(self.conv_layers, self.attn_layers):
            new_target_input = torch.cat([pad, x], 2)
            out = conv(new_target_input)
            c, attn = attention(base_target_emb, out,
                                src_memory_bank_t, src_memory_bank_c)
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
        output = x.squeeze(3).transpose(1, 2)

        # Process the result and update the attentions.
        dec_outs = output.transpose(0, 1).contiguous()
        if self.state["previous_input"] is not None:
            dec_outs = dec_outs[self.state["previous_input"].size(0):]
            attn = attn[:, self.state["previous_input"].size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self.copy_attn is not None:
            attns["copy"] = attn

        # Update the state.
        self.state["previous_input"] = tgt
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
# --------------------------------------------------------------------------------
# /onmt/decoders/ensemble.py:
# --------------------------------------------------------------------------------
"""Ensemble decoding.

Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""

import math

import torch
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.models import NMTModel
import onmt.model_builder


class EnsembleDecoderOutput(object):
    """Wrapper around multiple decoder final hidden states."""
    def __init__(self, model_dec_outs):
        self.model_dec_outs = tuple(model_dec_outs)

    def squeeze(self, dim=None):
        """Delegate squeeze to avoid modifying
        :func:`onmt.translate.translator.Translator.translate_batch()`
        """
        return EnsembleDecoderOutput([
            x.squeeze(dim) for x in self.model_dec_outs])

    def __getitem__(self, index):
        return self.model_dec_outs[index]


class EnsembleEncoder(EncoderBase):
    """Dummy Encoder that delegates to individual real Encoders."""
    def __init__(self, model_encoders):
        super(EnsembleEncoder, self).__init__()
        self.model_encoders = nn.ModuleList(model_encoders)

    def forward(self, src, lengths=None):
        """Run every encoder; return per-model tuples of states and banks.

        The incoming ``lengths`` is returned unchanged; per-model lengths are
        discarded.
        """
        enc_hidden, memory_bank, _ = zip(*[
            model_encoder(src, lengths)
            for model_encoder in self.model_encoders])
        return enc_hidden, memory_bank, lengths


class EnsembleDecoder(nn.Module):
    """Dummy Decoder that delegates to individual real Decoders."""
    def __init__(self, model_decoders):
        super(EnsembleDecoder, self).__init__()
        self.model_decoders = nn.ModuleList(model_decoders)

    def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs):
        """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.

        ``memory_bank`` is indexed per model. Extra keyword arguments are
        accepted for interface compatibility but not forwarded.
        """
        # Memory_lengths is a single tensor shared between all models.
        # This assumption will not hold if Translator is modified
        # to calculate memory_lengths as something other than the length
        # of the input.
        dec_outs, attns = zip(*[
            model_decoder(
                tgt, memory_bank[i],
                memory_lengths=memory_lengths, step=step)
            for i, model_decoder in enumerate(self.model_decoders)])
        mean_attns = self.combine_attns(attns)
        return EnsembleDecoderOutput(dec_outs), mean_attns

    def combine_attns(self, attns):
        """Average each attention type across models (keys of model 0)."""
        result = {}
        for key in attns[0].keys():
            result[key] = torch.stack([attn[key] for attn in attns]).mean(0)
        return result

    def init_state(self, src, memory_bank, enc_hidden):
        """ See :obj:`RNNDecoderBase.init_state()` """
        for i, model_decoder in enumerate(self.model_decoders):
            model_decoder.init_state(src, memory_bank[i], enc_hidden[i])

    def map_state(self, fn):
        for model_decoder in self.model_decoders:
            model_decoder.map_state(fn)


class EnsembleGenerator(nn.Module):
    """
    Dummy Generator that delegates to individual real Generators,
    and then averages the resulting target distributions.
    """
    def __init__(self, model_generators, raw_probs=False):
        super(EnsembleGenerator, self).__init__()
        self.model_generators = nn.ModuleList(model_generators)
        self._raw_probs = raw_probs

    def forward(self, hidden, attn=None, src_map=None):
        """
        Compute a distribution over the target dictionary
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.

        With ``raw_probs`` the average is taken in probability space and the
        log of it is returned; otherwise the log-space outputs are averaged.
        """
        distributions = torch.stack(
            [mg(h) if attn is None else mg(h, attn, src_map)
             for h, mg in zip(hidden, self.model_generators)]
        )
        if self._raw_probs:
            # log(mean_i exp(d_i)) computed as logsumexp(d) - log(n), which
            # does not underflow to log(0) = -inf for very negative inputs.
            return (torch.logsumexp(distributions, 0)
                    - math.log(distributions.size(0)))
        else:
            return distributions.mean(0)


class EnsembleModel(NMTModel):
    """Dummy NMTModel wrapping individual real NMTModels."""
    def __init__(self, models, raw_probs=False):
        encoder = EnsembleEncoder(model.encoder for model in models)
        decoder = EnsembleDecoder(model.decoder for model in models)
        super(EnsembleModel, self).__init__(encoder, decoder)
        self.generator = EnsembleGenerator(
            [model.generator for model in models], raw_probs)
        self.models = nn.ModuleList(models)


def load_test_model(opt):
    """Read in multiple models for ensemble.

    Loads every checkpoint in ``opt.models``, asserts that each field with a
    vocabulary matches the first model's, and returns the first model's
    fields and options together with the :class:`EnsembleModel`.
    """
    shared_fields = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        fields, model, model_opt = \
            onmt.model_builder.load_test_model(opt, model_path=model_path)
        if shared_fields is None:
            shared_fields = fields
        else:
            for key, field in fields.items():
                # A field is either a single field or an iterable of
                # (name, sub-field) pairs; normalize to the latter.
                try:
                    f_iter = iter(field)
                except TypeError:
                    f_iter = [(key, field)]
                for sn, sf in f_iter:
                    if sf is not None and 'vocab' in sf.__dict__:
                        sh_field = shared_fields[key]
                        try:
                            sh_f_iter = iter(sh_field)
                        except TypeError:
                            sh_f_iter = [(key, sh_field)]
                        sh_f_dict = dict(sh_f_iter)
                        assert sf.vocab.stoi == sh_f_dict[sn].vocab.stoi, \
                            "Ensemble models must use the same " \
                            "preprocessed data"
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opt
    ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
    return shared_fields, ensemble_model, shared_model_opt
# --------------------------------------------------------------------------------
# /onmt/decoders/rnn_uncond.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

from onmt.decoders.decoder import DecoderBase
from onmt.utils.rnn_factory import rnn_factory


class RNNUncondDecoder(DecoderBase):
    """Attention-free ("unconditional") recurrent decoder.

    Runs a stacked RNN over the target embeddings, starting each batch from
    an all-zero hidden state. The encoder output is consulted only for the
    batch size; no attention is computed.

    Args:
        rnn_type (str): recurrent unit passed to
            :func:`onmt.utils.rnn_factory.rnn_factory`. ``'GRU'`` raises
            ``NotImplementedError``; since :meth:`init_state` always builds an
            ``(h, c)`` pair, presumably only LSTM-style units work -- confirm.
        num_layers (int): number of stacked layers
        hidden_size (int): hidden size of each layer
        dropout (float): dropout passed to the RNN and also applied to its
            outputs through :class:`torch.nn.Dropout`
        embeddings (onmt.modules.Embeddings): embedding module to use
    """

    def __init__(self, rnn_type, num_layers,
                 hidden_size, dropout=0.0, embeddings=None):
        super(RNNUncondDecoder, self).__init__(attentional=False)
        if rnn_type == 'GRU':
            raise NotImplementedError

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embeddings = embeddings
        self.dropout = nn.Dropout(dropout)

        # Decoder state
        self.state = {}

        # Build the RNN.
        self.rnn = self._build_rnn(rnn_type,
                                   input_size=self._input_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_layers,
                                   dropout=dropout)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.rnn_type,
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.dropout,
            embeddings)

    def init_state(self, src, memory_bank, encoder_final):
        """Reset the hidden state to zeros.

        ``src`` and ``encoder_final`` are ignored; ``memory_bank`` is only
        used to read the batch size from its second dimension.
        """
        batch_size = memory_bank.shape[1]
        # Borrow dtype/device from the module's parameters.
        weight = next(self.parameters())
        h = weight.new_zeros(self.num_layers, batch_size, self.hidden_size)
        c = weight.new_zeros(self.num_layers, batch_size, self.hidden_size)
        self.state["hidden"] = (h, c)

    def map_state(self, fn):
        """Apply ``fn(tensor, 1)`` to each ``(layers, batch, hidden)`` state."""
        self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])

    def detach_state(self):
        """Detach the hidden state from the autograd graph."""
        self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])

    def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs):
        """Run the RNN over ``tgt`` and carry the hidden state forward.

        Args:
            tgt (LongTensor): sequences of padded tokens
                ``(tgt_len, batch, nfeats)``.
            memory_bank: unused; accepted for interface compatibility.
            memory_lengths: unused.
            step: unused.

        Returns:
            (FloatTensor, None):

            * dec_outs: RNN outputs after dropout ``(tgt_len, batch, hidden)``.
            * ``None`` in place of an attention dict.
        """
        emb = self.embeddings(tgt)
        dec_outs, dec_state = self.rnn(emb, self.state["hidden"])

        self.state["hidden"] = dec_state

        # NOTE: v0.3 to 0.4: dec_outs may be a list rather than a tensor
        # (in particular in case of SRU); stack it before dropout, which
        # needs a tensor.
        if isinstance(dec_outs, list):
            dec_outs = torch.stack(dec_outs)
        dec_outs = self.dropout(dec_outs)

        return dec_outs, None

    def _build_rnn(self, rnn_type, **kwargs):
        rnn, _ = rnn_factory(rnn_type, **kwargs)
        return rnn

    @property
    def _input_size(self):
        return self.embeddings.embedding_size
# --------------------------------------------------------------------------------
# /onmt/encoders/__init__.py:
# --------------------------------------------------------------------------------
"""Module defining encoders."""
from onmt.encoders.encoder import EncoderBase
from onmt.encoders.transformer import TransformerEncoder
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder
from onmt.encoders.imgvec_encoder import ImgVecEncoder
from onmt.encoders.embonly import EmbOnlyEncoder


class NoneEncoder:
    """Stand-in for the ``'none'`` encoder type: building it yields None."""

    @classmethod
    def from_opt(cls, opt, embeddings):
        return None


# Maps encoder-type names to their classes.
str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder,
           "transformer": TransformerEncoder, "img": ImageEncoder,
           "audio": AudioEncoder, "mean": MeanEncoder,
           "embonly": EmbOnlyEncoder, 'imgvec': ImgVecEncoder,
           'none': NoneEncoder}

__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
           "MeanEncoder", "str2enc", "EmbOnlyEncoder"]
# --------------------------------------------------------------------------------
# /onmt/encoders/audio_encoder.py:
# --------------------------------------------------------------------------------
"""Audio encoder"""
import math

import torch.nn as nn

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from onmt.utils.rnn_factory import rnn_factory
from onmt.encoders.encoder import EncoderBase


class AudioEncoder(EncoderBase):
    """A simple encoder CNN -> RNN for audio input.

    Args:
        rnn_type (str): Type of RNN (e.g. GRU, LSTM, etc).
        enc_layers (int): Number of encoder layers.
        dec_layers (int): Number of decoder layers.
        brnn (bool): Bidirectional encoder.
        enc_rnn_size (int): Size of hidden states of the rnn.
        dec_rnn_size (int): Size of the decoder hidden states.
        enc_pooling (str): A comma separated list either of length 1
            or of length ``enc_layers`` specifying the pooling amount.
        dropout (float): dropout probability.
        sample_rate (float): input spec
        window_size (int): input spec
    """

    def __init__(self, rnn_type, enc_layers, dec_layers, brnn,
                 enc_rnn_size, dec_rnn_size, enc_pooling, dropout,
                 sample_rate, window_size):
        super(AudioEncoder, self).__init__()
        self.enc_layers = enc_layers
        self.rnn_type = rnn_type
        self.dec_layers = dec_layers
        num_directions = 2 if brnn else 1
        self.num_directions = num_directions
        # Each direction gets half of the hidden size when bidirectional.
        assert enc_rnn_size % num_directions == 0
        enc_rnn_size_real = enc_rnn_size // num_directions
        assert dec_rnn_size % num_directions == 0
        self.dec_rnn_size = dec_rnn_size
        dec_rnn_size_real = dec_rnn_size // num_directions
        self.dec_rnn_size_real = dec_rnn_size_real
        # Number of frequency bins per frame.
        input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        enc_pooling = enc_pooling.split(',')
        assert len(enc_pooling) == enc_layers or len(enc_pooling) == 1
        if len(enc_pooling) == 1:
            enc_pooling = enc_pooling * enc_layers
        enc_pooling = [int(p) for p in enc_pooling]
        self.enc_pooling = enc_pooling

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.W = nn.Linear(enc_rnn_size, dec_rnn_size, bias=False)
        self.batchnorm_0 = nn.BatchNorm1d(enc_rnn_size, affine=True)
        self.rnn_0, self.no_pack_padded_seq = \
            rnn_factory(rnn_type,
                        input_size=input_size,
                        hidden_size=enc_rnn_size_real,
                        num_layers=1,
                        dropout=dropout,
                        bidirectional=brnn)
        self.pool_0 = nn.MaxPool1d(enc_pooling[0])
        # Layers 1..enc_layers-1 are registered as rnn_<i>/pool_<i>/batchnorm_<i>.
        for layer in range(enc_layers - 1):
            batchnorm = nn.BatchNorm1d(enc_rnn_size, affine=True)
            rnn, _ = \
                rnn_factory(rnn_type,
                            input_size=enc_rnn_size,
                            hidden_size=enc_rnn_size_real,
                            num_layers=1,
                            dropout=dropout,
                            bidirectional=brnn)
            setattr(self, 'rnn_%d' % (layer + 1), rnn)
            setattr(self, 'pool_%d' % (layer + 1),
                    nn.MaxPool1d(enc_pooling[layer + 1]))
            setattr(self, 'batchnorm_%d' % (layer + 1), batchnorm)

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        """Alternate constructor."""
        if embeddings is not None:
            raise ValueError("Cannot use embeddings with AudioEncoder.")
        return cls(
            opt.rnn_type,
            opt.enc_layers,
            opt.dec_layers,
            opt.brnn,
            opt.enc_rnn_size,
            opt.dec_rnn_size,
            opt.audio_enc_pooling,
            opt.dropout,
            opt.sample_rate,
            opt.window_size)

    def forward(self, src, lengths=None):
        """See :func:`onmt.encoders.encoder.EncoderBase.forward()`

        ``lengths`` is required here. The returned lengths reflect the time
        pooling applied by every layer; the final state is all zeros.
        """
        batch_size, _, nfft, t = src.size()
        src = src.transpose(0, 1).transpose(0, 3).contiguous() \
                 .view(t, batch_size, nfft)
        orig_lengths = lengths
        lengths = lengths.view(-1).tolist()

        for layer in range(self.enc_layers):
            rnn = getattr(self, 'rnn_%d' % layer)
            pool = getattr(self, 'pool_%d' % layer)
            batchnorm = getattr(self, 'batchnorm_%d' % layer)
            stride = self.enc_pooling[layer]
            packed_emb = pack(src, lengths)
            memory_bank, _ = rnn(packed_emb)
            memory_bank = unpack(memory_bank)[0]
            # Pool over time: move time to the last axis for MaxPool1d.
            memory_bank = memory_bank.transpose(0, 2)
            memory_bank = pool(memory_bank)
            lengths = [int(math.floor((length - stride) / stride + 1))
                       for length in lengths]
            memory_bank = memory_bank.transpose(0, 2)
            src = memory_bank
            t, _, num_feat = src.size()
            src = batchnorm(src.contiguous().view(-1, num_feat))
            src = src.view(t, -1, num_feat)
            if self.dropout and layer + 1 != self.enc_layers:
                src = self.dropout(src)

        memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2))
        memory_bank = self.W(memory_bank).view(-1, batch_size,
                                               self.dec_rnn_size)

        state = memory_bank.new_full((self.dec_layers * self.num_directions,
                                      batch_size, self.dec_rnn_size_real), 0)
        if self.rnn_type == 'LSTM':
            # The encoder hidden is  (layers*directions) x batch x dim.
            encoder_final = (state, state)
        else:
            encoder_final = state
        return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)
# --------------------------------------------------------------------------------
# /onmt/encoders/cnn_encoder.py:
# --------------------------------------------------------------------------------
"""
Implementation of "Convolutional Sequence to Sequence Learning"
"""
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.utils.cnn_factory import shape_transform, StackedCNN

SCALE_WEIGHT = 0.5 ** 0.5


class CNNEncoder(EncoderBase):
    """Encoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.
    """

    def __init__(self, num_layers, hidden_size,
                 cnn_kernel_width, dropout, embeddings):
        super(CNNEncoder, self).__init__()

        self.embeddings = embeddings
        input_size = embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.cnn = StackedCNN(num_layers, hidden_size,
                              cnn_kernel_width, dropout)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings)

    def forward(self, input, lengths=None, hidden=None):
        """See :class:`onmt.modules.EncoderBase.forward()`

        Returns the linearly projected embeddings (in the "final state"
        slot), the stacked-CNN outputs as the memory bank, and ``lengths``.
        """
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        # s_len, batch, emb_dim = emb.size()

        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)

        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths
# --------------------------------------------------------------------------------
# /onmt/encoders/embonly.py:
# --------------------------------------------------------------------------------
from onmt.encoders.encoder import EncoderBase


class EmbOnlyEncoder(EncoderBase):
    """Encoder that only embeds the source.

    The embeddings are returned both as the final state and as the memory
    bank.
    """

    def __init__(self, embeddings):
        super(EmbOnlyEncoder, self).__init__()
        self.embeddings = embeddings

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(embeddings)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        emb = self.embeddings(src)
        return emb, emb, lengths
# --------------------------------------------------------------------------------
# /onmt/encoders/encoder.py:
# --------------------------------------------------------------------------------
"""Base class for encoders and generic multi encoders."""

import torch.nn as nn

from onmt.utils.misc import aeq


class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`onmt.models.NMTModel`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Assert that ``src`` (3-D, batch on dim 1) and ``lengths`` agree
        on batch size."""
        _, n_batch, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
               padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``


        Returns:
            (FloatTensor, FloatTensor, LongTensor):

            * final encoder state, used to initialize decoder
            * memory bank for attention, ``(src_len, batch, hidden)``
            * lengths of the memory bank entries, ``(batch,)``
        """

        raise NotImplementedError
# --------------------------------------------------------------------------------
# /onmt/encoders/image_encoder.py:
# --------------------------------------------------------------------------------
"""Image Encoder."""
import torch.nn as nn
import torch.nn.functional as F
import torch

from onmt.encoders.encoder import EncoderBase


class ImageEncoder(EncoderBase):
    """A simple encoder CNN -> RNN for image src.

    Args:
        num_layers (int): number of encoder layers.
        bidirectional (bool): bidirectional encoder.
        rnn_size (int): size of hidden states of the rnn.
        dropout (float): dropout probability.
        image_chanel_size (int): number of input image channels (the
            parameter's spelling is kept for compatibility).
    """

    def __init__(self, num_layers, bidirectional, rnn_size, dropout,
                 image_chanel_size=3):
        super(ImageEncoder, self).__init__()
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.hidden_size = rnn_size

        self.layer1 = nn.Conv2d(image_chanel_size, 64, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))

        self.batch_norm1 = nn.BatchNorm2d(256)
        self.batch_norm2 = nn.BatchNorm2d(512)
        self.batch_norm3 = nn.BatchNorm2d(512)

        src_size = 512
        self.rnn = nn.LSTM(src_size, int(rnn_size / self.num_directions),
                           num_layers=num_layers,
                           dropout=dropout,
                           bidirectional=bidirectional)
        # Learned embedding of the row index (up to 1000 rows).
        self.pos_lut = nn.Embedding(1000, src_size)

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        """Alternate constructor."""
        if embeddings is not None:
            raise ValueError("Cannot use embeddings with ImageEncoder.")
        # Fall back to 3 channels when the options object has no
        # image_channel_size attribute.
        if "image_channel_size" not in opt.__dict__:
            image_channel_size = 3
        else:
            image_channel_size = opt.image_channel_size
        return cls(
            opt.enc_layers,
            opt.brnn,
            opt.enc_rnn_size,
            opt.dropout,
            image_channel_size
        )

    def load_pretrained_vectors(self, opt):
        """Pass in needed options only when modify function definition."""
        pass

    def forward(self, src, lengths=None):
        """See :func:`onmt.encoders.encoder.EncoderBase.forward()`

        The returned final state is the LSTM state of the last image row.
        """

        batch_size = src.size(0)
        # (batch_size, 64, imgH, imgW)
        # layer 1
        src = F.relu(self.layer1(src[:, :, :, :] - 0.5), True)

        # (batch_size, 64, imgH/2, imgW/2)
        src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2))

        # (batch_size, 128, imgH/2, imgW/2)
        # layer 2
        src = F.relu(self.layer2(src), True)

        # (batch_size, 128, imgH/2/2, imgW/2/2)
        src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2))

        # (batch_size, 256, imgH/2/2, imgW/2/2)
        # layer 3
        # batch norm 1
        src = F.relu(self.batch_norm1(self.layer3(src)), True)

        # (batch_size, 256, imgH/2/2, imgW/2/2)
        # layer4
        src = F.relu(self.layer4(src), True)

        # kernel (1, 2) pools width only:
        # (batch_size, 256, imgH/2/2, imgW/2/2/2)
        src = F.max_pool2d(src, kernel_size=(1, 2), stride=(1, 2))

        # (batch_size, 512, imgH/2/2, imgW/2/2/2)
        # layer 5
        # batch norm 2
        src = F.relu(self.batch_norm2(self.layer5(src)), True)

        # kernel (2, 1) pools height only:
        # (batch_size, 512, imgH/2/2/2, imgW/2/2/2)
        src = F.max_pool2d(src, kernel_size=(2, 1), stride=(2, 1))

        # (batch_size, 512, imgH/2/2/2, imgW/2/2/2)
        src = F.relu(self.batch_norm3(self.layer6(src)), True)

        # # (batch_size, 512, H, W)
        # Each row becomes a (W, batch, 512) sequence with a row-position
        # embedding prepended as an extra first step.
        all_outputs = []
        for row in range(src.size(2)):
            inp = src[:, :, row, :].transpose(0, 2) \
                .transpose(1, 2)
            row_vec = torch.Tensor(batch_size).type_as(inp.data) \
                .long().fill_(row)
            pos_emb = self.pos_lut(row_vec)
            with_pos = torch.cat(
                (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0)
            outputs, hidden_t = self.rnn(with_pos)
            all_outputs.append(outputs)
        out = torch.cat(all_outputs, 0)

        return hidden_t, out, lengths
# --------------------------------------------------------------------------------
# /onmt/encoders/imgvec_encoder.py:
# --------------------------------------------------------------------------------
"""Define a minimal encoder."""
from onmt.encoders.encoder import EncoderBase
from torch import nn


class ImgVecEncoder(EncoderBase):
    """Encoder for precomputed feature vectors.

    Projects every input vector with a linear layer. The projected vectors
    form the memory bank, and their mean over the first dimension
    (replicated ``num_layers`` times) is the final ``(h, c)`` state.

    Args:
        num_layers (int): number of layers the mean state is replicated for
        emb_dim (int): size of the input vectors
        outp_dim (int): size of the projected vectors
    """

    def __init__(self, num_layers, emb_dim, outp_dim):
        super(ImgVecEncoder, self).__init__()
        self.num_layers = num_layers
        self.proj = nn.Linear(emb_dim, outp_dim)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor (``embeddings`` is unused)."""
        return cls(
            opt.enc_layers,
            opt.image_channel_size,
            opt.word_vec_size)

    def forward(self, emb, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(emb, lengths)

        emb = self.proj(emb)
        _, batch, emb_dim = emb.size()

        mean = emb.mean(0).expand(self.num_layers, batch, emb_dim)
        memory_bank = emb
        encoder_final = (mean, mean)
        return encoder_final, memory_bank, lengths
# --------------------------------------------------------------------------------
# /onmt/encoders/mean_encoder.py:
# --------------------------------------------------------------------------------
"""Define a minimal encoder."""
from onmt.encoders.encoder import EncoderBase


class MeanEncoder(EncoderBase):
    """A trivial non-recurrent encoder. Simply applies mean pooling.

    Args:
       num_layers (int): number of replicated layers
       embeddings (onmt.modules.Embeddings): embedding module to use
    """

    def __init__(self, num_layers, embeddings):
        super(MeanEncoder, self).__init__()
        self.num_layers = num_layers
        self.embeddings = embeddings

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.enc_layers,
            embeddings)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)
        _, batch, emb_dim = emb.size()
        # Mean over the sequence dimension, replicated for every layer.
        mean = emb.mean(0).expand(self.num_layers, batch, emb_dim)
        memory_bank = emb
        encoder_final = (mean, mean)
        return encoder_final, memory_bank, lengths
# --------------------------------------------------------------------------------
# /onmt/encoders/rnn_encoder.py:
# (this file continues past this section; lines 1-13 are kept unchanged in
#  the original dump form so that the following lines still line up)
# --------------------------------------------------------------------------------
# 1 | """Define RNN-based encoders."""
# 2 | import torch.nn as nn
# 3 | import torch.nn.functional as F
# 4 |
# 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack
# 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
# 7 |
# 8 | from onmt.encoders.encoder import EncoderBase
# 9 | from onmt.utils.rnn_factory import rnn_factory
# 10 |
# 11 |
# 12 | class RNNEncoder(EncoderBase):
# 13 | """ A generic recurrent neural network encoder.
14 | 15 | Args: 16 | rnn_type (str): 17 | style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] 18 | bidirectional (bool) : use a bidirectional RNN 19 | num_layers (int) : number of stacked layers 20 | hidden_size (int) : hidden size of each layer 21 | dropout (float) : dropout value for :class:`torch.nn.Dropout` 22 | embeddings (onmt.modules.Embeddings): embedding module to use 23 | """ 24 | 25 | def __init__(self, rnn_type, bidirectional, num_layers, 26 | hidden_size, dropout=0.0, embeddings=None, 27 | use_bridge=False): 28 | super(RNNEncoder, self).__init__() 29 | assert embeddings is not None 30 | 31 | num_directions = 2 if bidirectional else 1 32 | assert hidden_size % num_directions == 0 33 | hidden_size = hidden_size // num_directions 34 | self.embeddings = embeddings 35 | 36 | self.rnn, self.no_pack_padded_seq = \ 37 | rnn_factory(rnn_type, 38 | input_size=embeddings.embedding_size, 39 | hidden_size=hidden_size, 40 | num_layers=num_layers, 41 | dropout=dropout, 42 | bidirectional=bidirectional) 43 | 44 | # Initialize the bridge layer 45 | self.use_bridge = use_bridge 46 | if self.use_bridge: 47 | self._initialize_bridge(rnn_type, 48 | hidden_size, 49 | num_layers) 50 | 51 | @classmethod 52 | def from_opt(cls, opt, embeddings): 53 | """Alternate constructor.""" 54 | return cls( 55 | opt.rnn_type, 56 | opt.brnn, 57 | opt.enc_layers, 58 | opt.enc_rnn_size, 59 | opt.dropout, 60 | embeddings, 61 | opt.bridge) 62 | 63 | def forward(self, src, lengths=None): 64 | """See :func:`EncoderBase.forward()`""" 65 | self._check_args(src, lengths) 66 | 67 | emb = self.embeddings(src) 68 | # s_len, batch, emb_dim = emb.size() 69 | 70 | packed_emb = emb 71 | if lengths is not None and not self.no_pack_padded_seq: 72 | # Lengths data is wrapped inside a Tensor. 
73 | lengths_list = lengths.view(-1).tolist() 74 | packed_emb = pack(emb, lengths_list) 75 | 76 | memory_bank, encoder_final = self.rnn(packed_emb) 77 | 78 | if lengths is not None and not self.no_pack_padded_seq: 79 | memory_bank = unpack(memory_bank)[0] 80 | 81 | if self.use_bridge: 82 | encoder_final = self._bridge(encoder_final) 83 | return encoder_final, memory_bank, lengths 84 | 85 | def _initialize_bridge(self, rnn_type, 86 | hidden_size, 87 | num_layers): 88 | 89 | # LSTM has hidden and cell state, other only one 90 | number_of_states = 2 if rnn_type == "LSTM" else 1 91 | # Total number of states 92 | self.total_hidden_dim = hidden_size * num_layers 93 | 94 | # Build a linear layer for each 95 | self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim, 96 | self.total_hidden_dim, 97 | bias=True) 98 | for _ in range(number_of_states)]) 99 | 100 | def _bridge(self, hidden): 101 | """Forward hidden state through bridge.""" 102 | def bottle_hidden(linear, states): 103 | """ 104 | Transform from 3D to 2D, apply linear and return initial size 105 | """ 106 | size = states.size() 107 | result = linear(states.view(-1, self.total_hidden_dim)) 108 | return F.relu(result).view(size) 109 | 110 | if isinstance(hidden, tuple): # LSTM 111 | outs = tuple([bottle_hidden(layer, hidden[ix]) 112 | for ix, layer in enumerate(self.bridge)]) 113 | else: 114 | outs = bottle_hidden(self.bridge[0], hidden) 115 | return outs 116 | -------------------------------------------------------------------------------- /onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings. 
5 | """ 6 | from onmt.inputters.inputter import \ 7 | load_old_vocab, get_fields, OrderedIterator, \ 8 | build_vocab, old_style_vocab, filter_example 9 | from onmt.inputters.dataset_base import Dataset 10 | from onmt.inputters.text_dataset import text_sort_key, TextDataReader 11 | from onmt.inputters.image_dataset import img_sort_key, ImageDataReader 12 | from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader 13 | from onmt.inputters.image_vec_dataset import img_vec_sort_key, ImageVecDataReader 14 | from onmt.inputters.datareader_base import DataReaderBase 15 | 16 | 17 | str2reader = { 18 | "text": TextDataReader, "img": ImageDataReader, 19 | "audio": AudioDataReader, 'imgvec': ImageVecDataReader} 20 | 21 | def none_sort(ex): 22 | return len(ex.tgt[0]) if hasattr(ex, 'tgt') else 0 23 | str2sortkey = { 24 | 'text': text_sort_key, 'img': img_sort_key, 25 | 'audio': audio_sort_key, 'imgvec': img_vec_sort_key, 26 | 'none': none_sort} 27 | 28 | 29 | __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'DataReaderBase', 30 | 'filter_example', 'old_style_vocab', 31 | 'build_vocab', 'OrderedIterator', 32 | 'text_sort_key', 'img_sort_key', 'audio_sort_key', 33 | 'TextDataReader', 'ImageDataReader', 'AudioDataReader', 34 | 'img_vec_sort_key', 'ImageVecDataReader'] 35 | -------------------------------------------------------------------------------- /onmt/inputters/datareader_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | # several data readers need optional dependencies. There's no 5 | # appropriate builtin exception 6 | class MissingDependencyException(Exception): 7 | pass 8 | 9 | 10 | class DataReaderBase(object): 11 | """Read data from file system and yield as dicts. 12 | 13 | Raises: 14 | onmt.inputters.datareader_base.MissingDependencyException: A number 15 | of DataReaders need specific additional packages. 16 | If any are missing, this will be raised. 
17 | """ 18 | 19 | @classmethod 20 | def from_opt(cls, opt): 21 | """Alternative constructor. 22 | 23 | Args: 24 | opt (argparse.Namespace): The parsed arguments. 25 | """ 26 | 27 | return cls() 28 | 29 | @classmethod 30 | def _read_file(cls, path): 31 | """Line-by-line read a file as bytes.""" 32 | with open(path, "rb") as f: 33 | for line in f: 34 | yield line 35 | 36 | @staticmethod 37 | def _raise_missing_dep(*missing_deps): 38 | """Raise missing dep exception with standard error message.""" 39 | raise MissingDependencyException( 40 | "Could not create reader. Be sure to install " 41 | "the following dependencies: " + ", ".join(missing_deps)) 42 | 43 | def read(self, data, side, src_dir): 44 | """Read data from file system and yield as dicts.""" 45 | raise NotImplementedError() 46 | -------------------------------------------------------------------------------- /onmt/inputters/image_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | from torchtext.data import Field 7 | 8 | from onmt.inputters.datareader_base import DataReaderBase 9 | 10 | # domain specific dependencies 11 | try: 12 | from PIL import Image 13 | from torchvision import transforms 14 | import cv2 15 | except ImportError: 16 | Image, transforms, cv2 = None, None, None 17 | 18 | 19 | class ImageDataReader(DataReaderBase): 20 | """Read image data from disk. 21 | 22 | Args: 23 | truncate (tuple[int] or NoneType): maximum img size. Use 24 | ``(0,0)`` or ``None`` for unlimited. 25 | channel_size (int): Number of channels per image. 26 | 27 | Raises: 28 | onmt.inputters.datareader_base.MissingDependencyException: If 29 | importing any of ``PIL``, ``torchvision``, or ``cv2`` fail. 
30 | """ 31 | 32 | def __init__(self, truncate=None, channel_size=3): 33 | self._check_deps() 34 | self.truncate = truncate 35 | self.channel_size = channel_size 36 | 37 | @classmethod 38 | def from_opt(cls, opt): 39 | return cls(channel_size=opt.image_channel_size) 40 | 41 | @classmethod 42 | def _check_deps(cls): 43 | if any([Image is None, transforms is None, cv2 is None]): 44 | cls._raise_missing_dep( 45 | "PIL", "torchvision", "cv2") 46 | 47 | def read(self, images, side, img_dir=None): 48 | """Read data into dicts. 49 | 50 | Args: 51 | images (str or Iterable[str]): Sequence of image paths or 52 | path to file containing audio paths. 53 | In either case, the filenames may be relative to ``src_dir`` 54 | (default behavior) or absolute. 55 | side (str): Prefix used in return dict. Usually 56 | ``"src"`` or ``"tgt"``. 57 | img_dir (str): Location of source image files. See ``images``. 58 | 59 | Yields: 60 | a dictionary containing image data, path and index for each line. 61 | """ 62 | if isinstance(images, str): 63 | images = DataReaderBase._read_file(images) 64 | 65 | for i, filename in enumerate(images): 66 | filename = filename.decode("utf-8").strip() 67 | img_path = os.path.join(img_dir, filename) 68 | if not os.path.exists(img_path): 69 | img_path = filename 70 | 71 | assert os.path.exists(img_path), \ 72 | 'img path %s not found' % filename 73 | 74 | if self.channel_size == 1: 75 | img = transforms.ToTensor()( 76 | Image.fromarray(cv2.imread(img_path, 0))) 77 | else: 78 | img = transforms.ToTensor()(Image.open(img_path)) 79 | if self.truncate and self.truncate != (0, 0): 80 | if not (img.size(1) <= self.truncate[0] 81 | and img.size(2) <= self.truncate[1]): 82 | continue 83 | yield {side: img, side + '_path': filename, 'indices': i} 84 | 85 | 86 | def img_sort_key(ex): 87 | """Sort using the size of the image: (width, height).""" 88 | return ex.src.size(2), ex.src.size(1) 89 | 90 | 91 | def batch_img(data, vocab): 92 | """Pad and batch a sequence of 
images.""" 93 | c = data[0].size(0) 94 | h = max([t.size(1) for t in data]) 95 | w = max([t.size(2) for t in data]) 96 | imgs = torch.zeros(len(data), c, h, w).fill_(1) 97 | for i, img in enumerate(data): 98 | imgs[i, :, 0:img.size(1), 0:img.size(2)] = img 99 | return imgs 100 | 101 | 102 | def image_fields(**kwargs): 103 | img = Field( 104 | use_vocab=False, dtype=torch.float, 105 | postprocessing=batch_img, sequential=False) 106 | return img 107 | -------------------------------------------------------------------------------- /onmt/inputters/image_vec_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | 6 | import torch 7 | from torchtext.data import Field 8 | 9 | from onmt.inputters.datareader_base import DataReaderBase 10 | 11 | 12 | class ImageVecDataReader(DataReaderBase): 13 | """Read image data vectors from disk. 14 | 15 | Args: 16 | channel_size (int): Number of channels per image. 17 | """ 18 | def read(self, path, side, img_dir=None): 19 | """Read data into dicts. 20 | 21 | Args: 22 | path (str): Path to npy file with saved image vectors 23 | The filenames may be relative to ``src_dir`` 24 | (default behavior) or absolute. 25 | side (str): Prefix used in return dict. Usually 26 | ``"src"`` or ``"tgt"``. 27 | 28 | Yields: 29 | a dictionary containing image data and index for each line. 
30 | """ 31 | features = np.load(path, encoding='latin1')['vec_list'] 32 | 33 | for i in range(features.shape[0]): 34 | yield {side: torch.tensor(features[i]), 'indices': i} 35 | 36 | def img_vec_sort_key(ex): 37 | """Sort using the number of image box features.""" 38 | return ex.src.size(0) 39 | 40 | def batch_img_vec(data, vocab): 41 | """Batch a sequence of image vectors.""" 42 | imgs = torch.stack(data, dim=1) # [K, B, dim] 43 | return imgs 44 | 45 | def image_vec_fields(**kwargs): 46 | img = Field( 47 | use_vocab=False, dtype=torch.float, 48 | postprocessing=batch_img_vec, sequential=False) 49 | return img 50 | -------------------------------------------------------------------------------- /onmt/inputters/none_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | 6 | import torch 7 | from torchtext.data import Field 8 | 9 | from onmt.inputters.datareader_base import DataReaderBase 10 | 11 | 12 | class NoneDataReader(DataReaderBase): 13 | """Read image data vectors from disk. 14 | 15 | Args: 16 | channel_size (int): Number of channels per image. 17 | """ 18 | def read(self, path, side, img_dir=None): 19 | """Read data into dicts. 20 | 21 | Args: 22 | path (str): Path to npy file with saved image vectors 23 | The filenames may be relative to ``src_dir`` 24 | (default behavior) or absolute. 25 | side (str): Prefix used in return dict. Usually 26 | ``"src"`` or ``"tgt"``. 27 | 28 | Yields: 29 | a dictionary containing image data and index for each line. 
30 | """ 31 | 32 | for i in range(features.shape[0]): 33 | yield {side: torch.tensor(features[i]), 'indices': i} 34 | 35 | def img_vec_sort_key(ex): 36 | """Sort using the number of image box features.""" 37 | return ex.src.size(0) 38 | 39 | def batch_img_vec(data, vocab): 40 | """Batch a sequence of image vectors.""" 41 | imgs = torch.stack(data, dim=1) # [K, B, dim] 42 | return imgs 43 | 44 | def image_vec_fields(**kwargs): 45 | img = Field( 46 | use_vocab=False, dtype=torch.float, 47 | postprocessing=batch_img_vec, sequential=False) 48 | return img 49 | -------------------------------------------------------------------------------- /onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | from onmt.models.uncond_model import UncondModel 5 | from onmt.models.simple_fusion_model import SimpleFusionModel 6 | 7 | __all__ = ["build_model_saver", "ModelSaver", 8 | "NMTModel", "SimpleFusionModel", 9 | "UncondModel", 10 | "check_sru_requirement"] 11 | -------------------------------------------------------------------------------- /onmt/models/model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | 4 | 5 | class NMTModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 
9 | 10 | Args: 11 | encoder (onmt.encoders.EncoderBase): an encoder object 12 | decoder (onmt.decoders.DecoderBase): a decoder object 13 | """ 14 | 15 | def __init__(self, encoder, decoder): 16 | super(NMTModel, self).__init__() 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | 20 | def forward(self, src, tgt, lengths, bptt=False, **kwargs): 21 | """Forward propagate a `src` and `tgt` pair for training. 22 | Possible initialized with a beginning decoder state. 23 | 24 | Args: 25 | src (Tensor): A source sequence passed to encoder. 26 | typically for inputs this will be a padded `LongTensor` 27 | of size ``(len, batch, features)``. However, may be an 28 | image or other generic input depending on encoder. 29 | tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``. 30 | lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. 31 | bptt (Boolean): A flag indicating if truncated bptt is set. 32 | If reset then init_state 33 | 34 | Returns: 35 | (FloatTensor, dict[str, FloatTensor]): 36 | 37 | * decoder output ``(tgt_len, batch, hidden)`` 38 | * dictionary attention dists of ``(tgt_len, batch, src_len)`` 39 | """ 40 | tgt = tgt[:-1] # exclude last target from inputs 41 | 42 | enc_state, memory_bank, lengths = self.encoder(src, lengths) 43 | if bptt is False: 44 | self.decoder.init_state(src, memory_bank, enc_state) 45 | dec_out, attns = self.decoder(tgt, memory_bank, 46 | memory_lengths=lengths, **kwargs) 47 | return dec_out, attns 48 | -------------------------------------------------------------------------------- /onmt/models/model_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from collections import deque 6 | from onmt.utils.logging import logger 7 | 8 | from copy import deepcopy 9 | 10 | 11 | def build_model_saver(model_opt, opt, model, fields, optim): 12 | model_saver = ModelSaver(opt.save_model, 13 | model, 14 | model_opt, 15 
| fields, 16 | optim, 17 | opt.keep_checkpoint) 18 | return model_saver 19 | 20 | 21 | class ModelSaverBase(object): 22 | """Base class for model saving operations 23 | 24 | Inherited classes must implement private methods: 25 | * `_save` 26 | * `_rm_checkpoint 27 | """ 28 | 29 | def __init__(self, base_path, model, model_opt, fields, optim, 30 | keep_checkpoint=-1): 31 | self.base_path = base_path 32 | self.model = model 33 | self.model_opt = model_opt 34 | self.fields = fields 35 | self.optim = optim 36 | self.last_saved_step = None 37 | self.keep_checkpoint = keep_checkpoint 38 | if keep_checkpoint > 0: 39 | self.checkpoint_queue = deque([], maxlen=keep_checkpoint) 40 | 41 | def save(self, step, moving_average=None): 42 | """Main entry point for model saver 43 | 44 | It wraps the `_save` method with checks and apply `keep_checkpoint` 45 | related logic 46 | """ 47 | 48 | if self.keep_checkpoint == 0 or step == self.last_saved_step: 49 | return 50 | 51 | if moving_average: 52 | save_model = deepcopy(self.model) 53 | for avg, param in zip(moving_average, save_model.parameters()): 54 | param.data.copy_(avg.data) 55 | else: 56 | save_model = self.model 57 | 58 | chkpt, chkpt_name = self._save(step, save_model) 59 | self.last_saved_step = step 60 | 61 | if moving_average: 62 | del save_model 63 | 64 | if self.keep_checkpoint > 0: 65 | if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: 66 | todel = self.checkpoint_queue.popleft() 67 | self._rm_checkpoint(todel) 68 | self.checkpoint_queue.append(chkpt_name) 69 | 70 | def _save(self, step): 71 | """Save a resumable checkpoint. 
72 | 73 | Args: 74 | step (int): step number 75 | 76 | Returns: 77 | (object, str): 78 | 79 | * checkpoint: the saved object 80 | * checkpoint_name: name (or path) of the saved checkpoint 81 | """ 82 | 83 | raise NotImplementedError() 84 | 85 | def _rm_checkpoint(self, name): 86 | """Remove a checkpoint 87 | 88 | Args: 89 | name(str): name that indentifies the checkpoint 90 | (it may be a filepath) 91 | """ 92 | 93 | raise NotImplementedError() 94 | 95 | 96 | class ModelSaver(ModelSaverBase): 97 | """Simple model saver to filesystem""" 98 | 99 | def _save(self, step, model): 100 | real_model = (model.module 101 | if isinstance(model, nn.DataParallel) 102 | else model) 103 | real_generator = (real_model.generator.module 104 | if isinstance(real_model.generator, nn.DataParallel) 105 | else real_model.generator) 106 | 107 | model_state_dict = real_model.state_dict() 108 | model_state_dict = {k: v for k, v in model_state_dict.items() 109 | if 'generator' not in k} 110 | generator_state_dict = real_generator.state_dict() 111 | checkpoint = { 112 | 'model': model_state_dict, 113 | 'generator': generator_state_dict, 114 | 'vocab': self.fields, 115 | 'opt': self.model_opt, 116 | 'optim': self.optim.state_dict(), 117 | } 118 | 119 | logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) 120 | checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) 121 | torch.save(checkpoint, checkpoint_path) 122 | return checkpoint, checkpoint_path 123 | 124 | def _rm_checkpoint(self, name): 125 | os.remove(name) 126 | -------------------------------------------------------------------------------- /onmt/models/simple_fusion_model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | 4 | 5 | class SimpleFusionModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 
9 | 10 | Args: 11 | encoder (onmt.encoders.EncoderBase): an encoder object 12 | decoder (onmt.decoders.DecoderBase): a decoder object 13 | """ 14 | 15 | def __init__(self, encoder, decoder, lm_decoder): 16 | super(SimpleFusionModel, self).__init__() 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | self.lm_decoder = lm_decoder 20 | 21 | def forward(self, src, tgt, lengths, bptt=False, **kwargs): 22 | """Forward propagate a `src` and `tgt` pair for training. 23 | Possible initialized with a beginning decoder state. 24 | 25 | Args: 26 | src (Tensor): A source sequence passed to encoder. 27 | typically for inputs this will be a padded `LongTensor` 28 | of size ``(len, batch, features)``. However, may be an 29 | image or other generic input depending on encoder. 30 | tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``. 31 | lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. 32 | bptt (Boolean): A flag indicating if truncated bptt is set. 33 | If reset then init_state 34 | 35 | Returns: 36 | (FloatTensor, dict[str, FloatTensor]): 37 | 38 | * decoder output ``(tgt_len, batch, hidden)`` 39 | * dictionary attention dists of ``(tgt_len, batch, src_len)`` 40 | """ 41 | tgt = tgt[:-1] # exclude last target from inputs 42 | 43 | enc_state, memory_bank, lengths = self.encoder(src, lengths) 44 | if bptt is False: 45 | self.decoder.init_state(src, memory_bank, enc_state) 46 | dec_out, attns = self.decoder(tgt, memory_bank, 47 | memory_lengths=lengths, **kwargs) 48 | 49 | if bptt is False: 50 | self.lm_decoder.init_state(src, None, None) 51 | lm_dec_out, _ = self.lm_decoder(tgt, memory_bank.new_zeros(1, 1, 1)) 52 | return [dec_out, lm_dec_out], attns 53 | -------------------------------------------------------------------------------- /onmt/models/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | """ Implementation of ONMT RNN for Input Feeding Decoding """ 2 | import torch 3 | import torch.nn 
as nn 4 | 5 | 6 | class StackedLSTM(nn.Module): 7 | """ 8 | Our own implementation of stacked LSTM. 9 | Needed for the decoder, because we do input feeding. 10 | """ 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTM, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | self.layers = nn.ModuleList() 17 | 18 | for _ in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, input_feed, hidden): 23 | h_0, c_0 = hidden 24 | h_1, c_1 = [], [] 25 | for i, layer in enumerate(self.layers): 26 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 27 | input_feed = h_1_i 28 | if i + 1 != self.num_layers: 29 | input_feed = self.dropout(input_feed) 30 | h_1 += [h_1_i] 31 | c_1 += [c_1_i] 32 | 33 | h_1 = torch.stack(h_1) 34 | c_1 = torch.stack(c_1) 35 | 36 | return input_feed, (h_1, c_1) 37 | 38 | 39 | class StackedGRU(nn.Module): 40 | """ 41 | Our own implementation of stacked GRU. 42 | Needed for the decoder, because we do input feeding. 
43 | """ 44 | 45 | def __init__(self, num_layers, input_size, rnn_size, dropout): 46 | super(StackedGRU, self).__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList() 50 | 51 | for _ in range(num_layers): 52 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 53 | input_size = rnn_size 54 | 55 | def forward(self, input_feed, hidden): 56 | h_1 = [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i = layer(input_feed, hidden[0][i]) 59 | input_feed = h_1_i 60 | if i + 1 != self.num_layers: 61 | input_feed = self.dropout(input_feed) 62 | h_1 += [h_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | return input_feed, (h_1,) 66 | -------------------------------------------------------------------------------- /onmt/models/uncond_model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | import torch 4 | 5 | class UncondModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 9 | 10 | Args: 11 | encoder (onmt.encoders.EncoderBase): an encoder object 12 | decoder (onmt.decoders.DecoderBase): a decoder object 13 | """ 14 | 15 | def __init__(self, decoder): 16 | super(UncondModel, self).__init__() 17 | self.decoder = decoder 18 | 19 | def forward(self, src, tgt, lengths, bptt=False, **kwargs): 20 | """Forward propagate a `src` and `tgt` pair for training. 21 | Possible initialized with a beginning decoder state. 22 | 23 | Args: 24 | src (Tensor): A source sequence passed to encoder. 25 | typically for inputs this will be a padded `LongTensor` 26 | of size ``(len, batch, features)``. However, may be an 27 | image or other generic input depending on encoder. 28 | tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``. 29 | lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. 
30 | bptt (Boolean): A flag indicating if truncated bptt is set. 31 | If reset then init_state 32 | 33 | Returns: 34 | (FloatTensor, dict[str, FloatTensor]): 35 | 36 | * decoder output ``(tgt_len, batch, hidden)`` 37 | * dictionary attention dists of ``(tgt_len, batch, src_len)`` 38 | """ 39 | tgt = tgt[:-1] # exclude last target from inputs 40 | 41 | memory_bank = torch.zeros((1, tgt.shape[1], 1), dtype=torch.float, device=tgt.device) 42 | 43 | if bptt is False: 44 | self.decoder.init_state(src, memory_bank, None) 45 | dec_out, attns = self.decoder(tgt, memory_bank, 46 | memory_lengths=lengths, **kwargs) 47 | return dec_out, attns 48 | -------------------------------------------------------------------------------- /onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention and normalization modules """ 2 | from onmt.modules.util_class import Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 6 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ 7 | CopyGeneratorLossCompute 8 | from onmt.modules.multi_headed_attn import MultiHeadedAttention, JointMultiHeadedAttention 9 | from onmt.modules.embeddings import Embeddings, PositionalEncoding 10 | from onmt.modules.weight_norm import WeightNormConv2d 11 | from onmt.modules.average_attn import AverageAttention 12 | from onmt.modules.simple_fusion_generator import SimpleFusionGenerator 13 | from onmt.modules.gpt_mlp import MLP 14 | 15 | __all__ = ["Elementwise", "context_gate_factory", "ContextGate", 16 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 17 | "CopyGeneratorLoss", "CopyGeneratorLossCompute", 18 | "MultiHeadedAttention", "Embeddings", "PositionalEncoding", 19 | "WeightNormConv2d", "AverageAttention", "JointMultiHeadedAttention", 20 | 
"SimpleFusionGenerator", "MLP"] 21 | -------------------------------------------------------------------------------- /onmt/modules/average_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Average Attention module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onmt.modules.position_ffn import PositionwiseFeedForward 8 | 9 | 10 | class AverageAttention(nn.Module): 11 | """ 12 | Average Attention module from 13 | "Accelerating Neural Transformer via an Average Attention Network" 14 | :cite:`DBLP:journals/corr/abs-1805-00631`. 15 | 16 | Args: 17 | model_dim (int): the dimension of keys/values/queries, 18 | must be divisible by head_count 19 | dropout (float): dropout parameter 20 | """ 21 | 22 | def __init__(self, model_dim, dropout=0.1): 23 | self.model_dim = model_dim 24 | 25 | super(AverageAttention, self).__init__() 26 | 27 | self.average_layer = PositionwiseFeedForward(model_dim, model_dim, 28 | dropout) 29 | self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2) 30 | 31 | def cumulative_average_mask(self, batch_size, inputs_len): 32 | """ 33 | Builds the mask to compute the cumulative average as described in 34 | :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3 35 | 36 | Args: 37 | batch_size (int): batch size 38 | inputs_len (int): length of the inputs 39 | 40 | Returns: 41 | (FloatTensor): 42 | 43 | * A Tensor of shape ``(batch_size, input_len, input_len)`` 44 | """ 45 | 46 | triangle = torch.tril(torch.ones(inputs_len, inputs_len)) 47 | weights = torch.ones(1, inputs_len) / torch.arange( 48 | 1, inputs_len + 1, dtype=torch.float) 49 | mask = triangle * weights.transpose(0, 1) 50 | 51 | return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len) 52 | 53 | def cumulative_average(self, inputs, mask_or_step, 54 | layer_cache=None, step=None): 55 | """ 56 | Computes the cumulative average as described in 57 | :cite:`DBLP:journals/corr/abs-1805-00631` -- 
Equations (1) (5) (6) 58 | 59 | Args: 60 | inputs (FloatTensor): sequence to average 61 | ``(batch_size, input_len, dimension)`` 62 | mask_or_step: if cache is set, this is assumed 63 | to be the current step of the 64 | dynamic decoding. Otherwise, it is the mask matrix 65 | used to compute the cumulative average. 66 | layer_cache: a dictionary containing the cumulative average 67 | of the previous step. 68 | 69 | Returns: 70 | a tensor of the same shape and type as ``inputs``. 71 | """ 72 | 73 | if layer_cache is not None: 74 | step = mask_or_step 75 | device = inputs.device 76 | average_attention = (inputs + step * 77 | layer_cache["prev_g"].to(device)) / (step + 1) 78 | layer_cache["prev_g"] = average_attention 79 | return average_attention 80 | else: 81 | mask = mask_or_step 82 | return torch.matmul(mask, inputs) 83 | 84 | def forward(self, inputs, mask=None, layer_cache=None, step=None): 85 | """ 86 | Args: 87 | inputs (FloatTensor): ``(batch_size, input_len, model_dim)`` 88 | 89 | Returns: 90 | (FloatTensor, FloatTensor): 91 | 92 | * gating_outputs ``(batch_size, input_len, model_dim)`` 93 | * average_outputs average attention 94 | ``(batch_size, input_len, model_dim)`` 95 | """ 96 | 97 | batch_size = inputs.size(0) 98 | inputs_len = inputs.size(1) 99 | 100 | device = inputs.device 101 | average_outputs = self.cumulative_average( 102 | inputs, self.cumulative_average_mask(batch_size, 103 | inputs_len).to(device).float() 104 | if layer_cache is None else step, layer_cache=layer_cache) 105 | average_outputs = self.average_layer(average_outputs) 106 | gating_outputs = self.gating_layer(torch.cat((inputs, 107 | average_outputs), -1)) 108 | input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2) 109 | gating_outputs = torch.sigmoid(input_gate) * inputs + \ 110 | torch.sigmoid(forget_gate) * average_outputs 111 | 112 | return gating_outputs, average_outputs 113 | -------------------------------------------------------------------------------- 
/onmt/modules/conv_multi_step_attention.py: -------------------------------------------------------------------------------- 1 | """ Multi Step Attention for CNN """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | SCALE_WEIGHT = 0.5 ** 0.5 9 | 10 | 11 | def seq_linear(linear, x): 12 | """ linear transform for 3-d tensor """ 13 | batch, hidden_size, length, _ = x.size() 14 | h = linear(torch.transpose(x, 1, 2).contiguous().view( 15 | batch * length, hidden_size)) 16 | return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) 17 | 18 | 19 | class ConvMultiStepAttention(nn.Module): 20 | """ 21 | Conv attention takes a key matrix, a value matrix and a query vector. 22 | Attention weight is calculated by key matrix with the query vector 23 | and sum on the value matrix. And the same operation is applied 24 | in each decode conv layer. 25 | """ 26 | 27 | def __init__(self, input_size): 28 | super(ConvMultiStepAttention, self).__init__() 29 | self.linear_in = nn.Linear(input_size, input_size) 30 | self.mask = None 31 | 32 | def apply_mask(self, mask): 33 | """ Apply mask """ 34 | self.mask = mask 35 | 36 | def forward(self, base_target_emb, input_from_dec, encoder_out_top, 37 | encoder_out_combine): 38 | """ 39 | Args: 40 | base_target_emb: target emb tensor 41 | input_from_dec: output of decode conv 42 | encoder_out_top: the key matrix for calculation of attetion weight, 43 | which is the top output of encode conv 44 | encoder_out_combine: 45 | the value matrix for the attention-weighted sum, 46 | which is the combination of base emb and top output of encode 47 | """ 48 | 49 | # checks 50 | # batch, channel, height, width = base_target_emb.size() 51 | batch, _, height, _ = base_target_emb.size() 52 | # batch_, channel_, height_, width_ = input_from_dec.size() 53 | batch_, _, height_, _ = input_from_dec.size() 54 | aeq(batch, batch_) 55 | aeq(height, height_) 56 | 57 | # enc_batch, 
enc_channel, enc_height = encoder_out_top.size() 58 | enc_batch, _, enc_height = encoder_out_top.size() 59 | # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() 60 | enc_batch_, _, enc_height_ = encoder_out_combine.size() 61 | 62 | aeq(enc_batch, enc_batch_) 63 | aeq(enc_height, enc_height_) 64 | 65 | preatt = seq_linear(self.linear_in, input_from_dec) 66 | target = (base_target_emb + preatt) * SCALE_WEIGHT 67 | target = torch.squeeze(target, 3) 68 | target = torch.transpose(target, 1, 2) 69 | pre_attn = torch.bmm(target, encoder_out_top) 70 | 71 | if self.mask is not None: 72 | pre_attn.data.masked_fill_(self.mask, -float('inf')) 73 | 74 | attn = F.softmax(pre_attn, dim=2) 75 | 76 | context_output = torch.bmm( 77 | attn, torch.transpose(encoder_out_combine, 1, 2)) 78 | context_output = torch.transpose( 79 | torch.unsqueeze(context_output, 3), 1, 2) 80 | return context_output, attn 81 | -------------------------------------------------------------------------------- /onmt/modules/gate.py: -------------------------------------------------------------------------------- 1 | """ ContextGate module """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def context_gate_factory(gate_type, embeddings_size, decoder_size, 7 | attention_size, output_size): 8 | """Returns the correct ContextGate class""" 9 | 10 | gate_types = {'source': SourceContextGate, 11 | 'target': TargetContextGate, 12 | 'both': BothContextGate} 13 | 14 | assert gate_type in gate_types, "Not valid ContextGate type: {0}".format( 15 | gate_type) 16 | return gate_types[gate_type](embeddings_size, decoder_size, attention_size, 17 | output_size) 18 | 19 | 20 | class ContextGate(nn.Module): 21 | """ 22 | Context gate is a decoder module that takes as input the previous word 23 | embedding, the current decoder state and the attention state, and 24 | produces a gate. 
25 | The gate can be used to select the input from the target side context 26 | (decoder state), from the source context (attention state) or both. 27 | """ 28 | 29 | def __init__(self, embeddings_size, decoder_size, 30 | attention_size, output_size): 31 | super(ContextGate, self).__init__() 32 | input_size = embeddings_size + decoder_size + attention_size 33 | self.gate = nn.Linear(input_size, output_size, bias=True) 34 | self.sig = nn.Sigmoid() 35 | self.source_proj = nn.Linear(attention_size, output_size) 36 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 37 | output_size) 38 | 39 | def forward(self, prev_emb, dec_state, attn_state): 40 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 41 | z = self.sig(self.gate(input_tensor)) 42 | proj_source = self.source_proj(attn_state) 43 | proj_target = self.target_proj( 44 | torch.cat((prev_emb, dec_state), dim=1)) 45 | return z, proj_source, proj_target 46 | 47 | 48 | class SourceContextGate(nn.Module): 49 | """Apply the context gate only to the source context""" 50 | 51 | def __init__(self, embeddings_size, decoder_size, 52 | attention_size, output_size): 53 | super(SourceContextGate, self).__init__() 54 | self.context_gate = ContextGate(embeddings_size, decoder_size, 55 | attention_size, output_size) 56 | self.tanh = nn.Tanh() 57 | 58 | def forward(self, prev_emb, dec_state, attn_state): 59 | z, source, target = self.context_gate( 60 | prev_emb, dec_state, attn_state) 61 | return self.tanh(target + z * source) 62 | 63 | 64 | class TargetContextGate(nn.Module): 65 | """Apply the context gate only to the target context""" 66 | 67 | def __init__(self, embeddings_size, decoder_size, 68 | attention_size, output_size): 69 | super(TargetContextGate, self).__init__() 70 | self.context_gate = ContextGate(embeddings_size, decoder_size, 71 | attention_size, output_size) 72 | self.tanh = nn.Tanh() 73 | 74 | def forward(self, prev_emb, dec_state, attn_state): 75 | z, source, target = 
self.context_gate(prev_emb, dec_state, attn_state) 76 | return self.tanh(z * target + source) 77 | 78 | 79 | class BothContextGate(nn.Module): 80 | """Apply the context gate to both contexts""" 81 | 82 | def __init__(self, embeddings_size, decoder_size, 83 | attention_size, output_size): 84 | super(BothContextGate, self).__init__() 85 | self.context_gate = ContextGate(embeddings_size, decoder_size, 86 | attention_size, output_size) 87 | self.tanh = nn.Tanh() 88 | 89 | def forward(self, prev_emb, dec_state, attn_state): 90 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 91 | return self.tanh((1. - z) * target + z * source) 92 | -------------------------------------------------------------------------------- /onmt/modules/gpt_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | def gelu(x): 6 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, n_embd, n_state, dropout): # in MLP: n_state=3072 (4 * n_embd) 10 | super(MLP, self).__init__() 11 | self.c_fc = nn.Linear(n_embd, n_state) 12 | self.c_proj = nn.Linear(n_state, n_embd) 13 | self.act = gelu 14 | self.dropout_1 = nn.Dropout(dropout) 15 | self.dropout_2 = nn.Dropout(dropout) 16 | 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | self.c_fc.weight.data.normal_(std=0.02) 21 | self.c_fc.bias.data.zero_() 22 | self.c_proj.weight.data.normal_(std=0.02) 23 | self.c_proj.bias.data.zero_() 24 | 25 | def forward(self, x): 26 | """ 27 | x is input, [T, B, n_state] 28 | """ 29 | h = self.dropout_1(self.act(self.c_fc(x))) 30 | h2 = self.dropout_2(self.c_proj(h)) 31 | return h2 32 | -------------------------------------------------------------------------------- /onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | 
"""Position feed-forward network from "Attention is All You Need".""" 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class PositionwiseFeedForward(nn.Module): 7 | """ A two-layer Feed-Forward-Network with residual layer norm. 8 | 9 | Args: 10 | d_model (int): the size of input for the first-layer of the FFN. 11 | d_ff (int): the hidden layer size of the second-layer 12 | of the FNN. 13 | dropout (float): dropout probability in :math:`[0, 1)`. 14 | """ 15 | 16 | def __init__(self, d_model, d_ff, dropout=0.1): 17 | super(PositionwiseFeedForward, self).__init__() 18 | self.w_1 = nn.Linear(d_model, d_ff) 19 | self.w_2 = nn.Linear(d_ff, d_model) 20 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 21 | self.dropout_1 = nn.Dropout(dropout) 22 | self.relu = nn.ReLU() 23 | self.dropout_2 = nn.Dropout(dropout) 24 | 25 | def forward(self, x): 26 | """Layer definition. 27 | 28 | Args: 29 | x: ``(batch_size, input_len, model_dim)`` 30 | 31 | Returns: 32 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 33 | """ 34 | 35 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 36 | output = self.dropout_2(self.w_2(inter)) 37 | return output + x 38 | -------------------------------------------------------------------------------- /onmt/modules/simple_fusion_generator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class SimpleFusionGenerator(nn.Module): 4 | def __init__(self, decoder_input_size, lm_input_size, output_size): 5 | super(SimpleFusionGenerator, self).__init__() 6 | self.decoder_linear = nn.Linear(decoder_input_size, output_size) 7 | self.lm_linear = nn.Linear(lm_input_size, output_size, bias=False) 8 | self.gen_func = nn.LogSoftmax(dim=-1) 9 | 10 | def forward(self, decoder_hidden, lm_hidden): 11 | """ 12 | Compute a distribution over the target dictionary 13 | extended by the dynamic dictionary implied by copying 14 | source words. 
15 | 16 | Args: 17 | decoder_hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` 18 | lm_hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` 19 | """ 20 | 21 | # Original probabilities. 22 | decoder_logits = self.decoder_linear(decoder_hidden) 23 | lm_logits = self.lm_linear(lm_hidden) 24 | logits = (decoder_logits + lm_logits).float() 25 | log_probs = self.gen_func(logits) 26 | 27 | return log_probs 28 | -------------------------------------------------------------------------------- /onmt/modules/sparse_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of sparsemax (Martins & Astudillo, 2016). See 3 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 4 | 5 | By Ben Peters and Vlad Niculae 6 | """ 7 | 8 | import torch 9 | from torch.autograd import Function 10 | import torch.nn as nn 11 | 12 | 13 | def _make_ix_like(input, dim=0): 14 | d = input.size(dim) 15 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 16 | view = [1] * input.dim() 17 | view[0] = -1 18 | return rho.view(view).transpose(0, dim) 19 | 20 | 21 | def _threshold_and_support(input, dim=0): 22 | """Sparsemax building block: compute the threshold 23 | 24 | Args: 25 | input: any dimension 26 | dim: dimension along which to apply the sparsemax 27 | 28 | Returns: 29 | the threshold value 30 | """ 31 | 32 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 33 | input_cumsum = input_srt.cumsum(dim) - 1 34 | rhos = _make_ix_like(input, dim) 35 | support = rhos * input_srt > input_cumsum 36 | 37 | support_size = support.sum(dim=dim).unsqueeze(dim) 38 | tau = input_cumsum.gather(dim, support_size - 1) 39 | tau /= support_size.to(input.dtype) 40 | return tau, support_size 41 | 42 | 43 | class SparsemaxFunction(Function): 44 | 45 | @staticmethod 46 | def forward(ctx, input, dim=0): 47 | """sparsemax: normalizing sparse transform (a la softmax) 48 | 49 | 
Parameters: 50 | input (Tensor): any shape 51 | dim: dimension along which to apply sparsemax 52 | 53 | Returns: 54 | output (Tensor): same shape as input 55 | """ 56 | ctx.dim = dim 57 | max_val, _ = input.max(dim=dim, keepdim=True) 58 | input -= max_val # same numerical stability trick as for softmax 59 | tau, supp_size = _threshold_and_support(input, dim=dim) 60 | output = torch.clamp(input - tau, min=0) 61 | ctx.save_for_backward(supp_size, output) 62 | return output 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | supp_size, output = ctx.saved_tensors 67 | dim = ctx.dim 68 | grad_input = grad_output.clone() 69 | grad_input[output == 0] = 0 70 | 71 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 72 | v_hat = v_hat.unsqueeze(dim) 73 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 74 | return grad_input, None 75 | 76 | 77 | sparsemax = SparsemaxFunction.apply 78 | 79 | 80 | class Sparsemax(nn.Module): 81 | 82 | def __init__(self, dim=0): 83 | self.dim = dim 84 | super(Sparsemax, self).__init__() 85 | 86 | def forward(self, input): 87 | return sparsemax(input, self.dim) 88 | 89 | 90 | class LogSparsemax(nn.Module): 91 | 92 | def __init__(self, dim=0): 93 | self.dim = dim 94 | super(LogSparsemax, self).__init__() 95 | 96 | def forward(self, input): 97 | return torch.log(sparsemax(input, self.dim)) 98 | -------------------------------------------------------------------------------- /onmt/modules/sparse_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from onmt.modules.sparse_activations import _threshold_and_support 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class SparsemaxLossFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, input, target): 12 | """ 13 | input (FloatTensor): ``(n, num_classes)``. 
14 | target (LongTensor): ``(n,)``, the indices of the target classes 15 | """ 16 | input_batch, classes = input.size() 17 | target_batch = target.size(0) 18 | aeq(input_batch, target_batch) 19 | 20 | z_k = input.gather(1, target.unsqueeze(1)).squeeze() 21 | tau_z, support_size = _threshold_and_support(input, dim=1) 22 | support = input > tau_z 23 | x = torch.where( 24 | support, input**2 - tau_z**2, 25 | torch.tensor(0.0, device=input.device) 26 | ).sum(dim=1) 27 | ctx.save_for_backward(input, target, tau_z) 28 | # clamping necessary because of numerical errors: loss should be lower 29 | # bounded by zero, but negative values near zero are possible without 30 | # the clamp 31 | return torch.clamp(x / 2 - z_k + 0.5, min=0.0) 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, target, tau_z = ctx.saved_tensors 36 | sparsemax_out = torch.clamp(input - tau_z, min=0) 37 | delta = torch.zeros_like(sparsemax_out) 38 | delta.scatter_(1, target.unsqueeze(1), 1) 39 | return sparsemax_out - delta, None 40 | 41 | 42 | sparsemax_loss = SparsemaxLossFunction.apply 43 | 44 | 45 | class SparsemaxLoss(nn.Module): 46 | """ 47 | An implementation of sparsemax loss, first proposed in 48 | :cite:`DBLP:journals/corr/MartinsA16`. If using 49 | a sparse output layer, it is not possible to use negative log likelihood 50 | because the loss is infinite in the case the target is assigned zero 51 | probability. Inputs to SparsemaxLoss are arbitrary dense real-valued 52 | vectors (like in nn.CrossEntropyLoss), not probability vectors (like in 53 | nn.NLLLoss). 
54 | """ 55 | 56 | def __init__(self, weight=None, ignore_index=-100, 57 | reduction='elementwise_mean'): 58 | assert reduction in ['elementwise_mean', 'sum', 'none'] 59 | self.reduction = reduction 60 | self.weight = weight 61 | self.ignore_index = ignore_index 62 | super(SparsemaxLoss, self).__init__() 63 | 64 | def forward(self, input, target): 65 | loss = sparsemax_loss(input, target) 66 | if self.ignore_index >= 0: 67 | ignored_positions = target == self.ignore_index 68 | size = float((target.size(0) - ignored_positions.sum()).item()) 69 | loss.masked_fill_(ignored_positions, 0.0) 70 | else: 71 | size = float(target.size(0)) 72 | if self.reduction == 'sum': 73 | loss = loss.sum() 74 | elif self.reduction == 'elementwise_mean': 75 | loss = loss.sum() / size 76 | return loss 77 | -------------------------------------------------------------------------------- /onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | 5 | 6 | class MatrixTree(nn.Module): 7 | """Implementation of the matrix-tree theorem for computing marginals 8 | of non-projective dependency parsing. This attention layer is used 9 | in the paper "Learning Structured Text Representations" 10 | :cite:`DBLP:journals/corr/LiuL17d`. 
11 | """ 12 | 13 | def __init__(self, eps=1e-5): 14 | self.eps = eps 15 | super(MatrixTree, self).__init__() 16 | 17 | def forward(self, input): 18 | laplacian = input.exp() + self.eps 19 | output = input.clone() 20 | for b in range(input.size(0)): 21 | lap = laplacian[b].masked_fill( 22 | torch.eye(input.size(1), device=input.device).ne(0), 0) 23 | lap = -lap + torch.diag(lap.sum(0)) 24 | # store roots on diagonal 25 | lap[0] = input[b].diag().exp() 26 | inv_laplacian = lap.inverse() 27 | 28 | factor = inv_laplacian.diag().unsqueeze(1)\ 29 | .expand_as(input[b]).transpose(0, 1) 30 | term1 = input[b].exp().mul(factor).clone() 31 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 32 | term1[:, 0] = 0 33 | term2[0] = 0 34 | output[b] = term1 - term2 35 | roots_output = input[b].diag().exp().mul( 36 | inv_laplacian.transpose(0, 1)[0]) 37 | output[b] = output[b] + torch.diag(roots_output) 38 | return output 39 | -------------------------------------------------------------------------------- /onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # At the moment this class is only used by embeddings.Embeddings look-up tables 7 | class Elementwise(nn.ModuleList): 8 | """ 9 | A simple network container. 10 | Parameters are a list of modules. 11 | Inputs are a 3d Tensor whose last dimension is the same length 12 | as the list. 13 | Outputs are the result of applying modules to inputs elementwise. 14 | An optional merge parameter allows the outputs to be reduced to a 15 | single Tensor. 
16 | """ 17 | 18 | def __init__(self, merge=None, *args): 19 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 20 | self.merge = merge 21 | super(Elementwise, self).__init__(*args) 22 | 23 | def forward(self, inputs): 24 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 25 | assert len(self) == len(inputs_) 26 | outputs = [f(x) for f, x in zip(self, inputs_)] 27 | if self.merge == 'first': 28 | return outputs[0] 29 | elif self.merge == 'concat' or self.merge == 'mlp': 30 | return torch.cat(outputs, 2) 31 | elif self.merge == 'sum': 32 | return sum(outputs) 33 | else: 34 | return outputs 35 | 36 | 37 | class Cast(nn.Module): 38 | """ 39 | Basic layer that casts its input to a specific data type. The same tensor 40 | is returned if the data type is already correct. 41 | """ 42 | 43 | def __init__(self, dtype): 44 | super(Cast, self).__init__() 45 | self._dtype = dtype 46 | 47 | def forward(self, x): 48 | return x.to(self._dtype) 49 | -------------------------------------------------------------------------------- /onmt/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harvardnlp/encoder-agnostic-adaptation/5eff09874f25ac256f07daa0d3b9e7c03705086f/onmt/tests/__init__.py -------------------------------------------------------------------------------- /onmt/tests/rebuild_test_models.sh: -------------------------------------------------------------------------------- 1 | # # Retrain the models used for CI. 2 | # # Should be done rarely, indicates a major breaking change. 
3 | my_python=python 4 | 5 | ############### TEST regular RNN choose either -rnn_type LSTM / GRU / SRU and set input_feed 0 for SRU 6 | if true; then 7 | rm data/*.pt 8 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 9 | 10 | $my_python train.py -data data/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 1 -train_steps 10000 -optim adam -learning_rate 0.001 -rnn_type LSTM -input_feed 0 11 | #-truncated_decoder 5 12 | #-label_smoothing 0.1 13 | 14 | mv tmp*e10.pt onmt/tests/test_model.pt 15 | rm tmp*.pt 16 | fi 17 | # 18 | # 19 | ############### TEST CNN 20 | if false; then 21 | rm data/*.pt 22 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 23 | 24 | $my_python train.py -data data/data -save_model /tmp/tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 2 -train_steps 10000 -optim adam -learning_rate 0.001 -encoder_type cnn -decoder_type cnn 25 | 26 | 27 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 28 | 29 | rm /tmp/tmp*.pt 30 | fi 31 | # 32 | ################# MORPH DATA 33 | if true; then 34 | rm data/morph/*.pt 35 | $my_python preprocess.py -train_src data/morph/src.train -train_tgt data/morph/tgt.train -valid_src data/morph/src.valid -valid_tgt data/morph/tgt.valid -save_data data/morph/data 36 | 37 | $my_python train.py -data data/morph/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 400 -word_vec_size 100 -layers 1 -train_steps 8000 -optim adam -learning_rate 0.001 38 | 39 | 40 | mv tmp*e8.pt onmt/tests/test_model2.pt 41 | 42 | rm tmp*.pt 43 | fi 44 | ############### TEST TRANSFORMER 45 | if false; then 46 | rm data/*.pt 47 | $my_python preprocess.py -train_src data/src-train.txt 
-train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 -share_vocab 48 | 49 | 50 | $my_python train.py -data data/data -save_model /tmp/tmp -batch_type tokens -batch_size 1024 -accum_count 4 \ 51 | -layers 4 -rnn_size 256 -word_vec_size 256 -encoder_type transformer -decoder_type transformer -share_embedding \ 52 | -train_steps 10000 -world_size 1 -gpu_ranks 0 -max_generator_batches 4 -dropout 0.1 -normalization tokens \ 53 | -max_grad_norm 0 -optim adam -decay_method noam -learning_rate 2 -label_smoothing 0.1 \ 54 | -position_encoding -param_init 0 -warmup_steps 100 -param_init_glorot -adam_beta2 0.998 55 | # 56 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 57 | rm /tmp/tmp*.pt 58 | fi 59 | # 60 | if false; then 61 | $my_python translate.py -gpu 0 -model onmt/tests/test_model.pt \ 62 | -src data/src-val.txt -output onmt/tests/output_hyp.txt -beam 5 -batch_size 16 63 | 64 | fi 65 | 66 | 67 | -------------------------------------------------------------------------------- /onmt/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Here come the tests for attention types and their compatibility 3 | """ 4 | import unittest 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | import onmt 9 | 10 | 11 | class TestAttention(unittest.TestCase): 12 | 13 | def test_masked_global_attention(self): 14 | 15 | source_lengths = torch.IntTensor([7, 3, 5, 2]) 16 | # illegal_weights_mask = torch.ByteTensor([ 17 | # [0, 0, 0, 0, 0, 0, 0], 18 | # [0, 0, 0, 1, 1, 1, 1], 19 | # [0, 0, 0, 0, 0, 1, 1], 20 | # [0, 0, 1, 1, 1, 1, 1]]) 21 | 22 | batch_size = source_lengths.size(0) 23 | dim = 20 24 | 25 | memory_bank = Variable(torch.randn(batch_size, 26 | source_lengths.max(), dim)) 27 | hidden = Variable(torch.randn(batch_size, dim)) 28 | 29 | attn = onmt.modules.GlobalAttention(dim) 30 | 31 | _, alignments = attn(hidden, 
memory_bank, 32 | memory_lengths=source_lengths) 33 | # TODO: fix for pytorch 0.3 34 | # illegal_weights = alignments.masked_select(illegal_weights_mask) 35 | 36 | # self.assertEqual(0.0, illegal_weights.data.sum()) 37 | -------------------------------------------------------------------------------- /onmt/tests/test_copy_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss 3 | 4 | import itertools 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch.nn.functional import softmax 9 | 10 | from onmt.tests.utils_for_tests import product_dict 11 | 12 | 13 | class TestCopyGenerator(unittest.TestCase): 14 | INIT_CASES = list(product_dict( 15 | input_size=[172], 16 | output_size=[319], 17 | pad_idx=[0, 39], 18 | )) 19 | PARAMS = list(product_dict( 20 | batch_size=[1, 14], 21 | max_seq_len=[23], 22 | tgt_max_len=[50], 23 | n_extra_words=[107] 24 | )) 25 | 26 | @classmethod 27 | def dummy_inputs(cls, params, init_case): 28 | hidden = torch.randn((params["batch_size"] * params["tgt_max_len"], 29 | init_case["input_size"])) 30 | attn = torch.randn((params["batch_size"] * params["tgt_max_len"], 31 | params["max_seq_len"])) 32 | src_map = torch.randn((params["max_seq_len"], params["batch_size"], 33 | params["n_extra_words"])) 34 | return hidden, attn, src_map 35 | 36 | @classmethod 37 | def expected_shape(cls, params, init_case): 38 | return params["tgt_max_len"] * params["batch_size"], \ 39 | init_case["output_size"] + params["n_extra_words"] 40 | 41 | def test_copy_gen_forward_shape(self): 42 | for params, init_case in itertools.product( 43 | self.PARAMS, self.INIT_CASES): 44 | cgen = CopyGenerator(**init_case) 45 | dummy_in = self.dummy_inputs(params, init_case) 46 | res = cgen(*dummy_in) 47 | expected_shape = self.expected_shape(params, init_case) 48 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 49 | 50 
| def test_copy_gen_outp_has_no_prob_of_pad(self): 51 | for params, init_case in itertools.product( 52 | self.PARAMS, self.INIT_CASES): 53 | cgen = CopyGenerator(**init_case) 54 | dummy_in = self.dummy_inputs(params, init_case) 55 | res = cgen(*dummy_in) 56 | self.assertTrue( 57 | res[:, init_case["pad_idx"]].allclose(torch.tensor(0.0))) 58 | 59 | def test_copy_gen_trainable_params_update(self): 60 | for params, init_case in itertools.product( 61 | self.PARAMS, self.INIT_CASES): 62 | cgen = CopyGenerator(**init_case) 63 | trainable_params = {n: p for n, p in cgen.named_parameters() 64 | if p.requires_grad} 65 | assert len(trainable_params) > 0 # sanity check 66 | old_weights = deepcopy(trainable_params) 67 | dummy_in = self.dummy_inputs(params, init_case) 68 | res = cgen(*dummy_in) 69 | pretend_loss = res.sum() 70 | pretend_loss.backward() 71 | dummy_optim = torch.optim.SGD(trainable_params.values(), 1) 72 | dummy_optim.step() 73 | for param_name in old_weights.keys(): 74 | self.assertTrue( 75 | trainable_params[param_name] 76 | .ne(old_weights[param_name]).any(), 77 | param_name + " " + init_case.__str__()) 78 | 79 | 80 | class TestCopyGeneratorLoss(unittest.TestCase): 81 | INIT_CASES = list(product_dict( 82 | vocab_size=[172], 83 | unk_index=[0, 39], 84 | ignore_index=[1, 17], # pad idx 85 | force_copy=[True, False] 86 | )) 87 | PARAMS = list(product_dict( 88 | batch_size=[1, 14], 89 | tgt_max_len=[50], 90 | n_extra_words=[107] 91 | )) 92 | 93 | @classmethod 94 | def dummy_inputs(cls, params, init_case): 95 | n_unique_src_words = 13 96 | scores = torch.randn((params["batch_size"] * params["tgt_max_len"], 97 | init_case["vocab_size"] + n_unique_src_words)) 98 | scores = softmax(scores, dim=1) 99 | align = torch.randint(0, n_unique_src_words, 100 | (params["batch_size"] * params["tgt_max_len"],)) 101 | target = torch.randint(0, init_case["vocab_size"], 102 | (params["batch_size"] * params["tgt_max_len"],)) 103 | target[0] = init_case["unk_index"] 104 | target[1] = 
init_case["ignore_index"] 105 | return scores, align, target 106 | 107 | @classmethod 108 | def expected_shape(cls, params, init_case): 109 | return (params["batch_size"] * params["tgt_max_len"],) 110 | 111 | def test_copy_loss_forward_shape(self): 112 | for params, init_case in itertools.product( 113 | self.PARAMS, self.INIT_CASES): 114 | loss = CopyGeneratorLoss(**init_case) 115 | dummy_in = self.dummy_inputs(params, init_case) 116 | res = loss(*dummy_in) 117 | expected_shape = self.expected_shape(params, init_case) 118 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 119 | 120 | def test_copy_loss_ignore_index_is_ignored(self): 121 | for params, init_case in itertools.product( 122 | self.PARAMS, self.INIT_CASES): 123 | loss = CopyGeneratorLoss(**init_case) 124 | scores, align, target = self.dummy_inputs(params, init_case) 125 | res = loss(scores, align, target) 126 | should_be_ignored = (target == init_case["ignore_index"]).nonzero() 127 | assert len(should_be_ignored) > 0 # otherwise not testing anything 128 | self.assertTrue(res[should_be_ignored].allclose(torch.tensor(0.0))) 129 | 130 | def test_copy_loss_output_range_is_positive(self): 131 | for params, init_case in itertools.product( 132 | self.PARAMS, self.INIT_CASES): 133 | loss = CopyGeneratorLoss(**init_case) 134 | dummy_in = self.dummy_inputs(params, init_case) 135 | res = loss(*dummy_in) 136 | self.assertTrue((res >= 0).all()) 137 | -------------------------------------------------------------------------------- /onmt/tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.embeddings import Embeddings 3 | 4 | import itertools 5 | from copy import deepcopy 6 | 7 | import torch 8 | 9 | from onmt.tests.utils_for_tests import product_dict 10 | 11 | 12 | class TestEmbeddings(unittest.TestCase): 13 | INIT_CASES = list(product_dict( 14 | word_vec_size=[172], 15 | word_vocab_size=[319], 16 | 
word_padding_idx=[17], 17 | position_encoding=[False, True], 18 | feat_merge=["first", "concat", "sum", "mlp"], 19 | feat_vec_exponent=[-1, 1.1, 0.7], 20 | feat_vec_size=[0, 200], 21 | feat_padding_idx=[[], [29], [0, 1]], 22 | feat_vocab_sizes=[[], [39], [401, 39]], 23 | dropout=[0, 0.5], 24 | fix_word_vecs=[False, True] 25 | )) 26 | PARAMS = list(product_dict( 27 | batch_size=[1, 14], 28 | max_seq_len=[23] 29 | )) 30 | 31 | @classmethod 32 | def case_is_degenerate(cls, case): 33 | no_feats = len(case["feat_vocab_sizes"]) == 0 34 | if case["feat_merge"] != "first" and no_feats: 35 | return True 36 | if case["feat_merge"] == "first" and not no_feats: 37 | return True 38 | if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1: 39 | return True 40 | if no_feats and case["feat_vec_exponent"] != -1: 41 | return True 42 | if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]): 43 | return True 44 | if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0: 45 | return True 46 | if case["feat_merge"] == "sum": 47 | if case["feat_vec_exponent"] != -1: 48 | return True 49 | if case["feat_vec_size"] != 0: 50 | return True 51 | if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1: 52 | return True 53 | return False 54 | 55 | @classmethod 56 | def cases(cls): 57 | for case in cls.INIT_CASES: 58 | if not cls.case_is_degenerate(case): 59 | yield case 60 | 61 | @classmethod 62 | def dummy_inputs(cls, params, init_case): 63 | max_seq_len = params["max_seq_len"] 64 | batch_size = params["batch_size"] 65 | fv_sizes = init_case["feat_vocab_sizes"] 66 | n_words = init_case["word_vocab_size"] 67 | voc_sizes = [n_words] + fv_sizes 68 | pad_idxs = [init_case["word_padding_idx"]] + \ 69 | init_case["feat_padding_idx"] 70 | lengths = torch.randint(0, max_seq_len, (batch_size,)) 71 | lengths[0] = max_seq_len 72 | inps = torch.empty((max_seq_len, batch_size, len(voc_sizes)), 73 | dtype=torch.long) 74 | for f, (voc_size, pad_idx) in 
enumerate(zip(voc_sizes, pad_idxs)): 75 | for b, len_ in enumerate(lengths): 76 | inps[:len_, b, f] = torch.randint(0, voc_size-1, (len_,)) 77 | inps[len_:, b, f] = pad_idx 78 | return inps 79 | 80 | @classmethod 81 | def expected_shape(cls, params, init_case): 82 | wvs = init_case["word_vec_size"] 83 | fvs = init_case["feat_vec_size"] 84 | nf = len(init_case["feat_vocab_sizes"]) 85 | size = wvs 86 | if init_case["feat_merge"] not in {"sum", "mlp"}: 87 | size += nf * fvs 88 | return params["max_seq_len"], params["batch_size"], size 89 | 90 | def test_embeddings_forward_shape(self): 91 | for params, init_case in itertools.product(self.PARAMS, self.cases()): 92 | emb = Embeddings(**init_case) 93 | dummy_in = self.dummy_inputs(params, init_case) 94 | res = emb(dummy_in) 95 | expected_shape = self.expected_shape(params, init_case) 96 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 97 | 98 | def test_embeddings_trainable_params(self): 99 | for params, init_case in itertools.product(self.PARAMS, 100 | self.cases()): 101 | emb = Embeddings(**init_case) 102 | trainable_params = {n: p for n, p in emb.named_parameters() 103 | if p.requires_grad} 104 | # first check there's nothing unexpectedly not trainable 105 | for key in emb.state_dict(): 106 | if key not in trainable_params: 107 | if key.endswith("emb_luts.0.weight") and \ 108 | init_case["fix_word_vecs"]: 109 | # ok: word embeddings shouldn't be trainable 110 | # if word vecs are fixed 111 | continue 112 | if key.endswith(".pe.pe"): 113 | # ok: positional encodings shouldn't be trainable 114 | assert init_case["position_encoding"] 115 | continue 116 | else: 117 | self.fail("Param {:s} is unexpectedly not " 118 | "trainable.".format(key)) 119 | # then check nothing unexpectedly trainable 120 | if init_case["fix_word_vecs"]: 121 | self.assertFalse( 122 | any(trainable_param.endswith("emb_luts.0.weight") 123 | for trainable_param in trainable_params), 124 | "Word embedding is trainable but word vecs are 
fixed.") 125 | if init_case["position_encoding"]: 126 | self.assertFalse( 127 | any(trainable_p.endswith(".pe.pe") 128 | for trainable_p in trainable_params), 129 | "Positional encoding is trainable.") 130 | 131 | def test_embeddings_trainable_params_update(self): 132 | for params, init_case in itertools.product(self.PARAMS, self.cases()): 133 | emb = Embeddings(**init_case) 134 | trainable_params = {n: p for n, p in emb.named_parameters() 135 | if p.requires_grad} 136 | if len(trainable_params) > 0: 137 | old_weights = deepcopy(trainable_params) 138 | dummy_in = self.dummy_inputs(params, init_case) 139 | res = emb(dummy_in) 140 | pretend_loss = res.sum() 141 | pretend_loss.backward() 142 | dummy_optim = torch.optim.SGD(trainable_params.values(), 1) 143 | dummy_optim.step() 144 | for param_name in old_weights.keys(): 145 | self.assertTrue( 146 | trainable_params[param_name] 147 | .ne(old_weights[param_name]).any(), 148 | param_name + " " + init_case.__str__()) 149 | -------------------------------------------------------------------------------- /onmt/tests/test_image_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.inputters.image_dataset import ImageDataReader 3 | 4 | import os 5 | import shutil 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class TestImageDataReader(unittest.TestCase): 13 | # this test touches the file system, so it could be considered an 14 | # integration test 15 | _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | _IMG_DATA_DIRNAME = "test_image_data" 17 | _IMG_DATA_DIR = os.path.join(_THIS_DIR, _IMG_DATA_DIRNAME) 18 | _IMG_DATA_FMT = "test_img_{:d}.png" 19 | _IMG_DATA_PATH_FMT = os.path.join(_IMG_DATA_DIR, _IMG_DATA_FMT) 20 | 21 | _IMG_LIST_DIR = "test_image_filenames" 22 | # file to hold full paths to image data 23 | _IMG_LIST_PATHS_FNAME = "test_files.txt" 24 | _IMG_LIST_PATHS_PATH = os.path.join( 25 | _IMG_LIST_DIR, 
_IMG_LIST_PATHS_FNAME) 26 | # file to hold image paths relative to _IMG_DATA_DIR (i.e. file names) 27 | _IMG_LIST_FNAMES_FNAME = "test_fnames.txt" 28 | _IMG_LIST_FNAMES_PATH = os.path.join( 29 | _IMG_LIST_DIR, _IMG_LIST_FNAMES_FNAME) 30 | 31 | # it's ok if non-image files co-exist with image files in the data dir 32 | _JUNK_FILE = os.path.join( 33 | _IMG_DATA_DIR, "this_is_junk.txt") 34 | 35 | _N_EXAMPLES = 20 36 | _N_CHANNELS = 3 37 | 38 | @classmethod 39 | def setUpClass(cls): 40 | if not os.path.exists(cls._IMG_DATA_DIR): 41 | os.makedirs(cls._IMG_DATA_DIR) 42 | if not os.path.exists(cls._IMG_LIST_DIR): 43 | os.makedirs(cls._IMG_LIST_DIR) 44 | 45 | with open(cls._JUNK_FILE, "w") as f: 46 | f.write("this is some garbage\nShould have no impact.") 47 | 48 | with open(cls._IMG_LIST_PATHS_PATH, "w") as f_list_fnames, \ 49 | open(cls._IMG_LIST_FNAMES_PATH, "w") as f_list_paths: 50 | cls.n_rows = torch.randint(30, 314, (cls._N_EXAMPLES,)) 51 | cls.n_cols = torch.randint(30, 314, (cls._N_EXAMPLES,)) 52 | for i in range(cls._N_EXAMPLES): 53 | img = np.random.randint( 54 | 0, 255, (cls.n_rows[i], cls.n_cols[i], cls._N_CHANNELS)) 55 | f_path = cls._IMG_DATA_PATH_FMT.format(i) 56 | cv2.imwrite(f_path, img) 57 | f_name_short = cls._IMG_DATA_FMT.format(i) 58 | f_list_fnames.write(f_name_short + "\n") 59 | f_list_paths.write(f_path + "\n") 60 | 61 | @classmethod 62 | def tearDownClass(cls): 63 | shutil.rmtree(cls._IMG_DATA_DIR) 64 | shutil.rmtree(cls._IMG_LIST_DIR) 65 | 66 | def test_read_from_dir_and_data_file_containing_filenames(self): 67 | rdr = ImageDataReader(channel_size=self._N_CHANNELS) 68 | i = 0 # initialize since there's a sanity check on i 69 | for i, img in enumerate(rdr.read( 70 | self._IMG_LIST_FNAMES_PATH, "src", self._IMG_DATA_DIR)): 71 | self.assertEqual( 72 | img["src"].shape, 73 | (self._N_CHANNELS, self.n_rows[i], self.n_cols[i])) 74 | self.assertEqual(img["src_path"], 75 | self._IMG_DATA_PATH_FMT.format(i)) 76 | self.assertGreater(i, 0, "No image data 
was read.") 77 | 78 | def test_read_from_dir_and_data_file_containing_paths(self): 79 | rdr = ImageDataReader(channel_size=self._N_CHANNELS) 80 | i = 0 # initialize since there's a sanity check on i 81 | for i, img in enumerate(rdr.read( 82 | self._IMG_LIST_PATHS_PATH, "src", self._IMG_DATA_DIR)): 83 | self.assertEqual( 84 | img["src"].shape, 85 | (self._N_CHANNELS, self.n_rows[i], self.n_cols[i])) 86 | self.assertEqual(img["src_path"], 87 | self._IMG_DATA_FMT.format(i)) 88 | self.assertGreater(i, 0, "No image data was read.") 89 | 90 | 91 | class TestImageDataReader1Channel(TestImageDataReader): 92 | _N_CHANNELS = 1 93 | -------------------------------------------------------------------------------- /onmt/tests/test_simple.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | 3 | 4 | def test_load(): 5 | onmt 6 | pass 7 | -------------------------------------------------------------------------------- /onmt/tests/test_structured_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.structured_attention import MatrixTree 3 | 4 | import torch 5 | 6 | 7 | class TestStructuredAttention(unittest.TestCase): 8 | def test_matrix_tree_marg_pdfs_sum_to_1(self): 9 | dtree = MatrixTree() 10 | q = torch.rand(1, 5, 5) 11 | marg = dtree.forward(q) 12 | self.assertTrue( 13 | marg.sum(1).allclose(torch.tensor(1.0))) 14 | -------------------------------------------------------------------------------- /onmt/tests/utils_for_tests.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | 4 | def product_dict(**kwargs): 5 | keys = kwargs.keys() 6 | vals = kwargs.values() 7 | for instance in itertools.product(*vals): 8 | yield dict(zip(keys, instance)) 9 | -------------------------------------------------------------------------------- /onmt/translate/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam import Beam, GNMTGlobalScorer 5 | from onmt.translate.beam_search import BeamSearch 6 | from onmt.translate.decode_strategy import DecodeStrategy 7 | from onmt.translate.random_sampling import RandomSampling 8 | from onmt.translate.penalties import PenaltyBuilder 9 | from onmt.translate.translation_server import TranslationServer, \ 10 | ServerModelError 11 | 12 | __all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch', 13 | 'GNMTGlobalScorer', 'TranslationBuilder', 14 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', 15 | "DecodeStrategy", "RandomSampling"] 16 | -------------------------------------------------------------------------------- /onmt/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """Returns the Length and Coverage Penalty function for Beam Search. 7 | 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | 12 | Attributes: 13 | has_cov_pen (bool): Whether coverage penalty is None (applying it 14 | is a no-op). Note that the converse isn't true. Setting beta 15 | to 0 should force coverage length to be a no-op. 16 | has_len_pen (bool): Whether length penalty is None (applying it 17 | is a no-op). Note that the converse isn't true. Setting alpha 18 | to 1 should force length penalty to be a no-op. 19 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 20 | Calculates the coverage penalty. 21 | length_penalty (callable[[int, float], float]): Calculates 22 | the length penalty. 
23 | """ 24 | 25 | def __init__(self, cov_pen, length_pen): 26 | self.has_cov_pen = not self._pen_is_none(cov_pen) 27 | self.coverage_penalty = self._coverage_penalty(cov_pen) 28 | self.has_len_pen = not self._pen_is_none(length_pen) 29 | self.length_penalty = self._length_penalty(length_pen) 30 | 31 | @staticmethod 32 | def _pen_is_none(pen): 33 | return pen == "none" or pen is None 34 | 35 | def _coverage_penalty(self, cov_pen): 36 | if cov_pen == "wu": 37 | return self.coverage_wu 38 | elif cov_pen == "summary": 39 | return self.coverage_summary 40 | elif self._pen_is_none(cov_pen): 41 | return self.coverage_none 42 | else: 43 | raise NotImplementedError("No '{:s}' coverage penalty.".format( 44 | cov_pen)) 45 | 46 | def _length_penalty(self, length_pen): 47 | if length_pen == "wu": 48 | return self.length_wu 49 | elif length_pen == "avg": 50 | return self.length_average 51 | elif self._pen_is_none(length_pen): 52 | return self.length_none 53 | else: 54 | raise NotImplementedError("No '{:s}' length penalty.".format( 55 | length_pen)) 56 | 57 | # Below are all the different penalty terms implemented so far. 58 | # Subtract coverage penalty from topk log probs. 59 | # Divide topk log probs by length penalty. 60 | 61 | def coverage_wu(self, cov, beta=0.): 62 | """GNMT coverage re-ranking score. 63 | 64 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 65 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 66 | probably ``batch_size x beam_size`` but could be several 67 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 68 | then the ``seq_len`` axis probably sums to (almost) 1. 
69 | """ 70 | 71 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 72 | return beta * penalty 73 | 74 | def coverage_summary(self, cov, beta=0.): 75 | """Our summary penalty.""" 76 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 77 | penalty -= cov.size(-1) 78 | return beta * penalty 79 | 80 | def coverage_none(self, cov, beta=0.): 81 | """Returns zero as penalty""" 82 | none = torch.zeros((1,), device=cov.device, 83 | dtype=torch.float) 84 | if cov.dim() == 3: 85 | none = none.unsqueeze(0) 86 | return none 87 | 88 | def length_wu(self, cur_len, alpha=0.): 89 | """GNMT length re-ranking score. 90 | 91 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 92 | """ 93 | 94 | return ((5 + cur_len) / 6.0) ** alpha 95 | 96 | def length_average(self, cur_len, alpha=0.): 97 | """Returns the current sequence length.""" 98 | return cur_len 99 | 100 | def length_none(self, cur_len, alpha=0.): 101 | """Returns unmodified scores.""" 102 | return 1.0 103 | -------------------------------------------------------------------------------- /onmt/translate/translation.py: -------------------------------------------------------------------------------- 1 | """ Translation main class """ 2 | from __future__ import unicode_literals, print_function 3 | 4 | import torch 5 | from onmt.inputters.text_dataset import TextMultiField 6 | 7 | 8 | class TranslationBuilder(object): 9 | """ 10 | Build a word-based translation from the batch output 11 | of translator and the underlying dictionaries. 12 | 13 | Replacement based on "Addressing the Rare Word 14 | Problem in Neural Machine Translation" :cite:`Luong2015b` 15 | 16 | Args: 17 | data (onmt.inputters.Dataset): Data. 
18 | fields (List[Tuple[str, torchtext.data.Field]]): data fields 19 | n_best (int): number of translations produced 20 | replace_unk (bool): replace unknown words using attention 21 | has_tgt (bool): will the batch have gold targets 22 | """ 23 | 24 | def __init__(self, data, fields, n_best=1, replace_unk=False, 25 | has_tgt=False): 26 | self.data = data 27 | self.fields = fields 28 | self._has_text_src = 'src' in self.fields and isinstance( 29 | dict(self.fields)["src"], TextMultiField) 30 | self.n_best = n_best 31 | self.replace_unk = replace_unk 32 | self.has_tgt = has_tgt 33 | 34 | def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn): 35 | tgt_field = dict(self.fields)["tgt"].base_field 36 | vocab = tgt_field.vocab 37 | tokens = [] 38 | for tok in pred: 39 | if tok < len(vocab): 40 | tokens.append(vocab.itos[tok]) 41 | else: 42 | tokens.append(src_vocab.itos[tok - len(vocab)]) 43 | if tokens[-1] == tgt_field.eos_token: 44 | tokens = tokens[:-1] 45 | break 46 | if self.replace_unk and attn is not None and src is not None: 47 | for i in range(len(tokens)): 48 | if tokens[i] == tgt_field.unk_token: 49 | _, max_index = attn[i].max(0) 50 | tokens[i] = src_raw[max_index.item()] 51 | return tokens 52 | 53 | def from_batch(self, translation_batch): 54 | batch = translation_batch["batch"] 55 | assert(len(translation_batch["gold_score"]) == 56 | len(translation_batch["predictions"])) 57 | batch_size = batch.batch_size 58 | 59 | preds, pred_score, attn, gold_score, indices = list(zip( 60 | *sorted(zip(translation_batch["predictions"], 61 | translation_batch["scores"], 62 | translation_batch["attention"], 63 | translation_batch["gold_score"], 64 | batch.indices.data), 65 | key=lambda x: x[-1]))) 66 | 67 | # Sorting 68 | inds, perm = torch.sort(batch.indices) 69 | if self._has_text_src: 70 | src = batch.src[0][:, :, 0].index_select(1, perm) 71 | else: 72 | src = None 73 | 74 | if self.has_tgt: 75 | tgt, _ = batch.tgt if isinstance(batch.tgt, tuple) else 
(batch.tgt, None) 76 | tgt = tgt[:, :, 0].index_select(1, perm) 77 | else: 78 | tgt = None 79 | 80 | translations = [] 81 | for b in range(batch_size): 82 | if self._has_text_src: 83 | src_vocab = self.data.src_vocabs[inds[b]] \ 84 | if self.data.src_vocabs else None 85 | src_raw = self.data.examples[inds[b]].src[0] 86 | else: 87 | src_vocab = None 88 | src_raw = None 89 | pred_sents = [self._build_target_tokens( 90 | src[:, b] if src is not None else None, 91 | src_vocab, src_raw, 92 | preds[b][n], attn[b][n]) 93 | for n in range(self.n_best)] 94 | gold_sent = None 95 | if tgt is not None: 96 | gold_sent = self._build_target_tokens( 97 | src[:, b] if src is not None else None, 98 | src_vocab, src_raw, 99 | tgt[1:, b] if tgt is not None else None, None) 100 | 101 | translation = Translation( 102 | src[:, b] if src is not None else None, 103 | src_raw, pred_sents, attn[b], pred_score[b], 104 | gold_sent, gold_score[b] 105 | ) 106 | translations.append(translation) 107 | 108 | return translations 109 | 110 | 111 | class Translation(object): 112 | """Container for a translated sentence. 113 | 114 | Attributes: 115 | src (LongTensor): Source word IDs. 116 | src_raw (List[str]): Raw source words. 117 | pred_sents (List[List[str]]): Words from the n-best translations. 118 | pred_scores (List[List[float]]): Log-probs of n-best translations. 119 | attns (List[FloatTensor]) : Attention distribution for each 120 | translation. 121 | gold_sent (List[str]): Words from gold translation. 122 | gold_score (List[float]): Log-prob of gold translation. 
123 | """ 124 | 125 | __slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores", 126 | "gold_sent", "gold_score"] 127 | 128 | def __init__(self, src, src_raw, pred_sents, 129 | attn, pred_scores, tgt_sent, gold_score): 130 | self.src = src 131 | self.src_raw = src_raw 132 | self.pred_sents = pred_sents 133 | self.attns = attn 134 | self.pred_scores = pred_scores 135 | self.gold_sent = tgt_sent 136 | self.gold_score = gold_score 137 | 138 | def log(self, sent_number): 139 | """ 140 | Log translation. 141 | """ 142 | 143 | msg = ['\nSENT {}: {}\n'.format(sent_number, self.src_raw)] 144 | 145 | best_pred = self.pred_sents[0] 146 | best_score = self.pred_scores[0] 147 | pred_sent = ' '.join(best_pred) 148 | msg.append('PRED {}: {}\n'.format(sent_number, pred_sent)) 149 | msg.append("PRED SCORE: {:.4f}\n".format(best_score)) 150 | 151 | if self.gold_sent is not None: 152 | tgt_sent = ' '.join(self.gold_sent) 153 | msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent)) 154 | msg.append(("GOLD SCORE: {:.4f}\n".format(self.gold_score))) 155 | if len(self.pred_sents) > 1: 156 | msg.append('\nBEST HYP:\n') 157 | for score, sent in zip(self.pred_scores, self.pred_sents): 158 | msg.append("[{:.4f}] {}\n".format(score, sent)) 159 | 160 | return "".join(msg) 161 | -------------------------------------------------------------------------------- /onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed 3 | from onmt.utils.report_manager import ReportMgr, build_report_manager 4 | from onmt.utils.statistics import Statistics 5 | from onmt.utils.optimizers import MultipleOptimizer, \ 6 | Optimizer, AdaFactor 7 | 8 | __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", 9 | "build_report_manager", "Statistics", 10 | "MultipleOptimizer", "Optimizer", "AdaFactor"] 11 | 
-------------------------------------------------------------------------------- /onmt/utils/cnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | 8 | import onmt.modules 9 | 10 | SCALE_WEIGHT = 0.5 ** 0.5 11 | 12 | 13 | def shape_transform(x): 14 | """ Tranform the size of the tensors to fit for conv input. """ 15 | return torch.unsqueeze(torch.transpose(x, 1, 2), 3) 16 | 17 | 18 | class GatedConv(nn.Module): 19 | """ Gated convolution for CNN class """ 20 | 21 | def __init__(self, input_size, width=3, dropout=0.2, nopad=False): 22 | super(GatedConv, self).__init__() 23 | self.conv = onmt.modules.WeightNormConv2d( 24 | input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), 25 | padding=(width // 2 * (1 - nopad), 0)) 26 | init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5) 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, x_var): 30 | x_var = self.dropout(x_var) 31 | x_var = self.conv(x_var) 32 | out, gate = x_var.split(int(x_var.size(1) / 2), 1) 33 | out = out * torch.sigmoid(gate) 34 | return out 35 | 36 | 37 | class StackedCNN(nn.Module): 38 | """ Stacked CNN class """ 39 | 40 | def __init__(self, num_layers, input_size, cnn_kernel_width=3, 41 | dropout=0.2): 42 | super(StackedCNN, self).__init__() 43 | self.dropout = dropout 44 | self.num_layers = num_layers 45 | self.layers = nn.ModuleList() 46 | for _ in range(num_layers): 47 | self.layers.append( 48 | GatedConv(input_size, cnn_kernel_width, dropout)) 49 | 50 | def forward(self, x): 51 | for conv in self.layers: 52 | x = x + conv(x) 53 | x *= SCALE_WEIGHT 54 | return x 55 | -------------------------------------------------------------------------------- /onmt/utils/distributed.py: 
-------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | import torch.distributed 12 | 13 | from onmt.utils.logging import logger 14 | 15 | 16 | def is_master(opt, device_id): 17 | return opt.gpu_ranks[device_id] == 0 18 | 19 | 20 | def multi_init(opt, device_id): 21 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format( 22 | master_ip=opt.master_ip, 23 | master_port=opt.master_port) 24 | dist_world_size = opt.world_size 25 | torch.distributed.init_process_group( 26 | backend=opt.gpu_backend, init_method=dist_init_method, 27 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(opt, device_id): 30 | logger.disabled = True 31 | 32 | return gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. 
# of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size 
// 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /onmt/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import random 5 | import inspect 6 | from itertools import islice 7 | 8 | 9 | def split_corpus(path, shard_size, iter_func=None, binary=True): 10 | priv_str = "r" 11 | if binary: 12 | priv_str += "b" 13 | with open(path, priv_str) as f: 14 | if shard_size <= 0: 15 | if iter_func is not None: 16 | yield 
iter_func(f.readlines()) 17 | else: 18 | yield f.readlines() 19 | else: 20 | while True: 21 | shard = list(islice(f, shard_size)) 22 | if not shard: 23 | break 24 | if iter_func is not None: 25 | yield iter_func(shard) 26 | else: 27 | yield shard 28 | 29 | 30 | def aeq(*args): 31 | """ 32 | Assert all arguments have the same value 33 | """ 34 | arguments = (arg for arg in args) 35 | first = next(arguments) 36 | assert all(arg == first for arg in arguments), \ 37 | "Not all arguments have the same value: " + str(args) 38 | 39 | 40 | def sequence_mask(lengths, max_len=None): 41 | """ 42 | Creates a boolean mask from sequence lengths. 43 | """ 44 | batch_size = lengths.numel() 45 | max_len = max_len or lengths.max() 46 | return (torch.arange(0, max_len) 47 | .type_as(lengths) 48 | .repeat(batch_size, 1) 49 | .lt(lengths.unsqueeze(1))) 50 | 51 | 52 | def tile(x, count, dim=0): 53 | """ 54 | Tiles x on dimension dim count times. 55 | """ 56 | perm = list(range(len(x.size()))) 57 | if dim != 0: 58 | perm[0], perm[dim] = perm[dim], perm[0] 59 | x = x.permute(perm).contiguous() 60 | out_size = list(x.size()) 61 | out_size[0] *= count 62 | batch = x.size(0) 63 | ''' 64 | x = x.view(batch, -1) \ 65 | .transpose(0, 1) \ 66 | .repeat(count, 1) \ 67 | .transpose(0, 1) \ 68 | .contiguous() \ 69 | .view(*out_size) 70 | ''' 71 | x = x.contiguous().view(batch, -1) 72 | x = x.transpose(0, 1) 73 | x = x.repeat(count, 1) 74 | x = x.transpose(0, 1) 75 | x = x.contiguous() 76 | x = x.view(*out_size) 77 | if dim != 0: 78 | x = x.permute(perm).contiguous() 79 | return x 80 | 81 | 82 | def use_gpu(opt): 83 | """ 84 | Creates a boolean if gpu used 85 | """ 86 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 87 | (hasattr(opt, 'gpu') and opt.gpu > -1) 88 | 89 | 90 | def set_random_seed(seed, is_cuda): 91 | """Sets the random seed.""" 92 | if seed > 0: 93 | torch.manual_seed(seed) 94 | # this one is needed for torchtext random call (shuffled iterator) 95 | # in multi gpu 
it ensures datasets are read in the same order 96 | random.seed(seed) 97 | # some cudnn methods can be random even after fixing the seed 98 | # unless you tell it to be deterministic 99 | torch.backends.cudnn.deterministic = True 100 | 101 | if is_cuda and seed > 0: 102 | # These ensure same initialization in multi gpu mode 103 | torch.cuda.manual_seed(seed) 104 | 105 | 106 | def generate_relative_positions_matrix(length, max_relative_positions, 107 | cache=False): 108 | """Generate the clipped relative positions matrix 109 | for a given length and maximum relative positions""" 110 | if cache: 111 | distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0) 112 | else: 113 | range_vec = torch.arange(length) 114 | range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) 115 | distance_mat = range_mat - range_mat.transpose(0, 1) 116 | distance_mat_clipped = torch.clamp(distance_mat, 117 | min=-max_relative_positions, 118 | max=max_relative_positions) 119 | # Shift values to be >= 0 120 | final_mat = distance_mat_clipped + max_relative_positions 121 | return final_mat 122 | 123 | 124 | def relative_matmul(x, z, transpose): 125 | """Helper function for relative positions attention.""" 126 | batch_size = x.shape[0] 127 | heads = x.shape[1] 128 | length = x.shape[2] 129 | x_t = x.permute(2, 0, 1, 3) 130 | x_t_r = x_t.reshape(length, heads * batch_size, -1) 131 | if transpose: 132 | z_t = z.transpose(1, 2) 133 | x_tz_matmul = torch.matmul(x_t_r, z_t) 134 | else: 135 | x_tz_matmul = torch.matmul(x_t_r, z) 136 | x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1) 137 | x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) 138 | return x_tz_matmul_r_t 139 | 140 | 141 | def fn_args(fun): 142 | """Returns the list of function arguments name.""" 143 | return inspect.getfullargspec(fun).args 144 | -------------------------------------------------------------------------------- /onmt/utils/report_manager.py: 
-------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | import time 4 | from datetime import datetime 5 | 6 | import onmt 7 | 8 | from onmt.utils.logging import logger 9 | 10 | 11 | def build_report_manager(opt): 12 | if opt.tensorboard: 13 | from tensorboardX import SummaryWriter 14 | tensorboard_log_dir = opt.tensorboard_log_dir 15 | 16 | #if not opt.train_from: 17 | # tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S") 18 | #tensorboard_log_dir += opt.config+'_'+opt.run_name 19 | 20 | writer = SummaryWriter(tensorboard_log_dir, 21 | comment="Unmt") 22 | else: 23 | writer = None 24 | 25 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 26 | tensorboard_writer=writer) 27 | return report_mgr 28 | 29 | 30 | class ReportMgrBase(object): 31 | """ 32 | Report Manager Base class 33 | Inherited classes should override: 34 | * `_report_training` 35 | * `_report_step` 36 | """ 37 | 38 | def __init__(self, report_every, start_time=-1.): 39 | """ 40 | Args: 41 | report_every(int): Report status every this many sentences 42 | start_time(float): manually set report start time. Negative values 43 | means that you will need to set it later or use `start()` 44 | """ 45 | self.report_every = report_every 46 | self.progress_step = 0 47 | self.start_time = start_time 48 | 49 | def start(self): 50 | self.start_time = time.time() 51 | 52 | def log(self, *args, **kwargs): 53 | logger.info(*args, **kwargs) 54 | 55 | def report_training(self, step, num_steps, learning_rate, 56 | report_stats, multigpu=False): 57 | """ 58 | This is the user-defined batch-level traing progress 59 | report function. 60 | 61 | Args: 62 | step(int): current step count. 63 | num_steps(int): total number of batches. 64 | learning_rate(float): current learning rate. 65 | report_stats(Statistics): old Statistics instance. 
66 | Returns: 67 | report_stats(Statistics): updated Statistics instance. 68 | """ 69 | if self.start_time < 0: 70 | raise ValueError("""ReportMgr needs to be started 71 | (set 'start_time' or use 'start()'""") 72 | 73 | if step % self.report_every == 0: 74 | if multigpu: 75 | report_stats = \ 76 | onmt.utils.Statistics.all_gather_stats(report_stats) 77 | self._report_training( 78 | step, num_steps, learning_rate, report_stats) 79 | self.progress_step += 1 80 | return onmt.utils.Statistics() 81 | else: 82 | return report_stats 83 | 84 | def _report_training(self, *args, **kwargs): 85 | """ To be overridden """ 86 | raise NotImplementedError() 87 | 88 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 89 | """ 90 | Report stats of a step 91 | 92 | Args: 93 | train_stats(Statistics): training stats 94 | valid_stats(Statistics): validation stats 95 | lr(float): current learning rate 96 | """ 97 | self._report_step( 98 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 99 | 100 | def _report_step(self, *args, **kwargs): 101 | raise NotImplementedError() 102 | 103 | 104 | class ReportMgr(ReportMgrBase): 105 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 106 | """ 107 | A report manager that writes statistics on standard output as well as 108 | (optionally) TensorBoard 109 | 110 | Args: 111 | report_every(int): Report status every this many sentences 112 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 113 | The TensorBoard Summary writer to use or None 114 | """ 115 | super(ReportMgr, self).__init__(report_every, start_time) 116 | self.tensorboard_writer = tensorboard_writer 117 | 118 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 119 | if self.tensorboard_writer is not None: 120 | stats.log_tensorboard( 121 | prefix, self.tensorboard_writer, learning_rate, step) 122 | 123 | def _report_training(self, step, num_steps, learning_rate, 124 | report_stats): 125 | """ 126 | See base 
class method `ReportMgrBase.report_training`. 127 | """ 128 | report_stats.output(step, num_steps, 129 | learning_rate, self.start_time) 130 | 131 | # Log the progress using the number of batches on the x-axis. 132 | self.maybe_log_tensorboard(report_stats, 133 | "progress", 134 | learning_rate, 135 | self.progress_step) 136 | report_stats = onmt.utils.Statistics() 137 | 138 | return report_stats 139 | 140 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 141 | """ 142 | See base class method `ReportMgrBase.report_step`. 143 | """ 144 | if train_stats is not None: 145 | self.log('Train perplexity: %g' % train_stats.ppl()) 146 | self.log('Train accuracy: %g' % train_stats.accuracy()) 147 | 148 | self.maybe_log_tensorboard(train_stats, 149 | "train", 150 | lr, 151 | step) 152 | 153 | if valid_stats is not None: 154 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 155 | self.log('Validation xent: %g' % valid_stats.xent()) 156 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 157 | 158 | self.maybe_log_tensorboard(valid_stats, 159 | "valid", 160 | lr, 161 | step) 162 | -------------------------------------------------------------------------------- /onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | import torch.nn as nn 5 | import onmt.models 6 | 7 | 8 | def rnn_factory(rnn_type, **kwargs): 9 | """ rnn factory, Use pytorch version when available. """ 10 | no_pack_padded_seq = False 11 | if rnn_type == "SRU": 12 | # SRU doesn't support PackedSequence. 
13 | no_pack_padded_seq = True 14 | rnn = onmt.models.sru.SRU(**kwargs) 15 | else: 16 | rnn = getattr(nn, rnn_type)(**kwargs) 17 | return rnn, no_pack_padded_seq 18 | -------------------------------------------------------------------------------- /onmt/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | import time 4 | import math 5 | import sys 6 | 7 | from onmt.utils.logging import logger 8 | 9 | 10 | class Statistics(object): 11 | """ 12 | Accumulator for loss statistics. 13 | Currently calculates: 14 | 15 | * accuracy 16 | * perplexity 17 | * elapsed time 18 | """ 19 | 20 | def __init__(self, loss=0, n_words=0, n_correct=0): 21 | self.loss = loss 22 | self.n_words = n_words 23 | self.n_correct = n_correct 24 | self.n_src_words = 0 25 | self.start_time = time.time() 26 | 27 | @staticmethod 28 | def all_gather_stats(stat, max_size=4096): 29 | """ 30 | Gather a `Statistics` object accross multiple process/nodes 31 | 32 | Args: 33 | stat(:obj:Statistics): the statistics object to gather 34 | accross all processes/nodes 35 | max_size(int): max buffer size to use 36 | 37 | Returns: 38 | `Statistics`, the update stats object 39 | """ 40 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 41 | return stats[0] 42 | 43 | @staticmethod 44 | def all_gather_stats_list(stat_list, max_size=4096): 45 | """ 46 | Gather a `Statistics` list accross all processes/nodes 47 | 48 | Args: 49 | stat_list(list([`Statistics`])): list of statistics objects to 50 | gather accross all processes/nodes 51 | max_size(int): max buffer size to use 52 | 53 | Returns: 54 | our_stats(list([`Statistics`])): list of updated stats 55 | """ 56 | from torch.distributed import get_rank 57 | from onmt.utils.distributed import all_gather_list 58 | 59 | # Get a list of world_size lists with len(stat_list) Statistics objects 60 | all_stats = 
all_gather_list(stat_list, max_size=max_size) 61 | 62 | our_rank = get_rank() 63 | our_stats = all_stats[our_rank] 64 | for other_rank, stats in enumerate(all_stats): 65 | if other_rank == our_rank: 66 | continue 67 | for i, stat in enumerate(stats): 68 | our_stats[i].update(stat, update_n_src_words=True) 69 | return our_stats 70 | 71 | def update(self, stat, update_n_src_words=False): 72 | """ 73 | Update statistics by suming values with another `Statistics` object 74 | 75 | Args: 76 | stat: another statistic object 77 | update_n_src_words(bool): whether to update (sum) `n_src_words` 78 | or not 79 | 80 | """ 81 | self.loss += stat.loss 82 | self.n_words += stat.n_words 83 | self.n_correct += stat.n_correct 84 | 85 | if update_n_src_words: 86 | self.n_src_words += stat.n_src_words 87 | 88 | def accuracy(self): 89 | """ compute accuracy """ 90 | return 100 * (self.n_correct / self.n_words) 91 | 92 | def xent(self): 93 | """ compute cross entropy """ 94 | return self.loss / self.n_words 95 | 96 | def ppl(self): 97 | """ compute perplexity """ 98 | return math.exp(min(self.loss / self.n_words, 100)) 99 | 100 | def elapsed_time(self): 101 | """ compute elapsed time """ 102 | return time.time() - self.start_time 103 | 104 | def output(self, step, num_steps, learning_rate, start): 105 | """Write out statistics to stdout. 106 | 107 | Args: 108 | step (int): current step 109 | n_batch (int): total batches 110 | start (int): start time of step. 
111 | """ 112 | t = self.elapsed_time() 113 | step_fmt = "%2d" % step 114 | if num_steps > 0: 115 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 116 | logger.info( 117 | ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.3f; " + 118 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 119 | % (step_fmt, 120 | self.accuracy(), 121 | self.ppl(), 122 | self.xent(), 123 | learning_rate, 124 | self.n_src_words / (t + 1e-5), 125 | self.n_words / (t + 1e-5), 126 | time.time() - start)) 127 | sys.stdout.flush() 128 | 129 | def log_tensorboard(self, prefix, writer, learning_rate, step): 130 | """ display statistics to tensorboard """ 131 | t = self.elapsed_time() 132 | writer.add_scalar(prefix + "/xent", self.xent(), step) 133 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 134 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 135 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 136 | writer.add_scalar(prefix + "/lr", learning_rate, step) 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | tqdm==4.30.* 3 | torch==1.0.1 4 | git+https://github.com/pytorch/text.git@master#wheel=torchtext 5 | future 6 | configargparse 7 | PyYAML 8 | tensorflow 9 | tensorboardX 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup(name='OpenNMT-py', 6 | description='A python implementation of OpenNMT', 7 | version='0.8.2', 8 | 9 | packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests', 10 | 'onmt.translate', 'onmt.decoders', 'onmt.inputters', 11 | 'onmt.models', 'onmt.utils']) 12 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Train models.""" 3 | import os 4 | import glob 5 | import numpy as np 6 | import signal 7 | import torch 8 | 9 | import onmt.opts as opts 10 | import onmt.utils.distributed 11 | 12 | from onmt.utils.logging import logger 13 | from onmt.train_single import main as single_main 14 | from onmt.utils.parse import ArgumentParser 15 | 16 | 17 | def main(opt): 18 | ArgumentParser.validate_train_opts(opt) 19 | ArgumentParser.update_model_opts(opt) 20 | ArgumentParser.validate_model_opts(opt) 21 | 22 | nb_gpu = len(opt.gpu_ranks) 23 | 24 | if opt.world_size > 1: 25 | mp = torch.multiprocessing.get_context('spawn') 26 | # Create a thread to listen for errors in the child processes. 27 | error_queue = mp.SimpleQueue() 28 | error_handler = ErrorHandler(error_queue) 29 | # Train with multiprocessing. 30 | procs = [] 31 | for device_id in range(nb_gpu): 32 | procs.append(mp.Process(target=run, args=( 33 | opt, device_id, error_queue, ), daemon=True)) 34 | procs[device_id].start() 35 | logger.info(" Starting process pid: %d " % procs[device_id].pid) 36 | error_handler.add_child(procs[device_id].pid) 37 | for p in procs: 38 | p.join() 39 | 40 | elif nb_gpu == 1: # case 1 GPU only 41 | single_main(opt, 0) 42 | else: # case only CPU 43 | single_main(opt, -1) 44 | 45 | 46 | def run(opt, device_id, error_queue): 47 | """ run process """ 48 | try: 49 | gpu_rank = onmt.utils.distributed.multi_init(opt, device_id) 50 | if gpu_rank != opt.gpu_ranks[device_id]: 51 | raise AssertionError("An error occurred in \ 52 | Distributed initialization") 53 | single_main(opt, device_id) 54 | except KeyboardInterrupt: 55 | pass # killed by parent, do nothing 56 | except Exception: 57 | # propagate exception to parent process, keeping original traceback 58 | import traceback 59 | error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc())) 60 | 61 | 62 | class ErrorHandler(object): 63 
| """A class that listens for exceptions in children processes and propagates 64 | the tracebacks to the parent process.""" 65 | 66 | def __init__(self, error_queue): 67 | """ init error handler """ 68 | import signal 69 | import threading 70 | self.error_queue = error_queue 71 | self.children_pids = [] 72 | self.error_thread = threading.Thread( 73 | target=self.error_listener, daemon=True) 74 | self.error_thread.start() 75 | signal.signal(signal.SIGUSR1, self.signal_handler) 76 | 77 | def add_child(self, pid): 78 | """ error handler """ 79 | self.children_pids.append(pid) 80 | 81 | def error_listener(self): 82 | """ error listener """ 83 | (rank, original_trace) = self.error_queue.get() 84 | self.error_queue.put((rank, original_trace)) 85 | os.kill(os.getpid(), signal.SIGUSR1) 86 | 87 | def signal_handler(self, signalnum, stackframe): 88 | """ signal handler """ 89 | for pid in self.children_pids: 90 | os.kill(pid, signal.SIGINT) # kill children processes 91 | (rank, original_trace) = self.error_queue.get() 92 | msg = """\n\n-- Tracebacks above this line can probably 93 | be ignored --\n\n""" 94 | msg += original_trace 95 | raise Exception(msg) 96 | 97 | 98 | def _get_parser(): 99 | parser = ArgumentParser(description='train.py') 100 | 101 | opts.config_opts(parser) 102 | opts.model_opts(parser) 103 | opts.train_opts(parser) 104 | return parser 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = _get_parser() 109 | 110 | opt = parser.parse_args() 111 | 112 | if opt.config is None or opt.run_name is None: 113 | raise ValueError('base config and run_name must be set during training') 114 | 115 | config_name = opt.config.split('/')[-1] 116 | config_name = ''.join(config_name.split('.')[:-1]) 117 | 118 | dataset_name = opt.data.split('/')[-2]+'_'+opt.data.split('/')[-1] 119 | output_dir = 'output/'+dataset_name+'/'+config_name+'_'+opt.run_name+'/' 120 | os.makedirs(output_dir, exist_ok=True) 121 | 122 | setattr(opt, 'save_model', 
output_dir+'checkpoints/model') 123 | setattr(opt, 'save_config', output_dir+'config.yml') 124 | setattr(opt, 'tensorboard_log_dir', 'output/'+dataset_name+'/tblogs/'+config_name+'_'+opt.run_name) 125 | parser.write_config_file(opt, [output_dir+'config.yml']) 126 | 127 | if opt.autorestart: 128 | filenames = [] 129 | step_nums = [] 130 | for filename in glob.glob(output_dir+'checkpoints/*.pt'): 131 | filenames.append(filename) 132 | step_num = os.path.basename(filename).split('_')[-1][:-3] 133 | step_nums.append(int(step_num)) 134 | 135 | if len(filenames) > 0: 136 | indices = np.argsort(step_nums) 137 | filenames = np.array(filenames)[indices] 138 | 139 | opt.train_from = filenames[-1] 140 | opt.gpt2_init_embanddec = False 141 | opt.encoder_from = None 142 | opt.gpt2_params_path = None 143 | 144 | main(opt) 145 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | from itertools import repeat 6 | import os 7 | import numpy as np 8 | import json 9 | 10 | from onmt.utils.logging import init_logger 11 | from onmt.utils.misc import split_corpus 12 | from onmt.translate.translator import build_translator 13 | 14 | import onmt.opts as opts 15 | from onmt.utils.parse import ArgumentParser 16 | 17 | def constraint_iter_func(f_iter): 18 | all_tags = [] 19 | for json_line in f_iter: 20 | data = json.loads(json_line) 21 | words = data['words'] 22 | probs = [p[1] for p in data['class_probabilities'][:len(words)]] 23 | tags = [1 if p > opt.bu_threshold else 0 for p in probs] 24 | all_tags.append(tags) 25 | #print(len(words), len(data['class_probabilities'])) 26 | #all_tags.append(words) 27 | return all_tags 28 | 29 | 30 | def main(opt): 31 | ArgumentParser.validate_translate_opts(opt) 32 | logger = init_logger(opt.log_file) 33 | 34 | if 
opt.constraint_file: 35 | tag_shards = split_corpus(opt.constraint_file, opt.shard_size, iter_func=constraint_iter_func, binary=False) 36 | 37 | translator = build_translator(opt, report_score=True) 38 | 39 | if opt.data_type == 'imgvec': 40 | assert opt.shard_size <= 0 41 | src_shards = [opt.src] 42 | else: 43 | if opt.data_type == 'none': 44 | src_shards = [None]*99999 45 | else: 46 | src_shards = split_corpus(opt.src, opt.shard_size) 47 | tgt_shards = split_corpus(opt.tgt, opt.shard_size) \ 48 | if opt.tgt is not None else repeat(None) 49 | shard_pairs = zip(src_shards, tgt_shards) 50 | 51 | for i, (src_shard, tgt_shard) in enumerate(shard_pairs): 52 | logger.info("Translating shard %d." % i) 53 | 54 | tag_shard = None 55 | if opt.constraint_file: 56 | tag_shard = next(tag_shards) 57 | 58 | translator.translate( 59 | src=src_shard, 60 | tgt=tgt_shard, 61 | src_dir=opt.src_dir, 62 | batch_size=opt.batch_size, 63 | attn_debug=opt.attn_debug, 64 | tag_shard=tag_shard 65 | ) 66 | 67 | 68 | def _get_parser(): 69 | parser = ArgumentParser(description='translate.py') 70 | 71 | opts.config_opts(parser) 72 | opts.translate_opts(parser) 73 | return parser 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = _get_parser() 78 | 79 | opt = parser.parse_args() 80 | 81 | model_path = opt.models[0] 82 | step = os.path.basename(model_path)[:-3].split('step_')[-1] 83 | temp = opt.random_sampling_temp 84 | 85 | if opt.extra_output_str: 86 | opt.extra_output_str = '_'+opt.extra_output_str 87 | 88 | if opt.output is None: 89 | output_path = '/'.join(model_path.split('/')[:-2])+'/output_%s_%s%s.encoded' % (step, temp, opt.extra_output_str) 90 | opt.output = output_path 91 | print(opt.output) 92 | 93 | main(opt) 94 | --------------------------------------------------------------------------------