├── .gitignore ├── .jenkins ├── README.md └── build.sh ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docker └── jenkins │ ├── Dockerfile │ ├── README.md │ ├── add_jenkins_user.sh │ ├── build.sh │ └── install_prereqs.sh ├── pytorch_translate ├── __init__.py ├── attention │ ├── __init__.py │ ├── attention_utils.py │ ├── base_attention.py │ ├── dot_attention.py │ ├── mlp_attention.py │ ├── multihead_attention.py │ ├── no_attention.py │ └── pooling_attention.py ├── average_attention.py ├── beam_decode.py ├── beam_search_and_decode_v2.py ├── benchmark.py ├── bleu_significance.py ├── char_aware_hybrid.py ├── char_encoder.py ├── char_source_hybrid.py ├── char_source_model.py ├── char_source_transformer_model.py ├── checkpoint.py ├── common_layers.py ├── constants.py ├── data │ ├── __init__.py │ ├── char_data.py │ ├── data.py │ ├── dictionary.py │ ├── iterators.py │ ├── language_pair_upsampling_dataset.py │ ├── masked_lm_dictionary.py │ ├── utils.py │ └── weighted_data.py ├── dual_learning │ ├── __init__.py │ ├── dual_learning_criterion.py │ ├── dual_learning_models.py │ └── dual_learning_task.py ├── ensemble_export.py ├── evals.py ├── examples │ ├── generate_iwslt14.sh │ ├── train_iwslt14.sh │ ├── train_lm.sh │ └── translate_iwslt14.sh ├── file_io.py ├── generate.py ├── hybrid_transformer_rnn.py ├── model_constants.py ├── models │ ├── __init__.py │ └── transformer_from_pretrained_xlm.py ├── multi_model.py ├── multilingual.py ├── multilingual_model.py ├── multilingual_utils.py ├── ngram.py ├── options.py ├── preprocess.py ├── rescoring │ ├── model_scorers.py │ ├── rescorer.py │ └── weights_search.py ├── research │ ├── __init__.py │ ├── attention │ │ └── multihead_attention.py │ ├── beam_search │ │ ├── __init__.py │ │ └── competing_completed.py │ ├── deliberation_networks │ │ └── deliberation_networks.py │ ├── knowledge_distillation │ │ ├── __init__.py │ │ ├── collect_top_k_probs.py │ │ ├── dual_decoder_kd_loss.py │ │ ├── dual_decoder_kd_model.py │ │ ├── hybrid_dual_decoder_kd_model.py │ │ ├── knowledge_distillation_loss.py │ │ └── teacher_score_data.py │ ├── lexical_choice │ │ ├── __init__.py │ │ └── lexical_translation.py │ ├── multisource │ │ ├── __init__.py │ │ ├── multisource_data.py │ │ └── multisource_decode.py │ ├── rescore │ │ ├── __init__.py │ │ ├── cloze_transformer_model.py │ │ └── rescoring_criterion.py │ └── tune_ensemble_weights │ │ ├── tune_model_weights.py │ │ └── tune_model_weights_with_ax.py ├── rnn.py ├── rnn_cell.py ├── semi_supervised.py ├── sequence_criterions.py ├── tasks │ ├── __init__.py │ ├── cross_lingual_lm.py │ ├── denoising_autoencoder_task.py │ ├── knowledge_distillation_task.py │ ├── multilingual_task.py │ ├── pytorch_translate_multi_task.py │ ├── pytorch_translate_task.py │ ├── semi_supervised_task.py │ ├── translation_from_pretrained_xlm.py │ ├── translation_lev_task.py │ └── utils.py ├── torchscript_export.py ├── train.py ├── transformer.py ├── transformer_aan.py ├── utils.py ├── vocab_constants.py ├── vocab_reduction.py ├── weighted_criterions.py └── word_prediction │ ├── __init__.py │ ├── word_prediction_criterion.py │ ├── word_prediction_model.py │ └── word_predictor.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Setuptools distribution folder. 5 | build/ 6 | /dist/ 7 | 8 | # Python egg metadata, regenerated from source files by setuptools. 
9 | /*.egg-info 10 | /*.egg 11 | 12 | # Dot files 13 | .python3 14 | 15 | # Produced by install script 16 | onnx/ 17 | nccl_2.1.15-1+cuda8.0_x86_64/ 18 | pytorch/ 19 | 20 | # Produced by example scripts 21 | checkpoints/ 22 | data/ 23 | decoder.pb 24 | encoder.pb 25 | model/ 26 | 27 | # CMake compilation 28 | CMakeCache.txt 29 | CMakeFiles 30 | Makefile 31 | cmake_install.cmake 32 | translation_decoder 33 | -------------------------------------------------------------------------------- /.jenkins/README.md: -------------------------------------------------------------------------------- 1 | # Jenkins continuous integration test 2 | 3 | ## Usage 4 | 5 | Jenkins will automatically trigger a new build/test under 6 | https://ci.pytorch.org/jenkins/job/translate-builds/job/translate-xenial-cuda9-cudnn7-py3-build-test/ 7 | whenever a pull request is opened or updated, and display the results on the 8 | pull request. If you need to re-run the test due to infrastructure issues 9 | or non-code-related changes, you can manually trigger a re-test by commenting 10 | on the pull request "@pytorchbot retest this please" (this should usually not 11 | be necessary). 12 | 13 | ## Implementation details 14 | 15 | After Jenkins has patched the appropriate version of the PyTorch Translate 16 | repo on a Docker image, it calls `build.sh` via 17 | https://github.com/pytorch/ossci-job-dsl/blob/master/src/jobs/translate.groovy 18 | to build/install PyTorch Translate and run tests. `build.sh` should be 19 | self-sufficient given the repo code, and not require any internet access 20 | to build and run the tests. (Note that fairseq's 21 | [setup.py](https://github.com/pytorch/fairseq/blob/main/setup.py#L42) 22 | may try to fetch some 23 | [required packages](https://github.com/pytorch/fairseq/blob/main/requirements.txt). 24 | They're currently already installed in our Docker image 25 | [requirements](https://github.com/pytorch/translate/blob/master/docker/jenkins/install_prereqs.sh) - 26 | though that may need to be updated in the future.) 27 | 28 | If you need to download other packages or dependencies, consider adding them 29 | to the Docker image instead, in the `../docker/jenkins/Dockerfile` used 30 | by `../docker/jenkins/build.sh`. 31 | -------------------------------------------------------------------------------- /.jenkins/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Builds PyTorch Translate and runs basic tests. 3 | 4 | pip uninstall -y pytorch-translate 5 | python3 setup.py build develop 6 | python3 setup.py test 7 | 8 | # TODO(weiho): Re-enable testing these end-to-end scripts after refactoring 9 | # out the wget to be part of the Dockerfile. Possibly wait for v2 of our 10 | # OSS CI. 11 | # . pytorch_translate/examples/train_iwslt14.sh 12 | # . pytorch_translate/examples/generate_iwslt14.sh 13 | # . pytorch_translate/examples/export_iwslt14.sh 14 | # echo "hallo welt ." | . 
pytorch_translate/examples/translate_iwslt14.sh 15 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | 78 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Translate 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | Development will be done by pull requests. In some rare cases, internal changes will be submitted. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | If you are not familiar with creating a Pull Request, here are some guides: 11 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request 12 | - https://help.github.com/articles/creating-a-pull-request/ 13 | 14 | ## Issues 15 | We use GitHub issues to track public bugs. Please ensure your description is 16 | clear and has sufficient instructions to be able to reproduce the issue. 17 | 18 | ## Coding Style 19 | 20 | * Python: please follow [the PEP style](https://www.python.org/dev/peps/pep-0008/) 21 | * C++: please use `clang-format -style google` to format your code 22 | 23 | ## License 24 | By contributing to Translate, you agree that your contributions will be licensed 25 | under the LICENSE file in the root directory of this source tree. 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018-, Facebook 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /docker/jenkins/Dockerfile: -------------------------------------------------------------------------------- 1 | # Pre-req installations: 2 | # https://docs.docker.com/install/linux/docker-ce/ubuntu/ 3 | # https://github.com/NVIDIA/nvidia-docker 4 | 5 | # Usage: 6 | # sudo docker build -t pytorch_translate_initial_release . 2>&1 | tee stdout 7 | # or 8 | # sudo nvidia-docker build -t pytorch_translate_initial_release . 2>&1 | tee stdout 9 | # sudo nvidia-docker run -i -t --rm pytorch_translate_initial_release /bin/bash 10 | 11 | # Remove all stopped Docker containers: sudo docker rm $(sudo docker ps -a -q) 12 | # Remove all untagged images: sudo docker rmi $(sudo docker images -q --filter "dangling=true") 13 | 14 | # Available versions: https://hub.docker.com/r/nvidia/cuda/ 15 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 16 | 17 | SHELL ["/bin/bash", "-c"] 18 | 19 | RUN apt-get update && apt-get install -y --no-install-recommends \ 20 | build-essential \ 21 | ca-certificates \ 22 | cmake \ 23 | git \ 24 | libgflags-dev \ 25 | libgoogle-glog-dev \ 26 | libgtest-dev \ 27 | libiomp-dev \ 28 | libleveldb-dev \ 29 | liblmdb-dev \ 30 | libopencv-dev \ 31 | libopenmpi-dev \ 32 | libprotobuf-dev \ 33 | libsnappy-dev \ 34 | locales \ 35 | openmpi-bin \ 36 | openmpi-doc \ 37 | protobuf-compiler \ 38 | sudo \ 39 | wget 40 | 41 | # Sometimes needed to avoid SSL CA issues. 42 | RUN update-ca-certificates 43 | 44 | ENV HOME /home 45 | WORKDIR ${HOME}/ 46 | 47 | RUN wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ 48 | chmod +x miniconda.sh && \ 49 | ./miniconda.sh -b -p ${HOME}/miniconda && \ 50 | rm miniconda.sh 51 | 52 | # Setting these env var outside of the install script to ensure 53 | # they persist in image 54 | # (See https://stackoverflow.com/questions/33379393/docker-env-vs-run-export) 55 | ENV PATH ${HOME}/miniconda/bin:$PATH 56 | ENV CONDA_PATH ${HOME}/miniconda 57 | 58 | # Needed to prevent UnicodeDecodeError: 'ascii' codec can't decode byte 59 | # when installing fairseq. 60 | RUN locale-gen en_US.UTF-8 61 | ENV LANG en_US.UTF-8 62 | ENV LANGUAGE en_US:en 63 | ENV LC_ALL en_US.UTF-8 64 | 65 | # Reminder: this should be updated when switching between CUDA 8 or 9. Should 66 | # be kept in sync with TMP_CUDA_VERSION in install_prereqs.sh 67 | ENV NCCL_ROOT_DIR ${HOME}/translate/nccl_2.1.15-1+cuda9.0_x86_64 68 | ENV LD_LIBRARY_PATH ${CONDA_PATH}/lib:${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH} 69 | 70 | ADD ./install_prereqs.sh install_prereqs.sh 71 | RUN ./install_prereqs.sh 72 | RUN rm install_prereqs.sh 73 | 74 | # Add Jenkins user 75 | ARG JENKINS 76 | ARG JENKINS_UID 77 | ARG JENKINS_GID 78 | ADD ./add_jenkins_user.sh add_jenkins_user.sh 79 | RUN if [ -n "${JENKINS}" ]; then bash ./add_jenkins_user.sh ${JENKINS_UID} ${JENKINS_GID}; fi 80 | RUN rm add_jenkins_user.sh 81 | -------------------------------------------------------------------------------- /docker/jenkins/README.md: -------------------------------------------------------------------------------- 1 | # Docker image for Jenkins continuous integration test 2 | 3 | `build.sh` is called by 4 | https://github.com/pytorch/ossci-job-dsl/blob/master/src/jobs/translate_docker.groovy 5 | in order to build a Docker image with PyTorch Translate dependencies, for use 6 | by Jenkins continuous integration tests. 
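For local experimentation, a minimal sketch of an equivalent manual invocation (run from `docker/jenkins/`; the image tag and the ignored first positional argument here are illustrative assumptions, not project conventions):

```bash
# build.sh ignores its first argument and forwards the rest to `docker build`;
# JENKINS is left unset here, so the Jenkins-user setup in the Dockerfile is skipped.
sudo ./build.sh unused -t translate-jenkins:local
```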
7 | 8 | `Dockerfile` contains the Docker image specifications, and calls 9 | `install_prereqs.sh` and `add_jenkins_user.sh` as a part of installing the 10 | dependencies. Note that the Docker image does NOT contain a copy of the 11 | PyTorch Translate repo, as that is done by 12 | https://github.com/pytorch/ossci-job-dsl/blob/master/src/jobs/translate.groovy 13 | (which in turn calls `../../.jenkins/build.sh` to actually build/install 14 | PyTorch Translate and run tests). 15 | 16 | ## Building a new Docker image 17 | 18 | Most changes to PyTorch Translate code will not require any change to the 19 | Docker image used by Jenkins. However, you may need to build a new Docker image 20 | if you: 21 | 1. Add a dependency to a new external package 22 | 2. Need to pull in an updated version of an external package (including 23 | dependencies such as PyTorch, Caffe2, ONNX) 24 | 25 | To do so, you need to push your changes to a branch of pytorch/translate. 26 | If your pull request is from a local fork of that repo, you may need to 27 | manually copy those changes to a new branch of pytorch/translate. 28 | The following links may be useful: 29 | * https://gist.github.com/IanVaughan/2887949 30 | * https://help.github.com/articles/adding-a-remote/ 31 | * https://stackoverflow.com/questions/4878249/how-to-change-the-remote-a-branch-is-tracking 32 | 33 | Afterwards, go to https://ci.pytorch.org/jenkins/job/translate-docker-trigger/build 34 | and build a new Docker image off that branch. Note the build # of your new 35 | Docker image. You can check whether a particular Jenkins build/test is using 36 | your Docker image by checking `DOCKER_IMAGE_TAG` under "Parameters" on the 37 | side bar. You can re-run a Jenkins build/test using an updated Docker image by 38 | clicking "Rebuild" on the side bar and updating `DOCKER_IMAGE_TAG` to the 39 | build # of your new Docker image. 40 | 41 | Assuming that tests at head all pass under the new Docker image, 42 | translate-docker-trigger should have automatically updated the version number in 43 | https://github.com/pytorch/ossci-job-dsl/blob/master/src/main/groovy/ossci/translate/DockerVersion.groovy. 44 | However, if the Docker image change is tied to another pull request, and they're 45 | both backward-incompatible, you may need to add some env-checking to the 46 | pull request to make it compatible with both the old and new Docker images; 47 | or temporarily break tests at head; or manually edit DockerVersion.groovy. 48 | 49 | It may be possible to automate this if someone is interested in trying to 50 | "fix the job dsl to automatically rebuild upon docker changes 🙂 51 | (but it is not that easy; you need to only rebuild docker if there is 52 | a change in the docker files; you don't want to keep rebuilding it for 53 | non-Docker changes)" per @ezyang. 
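As a rough sketch of the "copy your changes to a branch of pytorch/translate" step described above (the `upstream` remote name and the branch name are placeholders; this assumes you have push access to pytorch/translate):

```bash
# Add the main repo as an extra remote alongside your fork (the name is arbitrary).
git remote add upstream https://github.com/pytorch/translate.git
# Push the currently checked-out commit to a new branch on pytorch/translate,
# then build a new Docker image off that branch via translate-docker-trigger.
git push upstream HEAD:refs/heads/my-docker-update
```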
54 | -------------------------------------------------------------------------------- /docker/jenkins/add_jenkins_user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Taken from https://github.com/pytorch/pytorch/blob/master/docker/caffe2/jenkins/common/add_jenkins_user.sh 3 | 4 | set -ex 5 | 6 | # Mirror jenkins user in container 7 | echo "jenkins:x:$JENKINS_UID:$JENKINS_GID::/var/lib/jenkins:" >> /etc/passwd 8 | echo "jenkins:x:$JENKINS_GID:" >> /etc/group 9 | 10 | # Create $HOME 11 | mkdir -p /var/lib/jenkins 12 | chown jenkins:jenkins /var/lib/jenkins 13 | mkdir -p /var/lib/jenkins/.ccache 14 | chown jenkins:jenkins /var/lib/jenkins/.ccache 15 | 16 | # Allow writing to /usr/local (for make install) 17 | chown jenkins:jenkins /usr/local 18 | 19 | # Allow sudo 20 | echo 'jenkins ALL=(ALL) NOPASSWD:ALL' > /etc/sudoers.d/jenkins 21 | -------------------------------------------------------------------------------- /docker/jenkins/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Builds Docker image. 3 | 4 | # Set Jenkins UID and GID if running Jenkins 5 | if [ -n "${JENKINS:-}" ]; then 6 | JENKINS_UID=$(id -u jenkins) 7 | JENKINS_GID=$(id -g jenkins) 8 | fi 9 | 10 | # ${@:2} skips the first argument since we don't have multiple 11 | # Docker images, as opposed to like 12 | # https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh 13 | docker build \ 14 | --no-cache \ 15 | --build-arg "JENKINS=${JENKINS:-}" \ 16 | --build-arg "JENKINS_UID=${JENKINS_UID:-}" \ 17 | --build-arg "JENKINS_GID=${JENKINS_GID:-}" \ 18 | "${@:2}" \ 19 | . 20 | -------------------------------------------------------------------------------- /docker/jenkins/install_prereqs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | . ${HOME}/miniconda/bin/activate 6 | export LD_LIBRARY_PATH="${CONDA_PATH}/lib:${LD_LIBRARY_PATH}" 7 | 8 | # Toggles between CUDA 8 or 9. Needs to be kept in sync with Dockerfile 9 | TMP_CUDA_VERSION="9" 10 | 11 | # Uninstall previous versions of PyTorch. Doing this twice is intentional. 12 | # Error messages about torch not being installed are benign. 13 | pip uninstall -y torch || true 14 | pip uninstall -y torch || true 15 | 16 | # Install basic PyTorch dependencies. 17 | # Numpy should be > 1.12 to prevent torch tensor from treating single-element value as shape. 18 | conda install -y cffi cmake mkl mkl-include numpy=1.14 pyyaml setuptools typing tqdm 19 | # Add LAPACK support for the GPU. 20 | conda install -y -c pytorch "magma-cuda${TMP_CUDA_VERSION}0" 21 | 22 | # Caffe2 relies on the past module. 23 | yes | pip install future 24 | 25 | # statistical significance requires pandas 26 | yes | pip install pandas 27 | 28 | # Install NCCL2. 29 | wget "https://s3.amazonaws.com/pytorch/nccl_2.1.15-1%2Bcuda${TMP_CUDA_VERSION}.0_x86_64.txz" 30 | TMP_NCCL_VERSION="nccl_2.1.15-1+cuda${TMP_CUDA_VERSION}.0_x86_64" 31 | tar -xvf "${TMP_NCCL_VERSION}.txz" 32 | export NCCL_ROOT_DIR="$(pwd)/${TMP_NCCL_VERSION}" 33 | export LD_LIBRARY_PATH="${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH}" 34 | rm "${TMP_NCCL_VERSION}.txz" 35 | 36 | 37 | # Install the combined PyTorch nightly conda package. 
38 | conda install pytorch-nightly cudatoolkit=${TMP_CUDA_VERSION}.0 -c pytorch 39 | 40 | echo "Starting to install ONNX" 41 | git clone --recursive https://github.com/onnx/onnx.git 42 | yes | pip install ./onnx 2>&1 | tee ONNX_OUT 43 | 44 | # train with tensorboard 45 | yes | pip install tensorboard_logger 46 | -------------------------------------------------------------------------------- /pytorch_translate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/attention/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import importlib 4 | import os 5 | 6 | from pytorch_translate.attention.base_attention import BaseAttention 7 | 8 | 9 | ATTENTION_REGISTRY = {} 10 | 11 | 12 | def build_attention(attention_type, decoder_hidden_state_dim, context_dim, **kwargs): 13 | return ATTENTION_REGISTRY[attention_type]( 14 | decoder_hidden_state_dim, context_dim, **kwargs 15 | ) 16 | 17 | 18 | def register_attention(name): 19 | """Decorator to register a new attention type.""" 20 | 21 | def register_attention_cls(cls): 22 | if name in ATTENTION_REGISTRY: 23 | raise ValueError("Cannot register duplicate attention ({})".format(name)) 24 | if not issubclass(cls, BaseAttention): 25 | raise ValueError( 26 | "Attention ({} : {}) must extend BaseAttention".format( 27 | name, cls.__name__ 28 | ) 29 | ) 30 | ATTENTION_REGISTRY[name] = cls 31 | return cls 32 | 33 | return register_attention_cls 34 | 35 | 36 | # automatically import any Python files in the attention/ directory 37 | for file in sorted(os.listdir(os.path.dirname(__file__))): 38 | if file.endswith(".py") and not file.startswith("_"): 39 | module = file[: file.find(".py")] 40 | importlib.import_module("pytorch_translate.attention.{}".format(module)) 41 | -------------------------------------------------------------------------------- /pytorch_translate/attention/attention_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Dict, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | 11 | def create_src_lengths_mask( 12 | batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None 13 | ): 14 | """ 15 | Generate boolean mask to prevent attention beyond the end of source 16 | 17 | Inputs: 18 | batch_size : int 19 | src_lengths : [batch_size] of sentence lengths 20 | max_src_len: Optionally override max_src_len for the mask 21 | 22 | Outputs: 23 | [batch_size, max_src_len] 24 | """ 25 | if max_src_len is None: 26 | max_src_len = int(src_lengths.max()) 27 | src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths) 28 | src_indices = src_indices.expand(batch_size, max_src_len) 29 | src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len) 30 | # returns [batch_size, max_seq_len] 31 | return (src_indices < src_lengths).int().detach() 32 | 33 | 34 | def masked_softmax(scores, src_lengths, src_length_masking=True): 35 | """Apply source length masking then softmax. 
36 | Input and output have shape bsz x src_len""" 37 | if src_length_masking: 38 | bsz, max_src_len = scores.size() 39 | # compute masks 40 | src_mask = create_src_lengths_mask(bsz, src_lengths) 41 | # Fill pad positions with -inf 42 | scores = scores.masked_fill(src_mask == 0, -np.inf) 43 | 44 | # Cast to float and then back again to prevent loss explosion under fp16. 45 | return F.softmax(scores.float(), dim=-1).type_as(scores) 46 | -------------------------------------------------------------------------------- /pytorch_translate/attention/base_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class BaseAttention(nn.Module): 7 | def __init__(self, decoder_hidden_state_dim, context_dim): 8 | super().__init__() 9 | self.decoder_hidden_state_dim = decoder_hidden_state_dim 10 | self.context_dim = context_dim 11 | 12 | def forward(self, decoder_state, source_hids, src_lengths): 13 | """ 14 | Input 15 | decoder_state: bsz x decoder_hidden_state_dim 16 | source_hids: srclen x bsz x context_dim 17 | src_lengths: bsz x 1, actual sequence lengths 18 | Output 19 | output: bsz x context_dim 20 | attn_scores: max_src_len x bsz 21 | """ 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /pytorch_translate/attention/dot_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from pytorch_translate.attention import ( 7 | attention_utils, 8 | BaseAttention, 9 | register_attention, 10 | ) 11 | from pytorch_translate.common_layers import Linear 12 | 13 | 14 | @register_attention("dot") 15 | class DotAttention(BaseAttention): 16 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 17 | super().__init__(decoder_hidden_state_dim, context_dim) 18 | 19 | self.input_proj = None 20 | force_projection = kwargs.get("force_projection", False) 21 | if force_projection or decoder_hidden_state_dim != context_dim: 22 | self.input_proj = Linear(decoder_hidden_state_dim, context_dim, bias=True) 23 | self.src_length_masking = kwargs.get("src_length_masking", True) 24 | 25 | def prepare_for_onnx_export_(self, **kwargs): 26 | self.src_length_masking = False 27 | 28 | def forward(self, decoder_state, source_hids, src_lengths): 29 | # Reshape to bsz x src_len x context_dim 30 | source_hids = source_hids.transpose(0, 1) 31 | # decoder_state: bsz x context_dim 32 | if self.input_proj is not None: 33 | decoder_state = self.input_proj(decoder_state) 34 | # compute attention (bsz x src_len x context_dim) * (bsz x context_dim x 1) 35 | attn_scores = torch.bmm(source_hids, decoder_state.unsqueeze(2)).squeeze(2) 36 | 37 | # Mask + softmax (bsz x src_len) 38 | normalized_masked_attn_scores = attention_utils.masked_softmax( 39 | attn_scores, src_lengths, self.src_length_masking 40 | ) 41 | 42 | # Sum weighted sources 43 | attn_weighted_context = ( 44 | (source_hids * normalized_masked_attn_scores.unsqueeze(2)) 45 | .contiguous() 46 | .sum(1) 47 | ) 48 | 49 | return attn_weighted_context, normalized_masked_attn_scores.t() 50 | -------------------------------------------------------------------------------- /pytorch_translate/attention/mlp_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import 
torch 5 | import torch.nn.functional as F 6 | from pytorch_translate.attention import ( 7 | attention_utils, 8 | BaseAttention, 9 | register_attention, 10 | ) 11 | from pytorch_translate.common_layers import Linear 12 | 13 | 14 | @register_attention("mlp") 15 | class MLPAttention(BaseAttention): 16 | """The original attention from Bahdanau et al. (2014) 17 | https://arxiv.org/abs/1409.0473 based on a Multi-Layer Perceptron. 18 | 19 | The attention score between position i in the encoder and position j in the 20 | decoder is: 21 | alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a) 22 | """ 23 | 24 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 25 | super().__init__(decoder_hidden_state_dim, context_dim) 26 | 27 | self.context_dim = context_dim 28 | self.attention_dim = kwargs.get("attention_dim", context_dim) 29 | # W_ae and b_a 30 | self.encoder_proj = Linear(context_dim, self.attention_dim, bias=True) 31 | # W_ad 32 | self.decoder_proj = Linear( 33 | decoder_hidden_state_dim, self.attention_dim, bias=False 34 | ) 35 | # V_a 36 | self.to_scores = Linear(self.attention_dim, 1, bias=False) 37 | self.src_length_masking = kwargs.get("src_length_masking", True) 38 | 39 | def prepare_for_onnx_export_(self, **kwargs): 40 | self.src_length_masking = False 41 | 42 | def forward(self, decoder_state, source_hids, src_lengths): 43 | """The expected input dimensions are: 44 | 45 | decoder_state: bsz x decoder_hidden_state_dim 46 | source_hids: src_len x bsz x context_dim 47 | src_lengths: bsz 48 | """ 49 | src_len, bsz, _ = source_hids.size() 50 | # (src_len*bsz) x context_dim (to feed through linear) 51 | flat_source_hids = source_hids.view(-1, self.context_dim) 52 | # (src_len*bsz) x attention_dim 53 | encoder_component = self.encoder_proj(flat_source_hids) 54 | # src_len x bsz x attention_dim 55 | encoder_component = encoder_component.view(src_len, bsz, self.attention_dim) 56 | # 1 x bsz x attention_dim 57 | decoder_component = self.decoder_proj(decoder_state).unsqueeze(0) 58 | # Sum with broadcasting and apply the non-linearity 59 | # src_len x bsz x attention_dim 60 | hidden_att = F.tanh( 61 | (decoder_component + encoder_component).view(-1, self.attention_dim) 62 | ) 63 | # Project onto the reals to get attention scores (bsz x src_len) 64 | attn_scores = self.to_scores(hidden_att).view(src_len, bsz).t() 65 | 66 | # Mask + softmax (src_len x bsz) 67 | normalized_masked_attn_scores = attention_utils.masked_softmax( 68 | attn_scores, src_lengths, self.src_length_masking 69 | ).t() 70 | 71 | # Sum weighted sources (bsz x context_dim) 72 | attn_weighted_context = ( 73 | source_hids * normalized_masked_attn_scores.unsqueeze(2) 74 | ).sum(0) 75 | 76 | return attn_weighted_context, normalized_masked_attn_scores 77 | -------------------------------------------------------------------------------- /pytorch_translate/attention/multihead_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional 4 | 5 | from fairseq.modules import multihead_attention as fair_multihead 6 | from pytorch_translate.attention import ( 7 | attention_utils, 8 | BaseAttention, 9 | register_attention, 10 | ) 11 | from torch import Tensor 12 | 13 | 14 | @register_attention("multihead") 15 | class MultiheadAttention(BaseAttention): 16 | """ 17 | Multiheaded Scaled Dot Product Attention 18 | 19 | Implements equation: 20 | MultiHead(Q, K, V) = Concat(head_1,...,head_h)W^O 21 | where head_i = Attention(QW_i^Q, 
KW_i^K, VW_i^V) 22 | 23 | Similarly to the above, d_k = d_v = d_model / h 24 | In this implementation, keys and values are both set to encoder output 25 | 26 | Inputs 27 | init: 28 | decoder_hidden_state_dim : dimensionality of decoder hidden state 29 | context_dim : dimensionality of encoder output 30 | kwargs : 31 | nheads : integer # of attention heads 32 | unseen_mask: if True, only attend to previous sequence positions 33 | src_lengths_mask: if True, mask padding based on src_lengths 34 | 35 | forward: 36 | decoder_state : [batch size, d_model] 37 | source_hids : [sequence length, batch size, d_model] 38 | src_lengths : [batch size] 39 | 40 | forward: 41 | query : [sequence length, batch size, d_model] 42 | key: [sequence length, batch size, d_model] 43 | value: [sequence length, batch size, d_model] 44 | 45 | Output 46 | result : [batch_size, d_model] 47 | """ 48 | 49 | def __init__( 50 | self, 51 | decoder_hidden_state_dim, 52 | context_dim, 53 | *, 54 | nheads=1, 55 | unseen_mask=False, 56 | src_length_mask=True, 57 | ): 58 | super().__init__(decoder_hidden_state_dim, context_dim) 59 | assert decoder_hidden_state_dim == context_dim 60 | d_model = decoder_hidden_state_dim # for brevity 61 | assert d_model % nheads == 0 62 | 63 | if unseen_mask: 64 | raise NotImplementedError( 65 | "Unseen mask not supported with sequential decoding" 66 | ) 67 | self._fair_attn = fair_multihead.MultiheadAttention(d_model, nheads) 68 | self.use_src_length_mask = src_length_mask 69 | 70 | def forward( 71 | self, 72 | decoder_state, 73 | source_hids, 74 | src_lengths, 75 | squeeze: bool = True, 76 | max_src_len: Optional[int] = None, 77 | ): 78 | """ 79 | Computes MultiheadAttention with respect to either a vector 80 | or a tensor 81 | 82 | Inputs: 83 | decoder_state: (bsz x decoder_hidden_state_dim) or 84 | (bsz x T x decoder_hidden_state_dim) 85 | source_hids: srclen x bsz x context_dim 86 | src_lengths: bsz x 1, actual sequence lengths 87 | squeeze: Whether or not to squeeze on the time dimension. 88 | Even if decoder_state.dim() is 2 dimensional an 89 | explicit time step dimension will be unsqueezed. 90 | max_src_len: Optionally override the max_src_len otherwise 91 | inferred from src_lengths. 
Useful during beam search when we 92 | might have already finalized the longest src_sequence 93 | Outputs: 94 | [batch_size, max_src_len] if decoder_state.dim() == 2 & squeeze 95 | or 96 | [batch_size, 1, max_src_len] if decoder_state.dim() == 2 & !squeeze 97 | or 98 | [batch_size, T, max_src_len] if decoder_state.dim() == 3 & !squeeze 99 | or 100 | [batch_size, T, max_src_len] if decoder_state.dim() == 3 & squeeze & T != 1 101 | or 102 | [batch_size, max_src_len] if decoder_state.dim() == 3 & squeeze & T == 1 103 | """ 104 | batch_size = decoder_state.shape[0] 105 | if decoder_state.dim() == 3: 106 | query = decoder_state 107 | elif decoder_state.dim() == 2: 108 | query = decoder_state.unsqueeze(1) 109 | else: 110 | raise ValueError("decoder state must be either 2 or 3 dimensional") 111 | query = query.transpose(0, 1) 112 | value = key = source_hids 113 | 114 | src_len_mask: Optional[Tensor] = None 115 | if src_lengths is not None and self.use_src_length_mask: 116 | # [batch_size, 1, seq_len] 117 | src_len_mask_int = attention_utils.create_src_lengths_mask( 118 | batch_size=batch_size, src_lengths=src_lengths, max_src_len=max_src_len 119 | ) 120 | src_len_mask = src_len_mask_int != 1 121 | attn, attn_weights = self._fair_attn.forward( 122 | query, key, value, key_padding_mask=src_len_mask, need_weights=True 123 | ) 124 | # attn.shape = T X bsz X embed_dim 125 | # attn_weights.shape = bsz X T X src_len 126 | if attn_weights is not None: 127 | attn_weights = attn_weights.transpose(0, 2) 128 | # attn_weights.shape = src_len X T X bsz 129 | 130 | if squeeze: 131 | attn = attn.squeeze(0) 132 | # attn.shape = squeeze(T) X bsz X embed_dim 133 | if attn_weights is not None: 134 | attn_weights = attn_weights.squeeze(1) 135 | # attn_weights.shape = src_len X squeeze(T) X bsz 136 | return attn, attn_weights 137 | return attn, attn_weights 138 | -------------------------------------------------------------------------------- /pytorch_translate/attention/no_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from pytorch_translate.attention import BaseAttention, register_attention 5 | from pytorch_translate.utils import maybe_cuda 6 | 7 | 8 | @register_attention("no") 9 | class NoAttention(BaseAttention): 10 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 11 | super().__init__(decoder_hidden_state_dim, 0) 12 | 13 | def forward(self, decoder_state, source_hids, src_lengths): 14 | return None, maybe_cuda(torch.zeros(1, src_lengths.shape[0])) 15 | -------------------------------------------------------------------------------- /pytorch_translate/attention/pooling_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from pytorch_translate.attention import ( 5 | attention_utils, 6 | BaseAttention, 7 | register_attention, 8 | ) 9 | from torch.autograd import Variable 10 | 11 | 12 | @register_attention("pooling") 13 | class PoolingAttention(BaseAttention): 14 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 15 | super().__init__(decoder_hidden_state_dim, context_dim) 16 | 17 | self.pool_type = kwargs.get("pool_type", "mean") 18 | 19 | def forward(self, decoder_state, source_hids, src_lengths): 20 | assert self.decoder_hidden_state_dim == self.context_dim 21 | max_src_len = source_hids.size()[0] 22 | assert max_src_len == src_lengths.data.max() 23 | batch_size = 
source_hids.size()[1] 24 | 25 | src_mask = ( 26 | attention_utils.create_src_lengths_mask(batch_size, src_lengths) 27 | .type_as(source_hids) 28 | .t() 29 | .unsqueeze(2) 30 | ) 31 | 32 | if self.pool_type == "mean": 33 | # need to make src_lengths a 3-D tensor to normalize masked_hiddens 34 | denom = src_lengths.view(1, batch_size, 1).type_as(source_hids) 35 | masked_hiddens = source_hids * src_mask 36 | context = (masked_hiddens / denom).sum(dim=0) 37 | elif self.pool_type == "max": 38 | masked_hiddens = source_hids - 10e6 * (1 - src_mask) 39 | context = masked_hiddens.max(dim=0)[0] 40 | else: 41 | raise ValueError(f"Pooling type {self.pool_type} is not supported.") 42 | attn_scores = Variable( 43 | torch.ones(src_mask.shape[1], src_mask.shape[0]).type_as(source_hids.data), 44 | requires_grad=False, 45 | ).t() 46 | 47 | return context, attn_scores 48 | 49 | 50 | @register_attention("max") 51 | class MaxPoolingAttention(PoolingAttention): 52 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 53 | super().__init__(decoder_hidden_state_dim, context_dim, pool_type="max") 54 | 55 | 56 | @register_attention("mean") 57 | class MeanPoolingAttention(PoolingAttention): 58 | def __init__(self, decoder_hidden_state_dim, context_dim, **kwargs): 59 | super().__init__(decoder_hidden_state_dim, context_dim, pool_type="mean") 60 | -------------------------------------------------------------------------------- /pytorch_translate/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import random 5 | import tempfile 6 | 7 | from fairseq import options, tasks 8 | from pytorch_translate import ( # noqa; noqa 9 | generate as pytorch_translate_generate, 10 | options as pytorch_translate_options, 11 | rnn, 12 | utils as pytorch_translate_utils, 13 | ) 14 | from pytorch_translate.constants import CHECKPOINT_PATHS_DELIMITER 15 | 16 | 17 | def get_parser_with_args(): 18 | parser = options.get_parser("Generation", default_task="pytorch_translate") 19 | pytorch_translate_options.add_verbosity_args(parser) 20 | pytorch_translate_options.add_dataset_args(parser, gen=True) 21 | generation_group = options.add_generation_args(parser) 22 | pytorch_translate_options.expand_generation_args(generation_group) 23 | 24 | generation_group.add_argument( 25 | "--source-vocab-file", 26 | default="", 27 | metavar="FILE", 28 | help="Path to text file representing the Dictionary to use.", 29 | ) 30 | generation_group.add_argument( 31 | "--char-source-vocab-file", 32 | default="", 33 | metavar="FILE", 34 | help=( 35 | "Same as --source-vocab-file except using characters. " 36 | "(For use with char_source and char_aware models only.)" 37 | ), 38 | ) 39 | generation_group.add_argument( 40 | "--target-vocab-file", 41 | default="", 42 | metavar="FILE", 43 | help="Path to text file representing the Dictionary to use.", 44 | ) 45 | generation_group.add_argument( 46 | "--char-target-vocab-file", 47 | default="", 48 | metavar="FILE", 49 | help=( 50 | "Same as --source-target-file except using characters. " 51 | "(For use with char_aware models only.)" 52 | ), 53 | ) 54 | generation_group.add_argument( 55 | "--multiling-source-lang", 56 | action="append", 57 | metavar="SRC", 58 | help=( 59 | "Must be set for decoding with multilingual models. " 60 | "Must match an entry from --multiling-encoder-lang from training." 
61 | ), 62 | ) 63 | generation_group.add_argument( 64 | "--multiling-target-lang", 65 | action="append", 66 | metavar="TARGET", 67 | help=( 68 | "Must be set for decoding with multilingual models. " 69 | "Must match an entry from --multiling-decoder-lang from training." 70 | ), 71 | ) 72 | 73 | # Add args related to benchmarking. 74 | group = parser.add_argument_group("Benchmarking") 75 | group.add_argument( 76 | "--runs-per-length", 77 | default=10, 78 | type=int, 79 | help="Number of times to run generation on each length.", 80 | ) 81 | group.add_argument( 82 | "--examples-per-length", 83 | default=1, 84 | type=int, 85 | help="Sentences of each length to include in each eval (batched if >1).", 86 | ) 87 | 88 | return parser 89 | 90 | 91 | def main(): 92 | parser = get_parser_with_args() 93 | # args = parser.parse_args() 94 | args = options.parse_args_and_arch(parser) 95 | # Disable printout of all source and target sentences 96 | args.quiet = True 97 | benchmark(args) 98 | 99 | 100 | def generate_synthetic_text(dialect, dialect_symbols, length, examples): 101 | temp_file = tempfile.NamedTemporaryFile(mode="w", delete=False, dir="/tmp") 102 | temp_file_name = temp_file.name 103 | temp_file.close() 104 | with open(temp_file_name, "w") as temp_file: 105 | for _ in range(examples): 106 | temp_file.write(" ".join(random.sample(dialect_symbols, length)) + "\n") 107 | return temp_file_name 108 | 109 | 110 | def benchmark(args): 111 | assert args.source_vocab_file and os.path.isfile( 112 | args.source_vocab_file 113 | ), "Please specify a valid file for --source-vocab-file" 114 | assert args.target_vocab_file and os.path.isfile( 115 | args.target_vocab_file 116 | ), "Please specify a valid file for --target-vocab_file" 117 | assert args.path is not None, "--path required for generation!" 
118 | 119 | print(args) 120 | 121 | # Benchmarking should be language-agnostic 122 | args.source_lang = "src" 123 | args.target_lang = "tgt" 124 | 125 | ( 126 | models, 127 | model_args, 128 | task, 129 | ) = pytorch_translate_utils.load_diverse_ensemble_for_inference( 130 | args.path.split(CHECKPOINT_PATHS_DELIMITER) 131 | ) 132 | 133 | append_eos_to_source = model_args[0].append_eos_to_source 134 | reverse_source = model_args[0].reverse_source 135 | assert all( 136 | a.append_eos_to_source == append_eos_to_source 137 | and a.reverse_source == reverse_source 138 | for a in model_args 139 | ) 140 | 141 | def benchmark_length(n): 142 | # Generate synthetic raw text files 143 | source_text_file = generate_synthetic_text( 144 | dialect=args.source_lang, 145 | dialect_symbols=task.source_dictionary.symbols, 146 | length=n, 147 | examples=args.examples_per_length, 148 | ) 149 | target_text_file = generate_synthetic_text( 150 | dialect=args.target_lang, 151 | dialect_symbols=task.target_dictionary.symbols, 152 | length=n, 153 | examples=args.examples_per_length, 154 | ) 155 | 156 | task.load_dataset_from_text( 157 | args.gen_subset, 158 | source_text_file=source_text_file, 159 | target_text_file=target_text_file, 160 | append_eos=append_eos_to_source, 161 | reverse_source=reverse_source, 162 | ) 163 | 164 | # Remove temporary text files 165 | os.remove(source_text_file) 166 | os.remove(target_text_file) 167 | 168 | # priming 169 | scorer, num_sentences, gen_timer, _ = pytorch_translate_generate.generate_score( 170 | models=models, args=args, task=task, dataset=task.dataset(args.gen_subset) 171 | ) 172 | 173 | total_time = 0.0 174 | for _ in range(args.runs_per_length): 175 | ( 176 | scorer, 177 | num_sentences, 178 | gen_timer, 179 | _, 180 | ) = pytorch_translate_generate.generate_score( 181 | models=models, 182 | args=args, 183 | task=task, 184 | dataset=task.dataset(args.gen_subset), 185 | ) 186 | total_time += gen_timer.sum 187 | gen_timer.reset() 188 | 189 | sentences_per_run = args.examples_per_length 190 | runs = args.runs_per_length 191 | total_sentences = sentences_per_run * runs 192 | total_tokens = total_sentences * n 193 | 194 | print(f"--- {n} tokens ---") 195 | print(f"Generated {total_tokens} tokens ({runs} runs of {sentences_per_run})") 196 | print(f"Total time: {total_time:.3f} seconds") 197 | time_per_sentence = total_time / total_sentences 198 | print(f"Time per sentence: {time_per_sentence:.3f} seconds\n") 199 | 200 | benchmark_length(6) 201 | benchmark_length(10) 202 | benchmark_length(20) 203 | 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /pytorch_translate/bleu_significance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from typing import List, NamedTuple, Optional 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import sacrebleu 9 | 10 | 11 | def get_sufficient_stats( 12 | translations: List[str], references: List[str] 13 | ) -> pd.DataFrame: 14 | assert len(translations) == len(references), ( 15 | f"There are {len(translations)} translated sentences " 16 | f"but {len(references)} reference sentences" 17 | ) 18 | assert sacrebleu.metrics.bleu.BLEU.NGRAM_ORDER == 4, ( 19 | f"Expected SacreBLEU to be using n-gram order 4 " 20 | f"instead of {sacrebleu.metrics.bleu.BLEU.NGRAM_ORDER}." 
21 | ) 22 | 23 | sufficient_stats: List[List[int]] = [] 24 | for sentence, ref in zip(translations, references): 25 | sentence_bleu = sacrebleu.corpus_bleu( 26 | sys_stream=sentence, 27 | ref_streams=ref, 28 | lowercase=False, 29 | tokenize="none", 30 | use_effective_order=False, 31 | ) 32 | sufficient_stats.append( 33 | [ 34 | # Number of correct 1-grams, .., 4-grams 35 | sentence_bleu.counts[0], 36 | sentence_bleu.counts[1], 37 | sentence_bleu.counts[2], 38 | sentence_bleu.counts[3], 39 | # Total number of 1-grams, .., 4-grams 40 | sentence_bleu.totals[0], 41 | sentence_bleu.totals[1], 42 | sentence_bleu.totals[2], 43 | sentence_bleu.totals[3], 44 | # Length of translated sentence. 45 | sentence_bleu.sys_len, 46 | # Length of reference sentence. 47 | sentence_bleu.ref_len, 48 | ] 49 | ) 50 | return pd.DataFrame( 51 | sufficient_stats, 52 | columns=[ 53 | "correct_1_grams", 54 | "correct_2_grams", 55 | "correct_3_grams", 56 | "correct_4_grams", 57 | "total_1_grams", 58 | "total_2_grams", 59 | "total_3_grams", 60 | "total_4_grams", 61 | "translation_length", 62 | "reference_length", 63 | ], 64 | ) 65 | 66 | 67 | def calc_bleu_from_stats(sentence_stats: pd.DataFrame) -> sacrebleu.BLEU: 68 | corpus_stats = sentence_stats.sum(axis=0) 69 | corpus_bleu = sacrebleu.compute_bleu( 70 | correct=[ 71 | corpus_stats.correct_1_grams, 72 | corpus_stats.correct_2_grams, 73 | corpus_stats.correct_3_grams, 74 | corpus_stats.correct_4_grams, 75 | ], 76 | total=[ 77 | corpus_stats.total_1_grams, 78 | corpus_stats.total_2_grams, 79 | corpus_stats.total_3_grams, 80 | corpus_stats.total_4_grams, 81 | ], 82 | sys_len=corpus_stats.translation_length, 83 | ref_len=corpus_stats.reference_length, 84 | ) 85 | return corpus_bleu 86 | 87 | 88 | class PairedBootstrapOutput(NamedTuple): 89 | baseline_bleu: sacrebleu.BLEU 90 | new_bleu: sacrebleu.BLEU 91 | num_samples: int 92 | # Number of samples where the baseline was better than the new. 93 | baseline_better: int 94 | # Number of samples where the baseline and new had identical BLEU score. 95 | num_equal: int 96 | # Number of samples where the new was better than baseline. 97 | new_better: int 98 | 99 | 100 | def paired_bootstrap_resample( 101 | baseline_stats: pd.DataFrame, 102 | new_stats: pd.DataFrame, 103 | num_samples: int = 1000, 104 | sample_size: Optional[int] = None, 105 | ) -> PairedBootstrapOutput: 106 | """ 107 | From http://aclweb.org/anthology/W04-3250 108 | Statistical significance tests for machine translation evaluation (Koehn, 2004) 109 | """ 110 | assert len(baseline_stats) == len(new_stats), ( 111 | f"Length mismatch - baseline has {len(baseline_stats)} lines " 112 | f"while new has {len(new_stats)} lines." 113 | ) 114 | num_sentences = len(baseline_stats) 115 | if not sample_size: 116 | # Defaults to sampling new corpora of the same size as the original. 117 | # This is not identical to the original corpus since we are sampling 118 | # with replacement. 
119 | sample_size = num_sentences 120 | indices = np.random.randint( 121 | low=0, high=num_sentences, size=(num_samples, sample_size) 122 | ) 123 | 124 | baseline_better: int = 0 125 | new_better: int = 0 126 | num_equal: int = 0 127 | for index in indices: 128 | baseline_bleu = calc_bleu_from_stats(baseline_stats.iloc[index]).score 129 | new_bleu = calc_bleu_from_stats(new_stats.iloc[index]).score 130 | if new_bleu > baseline_bleu: 131 | new_better += 1 132 | elif baseline_bleu > new_bleu: 133 | baseline_better += 1 134 | else: 135 | # If the baseline corpus and new corpus are identical, this 136 | # degenerate case may occur. 137 | num_equal += 1 138 | 139 | return PairedBootstrapOutput( 140 | baseline_bleu=calc_bleu_from_stats(baseline_stats), 141 | new_bleu=calc_bleu_from_stats(new_stats), 142 | num_samples=num_samples, 143 | baseline_better=baseline_better, 144 | num_equal=num_equal, 145 | new_better=new_better, 146 | ) 147 | 148 | 149 | def paired_bootstrap_resample_from_files( 150 | reference_file: str, 151 | baseline_file: str, 152 | new_file: str, 153 | num_samples: int = 1000, 154 | sample_size: Optional[int] = None, 155 | ) -> PairedBootstrapOutput: 156 | with open(reference_file, "r") as f: 157 | references: List[str] = [line for line in f] 158 | 159 | with open(baseline_file, "r") as f: 160 | baseline_translations: List[str] = [line for line in f] 161 | baseline_stats: pd.DataFrame = get_sufficient_stats( 162 | translations=baseline_translations, references=references 163 | ) 164 | 165 | with open(new_file, "r") as f: 166 | new_translations: List[str] = [line for line in f] 167 | new_stats: pd.DataFrame = get_sufficient_stats( 168 | translations=new_translations, references=references 169 | ) 170 | 171 | return paired_bootstrap_resample( 172 | baseline_stats=baseline_stats, 173 | new_stats=new_stats, 174 | num_samples=num_samples, 175 | sample_size=sample_size, 176 | ) 177 | 178 | 179 | def main(): 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument( 182 | "--reference-file", 183 | type=str, 184 | required=True, 185 | help="Text file containing reference tokenized (with whitespace separator) sentences.", 186 | ) 187 | parser.add_argument( 188 | "--baseline-file", 189 | type=str, 190 | required=True, 191 | help="Text file containing tokenized sentences translated by baseline system.", 192 | ) 193 | parser.add_argument( 194 | "--new-file", 195 | type=str, 196 | required=True, 197 | help="Text file containing tokenized sentences translated by new system.", 198 | ) 199 | args = parser.parse_args() 200 | 201 | output = paired_bootstrap_resample_from_files( 202 | reference_file=args.reference_file, 203 | baseline_file=args.baseline_file, 204 | new_file=args.new_file, 205 | ) 206 | 207 | print(f"Baseline BLEU: {output.baseline_bleu.score:.2f}") 208 | print(f"New BLEU: {output.new_bleu.score:.2f}") 209 | print(f"BLEU delta: {output.new_bleu.score - output.baseline_bleu.score:.2f} ") 210 | print( 211 | f"Baseline better confidence: {output.baseline_better / output.num_samples:.2%}" 212 | ) 213 | print(f"New better confidence: {output.new_better / output.num_samples:.2%}") 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /pytorch_translate/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | AVERAGED_CHECKPOINT_BEST_FILENAME = "averaged_checkpoint_best.pt" 4 | LAST_CHECKPOINT_FILENAME = 
"checkpoint_last.pt" 5 | 6 | MONOLINGUAL_DATA_IDENTIFIER = "mono" 7 | 8 | SEMI_SUPERVISED_TASK = "pytorch_translate_semi_supervised" 9 | KNOWLEDGE_DISTILLATION_TASK = "pytorch_translate_knowledge_distillation" 10 | DENOISING_AUTOENCODER_TASK = "pytorch_translate_denoising_autoencoder" 11 | MULTILINGUAL_TRANSLATION_TASK = "pytorch_translate_multilingual_task" 12 | LATENT_VARIABLE_TASK = "translation_vae" 13 | 14 | ARCHS_FOR_CHAR_SOURCE = { 15 | "char_source", 16 | "char_source_hybrid", 17 | "char_source_transformer", 18 | "char_aware_hybrid", 19 | } 20 | ARCHS_FOR_CHAR_TARGET = {"char_aware_hybrid"} 21 | CHECKPOINT_PATHS_DELIMITER = "|" 22 | -------------------------------------------------------------------------------- /pytorch_translate/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/data/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/data/dictionary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import re 5 | from typing import Dict, List, Optional, Set 6 | 7 | from fairseq.data import dictionary 8 | 9 | # TODO(T55884145): Replace with 10 | # from fvcore.common.file_io import PathManager 11 | from fairseq.file_io import PathManager 12 | from pytorch_translate import vocab_constants 13 | 14 | 15 | TAGS = [ 16 | "@DIGITS", 17 | "@EMOTICON", 18 | "@FBENTITY", 19 | "@MULTIPUNCT", 20 | "@NOTRANSLATE", 21 | "@PERSON", 22 | "@PLAIN", 23 | "@URL", 24 | "@USERNAME", 25 | ] 26 | 27 | SPACE_NORMALIZER = re.compile(r"\s+") 28 | 29 | 30 | def default_dictionary_path(save_dir: str, dialect: str) -> str: 31 | return os.path.join(save_dir, f"dict.{dialect}.txt") 32 | 33 | 34 | def default_char_dictionary_path(save_dir: str, dialect: str) -> str: 35 | return os.path.join(save_dir, f"char-dict.{dialect}.txt") 36 | 37 | 38 | def tokenize_line(line, embed_bytes=False): 39 | line = SPACE_NORMALIZER.sub(" ", line) 40 | line = line.strip() 41 | return line.split() 42 | 43 | 44 | def char_tokenize_line(line): 45 | words = tokenize_line(line) 46 | chars = [] 47 | for word in words: 48 | if word in TAGS: 49 | chars.append(word) 50 | else: 51 | chars.extend(c for c in word) 52 | return chars 53 | 54 | 55 | def add_file_to_dictionary(filename, dict, tokenize): 56 | with PathManager.open(filename, "r", encoding="utf-8") as f: 57 | for line in f: 58 | for word in tokenize(line): 59 | dict.add_symbol(word) 60 | dict.add_symbol(dict.eos_word) 61 | 62 | 63 | class Dictionary(dictionary.Dictionary): 64 | """A mapping from symbols to consecutive integers""" 65 | 66 | def __init__( 67 | self, 68 | pad: str = "", 69 | eos: str = "", 70 | unk: str = "", 71 | bos: str = "", 72 | max_special_tokens: int = vocab_constants.MAX_SPECIAL_TOKENS, 73 | ) -> None: 74 | self.unk_word, self.pad_word, self.eos_word = unk, pad, eos 75 | self.symbols: List[str] = [] 76 | self.count: List[int] = [] 77 | self.indices: Dict[str, int] = {} 78 | self.lexicon_indices: Set[int] = set() 79 | 80 | self.pad_index = self.add_symbol(pad) 81 | assert self.pad_index == vocab_constants.PAD_ID 82 | 83 | # Adds a junk symbol for vocab_constants' GO_ID 84 | self.add_symbol("") 85 | 86 | self.eos_index = self.add_symbol(eos) 87 | assert self.eos_index == vocab_constants.EOS_ID 88 | 89 | self.unk_index = self.add_symbol(unk) 90 | assert 
self.unk_index == vocab_constants.UNK_ID 91 | self.bos_index = self.add_symbol(bos) 92 | 93 | # Adds junk symbols to pad up to the number of special tokens. 94 | num_reserved = max_special_tokens - len(self.symbols) 95 | for i in range(num_reserved): 96 | self.add_symbol(f"") 97 | 98 | self.nspecial = len(self.symbols) 99 | assert self.nspecial == max_special_tokens 100 | 101 | def lexicon_indices_list(self) -> List[int]: 102 | return list(self.lexicon_indices) 103 | 104 | @classmethod 105 | def build_vocab_file( 106 | cls, 107 | corpus_files: List[str], 108 | vocab_file: str, 109 | max_vocab_size: int, 110 | tokens_with_penalty: Optional[str] = None, 111 | is_char_vocab: bool = False, 112 | embed_bytes: bool = False, 113 | padding_factor: int = 8, 114 | ) -> "Dictionary": # https://www.python.org/dev/peps/pep-0484/#forward-references 115 | d = cls() 116 | 117 | tokenize = char_tokenize_line if is_char_vocab else tokenize_line 118 | embed_bytes = embed_bytes and is_char_vocab 119 | 120 | # if we are embedding byte ids then no need to add these to the dict 121 | # the ids an be obtained directly from the character 122 | if not embed_bytes: 123 | for corpus_file in corpus_files: 124 | add_file_to_dictionary(filename=corpus_file, dict=d, tokenize=tokenize) 125 | 126 | # Set indices to receive penalty 127 | if tokens_with_penalty: 128 | # Assume input tokens are unique 129 | lexicon = [] 130 | with PathManager.open(tokens_with_penalty, "r", encoding="utf-8") as f: 131 | for line in f: 132 | tokens = line.strip().split() 133 | if len(tokens) == 1: 134 | lexicon.append(tokens[0]) 135 | 136 | for token, token_index in d.indices.items(): 137 | if token in lexicon: 138 | d.lexicon_indices.add(token_index) 139 | 140 | nwords = -1 if max_vocab_size <= 0 else max_vocab_size + d.nspecial 141 | d.finalize(nwords=nwords, padding_factor=padding_factor) 142 | d.save(vocab_file) 143 | print(f"Generated new vocab file saved at {vocab_file}.") 144 | if max_vocab_size < 0: 145 | print("No maximum vocab sized enforced.") 146 | else: 147 | print(f"Maximum vocab size {max_vocab_size}") 148 | return d 149 | 150 | @classmethod 151 | def build_vocab_file_if_nonexistent( 152 | cls, 153 | corpus_files: List[str], 154 | vocab_file: str, 155 | max_vocab_size: int, 156 | tokens_with_penalty: Optional[str] = None, 157 | is_char_vocab: bool = False, 158 | embed_bytes: bool = False, 159 | padding_factor: int = 8, 160 | ) -> "Dictionary": # https://www.python.org/dev/peps/pep-0484/#forward-references 161 | if PathManager.isfile(vocab_file): 162 | d = cls.load(vocab_file) 163 | print( 164 | f"Re-using existing vocab file {vocab_file}. Specified " 165 | f"max vocab size of {max_vocab_size} may not be enforced." 166 | ) 167 | return d 168 | 169 | print( 170 | f"Vocab file {vocab_file} does not exist. " 171 | "Creating new vocab file at that path." 
172 | ) 173 | return cls.build_vocab_file( 174 | corpus_files=corpus_files, 175 | vocab_file=vocab_file, 176 | max_vocab_size=max_vocab_size, 177 | tokens_with_penalty=tokens_with_penalty, 178 | is_char_vocab=is_char_vocab, 179 | embed_bytes=embed_bytes, 180 | padding_factor=padding_factor, 181 | ) 182 | 183 | 184 | class CharDictionary(Dictionary): 185 | """Character vocab with its additonal special tokens.""" 186 | 187 | def __init__(self, word_delim="", **kwargs): 188 | super().__init__(**kwargs) 189 | self.word_delim = word_delim 190 | self.bow_index = self.add_symbol("") 191 | self.eow_index = self.add_symbol("") 192 | self.word_delim_index = self.add_symbol(word_delim) 193 | self.nspecial += 3 194 | 195 | 196 | class MaxVocabDictionary(Dictionary): 197 | """This dictionary takes the form of the largest dictionary supplied via push().""" 198 | 199 | def push(self, d: Dictionary): 200 | if len(d) > len(self): 201 | self.copy_from(d) 202 | 203 | def copy_from(self, d: dictionary.Dictionary): 204 | """Makes self a shallow copy of d.""" 205 | self.unk_word = d.unk_word 206 | self.pad_word = d.pad_word 207 | self.eos_word = d.eos_word 208 | self.symbols = d.symbols 209 | self.count = d.count 210 | self.indices = d.indices 211 | self.pad_index = d.pad_index 212 | self.eos_index = d.eos_index 213 | self.unk_index = d.unk_index 214 | self.nspecial = d.nspecial 215 | self.lexicon_indices = d.lexicon_indices 216 | -------------------------------------------------------------------------------- /pytorch_translate/data/iterators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq.data import iterators, RoundRobinZipDatasets 4 | from pytorch_translate.data import weighted_data 5 | 6 | 7 | class WeightedEpochBatchIterator(iterators.EpochBatchIterator): 8 | def __init__( 9 | self, 10 | dataset, 11 | collate_fn, 12 | batch_sampler, 13 | seed=1, 14 | num_shards=1, 15 | shard_id=0, 16 | num_workers=0, 17 | weights=None, 18 | ): 19 | """ 20 | Extension of fairseq.iterators.EpochBatchIterator to use an additional 21 | weights structure. This weighs datasets as a function of epoch value. 22 | 23 | Args: 24 | dataset (~torch.utils.data.Dataset): dataset from which to load the data 25 | collate_fn (callable): merges a list of samples to form a mini-batch 26 | batch_sampler (~torch.utils.data.Sampler): an iterator over batches of 27 | indices 28 | seed (int, optional): seed for random number generator for 29 | reproducibility (default: 1). 30 | num_shards (int, optional): shard the data iterator into N 31 | shards (default: 1). 32 | shard_id (int, optional): which shard of the data iterator to 33 | return (default: 0). 34 | num_workers (int, optional): how many subprocesses to use for data 35 | loading. 0 means the data will be loaded in the main process 36 | (default: 0). 37 | weights: is of the format [(epoch, {dataset: weight})] 38 | """ 39 | super().__init__( 40 | dataset=dataset, 41 | collate_fn=collate_fn, 42 | batch_sampler=batch_sampler, 43 | seed=seed, 44 | num_shards=num_shards, 45 | shard_id=shard_id, 46 | num_workers=num_workers, 47 | ) 48 | self.weights = weights 49 | 50 | def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): 51 | """Return a new iterator over the dataset. 52 | 53 | Args: 54 | shuffle (bool, optional): shuffle batches before returning the 55 | iterator. Default: ``True`` 56 | fix_batches_to_gpus: ensure that batches are always 57 | allocated to the same shards across epochs. 
Requires 58 | that :attr:`dataset` supports prefetching. Default: 59 | ``False`` 60 | """ 61 | if self.weights and isinstance(self.dataset, RoundRobinZipDatasets): 62 | """ 63 | Set dataset weight based on schedule and current epoch 64 | """ 65 | prev_scheduled_epochs = 0 66 | dataset_weights_map = None 67 | for schedule in self.weights: 68 | # schedule looks like (num_epochs, {dataset: weight}) 69 | if self.epoch <= schedule[0] + prev_scheduled_epochs: 70 | dataset_weights_map = schedule[1] 71 | break 72 | prev_scheduled_epochs += schedule[0] 73 | # Use last weights map if weights map is not specified for the current epoch 74 | if dataset_weights_map is None: 75 | dataset_weights_map = self.weights[-1][1] 76 | for dataset_name in self.dataset.datasets: 77 | if dataset_name in dataset_weights_map: 78 | assert isinstance( 79 | self.dataset.datasets[dataset_name], 80 | weighted_data.WeightedLanguagePairDataset, 81 | ) or isinstance( 82 | self.dataset.datasets[dataset_name], 83 | weighted_data.WeightedBacktranslationDataset, 84 | ) 85 | self.dataset.datasets[dataset_name].weights = [ 86 | dataset_weights_map[dataset_name] 87 | ] 88 | if self._next_epoch_itr is not None: 89 | self._cur_epoch_itr = self._next_epoch_itr 90 | self._next_epoch_itr = None 91 | else: 92 | self.epoch += 1 93 | self._cur_epoch_itr = self._get_iterator_for_epoch( 94 | self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus 95 | ) 96 | return self._cur_epoch_itr 97 | -------------------------------------------------------------------------------- /pytorch_translate/data/language_pair_upsampling_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | from fairseq.data.concat_dataset import ConcatDataset 5 | 6 | 7 | class LanguagePairUpsamplingDataset(ConcatDataset): 8 | def __init__(self, datasets, sample_ratios=1): 9 | super(LanguagePairUpsamplingDataset, self).__init__(datasets, sample_ratios) 10 | if isinstance(sample_ratios, float): 11 | self.memoized_sizes = [self.size(idx) for idx in range(len(self))] 12 | else: 13 | self.memoized_sizes = np.concatenate( 14 | [ 15 | np.tile(ds.src_sizes, sr) 16 | for ds, sr in zip(self.datasets, self.sample_ratios) 17 | ] 18 | ) 19 | 20 | @property 21 | def sizes(self): 22 | return self.memoized_sizes 23 | -------------------------------------------------------------------------------- /pytorch_translate/data/masked_lm_dictionary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | 10 | from typing import Dict, List, Set 11 | 12 | from pytorch_translate import vocab_constants 13 | from pytorch_translate.data.dictionary import Dictionary 14 | 15 | 16 | class MaskedLMDictionary(Dictionary): 17 | """ 18 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by 19 | adding the mask symbol. 
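A minimal usage sketch (illustrative only; assumes the default special-token
layout from vocab_constants):

    d = MaskedLMDictionary()
    mask_idx = d.mask()  # == d.mask_index == vocab_constants.MASK_ID

so masked-LM code can look up the mask token's index without hard-coding it.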
20 | """ 21 | 22 | def __init__( 23 | self, 24 | pad="", 25 | eos="", 26 | unk="", 27 | mask="", 28 | max_special_tokens: int = vocab_constants.MAX_SPECIAL_TOKENS, 29 | ): 30 | self.symbols: List[str] = [] 31 | self.count: List[int] = [] 32 | self.indices: Dict[str, int] = {} 33 | self.lexicon_indices: Set[int] = set() 34 | 35 | self.pad_word = pad 36 | self.pad_index = self.add_symbol(pad) 37 | assert self.pad_index == vocab_constants.PAD_ID 38 | 39 | # Adds a junk symbol for vocab_constants' GO_ID 40 | self.add_symbol("") 41 | 42 | self.eos_word = eos 43 | self.eos_index = self.add_symbol(eos) 44 | assert self.eos_index == vocab_constants.EOS_ID 45 | 46 | self.unk_word = unk 47 | self.unk_index = self.add_symbol(unk) 48 | assert self.unk_index == vocab_constants.UNK_ID 49 | 50 | self.mask_word = mask 51 | self.mask_index = self.add_symbol(mask) 52 | assert self.mask_index == vocab_constants.MASK_ID 53 | 54 | # Adds junk symbols to pad up to the number of special tokens. 55 | num_reserved = max_special_tokens - len(self.symbols) 56 | for i in range(num_reserved): 57 | self.add_symbol(f"") 58 | 59 | self.nspecial = len(self.symbols) 60 | assert self.nspecial == max_special_tokens 61 | 62 | def mask(self): 63 | """Helper to get index of mask symbol""" 64 | return self.mask_index 65 | -------------------------------------------------------------------------------- /pytorch_translate/data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional 4 | 5 | from fvcore.common.file_io import PathManager 6 | from pytorch_translate.data import ( 7 | char_data, 8 | data as pytorch_translate_data, 9 | weighted_data, 10 | ) 11 | 12 | 13 | def load_parallel_dataset( 14 | source_lang, 15 | target_lang, 16 | src_bin_path, 17 | tgt_bin_path, 18 | source_dictionary, 19 | target_dictionary, 20 | split, 21 | remove_eos_from_source, 22 | append_eos_to_target=True, 23 | char_source_dict=None, 24 | log_verbose=True, 25 | ): 26 | corpus = pytorch_translate_data.ParallelCorpusConfig( 27 | source=pytorch_translate_data.CorpusConfig( 28 | dialect=source_lang, data_file=src_bin_path 29 | ), 30 | target=pytorch_translate_data.CorpusConfig( 31 | dialect=target_lang, data_file=tgt_bin_path 32 | ), 33 | weights_file=None, 34 | ) 35 | 36 | if log_verbose: 37 | print("Starting to load binarized data files.", flush=True) 38 | validate_corpus_exists(corpus=corpus, split=split) 39 | 40 | tgt_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 41 | corpus.target.data_file 42 | ) 43 | if char_source_dict is not None: 44 | src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file( 45 | corpus.source.data_file 46 | ) 47 | else: 48 | src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 49 | corpus.source.data_file 50 | ) 51 | parallel_dataset = weighted_data.WeightedLanguagePairDataset( 52 | src=src_dataset, 53 | src_sizes=src_dataset.sizes, 54 | src_dict=source_dictionary, 55 | tgt=tgt_dataset, 56 | tgt_sizes=tgt_dataset.sizes, 57 | tgt_dict=target_dictionary, 58 | remove_eos_from_source=remove_eos_from_source, 59 | append_eos_to_target=append_eos_to_target, 60 | ) 61 | return parallel_dataset, src_dataset, tgt_dataset 62 | 63 | 64 | def load_monolingual_dataset( 65 | bin_path, 66 | is_source=False, 67 | char_source_dict=None, 68 | log_verbose=True, 69 | num_examples_limit: Optional[int] = None, 70 | ): 71 | if log_verbose: 72 | print("Starting to load binarized monolingual 
data file.", flush=True) 73 | 74 | if not PathManager.exists(bin_path): 75 | raise ValueError(f"Monolingual binary path {bin_path} not found!") 76 | 77 | if char_source_dict is not None and is_source: 78 | dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(path=bin_path) 79 | 80 | else: 81 | dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 82 | path=bin_path, num_examples_limit=num_examples_limit 83 | ) 84 | 85 | if log_verbose: 86 | print(f"Finished loading dataset {bin_path}", flush=True) 87 | 88 | print( 89 | f"""| Loaded {len(dataset)} monolingual examples for """ 90 | f"""{"source" if is_source else "target"}""" 91 | ) 92 | return dataset 93 | 94 | 95 | def validate_fairseq_dataset_exists(prefix): 96 | if not PathManager.exists(f"{prefix}.idx"): 97 | raise ValueError(f"{prefix}.idx not found!") 98 | if not PathManager.exists(f"{prefix}.bin"): 99 | raise ValueError(f"{prefix}.bin not found!") 100 | 101 | 102 | def validate_corpus_exists( 103 | corpus: pytorch_translate_data.ParallelCorpusConfig, split: str, is_npz: bool = True 104 | ): 105 | """ 106 | Makes sure that the files in the `corpus` are valid files. `split` is used 107 | for logging. 108 | """ 109 | if is_npz: 110 | if not PathManager.exists(corpus.source.data_file): 111 | raise ValueError(f"{corpus.source.data_file} for {split} not found!") 112 | if not PathManager.exists(corpus.target.data_file): 113 | raise ValueError(f"{corpus.target.data_file} for {split} not found!") 114 | else: 115 | validate_fairseq_dataset_exists(corpus.source.data_file) 116 | validate_fairseq_dataset_exists(corpus.target.data_file) 117 | -------------------------------------------------------------------------------- /pytorch_translate/data/weighted_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from fairseq import data 5 | 6 | 7 | class IndexedWeightsDataset(data.indexed_dataset.IndexedDataset): 8 | def __init__(self, path): 9 | self.values = [] 10 | self.read_data(path) 11 | 12 | def read_data(self, path): 13 | with open(path, "r") as f: 14 | for line in f: 15 | self.values.append(float(line.strip("\n"))) 16 | self._len = len(self.values) 17 | 18 | def __getitem__(self, i): 19 | self.check_index(i) 20 | return self.values[i] 21 | 22 | def __del__(self): 23 | pass 24 | 25 | def __len__(self): 26 | return self._len 27 | 28 | 29 | class WeightedLanguagePairDataset(data.language_pair_dataset.LanguagePairDataset): 30 | """ 31 | Extension of fairseq.data.LanguagePairDataset where each example 32 | has a weight in [0.0, 1.0], which will be used to weigh the loss. 33 | 34 | TODO: Refactor this class to look like WeightedBacktranslationDataset. 35 | We could wrap an existing dataset object and provide additional weights 36 | feature. This way, it will be more composable and can be used with arbitrary 37 | datasets. See D13143051. 38 | 39 | Args: 40 | weights (list): list of per example weight values; each example 41 | has a weight in [0.0, 1.0]. Alternatively, when weights consists of a 42 | single value, that value is broadcast as weight to all examples. [0.0] 43 | gives 0 weight to all examples. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | src, 49 | src_sizes, 50 | src_dict, 51 | tgt=None, 52 | tgt_sizes=None, 53 | tgt_dict=None, 54 | weights=None, 55 | **kwargs, 56 | ): 57 | super().__init__(src, src_sizes, src_dict, tgt, tgt_sizes, tgt_dict, **kwargs) 58 | self.weights = weights 59 | self.src_dict = src_dict 60 | 61 | def __getitem__(self, i): 62 | example = super().__getitem__(i) 63 | if self.weights: 64 | """ 65 | If weight for example is missing, use last seen weight. Sometimes we just 66 | want to assign a weight to the entire dataset with a single value but also 67 | maintain the list convention of weights. This way, even if we don't care/know 68 | about dataset size, we can assign same weight to all examples. 69 | """ 70 | if len(self.weights) <= i: 71 | example["weight"] = self.weights[-1] 72 | else: 73 | example["weight"] = self.weights[i] 74 | else: 75 | example["weight"] = 1.0 76 | 77 | return example 78 | 79 | def __len__(self): 80 | return super().__len__() 81 | 82 | def collater(self, samples): 83 | return WeightedLanguagePairDataset.collate( 84 | samples, self.src_dict.pad(), self.src_dict.eos() 85 | ) 86 | 87 | @staticmethod 88 | def collate(samples, pad_idx, eos_idx, left_pad_source=False): 89 | if len(samples) == 0: 90 | return {} 91 | unweighted_data = data.language_pair_dataset.collate( 92 | samples, pad_idx, eos_idx, left_pad_source 93 | ) 94 | original_weights = torch.FloatTensor([s.get("weight", 1.0) for s in samples]) 95 | # sort by descending source length 96 | src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) 97 | src_lengths, sort_order = src_lengths.sort(descending=True) 98 | weights = original_weights.index_select(0, sort_order) 99 | unweighted_data["weights"] = weights 100 | return unweighted_data 101 | 102 | 103 | class WeightedBacktranslationDataset( 104 | data.backtranslation_dataset.BacktranslationDataset 105 | ): 106 | """ 107 | Extension of fairseq.data.BacktranslationDataset where each example 108 | has a weight in [0.0, 1.0], which will be used to weigh the loss. 109 | 110 | Args: 111 | weights (list): list of per example weight values; each example 112 | has a weight in [0.0, 1.0]. Alternatively, when weights consists of a 113 | single value, that value is broadcast as weight to all examples. [0.0] 114 | gives 0 weight to all examples. 115 | """ 116 | 117 | def __init__(self, dataset, weights=None, **kwargs): 118 | self.weights = weights 119 | self.dataset = dataset 120 | 121 | def __getattr__(self, attr): 122 | if attr in self.__dict__: 123 | return getattr(self, attr) 124 | return getattr(self.dataset, attr) 125 | 126 | def __getitem__(self, i): 127 | example = self.dataset.__getitem__(i) 128 | if self.weights: 129 | """ 130 | If weight for example is missing, use last seen weight. Sometimes we just 131 | want to assign a weight to the entire dataset with a single value but also 132 | maintain the list convention of weights. This way, even if we don't care or 133 | don't know about dataset size, we can assign same weight to all examples. 
134 | """ 135 | if len(self.weights) <= i: 136 | example["weight"] = self.weights[-1] 137 | else: 138 | example["weight"] = self.weights[i] 139 | else: 140 | example["weight"] = 1.0 141 | 142 | return example 143 | 144 | def collater(self, samples): 145 | if len(samples) == 0: 146 | return {} 147 | unweighted_data = self.dataset.collater(samples) 148 | original_weights = torch.FloatTensor([s.get("weight", 1.0) for s in samples]) 149 | # sort by descending source length 150 | src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) 151 | src_lengths, sort_order = src_lengths.sort(descending=True) 152 | weights = original_weights.index_select(0, sort_order) 153 | unweighted_data["weights"] = weights 154 | return unweighted_data 155 | -------------------------------------------------------------------------------- /pytorch_translate/dual_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/dual_learning/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/examples/generate_iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NCCL_ROOT_DIR="$(pwd)/nccl_2.1.15-1+cuda8.0_x86_64" 4 | export NCCL_ROOT_DIR 5 | LD_LIBRARY_PATH="${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH}" 6 | export LD_LIBRARY_PATH 7 | wget https://download.pytorch.org/models/translate/iwslt14/model.tar.gz https://download.pytorch.org/models/translate/iwslt14/data.tar.gz 8 | tar -xvzf model.tar.gz 9 | tar -xvzf data.tar.gz 10 | rm -f data.tar.gz model.tar.gz 11 | 12 | python3 pytorch_translate/generate.py \ 13 | "" \ 14 | --path model/averaged_checkpoint_best_0.pt \ 15 | --source-vocab-file model/dictionary-de.txt \ 16 | --target-vocab-file model/dictionary-en.txt \ 17 | --source-text-file data/test.tok.bpe.de \ 18 | --target-text-file data/test.tok.bpe.en \ 19 | --unk-reward -0.5 \ 20 | --length-penalty 0 \ 21 | --word-reward 0.25 \ 22 | --beam 6 \ 23 | --remove-bpe \ 24 | --quiet 25 | 26 | # output should look like: 27 | # | Translated 6750 sentences (152251 tokens) in 37.9s (4018.00 tokens/s) 28 | # | Generate test with beam=6: BLEU4 = 31.31, 65.7/39.2/25.2/16.6 (BP=0.971, ratio=0.972, syslen=127453, reflen=131152) 29 | 30 | python3 pytorch_translate/generate.py \ 31 | "" \ 32 | --path model/averaged_checkpoint_best_0.pt:model/averaged_checkpoint_best_1.pt \ 33 | --source-vocab-file model/dictionary-de.txt \ 34 | --target-vocab-file model/dictionary-en.txt \ 35 | --source-text-file data/test.tok.bpe.de \ 36 | --target-text-file data/test.tok.bpe.en \ 37 | --unk-reward -0.5 \ 38 | --length-penalty 0 \ 39 | --word-reward 0.25 \ 40 | --beam 6 \ 41 | --remove-bpe \ 42 | --quiet 43 | 44 | # output should look like: 45 | # | Translated 6750 sentences (152251 tokens) in 60.2s (2530.11 tokens/s) 46 | # | Generate test with beam=6: BLEU4 = 32.88, 67.4/41.2/27.1/18.2 (BP=0.962, ratio=0.962, syslen=126199, reflen=131152) 47 | # Notice how the performance improves when using two checkpoints instead of one: ensembling models is a very useful 48 | # technique to improve translation quality. 
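# As a hedged follow-up sketch (the hypothesis file names below are hypothetical):
# if the outputs of the single-checkpoint and ensemble runs above are saved to
# files, the significance of the BLEU difference can be checked with the
# paired bootstrap resampling script:
#
# python3 pytorch_translate/bleu_significance.py \
#   --reference-file data/test.tok.bpe.en \
#   --baseline-file single_model_hypotheses.tok.bpe.en \
#   --new-file ensemble_hypotheses.tok.bpe.en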
49 | -------------------------------------------------------------------------------- /pytorch_translate/examples/train_iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NCCL_ROOT_DIR="$(pwd)/nccl_2.1.15-1+cuda8.0_x86_64" 4 | export NCCL_ROOT_DIR 5 | LD_LIBRARY_PATH="${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH}" 6 | export LD_LIBRARY_PATH 7 | wget https://download.pytorch.org/models/translate/iwslt14/data.tar.gz 8 | tar -xvzf data.tar.gz 9 | rm -rf checkpoints data.tar.gz && mkdir -p checkpoints 10 | CUDA_VISIBLE_DEVICES=0 python3 pytorch_translate/train.py \ 11 | "" \ 12 | --arch rnn \ 13 | --log-verbose \ 14 | --lr-scheduler fixed \ 15 | --force-anneal 200 \ 16 | --cell-type lstm \ 17 | --sequence-lstm \ 18 | --reverse-source \ 19 | --encoder-bidirectional \ 20 | --max-epoch 100 \ 21 | --stop-time-hr 72 \ 22 | --stop-no-best-bleu-eval 5 \ 23 | --optimizer sgd \ 24 | --lr 0.5 \ 25 | --lr-shrink 0.95 \ 26 | --clip-norm 5.0 \ 27 | --encoder-dropout-in 0.1 \ 28 | --encoder-dropout-out 0.1 \ 29 | --decoder-dropout-in 0.2 \ 30 | --decoder-dropout-out 0.2 \ 31 | --criterion label_smoothed_cross_entropy \ 32 | --label-smoothing 0.1 \ 33 | --batch-size 256 \ 34 | --length-penalty 0 \ 35 | --unk-reward -0.5 \ 36 | --word-reward 0.25 \ 37 | --max-tokens 9999999 \ 38 | --encoder-layers 2 \ 39 | --encoder-embed-dim 256 \ 40 | --encoder-hidden-dim 512 \ 41 | --decoder-layers 2 \ 42 | --decoder-embed-dim 256 \ 43 | --decoder-hidden-dim 512 \ 44 | --decoder-out-embed-dim 256 \ 45 | --save-dir checkpoints \ 46 | --attention-type dot \ 47 | --sentence-avg \ 48 | --momentum 0 \ 49 | --num-avg-checkpoints 10 \ 50 | --beam 6 \ 51 | --no-beamable-mm \ 52 | --source-lang de \ 53 | --target-lang en \ 54 | --train-source-text-file data/train.tok.bpe.de \ 55 | --train-target-text-file data/train.tok.bpe.en \ 56 | --eval-source-text-file data/valid.tok.bpe.de \ 57 | --eval-target-text-file data/valid.tok.bpe.en \ 58 | --source-max-vocab-size 14000 \ 59 | --target-max-vocab-size 14000 \ 60 | --log-interval 10 \ 61 | --seed "${RANDOM}" \ 62 | 2>&1 | tee -a checkpoints/log 63 | -------------------------------------------------------------------------------- /pytorch_translate/examples/train_lm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NCCL_ROOT_DIR="$(pwd)/nccl_2.1.15-1+cuda8.0_x86_64" 4 | export NCCL_ROOT_DIR 5 | LD_LIBRARY_PATH="${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH}" 6 | export LD_LIBRARY_PATH 7 | wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz 8 | tar -xvzf 1-billion-word-language-modeling-benchmark-r13output.tar.gz 9 | rm -rf checkpoints 1-billion-word-language-modeling-benchmark-r13output.tar.gz && mkdir -p checkpoints 10 | cat 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00*-of-00100 > 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled.news.en 11 | CUDA_VISIBLE_DEVICES=0 python3 pytorch_translate/train.py \ 12 | "" \ 13 | --log-verbose \ 14 | --arch rnn \ 15 | --cell-type lstm \ 16 | --sequence-lstm \ 17 | --max-tokens 999999 \ 18 | --max-epoch 2 \ 19 | --optimizer sgd \ 20 | --lr 0.5 \ 21 | --lr-shrink 0.95 \ 22 | --clip-norm 5.0 \ 23 | --encoder-dropout-in 0.1 \ 24 | --encoder-dropout-out 0.1 \ 25 | --decoder-dropout-in 0.2 \ 26 | --decoder-dropout-out 0.2 \ 27 | --criterion "label_smoothed_cross_entropy" \ 28 | --label-smoothing 0.1 
\ 29 | --batch-size 64 \ 30 | --encoder-bidirectional \ 31 | --encoder-layers 2 \ 32 | --encoder-embed-dim 256 \ 33 | --encoder-hidden-dim 0 \ 34 | --decoder-layers 2 \ 35 | --decoder-embed-dim 256 \ 36 | --decoder-hidden-dim 512 \ 37 | --decoder-out-embed-dim 256 \ 38 | --save-dir checkpoints \ 39 | --attention-type no \ 40 | --sentence-avg \ 41 | --momentum 0 \ 42 | --generate-bleu-eval-avg-checkpoint 10 \ 43 | --beam 6 --no-beamable-mm --length-penalty 1.0 \ 44 | --max-sentences 64 --max-sentences-valid 64 \ 45 | --source-lang en \ 46 | --target-lang en \ 47 | --train-source-text-file 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled.news.en \ 48 | --train-target-text-file 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled.news.en \ 49 | --eval-source-text-file 1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en-00000-of-00100 \ 50 | --eval-target-text-file 1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en-00000-of-00100 \ 51 | --source-max-vocab-size 50000 \ 52 | --target-max-vocab-size 50000 \ 53 | --log-interval 500 \ 54 | --seed "${RANDOM}" \ 55 | 2>&1 | tee -a checkpoints/log 56 | -------------------------------------------------------------------------------- /pytorch_translate/examples/translate_iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Sample script to run the decoder to load an exported model and use it for 4 | # inference. Assumes that `install.sh` has been run, so that `cmake` and 5 | # `make` have already been run in the `pytorch_translate/cpp` directory, 6 | # producing the `translation_decoder` binary. 7 | # 8 | # Sample usage: 9 | # echo "hallo welt" | bash pytorch_translate/examples/translate_iwslt14.sh 10 | 11 | CONDA_PATH="$(dirname "$(which conda)")/../" 12 | export CONDA_PATH 13 | NCCL_ROOT_DIR="$(pwd)/nccl_2.1.15-1+cuda8.0_x86_64" 14 | export NCCL_ROOT_DIR 15 | LD_LIBRARY_PATH="${CONDA_PATH}/lib:${NCCL_ROOT_DIR}/lib:${LD_LIBRARY_PATH}" 16 | export LD_LIBRARY_PATH 17 | 18 | cat | pytorch_translate/cpp/build/translation_decoder \ 19 | --encoder_model "encoder.pb" \ 20 | --decoder_step_model "decoder.pb" \ 21 | --source_vocab_path "model/dictionary-de.txt" \ 22 | --target_vocab_path "model/dictionary-en.txt" \ 23 | `# Tuneable parameters` \ 24 | --beam_size 6 \ 25 | --max_out_seq_len_mult 1.1 \ 26 | --max_out_seq_len_bias 5 \ 27 | `# Must match your training settings` \ 28 | --reverse_source True \ 29 | --append_eos_to_source False \ 30 | `# Unset for more logging/debug messages` \ 31 | --caffe2_log_level 3 32 | -------------------------------------------------------------------------------- /pytorch_translate/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | TODO(T55884145): Deprecate this in favor of using 5 | fvcore.common.file_io.PathManager directly. 
6 | """ 7 | from fairseq.file_io import PathManager # noqa 8 | -------------------------------------------------------------------------------- /pytorch_translate/model_constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Pretrained model params 4 | PRETRAINED_CHAR_EMBED_DIM = 16 5 | PRETRAINED_CHAR_CNN_PARAMS = [ 6 | (32, 1), 7 | (32, 2), 8 | (64, 3), 9 | (128, 4), 10 | (256, 5), 11 | (512, 6), 12 | (1024, 7), 13 | ] 14 | PRETRAINED_NUM_HIGHWAY_LAYERS = 2 15 | PRETRAINED_CHAR_CNN_NONLINEAR_FN = "relu" 16 | PRETRAINED_CHAR_CNN_OUTPUT_DIM = 512 17 | -------------------------------------------------------------------------------- /pytorch_translate/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import importlib 4 | import os 5 | 6 | 7 | # automatically import any Python files in the models/ directory 8 | for file in sorted(os.listdir(os.path.dirname(__file__))): 9 | if file.endswith(".py") and not file.startswith("_"): 10 | model_name = file[: file.find(".py")] 11 | importlib.import_module("pytorch_translate.models." + model_name) 12 | -------------------------------------------------------------------------------- /pytorch_translate/models/transformer_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | from fairseq.models import register_model, register_model_architecture 10 | from fairseq.models.transformer import ( 11 | base_architecture as transformer_base_architecture, 12 | ) 13 | from fairseq.models.transformer_from_pretrained_xlm import ( 14 | TransformerFromPretrainedXLMModel, 15 | ) 16 | from pytorch_translate.data.masked_lm_dictionary import MaskedLMDictionary 17 | 18 | 19 | @register_model("pytorch_translate_transformer_from_pretrained_xlm") 20 | class PytorchTranslateTransformerFromPretrainedXLMModel( 21 | TransformerFromPretrainedXLMModel 22 | ): 23 | @classmethod 24 | def build_model(cls, args, task): 25 | return super().build_model(args, task, cls_dictionary=MaskedLMDictionary) 26 | 27 | 28 | @register_model_architecture( 29 | "pytorch_translate_transformer_from_pretrained_xlm", 30 | "pytorch_translate_transformer_from_pretrained_xlm", 31 | ) 32 | def base_architecture(args): 33 | transformer_base_architecture(args) 34 | -------------------------------------------------------------------------------- /pytorch_translate/multilingual_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import OrderedDict 4 | 5 | import torch.nn as nn 6 | from fairseq.models import FairseqMultiModel, register_model 7 | from pytorch_translate import common_layers, utils 8 | 9 | 10 | @register_model("multilingual") 11 | class MultilingualModel(FairseqMultiModel): 12 | """ 13 | To use, you must extend this class and define single_model_cls as a class 14 | variable. 
Example: 15 | 16 | @register_model("multilingual_transformer") 17 | class MultilingualTransformerModel(MultilingualModel): 18 | single_model_cls = TransformerModel 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | TransformerModel.add_args(parser) 23 | MultilingualModel.add_args(parser) 24 | """ 25 | 26 | def __init__(self, task, encoders, decoders): 27 | super().__init__(encoders, decoders) 28 | self.task = task 29 | self.models = nn.ModuleDict( 30 | { 31 | key: self.__class__.single_model_cls(task, encoders[key], decoders[key]) 32 | for key in self.keys 33 | } 34 | ) 35 | 36 | @staticmethod 37 | def add_args(parser): 38 | """Add model-specific arguments to the parser.""" 39 | parser.add_argument( 40 | "--share-encoder-embeddings", 41 | action="store_true", 42 | help="share encoder embeddings across languages", 43 | ) 44 | parser.add_argument( 45 | "--share-decoder-embeddings", 46 | action="store_true", 47 | help="share decoder embeddings across languages", 48 | ) 49 | parser.add_argument( 50 | "--share-encoders", 51 | action="store_true", 52 | help="share encoders across languages", 53 | ) 54 | parser.add_argument( 55 | "--share-decoders", 56 | action="store_true", 57 | help="share decoders across languages", 58 | ) 59 | 60 | @staticmethod 61 | def set_multilingual_arch_args(args): 62 | args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False) 63 | args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False) 64 | args.share_encoders = getattr(args, "share_encoders", False) 65 | args.share_decoders = getattr(args, "share_decoders", False) 66 | 67 | @classmethod 68 | def build_model(cls, args, task): 69 | """Build a new model instance.""" 70 | if not hasattr(args, "max_source_positions"): 71 | args.max_source_positions = 1024 72 | if not hasattr(args, "max_target_positions"): 73 | args.max_target_positions = 1024 74 | 75 | src_langs = [lang_pair.split("-")[0] for lang_pair in task.lang_pairs] 76 | tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.lang_pairs] 77 | 78 | if args.share_encoders: 79 | args.share_encoder_embeddings = True 80 | if args.share_decoders: 81 | args.share_decoder_embeddings = True 82 | 83 | # encoders/decoders for each language 84 | lang_encoders, lang_decoders = {}, {} 85 | 86 | def get_encoder(lang, shared_encoder_embed_tokens=None): 87 | if lang not in lang_encoders: 88 | src_dict = task.dicts[lang] 89 | if shared_encoder_embed_tokens is None: 90 | encoder_embed_tokens = common_layers.Embedding( 91 | num_embeddings=len(src_dict), 92 | embedding_dim=args.encoder_embed_dim, 93 | padding_idx=src_dict.pad(), 94 | freeze_embed=args.encoder_freeze_embed, 95 | normalize_embed=getattr(args, "encoder_normalize_embed", False), 96 | ) 97 | utils.load_embedding( 98 | embedding=encoder_embed_tokens, 99 | dictionary=src_dict, 100 | pretrained_embed=args.encoder_pretrained_embed, 101 | ) 102 | else: 103 | encoder_embed_tokens = shared_encoder_embed_tokens 104 | lang_encoders[lang] = cls.single_model_cls.build_encoder( 105 | args, src_dict, embed_tokens=encoder_embed_tokens 106 | ) 107 | return lang_encoders[lang] 108 | 109 | def get_decoder(lang, shared_decoder_embed_tokens=None): 110 | """ 111 | Fetch decoder for the input `lang`, which denotes the target 112 | language of the model 113 | """ 114 | if lang not in lang_decoders: 115 | tgt_dict = task.dicts[lang] 116 | if shared_decoder_embed_tokens is None: 117 | decoder_embed_tokens = common_layers.Embedding( 118 | num_embeddings=len(tgt_dict), 119 | 
embedding_dim=args.decoder_embed_dim, 120 | padding_idx=tgt_dict.pad(), 121 | freeze_embed=args.decoder_freeze_embed, 122 | ) 123 | utils.load_embedding( 124 | embedding=decoder_embed_tokens, 125 | dictionary=tgt_dict, 126 | pretrained_embed=args.decoder_pretrained_embed, 127 | ) 128 | else: 129 | decoder_embed_tokens = shared_decoder_embed_tokens 130 | lang_decoders[lang] = cls.single_model_cls.build_decoder( 131 | args, task.dicts[lang], tgt_dict, embed_tokens=decoder_embed_tokens 132 | ) 133 | return lang_decoders[lang] 134 | 135 | # shared encoders/decoders (if applicable) 136 | shared_encoder, shared_decoder = None, None 137 | if args.share_encoders: 138 | shared_encoder = get_encoder(src_langs[0]) 139 | if args.share_decoders: 140 | shared_decoder = get_decoder(tgt_langs[0]) 141 | 142 | shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None 143 | if args.share_encoder_embeddings: 144 | shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( 145 | dicts=task.dicts, 146 | langs=src_langs, 147 | embed_dim=args.encoder_embed_dim, 148 | build_embedding=common_layers.build_embedding, 149 | pretrained_embed_path=None, 150 | ) 151 | if args.share_decoder_embeddings: 152 | shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( 153 | dicts=task.dicts, 154 | langs=tgt_langs, 155 | embed_dim=args.decoder_embed_dim, 156 | build_embedding=common_layers.build_embedding, 157 | pretrained_embed_path=None, 158 | ) 159 | encoders, decoders = OrderedDict(), OrderedDict() 160 | for lang_pair, src_lang, tgt_lang in zip(task.lang_pairs, src_langs, tgt_langs): 161 | encoders[lang_pair] = ( 162 | shared_encoder 163 | if shared_encoder is not None 164 | else get_encoder( 165 | src_lang, shared_encoder_embed_tokens=shared_encoder_embed_tokens 166 | ) 167 | ) 168 | decoders[lang_pair] = ( 169 | shared_decoder 170 | if shared_decoder is not None 171 | else get_decoder( 172 | tgt_lang, shared_decoder_embed_tokens=shared_decoder_embed_tokens 173 | ) 174 | ) 175 | 176 | return cls(task, encoders, decoders) 177 | -------------------------------------------------------------------------------- /pytorch_translate/multilingual_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | from collections import defaultdict 5 | from typing import Dict, List, Optional, Tuple 6 | 7 | from pytorch_translate.data import dictionary as pytorch_translate_dictionary 8 | 9 | 10 | def get_source_langs(lang_pairs: List[str]) -> List[str]: 11 | """ 12 | Return list of source languages from args.lang_pairs 13 | lang_pairs: List[str] where each element is a str with comma separated list 14 | of language pairs 15 | """ 16 | return [lang_pair.split("-")[0] for lang_pair in lang_pairs] 17 | 18 | 19 | def get_target_langs(lang_pairs: List[str]) -> List[str]: 20 | """ 21 | Return list of target languages from args.lang_pairs 22 | lang_pairs: List[str] where each element is a str with comma separated list 23 | of language pairs 24 | """ 25 | return [lang_pair.split("-")[1] for lang_pair in lang_pairs] 26 | 27 | 28 | def default_binary_path(save_dir: str, lang_pair: str, lang: str, split: str) -> str: 29 | return os.path.join(save_dir, f"{split}-binary-{lang_pair}.{lang}") 30 | 31 | 32 | def get_dict_paths( 33 | vocabulary_args: Optional[List[str]], langs: List[str], save_dir: str 34 | ) -> Dict[str, str]: 35 | """ 36 | Extract dictionary files based on --vocabulary argument, for the given 37 | 
languages `langs`. 38 | vocabulary_arg: Optional[List[str]] where each element is a str with the format 39 | "lang:vocab_file" 40 | """ 41 | dicts = {} 42 | if vocabulary_args is not None: 43 | for vocab_config in vocabulary_args: 44 | # vocab_config is in the format "lang:vocab_file" 45 | lang, vocab = vocab_config.split(":") 46 | if lang in langs: 47 | dicts[lang] = vocab 48 | for lang in langs: 49 | if lang not in dicts: 50 | dicts[lang] = pytorch_translate_dictionary.default_dictionary_path( 51 | save_dir=save_dir, dialect=lang 52 | ) 53 | return dicts 54 | 55 | 56 | def get_corpora_for_lang(parallel_corpora: List[str], lang: str) -> List[str]: 57 | """ 58 | Fetches list of corpora that belong to given lang 59 | parallel_corpora: List[str] where each element is a str with the format 60 | "src_lang-tgt_lang:src_corpus,tgt_corpus" 61 | 62 | Returns [] if corpora for lang is not found 63 | """ 64 | corpora = [] 65 | for parallel_corpus_config in parallel_corpora: 66 | lang_pair, parallel_corpus = parallel_corpus_config.split(":") 67 | src_lang, tgt_lang = lang_pair.split("-") 68 | if src_lang == lang: 69 | corpora.append(parallel_corpus.split(",")[0]) 70 | if tgt_lang == lang: 71 | corpora.append(parallel_corpus.split(",")[1]) 72 | return corpora 73 | 74 | 75 | def get_parallel_corpus_for_lang_pair( 76 | parallel_corpora: List[str], lang_pair: str 77 | ) -> Tuple[str, str]: 78 | """ 79 | Fetches parallel corpus that belong to given lang_pair 80 | parallel_corpora: List[str] where each element is a str with the format 81 | "src_lang-tgt_lang:src_corpus,tgt_corpus" 82 | 83 | Returns None if parallel corpora for lang_pair is not found 84 | """ 85 | for parallel_corpus_config in parallel_corpora: 86 | corpus_lang_pair, parallel_corpus = parallel_corpus_config.split(":") 87 | if corpus_lang_pair == lang_pair: 88 | return tuple(parallel_corpus.split(",")) 89 | return None 90 | 91 | 92 | def prepare_dicts( 93 | args: argparse.Namespace, langs: List[str] 94 | ) -> Tuple[Dict[str, str], Dict[str, pytorch_translate_dictionary.Dictionary]]: 95 | """ 96 | Uses multilingual train corpora specified in args.multilingual_train_text_file 97 | to build dictionaries for languages specified in `langs`. 
98 | Vocab size is defined by args.target_max_vocab_size if lang is in the set 99 | of target languages, otherwise it is decided by args.target_max_vocab_size 100 | """ 101 | tgt_langs = get_target_langs(args.lang_pairs.split(",")) 102 | dict_paths = get_dict_paths(args.vocabulary, langs, args.save_dir) 103 | lang2corpus = defaultdict(list) 104 | for lang in langs: 105 | lang2corpus[lang] = get_corpora_for_lang( 106 | args.multilingual_train_text_file, lang 107 | ) 108 | dict_objects = { 109 | lang: pytorch_translate_dictionary.Dictionary.build_vocab_file_if_nonexistent( 110 | corpus_files=lang2corpus[lang], 111 | vocab_file=dict_paths[lang], 112 | max_vocab_size=( 113 | args.target_max_vocab_size 114 | if lang in tgt_langs 115 | else args.source_max_vocab_size 116 | ), 117 | tokens_with_penalty=None, 118 | ) 119 | for lang in langs 120 | } 121 | return dict_paths, dict_objects 122 | -------------------------------------------------------------------------------- /pytorch_translate/ngram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from fairseq import utils 7 | from pytorch_translate import attention, utils as pytorch_translate_utils 8 | from pytorch_translate.common_layers import ( 9 | DecoderWithOutputProjection, 10 | Embedding, 11 | Linear, 12 | NonlinearLayer, 13 | ) 14 | from pytorch_translate.utils import maybe_cat 15 | 16 | 17 | class NGramDecoder(DecoderWithOutputProjection): 18 | """n-gram decoder. 19 | 20 | This decoder implementation does not condition on the full target-side 21 | history. Instead, predictions only depend on the target n-gram history and 22 | the full source sentence via attention over encoder outputs. The decoder 23 | network is a feedforward network with source context as additional input. 
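For example, with the default n=4 the prediction of each target token depends
only on the previous history_len = n - 1 = 3 target tokens (via a convolution
over their embeddings) plus the attended source context, rather than on the
entire decoded prefix.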
24 | """ 25 | 26 | def __init__( 27 | self, 28 | src_dict, 29 | dst_dict, 30 | vocab_reduction_params=None, 31 | n=4, 32 | encoder_hidden_dim=512, 33 | embed_dim=512, 34 | freeze_embed=False, 35 | hidden_dim=512, 36 | out_embed_dim=512, 37 | num_layers=1, 38 | dropout_in=0.1, 39 | dropout_out=0.1, 40 | attention_type="dot", 41 | residual_level=None, 42 | activation_fn=nn.ReLU, 43 | project_output=True, 44 | pretrained_embed=None, 45 | projection_pretrained_embed=None, 46 | ): 47 | super().__init__( 48 | src_dict, 49 | dst_dict, 50 | vocab_reduction_params, 51 | out_embed_dim, 52 | project_output=project_output, 53 | pretrained_embed=projection_pretrained_embed, 54 | ) 55 | self.history_len = n - 1 56 | self.encoder_hidden_dim = encoder_hidden_dim 57 | self.embed_dim = embed_dim 58 | self.hidden_dim = hidden_dim 59 | self.out_embed_dim = out_embed_dim 60 | self.dropout_in = dropout_in 61 | self.dropout_out = dropout_out 62 | self.attention_type = attention_type 63 | self.residual_level = residual_level 64 | self.dst_dict = dst_dict 65 | self.activation_fn = activation_fn 66 | 67 | num_embeddings = len(dst_dict) 68 | padding_idx = dst_dict.pad() 69 | self.embed_tokens = Embedding( 70 | num_embeddings=num_embeddings, 71 | embedding_dim=embed_dim, 72 | padding_idx=padding_idx, 73 | freeze_embed=freeze_embed, 74 | ) 75 | pytorch_translate_utils.load_embedding( 76 | embedding=self.embed_tokens, 77 | dictionary=dst_dict, 78 | pretrained_embed=pretrained_embed, 79 | ) 80 | 81 | self.history_conv = nn.Sequential( 82 | torch.nn.Conv1d(embed_dim, hidden_dim, self.history_len), activation_fn() 83 | ) 84 | 85 | self.hidden_dim = hidden_dim 86 | self.layers = nn.ModuleList( 87 | [ 88 | NonlinearLayer(hidden_dim, hidden_dim, activation_fn=activation_fn) 89 | for _ in range(num_layers) 90 | ] 91 | ) 92 | 93 | self.attention = attention.build_attention( 94 | attention_type=attention_type, 95 | decoder_hidden_state_dim=hidden_dim, 96 | context_dim=encoder_hidden_dim, 97 | force_projection=True, 98 | ) 99 | self.combined_output_and_context_dim = self.attention.context_dim + hidden_dim 100 | if self.combined_output_and_context_dim != out_embed_dim: 101 | self.additional_fc = Linear( 102 | self.combined_output_and_context_dim, out_embed_dim 103 | ) 104 | 105 | def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None): 106 | padded_tokens = F.pad( 107 | input_tokens, 108 | (self.history_len - 1, 0, 0, 0), 109 | "constant", 110 | self.dst_dict.eos(), 111 | ) 112 | # We use incremental_state only to check whether we are decoding or not 113 | # self.training is false even for the forward pass through validation 114 | if incremental_state is not None: 115 | padded_tokens = padded_tokens[:, -self.history_len :] 116 | utils.set_incremental_state(self, incremental_state, "incremental_marker", True) 117 | 118 | bsz, seqlen = padded_tokens.size() 119 | seqlen -= self.history_len - 1 120 | 121 | # get outputs from encoder 122 | (encoder_outs, final_hidden, _, src_lengths, _) = encoder_out 123 | 124 | # padded_tokens has shape [batch_size, seq_len+history_len] 125 | x = self.embed_tokens(padded_tokens) 126 | x = F.dropout(x, p=self.dropout_in, training=self.training) 127 | 128 | # Convolution needs shape [batch_size, channels, seq_len] 129 | x = self.history_conv(x.transpose(1, 2)).transpose(1, 2) 130 | x = F.dropout(x, p=self.dropout_out, training=self.training) 131 | 132 | # x has shape [batch_size, seq_len, channels] 133 | for i, layer in enumerate(self.layers): 134 | prev_x = x 135 | x = 
layer(x) 136 | x = F.dropout(x, p=self.dropout_out, training=self.training) 137 | if self.residual_level is not None and i >= self.residual_level: 138 | x = x + prev_x 139 | 140 | # Attention 141 | attn_out, attn_scores = self.attention( 142 | x.transpose(0, 1).contiguous().view(-1, self.hidden_dim), 143 | encoder_outs.repeat(1, seqlen, 1), 144 | src_lengths.repeat(seqlen), 145 | ) 146 | if attn_out is not None: 147 | attn_out = attn_out.view(seqlen, bsz, -1).transpose(1, 0) 148 | attn_scores = attn_scores.view(-1, seqlen, bsz).transpose(0, 2) 149 | x = maybe_cat((x, attn_out), dim=2) 150 | 151 | # bottleneck layer 152 | if hasattr(self, "additional_fc"): 153 | x = self.additional_fc(x) 154 | x = F.dropout(x, p=self.dropout_out, training=self.training) 155 | return x, attn_scores 156 | 157 | def max_positions(self): 158 | """Maximum output length supported by the decoder.""" 159 | return int(1e5) # an arbitrary large number 160 | -------------------------------------------------------------------------------- /pytorch_translate/rescoring/weights_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import pickle 5 | 6 | import numpy as np 7 | import torch 8 | from fairseq.scoring import bleu 9 | from pytorch_translate import vocab_constants 10 | from pytorch_translate.data.dictionary import Dictionary 11 | from pytorch_translate.generate import smoothed_sentence_bleu 12 | 13 | 14 | def get_arg_parser(): 15 | parser = argparse.ArgumentParser( 16 | description=("Rescore generated hypotheses with extra models") 17 | ) 18 | parser.add_argument( 19 | "--scores-info-export-path", type=str, help="Model scores for weights search" 20 | ) 21 | parser.add_argument( 22 | "--num-trials", 23 | type=int, 24 | default=1000, 25 | help="Number of iterations of random search", 26 | ) 27 | parser.add_argument("--report-oracle-bleu", default=False, action="store_true") 28 | return parser 29 | 30 | 31 | class DummyTask: 32 | """ 33 | Default values for pad, eos, unk 34 | """ 35 | 36 | def __init__(self): 37 | self.target_dictionary = Dictionary() 38 | 39 | 40 | def evaluate_weights(scores_info, feature_weights, length_penalty): 41 | scorer = bleu.Scorer( 42 | bleu.BleuConfig( 43 | pad=vocab_constants.PAD_ID, 44 | eos=vocab_constants.EOS_ID, 45 | unk=vocab_constants.UNK_ID, 46 | ) 47 | ) 48 | 49 | for example in scores_info: 50 | weighted_scores = (example["scores"] * feature_weights).sum(axis=1) 51 | weighted_scores /= (example["tgt_len"] ** length_penalty) + 1e-12 52 | top_hypo_ind = np.argmax(weighted_scores) 53 | top_hypo = example["hypos"][top_hypo_ind] 54 | ref = example["target_tokens"] 55 | scorer.add(torch.IntTensor(ref), torch.IntTensor(top_hypo)) 56 | 57 | return scorer.score() 58 | 59 | 60 | def identify_nonzero_features(scores_info): 61 | nonzero_features = np.any(scores_info[0]["scores"] != 0, axis=0) 62 | for example in scores_info[1:]: 63 | nonzero_features |= np.any(example["scores"] != 0, axis=0) 64 | 65 | return np.where(nonzero_features)[0] 66 | 67 | 68 | def random_search(scores_info_export_path, num_trials, report_oracle_bleu=False): 69 | with open(scores_info_export_path, "rb") as f: 70 | scores_info = pickle.load(f) 71 | 72 | dummy_task = DummyTask() 73 | 74 | if report_oracle_bleu: 75 | oracle_scorer = bleu.Scorer( 76 | bleu.BleuConfig( 77 | pad=vocab_constants.PAD_ID, 78 | eos=vocab_constants.EOS_ID, 79 | unk=vocab_constants.UNK_ID, 80 | ) 81 | ) 82 | 83 | for example in scores_info: 84 | 
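# For each n-best list, score every hypothesis against the reference with
# smoothed sentence BLEU and treat the best one as the "oracle" choice; the
# resulting corpus-level oracle BLEU is (approximately) an upper bound on what
# any choice of rescoring weights could achieve on these hypotheses.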
smoothed_bleu = [] 85 | for hypo in example["hypos"]: 86 | eval_score = smoothed_sentence_bleu( 87 | dummy_task, 88 | torch.IntTensor(example["target_tokens"]), 89 | torch.IntTensor(hypo), 90 | ) 91 | smoothed_bleu.append(eval_score) 92 | best_hypo_ind = np.argmax(smoothed_bleu) 93 | example["best_hypo_ind"] = best_hypo_ind 94 | 95 | oracle_scorer.add( 96 | torch.IntTensor(example["target_tokens"]), 97 | torch.IntTensor(example["hypos"][best_hypo_ind]), 98 | ) 99 | 100 | print("oracle BLEU: ", oracle_scorer.score()) 101 | 102 | num_features = scores_info[0]["scores"].shape[1] 103 | assert all( 104 | example["scores"].shape[1] == num_features for example in scores_info 105 | ), "All examples must have the same number of scores!" 106 | feature_weights = np.zeros(num_features) 107 | feature_weights[0] = 1 108 | score = evaluate_weights(scores_info, feature_weights, length_penalty=1) 109 | print("base BLEU: ", score) 110 | best_score = score 111 | best_weights = feature_weights 112 | best_length_penalty = 0 113 | 114 | nonzero_features = identify_nonzero_features(scores_info) 115 | 116 | for i in range(num_trials): 117 | feature_weights = np.zeros(num_features) 118 | random_weights = np.random.dirichlet(np.ones(nonzero_features.size)) 119 | feature_weights[nonzero_features] = random_weights 120 | length_penalty = 1.5 * np.random.random() 121 | 122 | score = evaluate_weights(scores_info, feature_weights, length_penalty) 123 | if score > best_score: 124 | best_score = score 125 | best_weights = feature_weights 126 | best_length_penalty = length_penalty 127 | 128 | print(f"\r[{i}] best: {best_score}", end="", flush=True) 129 | 130 | print() 131 | print("best weights: ", best_weights) 132 | print("best length penalty: ", length_penalty) 133 | 134 | return best_weights, best_length_penalty, best_score 135 | 136 | 137 | def main(): 138 | args = get_arg_parser().parse_args() 139 | 140 | assert ( 141 | args.scores_info_export_path is not None 142 | ), "--scores-info-export-path is required for weights search" 143 | 144 | random_search( 145 | args.scores_info_export_path, args.num_trials, args.report_oracle_bleu 146 | ) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /pytorch_translate/research/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/research/attention/multihead_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | def create_src_lengths_mask(batch_size, src_lengths): 12 | max_srclen = src_lengths.max() 13 | src_indices = torch.arange(0, max_srclen).unsqueeze(0).type_as(src_lengths) 14 | src_indices = src_indices.expand(batch_size, max_srclen) 15 | src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) 16 | # returns [batch_size, max_seq_len] 17 | return (src_indices < src_lengths).int().detach() 18 | 19 | 20 | def apply_masks(scores, batch_size, unseen_mask, src_lengths): 21 | seq_len = scores.shape[-1] 22 | 23 | # [1, seq_len, seq_len] 24 | sequence_mask = torch.ones(seq_len, seq_len).unsqueeze(0).int() 
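# The all-ones mask above lets every position attend to every other position;
# when unseen_mask=True it is replaced below by a lower-triangular (causal)
# mask so that position i can only attend to positions <= i.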
25 | if unseen_mask: 26 | # [1, seq_len, seq_len] 27 | sequence_mask = ( 28 | torch.tril(torch.ones(seq_len, seq_len), diagonal=0).unsqueeze(0).int() 29 | ) 30 | 31 | if src_lengths is not None: 32 | # [batch_size, 1, seq_len] 33 | src_lengths_mask = create_src_lengths_mask( 34 | batch_size=batch_size, src_lengths=src_lengths 35 | ).unsqueeze(-2) 36 | 37 | # [batch_size, seq_len, seq_len] 38 | sequence_mask = sequence_mask & src_lengths_mask 39 | 40 | # [batch_size, 1, seq_len, seq_len] 41 | sequence_mask = sequence_mask.unsqueeze(1) 42 | 43 | scores = scores.masked_fill(sequence_mask == 0, -np.inf) 44 | return scores 45 | 46 | 47 | def scaled_dot_prod_attn(query, key, value, unseen_mask=False, src_lengths=None): 48 | """ 49 | Scaled Dot Product Attention 50 | 51 | Implements equation: 52 | Attention(Q, K, V) = softmax(QK^T/\sqrt{d_k})V 53 | 54 | Inputs: 55 | query : [batch size, nheads, sequence length, d_k] 56 | key : [batch size, nheads, sequence length, d_k] 57 | value : [batch size, nheads, sequence length, d_v] 58 | unseen_mask: if True, only attend to previous sequence positions 59 | src_lengths_mask: if True, mask padding based on src_lengths 60 | 61 | Outputs: 62 | attn: [batch size, sequence length, d_v] 63 | 64 | Note that in this implementation d_q = d_k = d_v = dim 65 | """ 66 | d_k = query.shape[-1] 67 | scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(d_k) 68 | if unseen_mask or src_lengths is not None: 69 | scores = apply_masks( 70 | scores=scores, 71 | batch_size=query.shape[0], 72 | unseen_mask=unseen_mask, 73 | src_lengths=src_lengths, 74 | ) 75 | p_attn = F.softmax(scores, dim=-1) 76 | return torch.matmul(p_attn, value), p_attn 77 | 78 | 79 | def split_heads(X, nheads): 80 | """ 81 | Split heads: 82 | 1) Split (reshape) last dimension (size d_model) into nheads, d_head 83 | 2) Transpose X from (batch size, sequence length, nheads, d_head) to 84 | (batch size, nheads, sequence length, d_head) 85 | 86 | Inputs: 87 | X : [batch size, sequence length, nheads * d_head] 88 | nheads : integer 89 | Outputs: 90 | [batch size, nheads, sequence length, d_head] 91 | 92 | """ 93 | last_dim = X.shape[-1] 94 | assert last_dim % nheads == 0 95 | X_last_dim_split = X.view(list(X.shape[:-1]) + [nheads, last_dim // nheads]) 96 | return X_last_dim_split.transpose(1, 2) 97 | 98 | 99 | def combine_heads(X): 100 | """ 101 | Combine heads (the inverse of split heads): 102 | 1) Transpose X from (batch size, nheads, sequence length, d_head) to 103 | (batch size, sequence length, nheads, d_head) 104 | 2) Combine (reshape) last 2 dimensions (nheads, d_head) into 1 (d_model) 105 | 106 | Inputs: 107 | X : [batch size * nheads, sequence length, d_head] 108 | nheads : integer 109 | d_head : integer 110 | 111 | Outputs: 112 | [batch_size, seq_len, d_model] 113 | 114 | """ 115 | X = X.transpose(1, 2) 116 | nheads, d_head = X.shape[-2:] 117 | return X.contiguous().view(list(X.shape[:-2]) + [nheads * d_head]) 118 | 119 | 120 | class MultiheadAttention(nn.Module): 121 | """ 122 | Multiheaded Scaled Dot Product Attention 123 | 124 | Implements equation: 125 | MultiHead(Q, K, V) = Concat(head_1,...,head_h)W^O 126 | where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V) 127 | 128 | Similarly to the above, d_k = d_v = d_model / h 129 | 130 | Inputs 131 | init: 132 | nheads : integer # of attention heads 133 | d_model : model dimensionality 134 | d_head : dimensionality of a single head 135 | 136 | forward: 137 | query : [batch size, sequence length, d_model] 138 | key: [batch size, sequence length, 
d_model] 139 | value: [batch size, sequence length, d_model] 140 | unseen_mask: if True, only attend to previous sequence positions 141 | src_lengths_mask: if True, mask padding based on src_lengths 142 | 143 | Output 144 | result : [batch_size, sequence length, d_model] 145 | """ 146 | 147 | def __init__(self, nheads, d_model): 148 | "Take in model size and number of heads." 149 | super(MultiheadAttention, self).__init__() 150 | assert d_model % nheads == 0 151 | self.d_head = d_model // nheads 152 | self.nheads = nheads 153 | self.Q_fc = nn.Linear(d_model, d_model, bias=False) 154 | self.K_fc = nn.Linear(d_model, d_model, bias=False) 155 | self.V_fc = nn.Linear(d_model, d_model, bias=False) 156 | self.output_fc = nn.Linear(d_model, d_model, bias=False) 157 | self.attn = None 158 | 159 | def forward(self, query, key, value, unseen_mask=False, src_lengths=None): 160 | # 1. Fully-connected layer on q, k, v then 161 | # 2. Split heads on q, k, v 162 | # (batch_size, seq_len, d_model) --> 163 | # (batch_size, nheads, seq_len, d_head) 164 | query = split_heads(self.Q_fc(query), self.nheads) 165 | key = split_heads(self.K_fc(key), self.nheads) 166 | value = split_heads(self.V_fc(value), self.nheads) 167 | 168 | # 4. Scaled dot product attention 169 | # (batch_size, nheads, seq_len, d_head) 170 | x, self.attn = scaled_dot_prod_attn( 171 | query=query, 172 | key=key, 173 | value=value, 174 | unseen_mask=unseen_mask, 175 | src_lengths=src_lengths, 176 | ) 177 | 178 | # 5. Combine heads 179 | x = combine_heads(x) 180 | 181 | # 6. Fully-connected layer for output 182 | return self.output_fc(x) 183 | -------------------------------------------------------------------------------- /pytorch_translate/research/beam_search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/beam_search/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/knowledge_distillation/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/collect_top_k_probs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from fairseq import options, progress_bar, utils 7 | from pytorch_translate import ( # noqa # noqa # noqa 8 | hybrid_transformer_rnn, 9 | options as pytorch_translate_options, 10 | rnn, 11 | transformer, 12 | utils as pytorch_translate_utils, 13 | ) 14 | from pytorch_translate.constants import CHECKPOINT_PATHS_DELIMITER 15 | from pytorch_translate.tasks.pytorch_translate_multi_task import ( # noqa 16 | PyTorchTranslateMultiTask, 17 | ) 18 | 19 | 20 | def compute_top_k( 21 | task, 22 | models, 23 | dataset, 24 | k, 25 | use_cuda, 26 | max_tokens=None, 27 | max_sentences=None, 28 | progress_bar_args=None, 29 | ): 30 | """ 31 | This function runs forward computation on an ensemble of trained models 32 | using binarized parallel training data and returns the top-k probabilities 33 | and their corresponding token indices for each output step. 
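Per-step probabilities are averaged across the ensemble members before the top k values are selected and re-normalized to sum to 1.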
34 | 35 | Returns: (top_k_scores, top_k_indices) 36 | Each a NumPy array of size (total_target_tokens, k) 37 | """ 38 | top_k_scores_list = [None for _ in range(len(dataset))] 39 | top_k_indices_list = [None for _ in range(len(dataset))] 40 | 41 | itr = task.get_batch_iterator( 42 | dataset=dataset, max_tokens=max_tokens, max_sentences=max_sentences 43 | ).next_epoch_itr(shuffle=False) 44 | if progress_bar_args is not None: 45 | itr = progress_bar.build_progress_bar( 46 | args=progress_bar_args, 47 | iterator=itr, 48 | prefix=f"top-k probs eval", 49 | no_progress_bar="simple", 50 | ) 51 | 52 | for sample in itr: 53 | sentence_ids = sample["id"] 54 | target_lengths = ( 55 | (sample["net_input"]["prev_output_tokens"] != dataset.tgt_dict.pad()) 56 | .sum(axis=1) 57 | .numpy() 58 | ) 59 | if use_cuda: 60 | sample = utils.move_to_cuda(sample) 61 | avg_probs = None 62 | for model in models: 63 | with torch.no_grad(): 64 | net_output = model(**sample["net_input"]) 65 | probs = model.get_normalized_probs(net_output, log_probs=False) 66 | if avg_probs is None: 67 | avg_probs = probs 68 | else: 69 | avg_probs.add_(probs) 70 | avg_probs.div_(len(models)) 71 | 72 | top_k_avg_probs, indices = torch.topk(avg_probs, k=k) 73 | 74 | top_k_probs_normalized = F.normalize(top_k_avg_probs, p=1, dim=2).cpu() 75 | indices = indices.cpu() 76 | 77 | for i, sentence_id in enumerate(sentence_ids): 78 | length = target_lengths[i] 79 | top_k_scores_list[sentence_id] = top_k_probs_normalized[i][:length].numpy() 80 | top_k_indices_list[sentence_id] = indices[i][:length].numpy() 81 | 82 | assert all( 83 | top_k_scores is not None for top_k_scores in top_k_scores_list 84 | ), "scores not calculated for all examples!" 85 | assert all( 86 | top_k_indices is not None for top_k_indices in top_k_indices_list 87 | ), "indices not calculated for all examples!" 88 | 89 | top_k_scores = np.concatenate(top_k_scores_list, axis=0) 90 | top_k_indices = np.concatenate(top_k_indices_list, axis=0) 91 | 92 | return top_k_scores, top_k_indices 93 | 94 | 95 | def save_top_k(args): 96 | """ 97 | This function runs forward computation on an ensemble of trained models 98 | using binarized parallel training data and saves the top-k probabilities 99 | and their corresponding token indices for each output step. 100 | 101 | Note that the Python binary accepts all generation params, but ignores 102 | inapplicable ones (such as those related to output length). --max-tokens 103 | is of particular importance to prevent memory errors. 104 | """ 105 | pytorch_translate_options.print_args(args) 106 | use_cuda = torch.cuda.is_available() and not getattr(args, "cpu", False) 107 | 108 | ( 109 | models, 110 | model_args, 111 | task, 112 | ) = pytorch_translate_utils.load_diverse_ensemble_for_inference( 113 | args.path.split(CHECKPOINT_PATHS_DELIMITER) 114 | ) 115 | for model in models: 116 | model.eval() 117 | if use_cuda: 118 | model.cuda() 119 | 120 | append_eos_to_source = model_args[0].append_eos_to_source 121 | reverse_source = model_args[0].reverse_source 122 | assert all( 123 | a.append_eos_to_source == append_eos_to_source 124 | and a.reverse_source == reverse_source 125 | for a in model_args 126 | ) 127 | assert ( 128 | args.source_binary_file != "" and args.target_binary_file != "" 129 | ), "collect_top_k_probs requires binarized data." 
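# Illustrative invocation (file and checkpoint names below are placeholders, not
# part of the repo; several teacher checkpoints may be joined with the
# CHECKPOINT_PATHS_DELIMITER consumed above):
#   python3 -m pytorch_translate.research.knowledge_distillation.collect_top_k_probs \
#     --path teacher_checkpoint.pt \
#     --source-binary-file train.src.bin --target-binary-file train.tgt.bin \
#     --top-k-probs-binary-file top_k_probs.npz \
#     --k-probs-to-collect 8 --teacher-max-tokens 1000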
130 | task.load_dataset(args.gen_subset, args.source_binary_file, args.target_binary_file) 131 | 132 | assert ( 133 | args.top_k_probs_binary_file != "" 134 | ), "must specify output file (--top-k-probs-binary-file)!" 135 | output_path = args.top_k_probs_binary_file 136 | 137 | dataset = task.dataset(args.gen_subset) 138 | 139 | top_k_scores, top_k_indices = compute_top_k( 140 | task=task, 141 | models=models, 142 | dataset=dataset, 143 | k=args.k_probs_to_collect, 144 | use_cuda=use_cuda, 145 | max_tokens=args.teacher_max_tokens, 146 | max_sentences=args.max_sentences, 147 | progress_bar_args=args, 148 | ) 149 | 150 | np.savez(output_path, top_k_scores=top_k_scores, top_k_indices=top_k_indices) 151 | print( 152 | f"Saved top {top_k_scores.shape[1]} probs for a total of " 153 | f"{top_k_scores.shape[0]} tokens to file {output_path}" 154 | ) 155 | 156 | 157 | def get_parser_with_args(): 158 | parser = options.get_parser("Collect Top-K Probs", default_task="pytorch_translate") 159 | pytorch_translate_options.add_verbosity_args(parser) 160 | pytorch_translate_options.add_dataset_args(parser, gen=True) 161 | generation_group = options.add_generation_args(parser) 162 | 163 | generation_group.add_argument( 164 | "--source-binary-file", 165 | default="", 166 | help="Path for the binary file containing source eval examples. " 167 | "(Overrides --source-text-file. Must be used in conjunction with " 168 | "--target-binary-file).", 169 | ) 170 | generation_group.add_argument( 171 | "--target-binary-file", 172 | default="", 173 | help="Path for the binary file containing target eval examples. " 174 | "(Overrides --target-text-file. Must be used in conjunction with " 175 | "--source-binary-file).", 176 | ) 177 | generation_group.add_argument( 178 | "--k-probs-to-collect", 179 | type=int, 180 | default=8, 181 | help="Number of probabilities to collect for each output step.", 182 | ) 183 | generation_group.add_argument( 184 | "--top-k-probs-binary-file", 185 | type=str, 186 | default="", 187 | help="File into which to save top-K probabilities for each token.", 188 | ) 189 | generation_group.add_argument( 190 | "--teacher-max-tokens", 191 | type=int, 192 | default=1000, 193 | help="Maximum number of words in minibatch for teacher to score.", 194 | ) 195 | return parser 196 | 197 | 198 | def main(): 199 | parser = get_parser_with_args() 200 | args = options.parse_args_and_arch(parser) 201 | save_top_k(args) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/dual_decoder_kd_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | 5 | from fairseq import utils 6 | from fairseq.criterions import LegacyFairseqCriterion, register_criterion 7 | 8 | 9 | @register_criterion("dual_decoder_kd_loss") 10 | class DualDecoderCriterion(LegacyFairseqCriterion): 11 | def __init__(self, args, task): 12 | super().__init__(args, task) 13 | self.eps = args.label_smoothing 14 | self.kd_weight = args.kd_weight 15 | 16 | @staticmethod 17 | def add_args(parser): 18 | """Add criterion-specific arguments to the parser.""" 19 | parser.add_argument( 20 | "--label-smoothing", 21 | default=0.0, 22 | type=float, 23 | metavar="D", 24 | help="[teacher decoder only] epsilon for label smoothing, 0 means " 25 | "no label smoothing.", 26 | ) 27 | parser.add_argument( 28 | "--kd-weight", 29 | type=float, 30 | 
default=0.5, 31 | help=( 32 | "[student decoder only] mixture weight between the knowledge " 33 | "distillation and negative log likelihood losses. Must be in " 34 | "[0.0, 1.0].", 35 | ), 36 | ) 37 | 38 | def forward(self, model, sample, reduce=True): 39 | """Compute the loss for the given sample. 40 | 41 | Returns a tuple with three elements: 42 | 1) the loss 43 | 2) the sample size, which is used as the denominator for the gradient 44 | 3) logging outputs to display while training 45 | """ 46 | 47 | src_tokens = sample["net_input"]["src_tokens"] 48 | src_lengths = sample["net_input"]["src_lengths"] 49 | prev_output_tokens = sample["net_input"]["prev_output_tokens"] 50 | 51 | encoder_out = model.encoder(src_tokens, src_lengths) 52 | student_output = model.student_decoder(prev_output_tokens, encoder_out) 53 | teacher_output = model.teacher_decoder(prev_output_tokens, encoder_out) 54 | 55 | teacher_loss, teacher_nll_loss, teacher_probs = self.compute_teacher_loss( 56 | model, teacher_output, sample, reduce=reduce 57 | ) 58 | 59 | # do not propagate gradient from student loss to teacher output 60 | teacher_probs = teacher_probs.detach() 61 | student_loss, student_nll_loss = self.compute_student_loss( 62 | model, student_output, sample, teacher_probs, reduce=reduce 63 | ) 64 | 65 | total_loss = student_loss + teacher_loss 66 | 67 | sample_size = ( 68 | sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] 69 | ) 70 | logging_output = { 71 | "teacher_loss": utils.item(teacher_loss.data) 72 | if reduce 73 | else teacher_loss.data, 74 | "teacher_nll_loss": utils.item(teacher_nll_loss.data) 75 | if reduce 76 | else teacher_nll_loss.data, 77 | "student_loss": utils.item(student_loss.data) 78 | if reduce 79 | else student_loss.data, 80 | "student_nll_loss": utils.item(student_nll_loss.data) 81 | if reduce 82 | else student_nll_loss.data, 83 | "loss": utils.item(total_loss.data) if reduce else total_loss.data, 84 | "ntokens": sample["ntokens"], 85 | "nsentences": sample["target"].size(0), 86 | "sample_size": sample_size, 87 | } 88 | return total_loss, sample_size, logging_output 89 | 90 | def compute_teacher_loss(self, model, net_output, sample, reduce=True): 91 | probs = model.get_normalized_probs(net_output, log_probs=False) 92 | probs = probs.view(-1, probs.size(-1)) 93 | lprobs = probs.log() 94 | target = model.get_targets(sample, net_output).view(-1, 1) 95 | non_pad_mask = target.ne(self.padding_idx) 96 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 97 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 98 | if reduce: 99 | nll_loss = nll_loss.sum() 100 | smooth_loss = smooth_loss.sum() 101 | eps_i = self.eps / lprobs.size(-1) 102 | loss = (1.0 - self.eps) * nll_loss + eps_i * smooth_loss 103 | return loss, nll_loss, probs 104 | 105 | def compute_student_loss( 106 | self, model, net_output, sample, teacher_probs, reduce=True 107 | ): 108 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 109 | lprobs = lprobs.view(-1, lprobs.size(-1)) 110 | target = model.get_targets(sample, net_output).view(-1, 1) 111 | non_pad_mask = target.ne(self.padding_idx) 112 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 113 | kd_loss = (teacher_probs * -lprobs).sum(dim=-1, keepdim=True)[non_pad_mask] 114 | if reduce: 115 | nll_loss = nll_loss.sum() 116 | kd_loss = kd_loss.sum() 117 | loss = (1.0 - self.kd_weight) * nll_loss + self.kd_weight * kd_loss 118 | return loss, nll_loss 119 | 120 | @staticmethod 121 | def 
aggregate_logging_outputs(logging_outputs): 122 | """Aggregate logging outputs from data parallel training.""" 123 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 124 | nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) 125 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 126 | return { 127 | "student_loss": sum(log.get("student_loss", 0) for log in logging_outputs) 128 | / sample_size 129 | / math.log(2), 130 | "student_nll_loss": sum( 131 | log.get("student_nll_loss", 0) for log in logging_outputs 132 | ) 133 | / ntokens 134 | / math.log(2), 135 | "teacher_loss": sum(log.get("teacher_loss", 0) for log in logging_outputs) 136 | / sample_size 137 | / math.log(2), 138 | "teacher_nll_loss": sum( 139 | log.get("teacher_nll_loss", 0) for log in logging_outputs 140 | ) 141 | / ntokens 142 | / math.log(2), 143 | "loss": sum(log.get("loss", 0) for log in logging_outputs) 144 | / sample_size 145 | / math.log(2), 146 | "ntokens": ntokens, 147 | "nsentences": nsentences, 148 | "sample_size": sample_size, 149 | } 150 | -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/dual_decoder_kd_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq.models import ( 4 | FairseqEncoderDecoderModel, 5 | register_model, 6 | register_model_architecture, 7 | ) 8 | from pytorch_translate import ( 9 | hybrid_transformer_rnn, 10 | transformer as pytorch_translate_transformer, 11 | ) 12 | from pytorch_translate.utils import torch_find 13 | 14 | 15 | @register_model("dual_decoder_kd") 16 | class DualDecoderKDModel(FairseqEncoderDecoderModel): 17 | def __init__(self, task, encoder, teacher_decoder, student_decoder): 18 | super().__init__(encoder, student_decoder) 19 | self.teacher_decoder = teacher_decoder 20 | self.student_decoder = student_decoder 21 | self.using_teacher = True 22 | self.task = task 23 | 24 | def get_teacher_model(self): 25 | return pytorch_translate_transformer.TransformerModel( 26 | self.task, self.encoder, self.teacher_decoder 27 | ) 28 | 29 | def get_student_model(self): 30 | return hybrid_transformer_rnn.HybridTransformerRNNModel( 31 | self.task, self.encoder, self.student_decoder 32 | ) 33 | 34 | @staticmethod 35 | def add_args(parser): 36 | """Add model-specific arguments to the parser.""" 37 | 38 | # command-line args for transformer model are used to build 39 | # encoder and teacher decoder 40 | pytorch_translate_transformer.TransformerModel.add_args(parser) 41 | 42 | # distinct args for student decoder 43 | parser.add_argument( 44 | "--student-decoder-embed-dim", 45 | type=int, 46 | metavar="N", 47 | help="[student RNN] decoder embedding dimension", 48 | ) 49 | parser.add_argument( 50 | "--student-decoder-layers", 51 | type=int, 52 | metavar="N", 53 | help="[student RNN] num decoder layers", 54 | ) 55 | parser.add_argument( 56 | "--student-decoder-attention-heads", 57 | type=int, 58 | metavar="N", 59 | help="[student RNN] num decoder attention heads", 60 | ) 61 | parser.add_argument( 62 | "--student-decoder-lstm-units", 63 | type=int, 64 | metavar="N", 65 | help="[student RNN] num LSTM units for each decoder layer", 66 | ) 67 | parser.add_argument( 68 | "--student-decoder-out-embed-dim", 69 | type=int, 70 | metavar="N", 71 | help="[student RNN] decoder output embedding dimension", 72 | ) 73 | parser.add_argument( 74 | "--student-decoder-reduced-attention-dim", 75 | 
type=int, 76 | default=None, 77 | metavar="N", 78 | help="if specified, computes attention with this dimensionality " 79 | "in the student decoder (instead of using encoder output dims)", 80 | ) 81 | 82 | @classmethod 83 | def build_model(cls, args, task): 84 | """Build a new model instance.""" 85 | # make sure that all args are properly defaulted 86 | # (in case there are any new ones) 87 | base_architecture(args) 88 | 89 | src_dict, tgt_dict = task.source_dictionary, task.target_dictionary 90 | 91 | encoder_embed_tokens = pytorch_translate_transformer.build_embedding( 92 | dictionary=src_dict, 93 | embed_dim=args.encoder_embed_dim, 94 | path=args.encoder_pretrained_embed, 95 | freeze=args.encoder_freeze_embed, 96 | ) 97 | 98 | teacher_decoder_embed_tokens = pytorch_translate_transformer.build_embedding( 99 | dictionary=tgt_dict, 100 | embed_dim=args.decoder_embed_dim, 101 | path=args.decoder_pretrained_embed, 102 | freeze=args.decoder_freeze_embed, 103 | ) 104 | 105 | student_decoder_embed_tokens = pytorch_translate_transformer.build_embedding( 106 | dictionary=tgt_dict, embed_dim=args.student_decoder_embed_dim 107 | ) 108 | 109 | encoder = pytorch_translate_transformer.TransformerEncoder( 110 | args, src_dict, encoder_embed_tokens, proj_to_decoder=True 111 | ) 112 | 113 | teacher_decoder = pytorch_translate_transformer.TransformerModel.build_decoder( 114 | args, src_dict, tgt_dict, embed_tokens=teacher_decoder_embed_tokens 115 | ) 116 | 117 | student_decoder = StudentHybridRNNDecoder( 118 | args, src_dict, tgt_dict, student_decoder_embed_tokens 119 | ) 120 | 121 | return DualDecoderKDModel( 122 | task=task, 123 | encoder=encoder, 124 | teacher_decoder=teacher_decoder, 125 | student_decoder=student_decoder, 126 | ) 127 | 128 | def get_targets(self, sample, net_output): 129 | targets = sample["target"].view(-1) 130 | possible_translation_tokens = net_output[-1] 131 | if possible_translation_tokens is not None: 132 | targets = torch_find( 133 | possible_translation_tokens, targets, len(self.task.target_dictionary) 134 | ) 135 | return targets 136 | 137 | 138 | class StudentHybridRNNDecoder(hybrid_transformer_rnn.HybridRNNDecoder): 139 | """ 140 | Subclass which constructs RNN decoder from student arguments. 141 | (dropout, attention_dropout, and vocab reduction params shared with teacher.) 
142 | """ 143 | 144 | def _init_dims(self, args, src_dict, dst_dict, embed_tokens): 145 | self.dropout = args.dropout 146 | 147 | embed_dim = embed_tokens.embedding_dim 148 | self.embed_tokens = embed_tokens 149 | 150 | self.lstm_units = args.student_decoder_lstm_units 151 | self.num_layers = args.student_decoder_layers 152 | self.initial_input_dim = embed_dim 153 | 154 | # for compatibility with transformer dimensions in encoder 155 | # and teacher decoder are different 156 | self.encoder_output_dim = args.decoder_embed_dim 157 | if args.student_decoder_reduced_attention_dim is None: 158 | self.attention_dim = self.encoder_output_dim 159 | else: 160 | self.attention_dim = args.student_decoder_reduced_attention_dim 161 | self.input_dim = self.lstm_units + self.attention_dim 162 | 163 | self.num_attention_heads = args.student_decoder_attention_heads 164 | self.bottleneck_dim = args.student_decoder_out_embed_dim 165 | 166 | 167 | @register_model_architecture("dual_decoder_kd", "dual_decoder_kd") 168 | def base_architecture(args): 169 | pytorch_translate_transformer.base_architecture(args) 170 | args.student_decoder_embed_dim = getattr(args, "student_decoder_embed_dim", 128) 171 | args.student_decoder_layers = getattr(args, "student_decoder_layers", 3) 172 | args.student_decoder_attention_heads = getattr( 173 | args, "student_decoder_attention_heads", 8 174 | ) 175 | args.student_decoder_lstm_units = getattr(args, "student_decoder_lstm_units", 128) 176 | args.student_decoder_out_embed_dim = getattr( 177 | args, "student_decoder_out_embed_dim", 128 178 | ) 179 | args.student_decoder_reduced_attention_dim = getattr( 180 | args, "student_decoder_reduced_attention_dim", None 181 | ) 182 | -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/hybrid_dual_decoder_kd_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq.models import ( 4 | FairseqEncoderDecoderModel, 5 | register_model, 6 | register_model_architecture, 7 | ) 8 | from pytorch_translate import ( 9 | hybrid_transformer_rnn, 10 | transformer as pytorch_translate_transformer, 11 | ) 12 | from pytorch_translate.utils import torch_find 13 | 14 | 15 | @register_model("hybrid_dual_decoder_kd") 16 | class HybridDualDecoderKDModel(FairseqEncoderDecoderModel): 17 | def __init__(self, task, encoder, teacher_decoder, student_decoder): 18 | super().__init__(encoder, student_decoder) 19 | self.teacher_decoder = teacher_decoder 20 | self.student_decoder = student_decoder 21 | self.using_teacher = True 22 | self.task = task 23 | 24 | def get_teacher_model(self): 25 | return hybrid_transformer_rnn.HybridTransformerRNNModel( 26 | self.task, self.encoder, self.teacher_decoder 27 | ) 28 | 29 | def get_student_model(self): 30 | return hybrid_transformer_rnn.HybridTransformerRNNModel( 31 | self.task, self.encoder, self.student_decoder 32 | ) 33 | 34 | @staticmethod 35 | def add_args(parser): 36 | """Add model-specific arguments to the parser.""" 37 | 38 | # command-line args for hybrid_transformer_rnn model are used to build 39 | # encoder and teacher decoder 40 | hybrid_transformer_rnn.HybridTransformerRNNModel.add_args(parser) 41 | 42 | # distinct args for student decoder 43 | parser.add_argument( 44 | "--student-decoder-embed-dim", 45 | type=int, 46 | metavar="N", 47 | help="[student RNN] decoder embedding dimension", 48 | ) 49 | parser.add_argument( 50 | "--student-decoder-layers", 51 | 
type=int, 52 | metavar="N", 53 | help="[student RNN] num decoder layers", 54 | ) 55 | parser.add_argument( 56 | "--student-decoder-attention-heads", 57 | type=int, 58 | metavar="N", 59 | help="[student RNN] num decoder attention heads", 60 | ) 61 | parser.add_argument( 62 | "--student-decoder-lstm-units", 63 | type=int, 64 | metavar="N", 65 | help="[student RNN] num LSTM units for each decoder layer", 66 | ) 67 | parser.add_argument( 68 | "--student-decoder-out-embed-dim", 69 | type=int, 70 | metavar="N", 71 | help="[student RNN] decoder output embedding dimension", 72 | ) 73 | parser.add_argument( 74 | "--student-decoder-reduced-attention-dim", 75 | type=int, 76 | default=None, 77 | metavar="N", 78 | help="if specified, computes attention with this dimensionality " 79 | "in the student decoder (instead of using encoder output dims)", 80 | ) 81 | 82 | @classmethod 83 | def build_model(cls, args, task): 84 | """Build a new model instance.""" 85 | # make sure that all args are properly defaulted 86 | # (in case there are any new ones) 87 | base_architecture(args) 88 | 89 | src_dict, tgt_dict = task.source_dictionary, task.target_dictionary 90 | 91 | encoder_embed_tokens = pytorch_translate_transformer.build_embedding( 92 | dictionary=src_dict, 93 | embed_dim=args.encoder_embed_dim, 94 | path=args.encoder_pretrained_embed, 95 | freeze=args.encoder_freeze_embed, 96 | ) 97 | 98 | teacher_decoder_embed_tokens = pytorch_translate_transformer.build_embedding( 99 | dictionary=tgt_dict, embed_dim=args.decoder_embed_dim 100 | ) 101 | 102 | student_decoder_embed_tokens = pytorch_translate_transformer.build_embedding( 103 | dictionary=tgt_dict, embed_dim=args.student_decoder_embed_dim 104 | ) 105 | 106 | encoder = pytorch_translate_transformer.TransformerEncoder( 107 | args, src_dict, encoder_embed_tokens, proj_to_decoder=False 108 | ) 109 | 110 | teacher_decoder = hybrid_transformer_rnn.HybridRNNDecoder( 111 | args, src_dict, tgt_dict, teacher_decoder_embed_tokens 112 | ) 113 | 114 | student_decoder = StudentHybridRNNDecoder( 115 | args, src_dict, tgt_dict, student_decoder_embed_tokens 116 | ) 117 | 118 | return HybridDualDecoderKDModel( 119 | task=task, 120 | encoder=encoder, 121 | teacher_decoder=teacher_decoder, 122 | student_decoder=student_decoder, 123 | ) 124 | 125 | def get_targets(self, sample, net_output): 126 | targets = sample["target"].view(-1) 127 | possible_translation_tokens = net_output[-1] 128 | if possible_translation_tokens is not None: 129 | targets = torch_find( 130 | possible_translation_tokens, targets, len(self.task.target_dictionary) 131 | ) 132 | return targets 133 | 134 | 135 | class StudentHybridRNNDecoder(hybrid_transformer_rnn.HybridRNNDecoder): 136 | """ 137 | Subclass which constructs RNN decoder from student arguments. 138 | (dropout, attention_dropout, and vocab reduction params shared with teacher.) 
139 | """ 140 | 141 | def _init_dims(self, args, src_dict, dst_dict, embed_tokens): 142 | self.dropout = args.dropout 143 | 144 | embed_dim = embed_tokens.embedding_dim 145 | self.embed_tokens = embed_tokens 146 | 147 | self.lstm_units = args.student_decoder_lstm_units 148 | self.num_layers = args.student_decoder_layers 149 | self.initial_input_dim = embed_dim 150 | 151 | self.encoder_output_dim = args.encoder_embed_dim 152 | if args.student_decoder_reduced_attention_dim is None: 153 | self.attention_dim = self.encoder_output_dim 154 | else: 155 | self.attention_dim = args.student_decoder_reduced_attention_dim 156 | self.input_dim = self.lstm_units + self.attention_dim 157 | 158 | self.num_attention_heads = args.student_decoder_attention_heads 159 | self.bottleneck_dim = args.student_decoder_out_embed_dim 160 | 161 | 162 | @register_model_architecture("hybrid_dual_decoder_kd", "hybrid_dual_decoder_kd") 163 | def base_architecture(args): 164 | hybrid_transformer_rnn.base_architecture(args) 165 | args.student_decoder_embed_dim = getattr(args, "student_decoder_embed_dim", 128) 166 | args.student_decoder_layers = getattr(args, "student_decoder_layers", 3) 167 | args.student_decoder_attention_heads = getattr( 168 | args, "student_decoder_attention_heads", 8 169 | ) 170 | args.student_decoder_lstm_units = getattr(args, "student_decoder_lstm_units", 128) 171 | args.student_decoder_out_embed_dim = getattr( 172 | args, "student_decoder_out_embed_dim", 128 173 | ) 174 | -------------------------------------------------------------------------------- /pytorch_translate/research/knowledge_distillation/knowledge_distillation_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from fairseq import utils 8 | from fairseq.criterions import LegacyFairseqCriterion, register_criterion 9 | from pytorch_translate import utils as pytorch_translate_utils 10 | 11 | 12 | @register_criterion("word_knowledge_distillation") 13 | class KnowledgeDistillationCriterion(LegacyFairseqCriterion): 14 | def __init__(self, args, task): 15 | """ 16 | This code is for word-level knowledge distillation. Most of the algorithm 17 | is inspired by the Kim and Rush (2016) paper: 18 | http://www.aclweb.org/anthology/D16-1139 19 | """ 20 | super().__init__(args, task) 21 | self.kd_weight = getattr(args, "kd_weight", 0) 22 | if self.kd_weight < 0 or self.kd_weight > 1: 23 | raise ValueError(f"--kd-weight ({self.kd_weight}) must be in [0, 1]") 24 | 25 | @staticmethod 26 | def add_args(parser): 27 | """Add criterion-specific arguments to the parser.""" 28 | parser.add_argument( 29 | "--kd-weight", 30 | type=float, 31 | default=0.0, 32 | help=( 33 | "mixture weight between the knowledge distillation and " 34 | "negative log likelihood losses. Must be in [0.0, 1.0]" 35 | ), 36 | ) 37 | 38 | def get_kd_loss(self, sample, student_lprobs, lprobs): 39 | """ 40 | Computes the word-level knowledge distillation term for one batch.
41 | 42 | Args: 43 | * sample: batched sample that has teacher score keys (top_k_scores and 44 | top_k_indices) 45 | * student_lprobs: tensor of student log probabilities 46 | * lprobs: flat version of student_lprobs 47 | """ 48 | top_k_teacher_probs_normalized = sample["top_k_scores"] 49 | indices = sample["top_k_indices"] 50 | 51 | assert indices.shape[0:1] == student_lprobs.shape[0:1] 52 | 53 | kd_loss = -( 54 | torch.sum( 55 | torch.gather(student_lprobs, 2, indices) 56 | * top_k_teacher_probs_normalized.float() 57 | ) 58 | ) 59 | return kd_loss 60 | 61 | def forward(self, model, sample, reduce=True): 62 | """Compute the loss for the given sample. 63 | 64 | Returns a tuple with three elements: 65 | 1) the loss, as a Variable 66 | 2) the sample size, which is used as the denominator for the gradient 67 | 3) logging outputs to display while training 68 | """ 69 | 70 | # 1. Generate translation using student model 71 | net_output = model(**sample["net_input"]) 72 | student_lprobs = model.get_normalized_probs(net_output, log_probs=True) 73 | # [bsz, seqlen, vocab] -> [bsz*seqlen, vocab] 74 | lprobs = student_lprobs.view(-1, student_lprobs.size(-1)) 75 | 76 | # 2. Get translation from teacher models and calulate KD loss. 77 | kd_loss = None 78 | if "top_k_scores" in sample: 79 | # top_k_scores is not present in the validation data. 80 | kd_loss = self.get_kd_loss(sample, student_lprobs, lprobs) 81 | 82 | # 3. Compute NLL loss with respect to the ground truth 83 | target = model.get_targets(sample, net_output).view(-1) 84 | nll_loss = F.nll_loss( 85 | lprobs, 86 | target, 87 | size_average=False, 88 | ignore_index=self.padding_idx, 89 | reduce=reduce, 90 | ) 91 | 92 | # 4. Linearly interpolate between NLL and KD loss 93 | if kd_loss is not None: 94 | loss = kd_loss * self.kd_weight + nll_loss * (1 - self.kd_weight) 95 | else: 96 | loss = nll_loss 97 | 98 | if self.args.sentence_avg: 99 | sample_size = sample["target"].size(0) 100 | else: 101 | sample_size = sample["ntokens"] 102 | if self.args.sentence_avg: 103 | sample_size = sample["target"].size(0) 104 | else: 105 | sample_size = sample["ntokens"] 106 | logging_output = { 107 | "loss": utils.item(loss.data) if reduce else loss.data, 108 | "ntokens": sample["ntokens"], 109 | "nsamples": sample["target"].size(0), 110 | "sample_size": sample_size, 111 | } 112 | return loss, sample_size, logging_output 113 | 114 | @staticmethod 115 | def aggregate_logging_outputs(logging_outputs): 116 | """Aggregate logging outputs from data parallel training.""" 117 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 118 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 119 | nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) 120 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 121 | agg_output = { 122 | "loss": loss_sum / sample_size / math.log(2), 123 | "ntokens": ntokens, 124 | "nsentences": nsentences, 125 | "sample_size": sample_size, 126 | } 127 | if sample_size != ntokens: 128 | agg_output["nll_loss"] = loss_sum / ntokens / math.log(2) 129 | return agg_output 130 | -------------------------------------------------------------------------------- /pytorch_translate/research/lexical_choice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/lexical_choice/__init__.py 
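The interpolation performed by KnowledgeDistillationCriterion above can be exercised in isolation with a minimal, self-contained sketch (illustrative only, not part of the repository; the tensor sizes and kd_weight value are made up):

import torch
import torch.nn.functional as F

bsz, seqlen, vocab, k, kd_weight = 2, 3, 11, 4, 0.5
student_lprobs = F.log_softmax(torch.randn(bsz, seqlen, vocab), dim=-1)
top_k_indices = torch.randint(0, vocab, (bsz, seqlen, k))  # teacher's top-k token ids
top_k_scores = F.normalize(torch.rand(bsz, seqlen, k), p=1, dim=-1)  # re-normalized teacher probs
target = torch.randint(0, vocab, (bsz, seqlen))

# KD term: student log-probs gathered at the teacher's top-k tokens, weighted by
# the re-normalized teacher probabilities (mirrors get_kd_loss above).
kd_loss = -(torch.gather(student_lprobs, 2, top_k_indices) * top_k_scores).sum()
# Standard NLL term against the ground-truth targets.
nll_loss = F.nll_loss(student_lprobs.view(-1, vocab), target.view(-1), reduction="sum")
# Linear interpolation, as in the criterion's forward().
loss = kd_weight * kd_loss + (1.0 - kd_weight) * nll_loss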
-------------------------------------------------------------------------------- /pytorch_translate/research/lexical_choice/lexical_translation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | 6 | def attention_weighted_src_embedding( 7 | src_embedding, attn_scores, activation_fn=torch.tanh 8 | ): 9 | """ 10 | use the attention weights to form a weighted average of embeddings 11 | parameters: 12 | src_embedding: srclen x bsz x embeddim 13 | attn_scores: bsz x tgtlen x srclen 14 | return: 15 | lex: bsz x tgtlen x embeddim 16 | """ 17 | # lexical choice varying lengths: T x B x C -> B x T x C 18 | src_embedding = src_embedding.transpose(0, 1) 19 | 20 | lex = torch.bmm(attn_scores, src_embedding) 21 | lex = activation_fn(lex) 22 | return lex 23 | 24 | 25 | def lex_logits(lex_h, output_projection_w_lex, output_projection_b_lex, logits_shape): 26 | """ 27 | calculate the logits of lexical layer output 28 | """ 29 | projection_lex_flat = torch.matmul(output_projection_w_lex, lex_h.t()).t() 30 | 31 | logits = ( 32 | torch.onnx.operators.reshape_from_tensor_shape( 33 | projection_lex_flat, logits_shape 34 | ) 35 | + output_projection_b_lex 36 | ) 37 | return logits 38 | -------------------------------------------------------------------------------- /pytorch_translate/research/multisource/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/multisource/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/research/multisource/multisource_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import torch 5 | from fairseq import data 6 | 7 | 8 | class MultisourceLanguagePairDataset(data.LanguagePairDataset): 9 | """A language pair dataset with multiple source sentences for 10 | each target sentence.""" 11 | 12 | def __getitem__(self, i): 13 | source = [src_sent.long() for src_sent in self.src[i]] 14 | res = {"id": i, "source": source} 15 | if self.tgt: 16 | res["target"] = self.tgt[i].long() 17 | 18 | return res 19 | 20 | def collater(self, samples): 21 | return MultisourceLanguagePairDataset.collate( 22 | samples, 23 | self.src_dict.pad(), 24 | self.src_dict.eos(), 25 | self.tgt is not None, 26 | self.left_pad_source, 27 | self.left_pad_target, 28 | ) 29 | 30 | @staticmethod 31 | def collate( 32 | samples, 33 | pad_idx, 34 | eos_idx, 35 | has_target=True, 36 | left_pad_source=True, 37 | left_pad_target=False, 38 | ): 39 | if len(samples) == 0: 40 | return {} 41 | 42 | n_sources = len(samples[0]["source"]) 43 | assert all( 44 | len(sample["source"]) == n_sources for sample in samples 45 | ), "All samples in a batch must have the same number of source sentences." 46 | 47 | def merge(key, left_pad, source=False, move_eos_to_beginning=False): 48 | if source: 49 | # Collate source sentences all source sentences together. 
Each 50 | return data.data_utils.collate_tokens( 51 | [s[key][src_id] for s in samples for src_id in range(n_sources)], 52 | pad_idx, 53 | eos_idx, 54 | left_pad, 55 | move_eos_to_beginning, 56 | ) 57 | else: 58 | return data.data_utils.collate_tokens( 59 | [s[key] for s in samples], 60 | pad_idx, 61 | eos_idx, 62 | left_pad, 63 | move_eos_to_beginning, 64 | ) 65 | 66 | id = torch.LongTensor([s["id"] for s in samples]) 67 | src_tokens = merge("source", left_pad=left_pad_source, source=True) 68 | # We sort all source sentences from each batch element by length 69 | src_lengths = torch.LongTensor( 70 | [ 71 | s["source"][src_id].numel() 72 | for s in samples 73 | for src_id in range(n_sources) 74 | ] 75 | ) 76 | src_lengths, sort_order = src_lengths.sort(descending=True) 77 | # id = id.index_select(0, sort_order) 78 | src_tokens = src_tokens.index_select(0, sort_order) 79 | # Record which sentence corresponds to which source and sample 80 | _, rev_order = sort_order.sort() 81 | # srcs_ids[k] contains the indices of kth source sentences of each 82 | # sample in src_tokens 83 | srcs_ids = [rev_order[k::n_sources] for k in range(n_sources)] 84 | 85 | prev_output_tokens = None 86 | target = None 87 | ntokens = None 88 | if has_target: 89 | target = merge("target", left_pad=left_pad_target) 90 | # we create a shifted version of targets for feeding the 91 | # previous output token(s) into the next decoder step 92 | prev_output_tokens = merge( 93 | "target", left_pad=left_pad_target, move_eos_to_beginning=True 94 | ) 95 | ntokens = sum(len(s["target"]) for s in samples) 96 | 97 | return { 98 | "id": id, 99 | "ntokens": ntokens, 100 | "net_input": { 101 | "src_tokens": src_tokens, 102 | "src_lengths": src_lengths, 103 | "src_ids": srcs_ids, 104 | "prev_output_tokens": prev_output_tokens, 105 | }, 106 | "target": target, 107 | } 108 | 109 | 110 | class IndexedRawTextMultisentDataset(data.IndexedRawTextDataset): 111 | """Takes a list of text file as input and binarizes them in memory at 112 | instantiation. 
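Each path in paths holds one sentence per line; after reading, the per-file line and token lists are zipped so that index i groups the i-th sentence from every file, and sizes are summed per sample.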
Original lines are also kept in memory""" 113 | 114 | def read_data(self, paths, dictionary): 115 | for path in paths: 116 | with open(path, "r") as f: 117 | file_lines = [] 118 | file_tokens_list = [] 119 | file_sizes = [] 120 | for line in f: 121 | file_lines.append(line.strip("\n")) 122 | tokens = dictionary.encode_line( 123 | line, 124 | add_if_not_exist=False, 125 | append_eos=self.append_eos, 126 | reverse_order=self.reverse_order, 127 | ) 128 | file_tokens_list.append(tokens) 129 | file_sizes.append(len(tokens)) 130 | self.lines.append(file_lines) 131 | self.tokens_list.append(file_tokens_list) 132 | self.sizes.append(file_sizes) 133 | # Zip all sentences for each sample together 134 | self.lines = list(zip(*self.lines)) 135 | self.tokens_list = list(zip(*self.tokens_list)) 136 | # Sum sentence sizes for each sample 137 | self.sizes = np.asarray(self.sizes).sum(axis=0) 138 | -------------------------------------------------------------------------------- /pytorch_translate/research/rescore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/research/rescore/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/research/rescore/cloze_transformer_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from fairseq import utils 5 | from fairseq.models import register_model, register_model_architecture 6 | from pytorch_translate.transformer import ( 7 | base_architecture, 8 | build_embedding, 9 | TransformerDecoder, 10 | TransformerModel, 11 | ) 12 | 13 | 14 | @register_model("cloze_transformer") 15 | class ClozeTransformerModel(TransformerModel): 16 | @classmethod 17 | def build_model(cls, args, task): 18 | """Build a new model instance.""" 19 | # make sure that all args are properly defaulted 20 | # (in case there are any new ones) 21 | cloze_transformer_architecture(args) 22 | 23 | src_dict, tgt_dict = task.source_dictionary, task.target_dictionary 24 | 25 | if args.share_all_embeddings: 26 | if src_dict != tgt_dict: 27 | raise RuntimeError( 28 | "--share-all-embeddings requires a joined dictionary" 29 | ) 30 | if args.encoder_embed_dim != args.decoder_embed_dim: 31 | raise RuntimeError( 32 | "--share-all-embeddings requires --encoder-embed-dim " 33 | "to match --decoder-embed-dim" 34 | ) 35 | if args.decoder_pretrained_embed and ( 36 | args.decoder_pretrained_embed != args.encoder_pretrained_embed 37 | ): 38 | raise RuntimeError( 39 | "--share-all-embeddings not compatible with " 40 | "--decoder-pretrained-embed" 41 | ) 42 | encoder_embed_tokens = build_embedding( 43 | dictionary=src_dict, 44 | embed_dim=args.encoder_embed_dim, 45 | path=args.encoder_pretrained_embed, 46 | freeze=args.encoder_freeze_embed, 47 | ) 48 | decoder_embed_tokens = encoder_embed_tokens 49 | args.share_decoder_input_output_embed = True 50 | else: 51 | encoder_embed_tokens = build_embedding( 52 | dictionary=src_dict, 53 | embed_dim=args.encoder_embed_dim, 54 | path=args.encoder_pretrained_embed, 55 | freeze=args.encoder_freeze_embed, 56 | ) 57 | decoder_embed_tokens = build_embedding( 58 | dictionary=tgt_dict, 59 | embed_dim=args.decoder_embed_dim, 60 | path=args.decoder_pretrained_embed, 61 | freeze=args.decoder_freeze_embed, 62 | ) 63 | 64 | encoder = ClozeTransformerModel.build_encoder( 65 | args, src_dict, 
embed_tokens=encoder_embed_tokens 66 | ) 67 | decoder = ClozeTransformerModel.build_decoder( 68 | args, src_dict, tgt_dict, embed_tokens=decoder_embed_tokens 69 | ) 70 | return ClozeTransformerModel(task, encoder, decoder) 71 | 72 | @classmethod 73 | def build_decoder(cls, args, src_dict, dst_dict, embed_tokens): 74 | return ClozeTransformerDecoder( 75 | args, src_dict, dst_dict, embed_tokens=embed_tokens 76 | ) 77 | 78 | 79 | class ClozeTransformerDecoder(TransformerDecoder): 80 | """Cloze-Transformer decoder.""" 81 | 82 | def __init__(self, args, src_dict, dst_dict, embed_tokens, left_pad=False): 83 | super().__init__(args, src_dict, dst_dict, embed_tokens) 84 | assert args.decoder_layers == 1 85 | 86 | def buffered_future_mask(self, tensor): 87 | """attend all surounding words except itself 88 | [[0, -inf, 0] 89 | [0, 0, -inf] 90 | [0, 0, 0]] 91 | The attention map is not ture diagonal since we predict y_{t+1} at time-step t 92 | """ 93 | dim = tensor.size(0) 94 | if ( 95 | not hasattr(self, "_future_mask") 96 | or self._future_mask is None 97 | or self._future_mask.device != tensor.device 98 | ): 99 | self._future_mask = torch.triu( 100 | utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 101 | ) 102 | self._future_mask = torch.tril(self._future_mask, 1) 103 | if self._future_mask.size(0) < dim: 104 | self._future_mask = torch.triu( 105 | utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 106 | ) 107 | self._future_mask = torch.tril(self._future_mask, 1) 108 | return self._future_mask[:dim, :dim] 109 | 110 | 111 | @register_model_architecture("cloze_transformer", "cloze_transformer") 112 | def cloze_transformer_architecture(args): 113 | base_architecture(args) 114 | -------------------------------------------------------------------------------- /pytorch_translate/research/tune_ensemble_weights/tune_model_weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from fairseq import options 8 | from pytorch_translate import generate 9 | from pytorch_translate.constants import CHECKPOINT_PATHS_DELIMITER 10 | 11 | 12 | def add_tune_args(parser): 13 | group = parser.add_argument_group("Tune parameter parser.") 14 | group.add_argument( 15 | "--n-grid", 16 | default=6, 17 | type=int, 18 | metavar="N", 19 | help="how many grid added to tune for each weight.", 20 | ) 21 | group.add_argument( 22 | "--weight-lower-bound", 23 | default=0.0, 24 | type=float, 25 | help="lower bound for each weight.", 26 | ) 27 | group.add_argument( 28 | "--weight-upper-bound", 29 | default=1.0, 30 | type=float, 31 | help="upper bound for each weight.", 32 | ) 33 | group.add_argument( 34 | "--output-file-name", 35 | default="output.csv", 36 | type=str, 37 | help="name of output file.", 38 | ) 39 | return parser 40 | 41 | 42 | def tune_model_weights(): 43 | parser = generate.get_parser_with_args() 44 | parser = add_tune_args(parser) 45 | args = options.parse_args_and_arch(parser) 46 | print(args.model_weights) 47 | n_models = len(args.path.split(CHECKPOINT_PATHS_DELIMITER)) 48 | print(n_models) 49 | 50 | weight_grid = np.linspace( 51 | args.weight_lower_bound, args.weight_upper_bound, args.n_grid + 1 52 | ) 53 | weight_vec_aux = list(itertools.product(weight_grid, weight_grid)) 54 | weight_vec = [] 55 | for w1, w2 in weight_vec_aux: 56 | weight_sum = w1 + w2 57 | if weight_sum <= 1: 58 | w3 = 1 - weight_sum 59 | weight_vec.append(str(w1) + "," + str(w2) + "," + 
str(w3)) 60 | 61 | print(len(weight_vec)) 62 | output = pd.DataFrame() 63 | for weight in weight_vec: 64 | args.model_weights = weight 65 | print(args.model_weights) 66 | generate.validate_args(args) 67 | score = generate.generate(args) 68 | print(score) 69 | output = output.append( 70 | {"weight": args.model_weights, "bleu_score": score}, ignore_index=True 71 | ) 72 | output.to_csv(args.output_file_name) 73 | return output 74 | 75 | 76 | if __name__ == "__main__": 77 | tune_model_weights() 78 | -------------------------------------------------------------------------------- /pytorch_translate/research/tune_ensemble_weights/tune_model_weights_with_ax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import json 3 | 4 | from ax.service.managed_loop import optimize 5 | from fairseq import options 6 | from pytorch_translate import generate 7 | from pytorch_translate.constants import CHECKPOINT_PATHS_DELIMITER 8 | 9 | 10 | def add_tune_args(parser): 11 | group = parser.add_argument_group("Tune parameter parser.") 12 | group.add_argument( 13 | "--n-grid", 14 | default=6, 15 | type=int, 16 | metavar="N", 17 | help="how many grid added to tune for each weight.", 18 | ) 19 | group.add_argument( 20 | "--weight-lower-bound", 21 | default=0.0, 22 | type=float, 23 | help="lower bound for each weight.", 24 | ) 25 | group.add_argument( 26 | "--weight-upper-bound", 27 | default=1.0, 28 | type=float, 29 | help="upper bound for each weight.", 30 | ) 31 | group.add_argument( 32 | "--num-trails-ax-opt", 33 | default=5, 34 | type=int, 35 | help="number of trials in AX optimization.", 36 | ) 37 | group.add_argument( 38 | "--output-json-best-parameters", 39 | default="best_parameters.json", 40 | type=str, 41 | help="name of output file for the best parameters.", 42 | ) 43 | group.add_argument( 44 | "--output-json-best-value", 45 | default="best_value.json", 46 | type=str, 47 | help="name of output file for the best value of the evaluation function.", 48 | ) 49 | return parser 50 | 51 | 52 | def tune_model_weights(): 53 | parser = generate.get_parser_with_args() 54 | parser = add_tune_args(parser) 55 | args = options.parse_args_and_arch(parser) 56 | n_models = len(args.path.split(CHECKPOINT_PATHS_DELIMITER)) 57 | print(n_models) 58 | print(args.weight_lower_bound) 59 | print(args.weight_upper_bound) 60 | print(args.output_json_best_parameters) 61 | print(args.output_json_best_value) 62 | print(args.num_trails_ax_opt) 63 | 64 | def evaluation_function(parameterization): 65 | w1 = parameterization.get("w1") 66 | w2 = parameterization.get("w2") 67 | w3 = parameterization.get("w3") 68 | weight = str(w1) + "," + str(w2) + "," + str(w3) 69 | args.model_weights = weight 70 | generate.validate_args(args) 71 | score = generate.generate(args) 72 | return {"bleu_score": (score, 0.0)} 73 | 74 | lower_bound = args.weight_lower_bound 75 | upper_bound = args.weight_upper_bound 76 | best_parameters, values, experiment, model = optimize( 77 | parameters=[ 78 | { 79 | "name": "w1", 80 | "type": "range", 81 | "bounds": [lower_bound, upper_bound], 82 | "value_type": "float", 83 | }, 84 | {"name": "w2", "type": "range", "bounds": [lower_bound, upper_bound]}, 85 | {"name": "w3", "type": "range", "bounds": [lower_bound, upper_bound]}, 86 | ], 87 | experiment_name="tune_model_weights", 88 | objective_name="bleu_score", 89 | evaluation_function=evaluation_function, 90 | minimize=True, # Optional, defaults to False. 
91 | parameter_constraints=[ 92 | "w1 + w2 + w3 <= 1", 93 | "w1 + w2 + w3 >= 0.99", 94 | ], # Optional. 95 | total_trials=args.num_trails_ax_opt, # Optional. 96 | ) 97 | 98 | json_file = json.dumps(best_parameters) 99 | with open(args.output_json_best_parameters, "w") as f: 100 | f.write(json_file) 101 | f.close() 102 | 103 | json_file = json.dumps(values) 104 | with open(args.output_json_best_value, "w") as f: 105 | f.write(json_file) 106 | f.close() 107 | return best_parameters, values 108 | 109 | 110 | if __name__ == "__main__": 111 | tune_model_weights() 112 | -------------------------------------------------------------------------------- /pytorch_translate/rnn_cell.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def LSTMCell(input_dim, hidden_dim, **kwargs): 11 | m = nn.LSTMCell(input_dim, hidden_dim, **kwargs) 12 | for name, param in m.named_parameters(): 13 | if "weight" in name or "bias" in name: 14 | param.data.uniform_(-0.1, 0.1) 15 | return m 16 | 17 | 18 | class MILSTMCellBackend(nn.RNNCell): 19 | def __init__(self, input_size, hidden_size, bias=True): 20 | super(MILSTMCellBackend, self).__init__(input_size, hidden_size, bias=False) 21 | self.input_size = input_size 22 | self.hidden_size = hidden_size 23 | self.bias = bias 24 | self.weight_ih = nn.Parameter(torch.Tensor(4 * hidden_size, input_size)) 25 | self.weight_hh = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size)) 26 | if bias: 27 | self.bias = nn.Parameter(torch.Tensor(4 * hidden_size)) 28 | else: 29 | self.register_parameter("bias", None) 30 | self.alpha = nn.Parameter(torch.Tensor(4 * hidden_size)) 31 | self.beta_h = nn.Parameter(torch.Tensor(4 * hidden_size)) 32 | self.beta_i = nn.Parameter(torch.Tensor(4 * hidden_size)) 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | stdv = 1.0 / math.sqrt(self.hidden_size) 37 | for weight in self.parameters(): 38 | weight.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, x, hidden): 41 | # get prev_t, cell_t from states 42 | hx, cx = hidden 43 | Wx = F.linear(x, self.weight_ih) 44 | Uz = F.linear(hx, self.weight_hh) 45 | 46 | # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf 47 | gates = self.alpha * Wx * Uz + self.beta_i * Wx + self.beta_h * Uz + self.bias 48 | 49 | # Same as LSTMCell after this point 50 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 51 | 52 | ingate = F.sigmoid(ingate) 53 | forgetgate = F.sigmoid(forgetgate) 54 | cellgate = F.tanh(cellgate) 55 | outgate = F.sigmoid(outgate) 56 | 57 | cy = (forgetgate * cx) + (ingate * cellgate) 58 | hy = outgate * F.tanh(cy) 59 | 60 | return hy, cy 61 | 62 | 63 | def MILSTMCell(input_dim, hidden_dim, **kwargs): 64 | m = MILSTMCellBackend(input_dim, hidden_dim, **kwargs) 65 | for name, param in m.named_parameters(): 66 | if "weight" in name or "bias" in name: 67 | param.data.uniform_(-0.1, 0.1) 68 | return m 69 | 70 | 71 | class LayerNormLSTMCellBackend(nn.LSTMCell): 72 | def __init__(self, input_dim, hidden_dim, bias=True, epsilon=0.00001): 73 | super(LayerNormLSTMCellBackend, self).__init__(input_dim, hidden_dim, bias) 74 | self.epsilon = epsilon 75 | 76 | def _layerNormalization(self, x): 77 | mean = x.mean(1, keepdim=True).expand_as(x) 78 | std = x.std(1, keepdim=True).expand_as(x) 79 | return (x - mean) / (std + self.epsilon) 80 | 81 | def forward(self, x, hidden): 82 | hx, cx = hidden 83 | gates 
= F.linear(x, self.weight_ih, self.bias_ih) + F.linear( 84 | hx, self.weight_hh, self.bias_hh 85 | ) 86 | 87 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 88 | 89 | ingate = F.sigmoid(self._layerNormalization(ingate)) 90 | forgetgate = F.sigmoid(self._layerNormalization(forgetgate)) 91 | cellgate = F.tanh(self._layerNormalization(cellgate)) 92 | outgate = F.sigmoid(self._layerNormalization(outgate)) 93 | 94 | cy = (forgetgate * cx) + (ingate * cellgate) 95 | 96 | hy = outgate * F.tanh(cy) 97 | 98 | return hy, cy 99 | 100 | 101 | def LayerNormLSTMCell(input_dim, hidden_dim, **kwargs): 102 | m = LayerNormLSTMCellBackend(input_dim, hidden_dim, **kwargs) 103 | for name, param in m.named_parameters(): 104 | if "weight" in name or "bias" in name: 105 | param.data.uniform_(-0.1, 0.1) 106 | return m 107 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import importlib 4 | import os 5 | 6 | 7 | # automatically import any Python files in the tasks/ directory 8 | for file in sorted(os.listdir(os.path.dirname(__file__))): 9 | if file.endswith(".py") and not file.startswith("_"): 10 | task_name = file[: file.find(".py")] 11 | importlib.import_module("pytorch_translate.tasks." + task_name) 12 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/cross_lingual_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | 11 | from fairseq import tokenizer 12 | from fairseq.tasks import register_task 13 | from fairseq.tasks.cross_lingual_lm import CrossLingualLMTask 14 | from pytorch_translate.data.masked_lm_dictionary import MaskedLMDictionary 15 | 16 | 17 | @register_task("pytorch_translate_cross_lingual_lm") 18 | class PytorchTranslateCrossLingualLMTask(CrossLingualLMTask): 19 | """ 20 | Task for training cross-lingual language models. 
21 | For more details look at: https://arxiv.org/pdf/1901.07291.pdf 22 | Args: 23 | dictionary (MaskedLMDictionary): the dictionary for the input of the task 24 | """ 25 | 26 | @staticmethod 27 | def add_args(parser): 28 | CrossLingualLMTask.add_args(parser) 29 | """Add task-specific arguments to the parser.""" 30 | parser.add_argument( 31 | "-s", "--source-lang", default=None, metavar="SRC", help="source language" 32 | ) 33 | parser.add_argument( 34 | "-t", 35 | "--target-lang", 36 | default=None, 37 | metavar="TARGET", 38 | help="target language", 39 | ) 40 | parser.add_argument( 41 | "--save-only", action="store_true", help="skip eval and only do save" 42 | ) 43 | 44 | @classmethod 45 | def load_dictionary(cls, filename): 46 | return MaskedLMDictionary.load(filename) 47 | 48 | @classmethod 49 | def build_dictionary( 50 | cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 51 | ): 52 | d = MaskedLMDictionary() 53 | for filename in filenames: 54 | MaskedLMDictionary.add_file_to_dictionary( 55 | filename, d, tokenizer.tokenize_line, workers 56 | ) 57 | d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) 58 | return d 59 | 60 | @classmethod 61 | def setup_task(cls, args, **kwargs): 62 | """Setup the task.""" 63 | if getattr(args, "raw_text", False): 64 | args.dataset_impl = "raw" 65 | elif getattr(args, "lazy_load", False): 66 | args.dataset_impl = "lazy" 67 | 68 | dictionary = MaskedLMDictionary.load( 69 | os.path.join(args.data, args.source_vocab_file) 70 | ) 71 | 72 | print("| dictionary: {} types".format(len(dictionary))) 73 | 74 | return cls(args, dictionary) 75 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/knowledge_distillation_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | from fairseq.tasks import register_task 8 | from pytorch_translate import constants, utils as pytorch_translate_utils 9 | from pytorch_translate.data import ( 10 | data as pytorch_translate_data, 11 | utils as data_utils, 12 | weighted_data, 13 | ) 14 | from pytorch_translate.research.knowledge_distillation.teacher_score_data import ( 15 | TeacherDataset, 16 | ) 17 | from pytorch_translate.tasks.pytorch_translate_task import PytorchTranslateTask 18 | 19 | 20 | @register_task(constants.KNOWLEDGE_DISTILLATION_TASK) 21 | class PytorchKnowledgeDistillationTask(PytorchTranslateTask): 22 | def __init__( 23 | self, args, src_dict, tgt_dict, char_source_dict=None, char_target_dict=None 24 | ): 25 | super().__init__( 26 | args, 27 | src_dict=src_dict, 28 | tgt_dict=tgt_dict, 29 | char_source_dict=char_source_dict, 30 | char_target_dict=char_target_dict, 31 | ) 32 | self.top_k_probs_binary_file = args.top_k_probs_binary_file 33 | self.top_k_teacher_tokens = args.top_k_teacher_tokens 34 | 35 | if self.top_k_probs_binary_file is None: 36 | # Load model ensemble from checkpoints 37 | ( 38 | self.teacher_models, 39 | _, 40 | _, 41 | ) = pytorch_translate_utils.load_diverse_ensemble_for_inference( 42 | args.teacher_path.split(":") 43 | ) 44 | if torch.cuda.is_available(): 45 | for teacher_model in self.teacher_models: 46 | teacher_model = pytorch_translate_utils.maybe_cuda(teacher_model) 47 | else: 48 | self.teacher_models = None 49 | 50 | # Memoized scores for teacher models. 
By having this and gradually memoizing 51 | # the values, we prevent the teacher model from keeping recalculating the 52 | # teacher scores. 53 | self.top_k_teacher_scores: Dict[int, np.ndarray] = {} 54 | self.top_k_teacher_indices: Dict[int, np.ndarray] = {} 55 | 56 | @staticmethod 57 | def add_args(parser): 58 | PytorchTranslateTask.add_args(parser) 59 | 60 | """Add knowledge-distillation arguments to the parser.""" 61 | parser.add_argument( 62 | "--top-k-probs-binary-file", 63 | metavar="PROBSFILE", 64 | type=str, 65 | default=None, 66 | help="path to .npz file containing KD target probabilities for " 67 | "each output token in training data.", 68 | ) 69 | parser.add_argument( 70 | "--teacher-path", 71 | metavar="FILE", 72 | type=str, 73 | default=None, 74 | help="path(s) to teacher model file(s) colon separated", 75 | ) 76 | parser.add_argument( 77 | "--top-k-teacher-tokens", 78 | type=int, 79 | default=8, 80 | help=( 81 | "Incorporating only the top k words from the teacher model.", 82 | "We zero out all other possibilities and normalize the probabilities", 83 | "based on the K top element.", 84 | "If top-k-teacher-tokens=0, it backs up to the original way of", 85 | "enumerating all.", 86 | ), 87 | ) 88 | 89 | def load_dataset( 90 | self, split, src_bin_path, tgt_bin_path, weights_file=None, is_train=False 91 | ): 92 | """ 93 | Currently this method does not support character models. 94 | """ 95 | corpus = pytorch_translate_data.ParallelCorpusConfig( 96 | source=pytorch_translate_data.CorpusConfig( 97 | dialect=self.args.source_lang, data_file=src_bin_path 98 | ), 99 | target=pytorch_translate_data.CorpusConfig( 100 | dialect=self.args.target_lang, data_file=tgt_bin_path 101 | ), 102 | weights_file=weights_file, 103 | ) 104 | 105 | if self.args.log_verbose: 106 | print("Starting to load binarized data files.", flush=True) 107 | data_utils.validate_corpus_exists(corpus=corpus, split=split) 108 | 109 | dst_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 110 | corpus.target.data_file 111 | ) 112 | src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 113 | corpus.source.data_file 114 | ) 115 | if is_train: 116 | self.datasets[split] = TeacherDataset( 117 | src=src_dataset, 118 | src_sizes=src_dataset.sizes, 119 | src_dict=self.src_dict, 120 | tgt=dst_dataset, 121 | tgt_sizes=dst_dataset.sizes, 122 | tgt_dict=self.tgt_dict, 123 | top_k_probs_binary_file=self.top_k_probs_binary_file, 124 | teacher_models=self.teacher_models, 125 | top_k_teacher_tokens=self.top_k_teacher_tokens, 126 | top_k_teacher_scores=self.top_k_teacher_scores, 127 | top_k_teacher_indices=self.top_k_teacher_indices, 128 | left_pad_source=False, 129 | ) 130 | else: 131 | self.datasets[split] = weighted_data.WeightedLanguagePairDataset( 132 | src=src_dataset, 133 | src_sizes=src_dataset.sizes, 134 | src_dict=self.src_dict, 135 | tgt=dst_dataset, 136 | tgt_sizes=dst_dataset.sizes, 137 | tgt_dict=self.tgt_dict, 138 | weights=None, 139 | left_pad_source=False, 140 | ) 141 | 142 | if self.args.log_verbose: 143 | print("Finished loading dataset", flush=True) 144 | 145 | print(f"| {split} {len(self.datasets[split])} examples") 146 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/multilingual_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import OrderedDict 4 | 5 | from fairseq import options 6 | from fairseq.data import 
RoundRobinZipDatasets 7 | from fairseq.tasks import register_task 8 | from pytorch_translate import constants 9 | from pytorch_translate.data import ( 10 | data as pytorch_translate_data, 11 | utils as data_utils, 12 | weighted_data, 13 | ) 14 | from pytorch_translate.tasks import utils as tasks_utils 15 | from pytorch_translate.tasks.pytorch_translate_multi_task import ( 16 | PyTorchTranslateMultiTask, 17 | ) 18 | 19 | 20 | @register_task(constants.MULTILINGUAL_TRANSLATION_TASK) 21 | class PyTorchTranslateMultilingualTranslationTask(PyTorchTranslateMultiTask): 22 | """ 23 | PyTorchTranslateMultilingualTranslationTask is eventually subclasses 24 | fairseq.tasks.MultilingualTranslationTask. The major differences are- 25 | - There is no --data folder containing data binaries and vocabularies, 26 | instead we use paths from --vocabulary, --multilingual-*-text-file and 27 | --multilingual-*-binary-path 28 | - loss_weights is used to weigh losses from different datasets differently. 29 | This is achieved by using pytorch_translate's WeightedLanguagePairDataset 30 | - The dictionaries are instances of pytorch_translate's Dictionary class 31 | """ 32 | 33 | @staticmethod 34 | def add_args(parser): 35 | PyTorchTranslateMultiTask.add_args(parser) 36 | """Add task-specific arguments to the parser.""" 37 | parser.add_argument( 38 | "--vocabulary", 39 | type=str, 40 | metavar="EXPR", 41 | action="append", 42 | help=( 43 | "Per-language vocabulary configuration." 44 | "Path to vocabulary file must be in the format lang:path" 45 | ), 46 | default=[], 47 | ) 48 | parser.add_argument( 49 | "--multilingual-train-text-file", 50 | type=str, 51 | metavar="EXPR", 52 | action="append", 53 | help=( 54 | "Path to train text file in the format " 55 | "src_lang-tgt_lang:source-path,target-path" 56 | ), 57 | ) 58 | parser.add_argument( 59 | "--multilingual-eval-text-file", 60 | type=str, 61 | metavar="EXPR", 62 | action="append", 63 | help=( 64 | "Path to eval text file in the format " 65 | "src_lang-tgt_lang:source-path,target-path" 66 | ), 67 | ) 68 | parser.add_argument( 69 | "--multilingual-train-binary-path", 70 | type=str, 71 | metavar="EXPR", 72 | action="append", 73 | help=( 74 | "Path to train binary file in the format " 75 | "src_lang-tgt_lang:source-path,target-path" 76 | ), 77 | ) 78 | parser.add_argument( 79 | "--multilingual-eval-binary-path", 80 | type=str, 81 | metavar="EXPR", 82 | action="append", 83 | help=( 84 | "Path to eval binary file in the format " 85 | "src_lang-tgt_lang:source-path,target-path" 86 | ), 87 | ) 88 | 89 | def __init__(self, args, dicts, training): 90 | super().__init__(args, dicts, training) 91 | self.loss_weights = [] 92 | 93 | @classmethod 94 | def setup_task(cls, args, **kwargs): 95 | args.left_pad_source = options.eval_bool(args.left_pad_source) 96 | args.left_pad_target = options.eval_bool(args.left_pad_target) 97 | 98 | if args.source_lang is not None or args.target_lang is not None: 99 | if args.lang_pairs is not None: 100 | raise ValueError( 101 | "--source-lang/--target-lang implies generation, which is " 102 | "incompatible with --lang-pairs" 103 | ) 104 | training = False 105 | args.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] 106 | else: 107 | training = True 108 | args.lang_pairs = args.lang_pairs.split(",") 109 | args.source_lang, args.target_lang = args.lang_pairs[0].split("-") 110 | 111 | dicts = tasks_utils.load_multilingual_vocabulary(args) 112 | 113 | return cls(args, dicts, training) 114 | 115 | def load_dataset(self, split, **kwargs): 
116 | """Load a dataset split.""" 117 | 118 | lang_pair_to_datasets = {} 119 | 120 | binary_path_arg = ( 121 | "--multilingual-train-binary-path" 122 | if split == "train" 123 | else "--multilingual-eval-binary-path" 124 | ) 125 | binary_path_value = ( 126 | self.args.multilingual_train_binary_path 127 | if split == "train" 128 | else self.args.multilingual_eval_binary_path 129 | ) 130 | 131 | format_warning = ( 132 | f"{binary_path_arg} has to be in the format " 133 | " src_lang-tgt_lang:src_dataset_path,tgt_dataset_path" 134 | ) 135 | 136 | for path_config in binary_path_value: 137 | # path_config: str 138 | # in the format "src_lang-tgt_lang:src_dataset_path,tgt_dataset_path" 139 | assert ":" in path_config, format_warning 140 | lang_pair, dataset_paths = path_config.split(":") 141 | 142 | assert "-" in lang_pair, format_warning 143 | 144 | assert "," in dataset_paths, format_warning 145 | src_dataset_path, tgt_dataset_path = dataset_paths.split(",") 146 | 147 | lang_pair_to_datasets[lang_pair] = (src_dataset_path, tgt_dataset_path) 148 | 149 | for lang_pair in self.args.lang_pairs: 150 | assert ( 151 | lang_pair in lang_pair_to_datasets 152 | ), "Not all language pairs have dataset binary paths specified!" 153 | 154 | datasets = {} 155 | for lang_pair in self.args.lang_pairs: 156 | src, tgt = lang_pair.split("-") 157 | src_bin_path, tgt_bin_path = lang_pair_to_datasets[lang_pair] 158 | corpus = pytorch_translate_data.ParallelCorpusConfig( 159 | source=pytorch_translate_data.CorpusConfig( 160 | dialect=src, data_file=src_bin_path 161 | ), 162 | target=pytorch_translate_data.CorpusConfig( 163 | dialect=tgt, data_file=tgt_bin_path 164 | ), 165 | ) 166 | if self.args.log_verbose: 167 | print("Starting to load binarized data files.", flush=True) 168 | 169 | data_utils.validate_corpus_exists(corpus=corpus, split=split) 170 | 171 | tgt_dataset = ( 172 | pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 173 | corpus.target.data_file 174 | ) 175 | ) 176 | src_dataset = ( 177 | pytorch_translate_data.InMemoryIndexedDataset.create_from_file( 178 | corpus.source.data_file 179 | ) 180 | ) 181 | datasets[lang_pair] = weighted_data.WeightedLanguagePairDataset( 182 | src=src_dataset, 183 | src_sizes=src_dataset.sizes, 184 | src_dict=self.dicts[src], 185 | tgt=tgt_dataset, 186 | tgt_sizes=tgt_dataset.sizes, 187 | tgt_dict=self.dicts[tgt], 188 | weights=None, 189 | left_pad_source=False, 190 | ) 191 | self.datasets[split] = RoundRobinZipDatasets( 192 | OrderedDict( 193 | [(lang_pair, datasets[lang_pair]) for lang_pair in self.args.lang_pairs] 194 | ), 195 | eval_key=None 196 | if self.training 197 | else f"{self.args.source_lang}-{self.args.target_lang}", 198 | ) 199 | 200 | if self.args.log_verbose: 201 | print("Finished loading dataset", flush=True) 202 | 203 | print(f"| {split} {len(self.datasets[split])} examples") 204 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/pytorch_translate_multi_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq import models 4 | from fairseq.data import data_utils, FairseqDataset 5 | from fairseq.models import FairseqMultiModel 6 | from fairseq.tasks.multilingual_translation import MultilingualTranslationTask 7 | from pytorch_translate.data import iterators as ptt_iterators 8 | 9 | 10 | class PyTorchTranslateMultiTask(MultilingualTranslationTask): 11 | def build_model(self, args): 12 | model = 
models.build_model(args, self) 13 | if not isinstance(model, FairseqMultiModel): 14 | raise ValueError( 15 | "PyTorchTranslateMultiTask requires a FairseqMultiModel architecture" 16 | ) 17 | return model 18 | 19 | def get_batch_iterator( 20 | self, 21 | dataset, 22 | max_tokens=None, 23 | max_sentences=None, 24 | max_positions=None, 25 | ignore_invalid_inputs=False, 26 | required_batch_size_multiple=1, 27 | seed=1, 28 | num_shards=1, 29 | shard_id=0, 30 | num_workers=0, 31 | epoch=1, 32 | data_buffer_size=0, 33 | disable_iterator_cache=False, 34 | ): 35 | assert isinstance(dataset, FairseqDataset) 36 | 37 | # get indices ordered by example size 38 | with data_utils.numpy_seed(seed): 39 | indices = dataset.ordered_indices() 40 | 41 | # filter examples that are too large 42 | indices = data_utils.filter_by_size( 43 | indices, dataset, max_positions, raise_exception=(not ignore_invalid_inputs) 44 | ) 45 | 46 | # create mini-batches with given size constraints 47 | batch_sampler = data_utils.batch_by_size( 48 | indices, 49 | num_tokens_fn=dataset.num_tokens, 50 | max_tokens=max_tokens, 51 | max_sentences=max_sentences, 52 | required_batch_size_multiple=required_batch_size_multiple, 53 | ) 54 | 55 | # return a reusable, sharded iterator 56 | return ptt_iterators.WeightedEpochBatchIterator( 57 | dataset=dataset, 58 | collate_fn=dataset.collater, 59 | batch_sampler=batch_sampler, 60 | seed=seed, 61 | num_shards=num_shards, 62 | shard_id=shard_id, 63 | num_workers=num_workers, 64 | weights=self.loss_weights, 65 | ) 66 | 67 | def max_positions(self): 68 | """Return None to allow model to dictate max sentence length allowed""" 69 | return None 70 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | from fairseq import options, tokenizer 10 | from fairseq.tasks import register_task 11 | from pytorch_translate import constants 12 | from pytorch_translate.data.masked_lm_dictionary import MaskedLMDictionary 13 | from pytorch_translate.tasks.pytorch_translate_task import PytorchTranslateTask 14 | 15 | 16 | @register_task("pytorch_translate_translation_from_pretrained_xlm") 17 | class PytorchTranslateTranslationFromPretrainedXLMTask(PytorchTranslateTask): 18 | """ 19 | Same as TranslationTask except use the MaskedLMDictionary class so that 20 | we can load data that was binarized with the MaskedLMDictionary class. 21 | 22 | This task should be used for the entire training pipeline when we want to 23 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 24 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 25 | of that trained model. 
26 | """ 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | PytorchTranslateTask.add_args(parser) 31 | """Add task-specific arguments to the parser.""" 32 | parser.add_argument( 33 | "--save-only", action="store_true", help="skip eval and only do save" 34 | ) 35 | 36 | @classmethod 37 | def load_dictionary(cls, filename): 38 | """Load the masked LM dictionary from the filename 39 | 40 | Args: 41 | filename (str): the filename 42 | """ 43 | return MaskedLMDictionary.load(filename) 44 | 45 | @classmethod 46 | def build_dictionary( 47 | cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 48 | ): 49 | """Build the dictionary 50 | 51 | Args: 52 | filenames (list): list of filenames 53 | workers (int): number of concurrent workers 54 | threshold (int): defines the minimum word count 55 | nwords (int): defines the total number of words in the final dictionary, 56 | including special symbols 57 | padding_factor (int): can be used to pad the dictionary size to be a 58 | multiple of 8, which is important on some hardware (e.g., Nvidia 59 | Tensor Cores). 60 | """ 61 | d = MaskedLMDictionary() 62 | for filename in filenames: 63 | MaskedLMDictionary.add_file_to_dictionary( 64 | filename, d, tokenizer.tokenize_line, workers 65 | ) 66 | d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) 67 | return d 68 | 69 | @classmethod 70 | def setup_task(cls, args, **kwargs): 71 | args.left_pad_source = options.eval_bool(args.left_pad_source) 72 | 73 | # Load dictionaries 74 | source_dict = MaskedLMDictionary.load(args.source_vocab_file) 75 | target_dict = MaskedLMDictionary.load(args.target_vocab_file) 76 | 77 | source_lang = args.source_lang or "src" 78 | target_lang = args.target_lang or "tgt" 79 | 80 | print(f"| [{source_lang}] dictionary: {len(source_dict)} types") 81 | print(f"| [{target_lang}] dictionary: {len(target_dict)} types") 82 | 83 | use_char_source = (args.char_source_vocab_file != "") or ( 84 | getattr(args, "arch", "") in constants.ARCHS_FOR_CHAR_SOURCE 85 | ) 86 | if use_char_source: 87 | char_source_dict = MaskedLMDictionary.load(args.char_source_vocab_file) 88 | # this attribute is used for CharSourceModel construction 89 | args.char_source_dict_size = len(char_source_dict) 90 | else: 91 | char_source_dict = None 92 | 93 | return cls(args, source_dict, target_dict, char_source_dict) 94 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/translation_lev_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | from fairseq import options 11 | from fairseq.tasks import register_task 12 | from fairseq.tasks.translation_lev import TranslationLevenshteinTask 13 | from pytorch_translate.data import dictionary as pytorch_translate_dictionary 14 | from pytorch_translate.tasks.pytorch_translate_task import PytorchTranslateTask 15 | 16 | 17 | @register_task("ptt_translation_lev") 18 | class PytorchTranslationLevenshteinTask(PytorchTranslateTask): 19 | """ 20 | Translation (Sequence Generation) task for Levenshtein Transformer 21 | See `"Levenshtein Transformer" `_. 
22 | """ 23 | 24 | def __init__( 25 | self, args, src_dict, tgt_dict, char_source_dict=None, char_target_dict=None 26 | ): 27 | super().__init__(args, src_dict, tgt_dict, char_source_dict, char_target_dict) 28 | self.src_dict = src_dict 29 | self.tgt_dict = tgt_dict 30 | self.char_source_dict = char_source_dict 31 | self.char_target_dict = char_target_dict 32 | self.trans_lev_task = TranslationLevenshteinTask(args, src_dict, tgt_dict) 33 | 34 | @staticmethod 35 | def add_args(parser): 36 | TranslationLevenshteinTask.add_args(parser) 37 | 38 | def inject_noise(self, target_tokens): 39 | return self.trans_lev_task.inject_noise(target_tokens) 40 | 41 | def build_generator(self, models, args): 42 | self.trans_lev_task.build_generator(models, args) 43 | 44 | @classmethod 45 | def setup_task(cls, args, **kwargs): 46 | args.left_pad_source = options.eval_bool(args.left_pad_source) 47 | source_dict = pytorch_translate_dictionary.Dictionary.load( 48 | args.source_vocab_file 49 | ) 50 | target_dict = pytorch_translate_dictionary.Dictionary.load( 51 | args.target_vocab_file 52 | ) 53 | source_lang = args.source_lang or "src" 54 | target_lang = args.target_lang or "tgt" 55 | args.append_bos = True 56 | 57 | print(f"| [{source_lang}] dictionary: {len(source_dict)} types") 58 | print(f"| [{target_lang}] dictionary: {len(target_dict)} types") 59 | 60 | return cls(args, source_dict, target_dict) 61 | 62 | def train_step( 63 | self, sample, model, criterion, optimizer, update_num, ignore_grad=False 64 | ): 65 | return self.trans_lev_task.train_step( 66 | sample, model, criterion, optimizer, update_num, ignore_grad 67 | ) 68 | 69 | def valid_step(self, sample, model, criterion): 70 | return self.trans_lev_task.valid_step(sample, model, criterion) 71 | -------------------------------------------------------------------------------- /pytorch_translate/tasks/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import OrderedDict 4 | 5 | from pytorch_translate.data import dictionary as pytorch_translate_dictionary 6 | 7 | 8 | def load_multilingual_vocabulary(args): 9 | dicts = OrderedDict() 10 | vocabulary_list = getattr(args, "vocabulary", []) 11 | comparison_lang = None 12 | for vocabulary in vocabulary_list: 13 | assert ( 14 | ":" in vocabulary 15 | ), "--vocabulary must be specified in the format lang:path" 16 | lang, path = vocabulary.split(":") 17 | dicts[lang] = pytorch_translate_dictionary.Dictionary.load(path) 18 | if len(dicts) > 1: 19 | assert dicts[lang].pad() == dicts[comparison_lang].pad() 20 | assert dicts[lang].eos() == dicts[comparison_lang].eos() 21 | assert dicts[lang].unk() == dicts[comparison_lang].unk() 22 | else: 23 | comparison_lang = lang 24 | print(f"| [{lang}] dictionary: {len(dicts[lang])} types") 25 | 26 | return dicts 27 | -------------------------------------------------------------------------------- /pytorch_translate/torchscript_export.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | 5 | from pytorch_translate import rnn # noqa 6 | from pytorch_translate.constants import CHECKPOINT_PATHS_DELIMITER 7 | from pytorch_translate.ensemble_export import BeamSearch 8 | 9 | 10 | def get_parser_with_args(): 11 | parser = argparse.ArgumentParser( 12 | description=("Export PyTorch-trained FBTranslate models") 13 | ) 14 | parser.add_argument( 15 | "--path", 16 | "--checkpoint", 17 | metavar="FILE", 18 | 
help="path(s) to model file(s), colon separated", 19 | ) 20 | parser.add_argument( 21 | "--output-file", 22 | default="", 23 | help="File name to which to save beam search network", 24 | ) 25 | parser.add_argument( 26 | "--output-graph-file", 27 | default="", 28 | help="File name to which to save the beam search graph for debugging", 29 | ) 30 | parser.add_argument( 31 | "--source-vocab-file", 32 | required=True, 33 | help="File encoding PyTorch dictionary for source language", 34 | ) 35 | parser.add_argument( 36 | "--target-vocab-file", 37 | required=True, 38 | help="File encoding PyTorch dictionary for source language", 39 | ) 40 | parser.add_argument( 41 | "--beam-size", 42 | type=int, 43 | default=6, 44 | help="Number of top candidates returned by each decoder step", 45 | ) 46 | parser.add_argument( 47 | "--word-reward", 48 | type=float, 49 | default=0.0, 50 | help="Value to add for each word (besides EOS)", 51 | ) 52 | parser.add_argument( 53 | "--unk-reward", 54 | type=float, 55 | default=0.0, 56 | help="Value to add for each word UNK token", 57 | ) 58 | 59 | return parser 60 | 61 | 62 | def main(): 63 | parser = get_parser_with_args() 64 | args = parser.parse_args() 65 | 66 | if args.output_file == "": 67 | print("No action taken. Need output_file to be specified.") 68 | parser.print_help() 69 | return 70 | 71 | checkpoint_filenames = args.path.split(CHECKPOINT_PATHS_DELIMITER) 72 | 73 | beam_search = BeamSearch.build_from_checkpoints( 74 | checkpoint_filenames=checkpoint_filenames, 75 | src_dict_filename=args.src_dict, 76 | dst_dict_filename=args.dst_dict, 77 | beam_size=args.beam_size, 78 | word_reward=args.word_reward, 79 | unk_reward=args.unk_reward, 80 | ) 81 | beam_search.save_to_pytorch(output_path=args.output_file) 82 | if args.output_graph_file: 83 | with open(args.output_graph_file.path, "w") as f: 84 | f.write(str(beam_search.graph)) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /pytorch_translate/vocab_constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | MAX_SPECIAL_TOKENS = 100 4 | 5 | # Number of Byte indices is always fixed at 256 (0-255). The additional 5 indices 6 | # correpsond to the special tokens for byte numberization including 7 | # padding, start and end of word, start and end of sentence. These are 8 | # separate from the special tokens in the dict and match up with the indices 9 | # used by pre-trained ELMo. 
10 | NUM_BYTE_INDICES = 261 11 | 12 | PAD_ID = 0 13 | GO_ID = 1 14 | EOS_ID = 2 15 | UNK_ID = 3 16 | MASK_ID = 4 17 | -------------------------------------------------------------------------------- /pytorch_translate/weighted_criterions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq import utils 4 | from fairseq.criterions import LegacyFairseqCriterion, register_criterion 5 | from fairseq.criterions.label_smoothed_cross_entropy import ( 6 | LabelSmoothedCrossEntropyCriterion, 7 | ) 8 | 9 | 10 | @register_criterion("weighted_label_smoothed_cross_entropy") 11 | class WeightedLabelSmoothedCrossEntropyCriterion(LegacyFairseqCriterion): 12 | def __init__(self, args, task): 13 | super().__init__(args, task) 14 | self.eps = args.label_smoothing 15 | 16 | @classmethod 17 | def add_args(cls, parser): 18 | """Add criterion-specific arguments to the parser.""" 19 | parser.add_argument( 20 | "--label-smoothing", 21 | default=0.0, 22 | type=float, 23 | metavar="D", 24 | help="epsilon for label smoothing, 0 means no label smoothing", 25 | ) 26 | 27 | def forward(self, model, sample, reduce=True): 28 | net_output = model(**sample["net_input"]) 29 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 30 | assert "weights" in sample, "Need to specify weights for examples." 31 | weights = sample["weights"].unsqueeze(1).unsqueeze(2) 32 | lprobs = lprobs * weights 33 | 34 | lprobs = lprobs.view(-1, lprobs.size(-1)) 35 | target = model.get_targets(sample, net_output).view(-1, 1) 36 | non_pad_mask = target.ne(self.padding_idx) 37 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 38 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 39 | if reduce: 40 | nll_loss = nll_loss.sum() 41 | smooth_loss = smooth_loss.sum() 42 | eps_i = self.eps / lprobs.size(-1) 43 | loss = (1.0 - self.eps) * nll_loss + eps_i * smooth_loss 44 | 45 | sample_size = ( 46 | sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] 47 | ) 48 | logging_output = { 49 | "loss": utils.item(loss.data) if reduce else loss.data, 50 | "nll_loss": utils.item(nll_loss.data) if reduce else loss.data, 51 | "ntokens": sample["ntokens"], 52 | "nsentences": sample["target"].size(0), 53 | "sample_size": sample_size, 54 | } 55 | return loss, sample_size, logging_output 56 | 57 | @classmethod 58 | def aggregate_logging_outputs(cls, logging_outputs): 59 | """Aggregate logging outputs from data parallel training.""" 60 | return LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs( 61 | logging_outputs 62 | ) 63 | -------------------------------------------------------------------------------- /pytorch_translate/word_prediction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/translate/b89dc35abeb7fe516e3b95ccacdedfc1a92e5626/pytorch_translate/word_prediction/__init__.py -------------------------------------------------------------------------------- /pytorch_translate/word_prediction/word_prediction_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from fairseq.models import ( 4 | FairseqEncoderDecoderModel, 5 | register_model, 6 | register_model_architecture, 7 | ) 8 | from pytorch_translate import rnn 9 | from pytorch_translate.rnn import LSTMSequenceEncoder, RNNDecoder, RNNEncoder, RNNModel 10 | from pytorch_translate.utils import torch_find 11 | from 
pytorch_translate.word_prediction import word_predictor 12 | 13 | 14 | class WordPredictionModel(FairseqEncoderDecoderModel): 15 | """ 16 | An architecuture which jointly learns translation and target words 17 | prediction, as described in http://aclweb.org/anthology/D17-1013. 18 | """ 19 | 20 | def __init__(self, task, encoder, decoder, predictor): 21 | super().__init__(encoder, decoder) 22 | self.predictor = predictor 23 | self.task = task 24 | 25 | def forward(self, src_tokens, src_lengths, prev_output_tokens): 26 | encoder_output = self.encoder(src_tokens, src_lengths) 27 | pred_output = self.predictor(encoder_output) 28 | decoder_output = self.decoder(prev_output_tokens, encoder_output) 29 | return pred_output, decoder_output 30 | 31 | def get_predictor_normalized_probs(self, pred_output, log_probs): 32 | return self.predictor.get_normalized_probs(pred_output, log_probs) 33 | 34 | def get_target_words(self, sample): 35 | return sample["target"] 36 | 37 | 38 | @register_model("rnn_word_pred") 39 | class RNNWordPredictionModel(WordPredictionModel): 40 | """ 41 | A subclass which adds words prediction to RNN arch. 42 | """ 43 | 44 | @staticmethod 45 | def add_args(parser): 46 | rnn.RNNModel.add_args(parser) 47 | parser.add_argument( 48 | "--predictor-hidden-dim", 49 | type=int, 50 | metavar="N", 51 | help="word predictor num units", 52 | ) 53 | 54 | parser.add_argument( 55 | "--topk-labels-per-source-token", 56 | type=int, 57 | metavar="N", 58 | help="Top k predicted words from the word predictor module for use" 59 | "as translation candidates in vocab reduction module, as a multiple" 60 | "of source tokens.", 61 | ) 62 | 63 | @classmethod 64 | def build_model(cls, args, task): 65 | """Build a new model instance.""" 66 | src_dict, dst_dict = task.source_dictionary, task.target_dictionary 67 | base_architecture_wp(args) 68 | 69 | encoder_embed_tokens, decoder_embed_tokens = RNNModel.build_embed_tokens( 70 | args, src_dict, dst_dict 71 | ) 72 | 73 | if args.sequence_lstm: 74 | encoder_class = LSTMSequenceEncoder 75 | else: 76 | encoder_class = RNNEncoder 77 | decoder_class = RNNDecoder 78 | 79 | encoder = encoder_class( 80 | src_dict, 81 | embed_tokens=encoder_embed_tokens, 82 | embed_dim=args.encoder_embed_dim, 83 | cell_type=args.cell_type, 84 | num_layers=args.encoder_layers, 85 | hidden_dim=args.encoder_hidden_dim, 86 | dropout_in=args.encoder_dropout_in, 87 | dropout_out=args.encoder_dropout_out, 88 | residual_level=args.residual_level, 89 | bidirectional=bool(args.encoder_bidirectional), 90 | ) 91 | predictor = word_predictor.WordPredictor( 92 | encoder_output_dim=args.encoder_hidden_dim, 93 | hidden_dim=args.predictor_hidden_dim, 94 | output_dim=len(dst_dict), 95 | topk_labels_per_source_token=args.topk_labels_per_source_token, 96 | ) 97 | decoder = decoder_class( 98 | src_dict=src_dict, 99 | dst_dict=dst_dict, 100 | embed_tokens=decoder_embed_tokens, 101 | vocab_reduction_params=args.vocab_reduction_params, 102 | encoder_hidden_dim=args.encoder_hidden_dim, 103 | embed_dim=args.decoder_embed_dim, 104 | out_embed_dim=args.decoder_out_embed_dim, 105 | cell_type=args.cell_type, 106 | num_layers=args.decoder_layers, 107 | hidden_dim=args.decoder_hidden_dim, 108 | attention_type=args.attention_type, 109 | dropout_in=args.decoder_dropout_in, 110 | dropout_out=args.decoder_dropout_out, 111 | residual_level=args.residual_level, 112 | averaging_encoder=args.averaging_encoder, 113 | predictor=None if args.topk_labels_per_source_token is None else predictor, 114 | ) 115 | 116 | return 
cls(task, encoder, decoder, predictor) 117 | 118 | def get_targets(self, sample, net_output): 119 | targets = sample["target"].view(-1) 120 | possible_translation_tokens = net_output[-1] 121 | if possible_translation_tokens is not None: 122 | targets = torch_find( 123 | possible_translation_tokens, targets, len(self.task.target_dictionary) 124 | ) 125 | return targets 126 | 127 | 128 | @register_model_architecture("rnn_word_pred", "rnn_word_pred") 129 | def base_architecture_wp(args): 130 | # default architecture 131 | rnn.base_architecture(args) 132 | args.predictor_hidden_dim = getattr(args, "predictor_hidden_dim", 512) 133 | args.topk_labels_per_source_token = getattr( 134 | args, "topk_labels_per_source_token", None 135 | ) 136 | -------------------------------------------------------------------------------- /pytorch_translate/word_prediction/word_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class WordPredictor(nn.Module): 9 | def __init__( 10 | self, 11 | encoder_output_dim, 12 | hidden_dim, 13 | output_dim, 14 | topk_labels_per_source_token=None, 15 | use_self_attention=False, 16 | ): 17 | super().__init__() 18 | self.encoder_output_dim = encoder_output_dim 19 | self.hidden_dim = hidden_dim 20 | self.output_dim = output_dim 21 | self.topk_labels_per_source_token = topk_labels_per_source_token 22 | self.use_self_attention = use_self_attention 23 | 24 | if self.use_self_attention: 25 | self.init_layer = nn.Linear(encoder_output_dim, encoder_output_dim) 26 | self.attn_layer = nn.Linear(2 * encoder_output_dim, 1) 27 | self.hidden_layer = nn.Linear(2 * encoder_output_dim, hidden_dim) 28 | self.output_layer = nn.Linear(hidden_dim, output_dim) 29 | else: 30 | self.hidden_layer = nn.Linear(encoder_output_dim, hidden_dim) 31 | self.output_layer = nn.Linear(hidden_dim, output_dim) 32 | 33 | def forward(self, encoder_output): 34 | # [source_length, batch_size, encoder_output_dim] 35 | encoder_hiddens, *_ = encoder_output 36 | assert encoder_hiddens.dim() 37 | 38 | if self.use_self_attention: 39 | # [batch_size, hidden_dim] 40 | init_state = self._get_init_state(encoder_hiddens) 41 | # [source_length, batch_size, 1] 42 | attn_scores = self._attention(encoder_hiddens, init_state) 43 | # [batch_size, hidden_dim] 44 | attned_state = (encoder_hiddens * attn_scores).sum(0) 45 | 46 | pred_input = torch.cat([init_state, attned_state], 1) 47 | pred_hidden = F.relu(self.hidden_layer(pred_input)) 48 | # [batch_size, vocab_size] 49 | logits = self.output_layer(pred_hidden) 50 | else: 51 | # [source_length, batch_size, hidden_dim] 52 | hidden = F.relu(self.hidden_layer(encoder_hiddens)) 53 | # [batch_size, hidden_dim] 54 | mean_hidden = torch.mean(hidden, 0) 55 | max_hidden = torch.max(hidden, 0)[0] 56 | # [batch_size, vocab_size] 57 | logits = self.output_layer(mean_hidden + max_hidden) 58 | 59 | return logits 60 | 61 | def _get_init_state(self, encoder_hiddens): 62 | x = torch.mean(encoder_hiddens, 0) 63 | x = F.relu(self.init_layer(x)) 64 | return x 65 | 66 | def _attention(self, encoder_hiddens, init_state): 67 | init_state = init_state.unsqueeze(0).expand_as(encoder_hiddens) 68 | attn_input = torch.cat([init_state, encoder_hiddens], 2) 69 | attn_scores = F.relu(self.attn_layer(attn_input)) 70 | attn_scores = F.softmax(attn_scores, 0) 71 | return attn_scores 72 | 73 | def get_normalized_probs(self, net_output, log_probs): 74 | """Get 
normalized probabilities (or log probs) from a net's output.""" 75 | logits = net_output # [batch, vocab] 76 | if log_probs: 77 | return F.log_softmax(logits, dim=1) 78 | else: 79 | return F.softmax(logits, dim=1) 80 | 81 | def get_topk_predicted_tokens(self, net_output, src_tokens, log_probs: bool): 82 | """ 83 | Get self.topk_labels_per_source_token top predicted words for vocab 84 | reduction (per source token). 85 | """ 86 | assert ( 87 | isinstance(self.topk_labels_per_source_token, int) 88 | and self.topk_labels_per_source_token > 0 89 | ), "topk_labels_per_source_token must be a positive int, or None" 90 | 91 | # number of labels to predict for each example in batch 92 | k = src_tokens.size(1) * self.topk_labels_per_source_token 93 | # [batch_size, vocab_size] 94 | probs = self.get_normalized_probs(net_output, log_probs) 95 | _, topk_indices = torch.topk(probs, k, dim=1) 96 | 97 | return topk_indices 98 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def readme(): 7 | with open("README.md") as f: 8 | return f.read() 9 | 10 | 11 | def requirements(): 12 | with open("requirements.txt") as f: 13 | return f.read() 14 | 15 | 16 | setup( 17 | name="pytorch-translate", 18 | version="0.1", 19 | author="Facebook AI", 20 | description=("Facebook Translation System"), 21 | long_description=readme(), 22 | url="https://github.com/pytorch/translate", 23 | license="BSD", 24 | packages=find_packages(), 25 | install_requires=[ 26 | "fairseq>=0.5.0", 27 | ], 28 | dependency_links=[ 29 | "git+https://github.com/pytorch/fairseq.git#egg=fairseq-0.5.0", 30 | ], 31 | test_suite="pytorch_translate", 32 | ) 33 | --------------------------------------------------------------------------------
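
A minimal sketch of how `WordPredictor` (pytorch_translate/word_prediction/word_predictor.py above) is exercised on its own: it maps encoder hidden states of shape [source_length, batch_size, encoder_output_dim] to per-batch vocabulary logits, and `get_topk_predicted_tokens` then selects k = source_length * topk_labels_per_source_token candidate token indices for vocabulary reduction. All tensor sizes below are illustrative assumptions, not values used anywhere in the repo, and the one-element tuple stands in for a real encoder's output structure.

```python
# Illustrative sketch (not part of the repo): run WordPredictor on random encoder output.
import torch

from pytorch_translate.word_prediction.word_predictor import WordPredictor

src_len, batch_size, encoder_dim, vocab_size = 7, 2, 16, 100  # assumed toy sizes

predictor = WordPredictor(
    encoder_output_dim=encoder_dim,
    hidden_dim=32,
    output_dim=vocab_size,
    topk_labels_per_source_token=3,
)

# forward() unpacks the first element of encoder_output as the encoder hidden
# states, so a one-element tuple is enough for this standalone sketch.
encoder_hiddens = torch.randn(src_len, batch_size, encoder_dim)
logits = predictor((encoder_hiddens,))  # [batch_size, vocab_size]

# Top-k candidates for vocab reduction: k = src_len * topk_labels_per_source_token = 21.
src_tokens = torch.zeros(batch_size, src_len, dtype=torch.long)
topk_indices = predictor.get_topk_predicted_tokens(logits, src_tokens, log_probs=True)
print(logits.shape, topk_indices.shape)  # torch.Size([2, 100]) torch.Size([2, 21])
```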
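
The forward pass of `WeightedLabelSmoothedCrossEntropyCriterion` (pytorch_translate/weighted_criterions.py above) scales each example's log-probabilities by a per-example weight, masks padding targets, and mixes the NLL term with a uniform smoothing term as loss = (1 - eps) * nll + (eps / V) * smooth. The sketch below restates that arithmetic on plain tensors outside of fairseq so the weighting and smoothing can be checked in isolation; the padding index, shapes, and the function name are assumptions made for illustration, not the criterion's actual interface.

```python
# Standalone restatement (assumption: not the repo's code path) of the loss arithmetic
# in WeightedLabelSmoothedCrossEntropyCriterion.forward.
import torch
import torch.nn.functional as F


def weighted_label_smoothed_nll(logits, target, weights, eps, padding_idx=0):
    """logits: [batch, tgt_len, vocab]; target: [batch, tgt_len];
    weights: [batch] per-example weights; eps: label-smoothing epsilon."""
    lprobs = F.log_softmax(logits, dim=-1)
    # Scale each example's log-probabilities by its weight, as the criterion does.
    lprobs = lprobs * weights.unsqueeze(1).unsqueeze(2)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = target.view(-1, 1)
    non_pad_mask = target.ne(padding_idx)
    nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask].sum()
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask].sum()
    eps_i = eps / lprobs.size(-1)
    return (1.0 - eps) * nll_loss + eps_i * smooth_loss


logits = torch.randn(2, 5, 11)
target = torch.randint(1, 11, (2, 5))      # no padding tokens in this toy batch
weights = torch.tensor([1.0, 0.5])         # down-weight the second example
print(weighted_label_smoothed_nll(logits, target, weights, eps=0.1))
```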