├── .gitignore ├── LICENSE ├── README.md ├── docs ├── benchmarks.md ├── index.md └── walkthrough.md ├── setup.py └── thumt ├── __init__.py ├── bin ├── __init__.py ├── scorer.py ├── trainer.py └── translator.py ├── data ├── __init__.py ├── dataset.py ├── iterator.py ├── pipeline.py └── vocab.py ├── models ├── __init__.py └── transformer.py ├── modules ├── __init__.py ├── affine.py ├── attention.py ├── embedding.py ├── feed_forward.py ├── layer_norm.py ├── losses.py ├── module.py └── recurrent.py ├── optimizers ├── __init__.py ├── clipping.py ├── optimizers.py └── schedules.py ├── scripts ├── average_checkpoints.py ├── build_vocab.py ├── convert_checkpoint.py └── shuffle_corpus.py ├── tokenizers ├── __init__.py ├── tokenizer.py └── unicode_tokenizer.py └── utils ├── __init__.py ├── bleu.py ├── bpe.py ├── checkpoint.py ├── convert_params.py ├── evaluation.py ├── hparams.py ├── inference.py ├── misc.py ├── nest.py ├── scope.py └── summary.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *~ 3 | __pycache__ 4 | .* 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Natural Language Processing Lab at Tsinghua University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # THUMT: An Open Source Toolkit for Neural Machine Translation 2 | 3 | ## Contents 4 | 5 | * [Introduction](#introduction) 6 | * [Online Demo](#online-demo) 7 | * [Implementations](#implementations) 8 | * [Notable Features](#notable-features) 9 | * [Documentation](#documentation) 10 | * [License](#license) 11 | * [Citation](#citation) 12 | * [Development Team](#development-team) 13 | * [Contact](#contact) 14 | * [Derivative Repositories](#derivative-repositories) 15 | 16 | ## Introduction 17 | 18 | Machine translation is a natural language processing task that aims to automatically translate between natural languages using computers. Recent years have witnessed the rapid development of end-to-end neural machine translation, which has become the new mainstream method in practical MT systems. 19 | 20 | THUMT is an open-source toolkit for neural machine translation developed by [the Natural Language Processing Group at Tsinghua University](http://nlp.csai.tsinghua.edu.cn/site2/index.php?lang=en). The website of THUMT is: [http://thumt.thunlp.org/](http://thumt.thunlp.org/). 21 | 22 | ## Online Demo 23 | 24 | The online demo of THUMT is available at [http://translate.thumt.cn/](http://101.6.5.207:3892/). The languages involved include Ancient Chinese, Arabic, Chinese, English, French, German, Indonesian, Japanese, Portuguese, Russian, and Spanish. 25 | 26 | ## Implementations 27 | 28 | THUMT currently has three main implementations: 29 | 30 | * [THUMT-PyTorch](https://github.com/thumt/THUMT): a new implementation developed with [PyTorch](https://github.com/pytorch/pytorch). It implements the Transformer model (**Transformer**) ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)). 31 | 32 | * [THUMT-TensorFlow](https://github.com/thumt/THUMT/tree/tensorflow): an implementation developed with [TensorFlow](https://github.com/tensorflow/tensorflow). It implements the sequence-to-sequence model (**Seq2Seq**) ([Sutskever et al., 2014](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf)), the standard attention-based model (**RNNsearch**) ([Bahdanau et al., 2014](https://arxiv.org/pdf/1409.0473.pdf)), and the Transformer model (**Transformer**) ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)). 33 | 34 | * [THUMT-Theano](https://github.com/thumt/THUMT/tree/theano): the original project developed with [Theano](https://github.com/Theano/Theano), which is no longer updated because MILA ended the development of [Theano](https://github.com/Theano/Theano). It implements the standard attention-based model (**RNNsearch**) ([Bahdanau et al., 2014](https://arxiv.org/pdf/1409.0473.pdf)), minimum risk training (**MRT**) ([Shen et al., 2016](http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2016_mrt.pdf)) for optimizing model parameters with respect to evaluation metrics, semi-supervised training (**SST**) ([Cheng et al., 2016](http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2016_semi.pdf)) for exploiting monolingual corpora to learn bi-directional translation models, and layer-wise relevance propagation (**LRP**) ([Ding et al., 2017](http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2017_dyz.pdf)) for visualizing and analyzing RNNsearch.
35 | 36 | The following table summarizes the features of the three implementations: 37 | 38 | | Implementation | Model | Criterion | Optimizer | LRP | 39 | | :------------: | :---: | :--------------: | :--------------: | :----------------: | 40 | | Theano | RNNsearch | MLE, MRT, SST | SGD, AdaDelta, Adam | RNNsearch | 41 | | TensorFlow | Seq2Seq, RNNsearch, Transformer | MLE | Adam | RNNsearch, Transformer | 42 | | PyTorch | Transformer | MLE | SGD, Adadelta, Adam | N.A. | 43 | 44 | We recommend using [THUMT-PyTorch](https://github.com/thumt/THUMT) or [THUMT-TensorFlow](https://github.com/thumt/THUMT/tree/tensorflow), which deliver better translation performance than [THUMT-Theano](https://github.com/thumt/THUMT/tree/theano). We will keep adding new features to [THUMT-PyTorch](https://github.com/thumt/THUMT) and [THUMT-TensorFlow](https://github.com/thumt/THUMT/tree/tensorflow). 45 | 46 | ## Notable Features 47 | 48 | * Transformer ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)) 49 | * Multi-GPU training & decoding 50 | * Multi-worker distributed training 51 | * Mixed precision training & decoding 52 | * Model ensemble & averaging 53 | * Gradient aggregation 54 | * TensorBoard for visualization 55 | 56 | ## Documentation 57 | 58 | The documentation of the PyTorch implementation is available [here](docs/index.md). 59 | 60 | ## License 61 | 62 | The source code is dual licensed. Open source licensing is under the [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause), which allows free use for research purposes. For commercial licensing, please email [thumt17@gmail.com](mailto:thumt17@gmail.com). 63 | 64 | ## Citation 65 | 66 | Please cite the following papers: 67 | 68 | > Zhixing Tan, Jiacheng Zhang, Xuancheng Huang, Gang Chen, Shuo Wang, Maosong Sun, Huanbo Luan, Yang Liu. [THUMT: An Open Source Toolkit for Neural Machine Translation](https://www.aclweb.org/anthology/2020.amta-research.11/). AMTA 2020. 69 | 70 | > Jiacheng Zhang, Yanzhuo Ding, Shiqi Shen, Yong Cheng, Maosong Sun, Huanbo Luan, Yang Liu. 2017. [THUMT: An Open Source Toolkit for Neural Machine Translation](https://arxiv.org/abs/1706.06415). arXiv:1706.06415. 71 | 72 | ## Development Team 73 | 74 | Project leaders: [Maosong Sun](http://www.thunlp.org/site2/index.php/zh/people?id=16), [Yang Liu](http://nlp.csai.tsinghua.edu.cn/~ly/), Huanbo Luan 75 | 76 | Project members: 77 | 78 | Theano: Jiacheng Zhang, Yanzhuo Ding, Shiqi Shen, Yong Cheng 79 | 80 | TensorFlow: Zhixing Tan, Jiacheng Zhang, Xuancheng Huang, Gang Chen, Shuo Wang, Zonghan Yang 81 | 82 | PyTorch: Zhixing Tan, Gang Chen 83 | 84 | ## Contact 85 | 86 | If you have questions, suggestions, or bug reports, please email [thumt17@gmail.com](mailto:thumt17@gmail.com).
87 | 88 | ## Derivative Repositories 89 | 90 | * [UCE4BT](https://github.com/THUNLP-MT/UCE4BT) (Improving Back-Translation with Uncertainty-based Confidence Estimation) 91 | * [L2Copy4APE](https://github.com/THUNLP-MT/L2Copy4APE) (Learning to Copy for Automatic Post-Editing) 92 | * [Document-Transformer](https://github.com/THUNLP-MT/Document-Transformer) (Improving the Transformer Translation Model with Document-Level Context) 93 | * [PR4NMT](https://github.com/THUNLP-MT/PR4NMT) (Prior Knowledge Integration for Neural Machine Translation using Posterior Regularization) 94 | -------------------------------------------------------------------------------- /docs/benchmarks.md: -------------------------------------------------------------------------------- 1 | # THUMT Documentation 2 | 3 | [THUMT](https://github.com/thumt/THUMT/tree/pytorch) is an open-source toolkit for neural machine translation developed by the Tsinghua Natural Language Processing Group. 4 | 5 | ## Benchmarks 6 | 7 | We benchmark THUMT on the following datasets: 8 | 9 | * [WMT14 En-DE](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) 10 | * [WMT18 Zh-En](http://data.statmt.org/wmt18/translation-task/preprocessed/zh-en/) 11 | 12 | The testsets for WMT14 En-De and WMT18 Zh-En are `newstest2014` and `newstest2017` respectively. 13 | 14 | | Dataset | Model | Size | Steps | GPUs | Batch/GPU | Mode | BLEU | 15 | |:---------:|:---------:|:----:|:-----:|:----:|:---------:|:--------:|:------:| 16 | |WMT14 En-De|Transformer| Base | 100k | 4 | 2*4096 | FP16 | 26.85 | 17 | |WMT14 En-De|Transformer| Base | 100k | 4 | 2*4096 | FP32 | 26.91 | 18 | |WMT14 En-De|Transformer| Base | 100k | 8 | 4096 | FP32 | 26.95 | 19 | |WMT14 En-De|Transformer| Base | 86k | 8 | 2*4096 | FP32 | 27.21 | 20 | |WMT14 En-De|Transformer| Big | 300k | 8 | 4096 | FP16 | 28.71 | 21 | |WMT14 En-De|Transformer| Big | 20k | 16 | 8*4096 | DistFP16 | 28.68 | 22 | |WMT18 Zh-En|Transformer| Big | 300k | 8 | 2*4096 | FP16 | 24.07 | 23 | 24 | [return to index](index.md) 25 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # THUMT Documentation 2 | 3 | [THUMT](https://github.com/thumt/THUMT/tree/pytorch) is an open-source toolkit for neural machine translation developed by the Tsinghua Natural Language Processing Group. This page describes the document of [THUMT-PyTorch](https://github.com/thumt/THUMT/tree/pytorch). 4 | 5 | ## Contents 6 | 7 | * [Basics](#basics) 8 | * [Prerequisite](#prerequisite) 9 | * [Installation](#installation) 10 | * [Features](#features) 11 | * [Walkthrough](#walkthrough) 12 | * [Benchmarks](#benchmarks) 13 | 14 | ## Basics 15 | 16 | ### Prerequisites 17 | 18 | * CUDA 10.0 19 | * PyTorch 20 | * TensorFlow-2.0 (CPU version) 21 | 22 | ### Installation 23 | 24 | ```bash 25 | pip install --upgrade pip 26 | pip install thumt 27 | ``` 28 | 29 | ### Features 30 | 31 | * Multi-GPU training & decoding 32 | * Multi-worker distributed training 33 | * Mixed precision training & decoding 34 | * Model ensemble & averaging 35 | * Gradient aggregation 36 | * TensorBoard for visualization 37 | 38 | ## Walkthrough 39 | 40 | We provide a step-by-step [walkthrough](walkthrough.md) with a running example: WMT 2018 Chinese-English news translation shared task. 41 | 42 | ## Benchmarks 43 | 44 | We provide benchmarks on several datasets. See [here](benchmarks.md). 
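Returning to the [Installation](#installation) step above: a quick way to verify the setup is to run the following from a Python shell. This is only a sanity check, not part of the toolkit itself, and it assumes a CUDA-enabled PyTorch build together with the CPU version of TensorFlow listed under Prerequisites.

```python
# Post-installation sanity check (assumes a CUDA-enabled PyTorch build).
import tensorflow as tf
import torch

import thumt  # noqa: F401 -- only verifies that the package is importable

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("TensorFlow:", tf.__version__)
```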
45 | -------------------------------------------------------------------------------- /docs/walkthrough.md: -------------------------------------------------------------------------------- 1 | # THUMT Documentation 2 | 3 | [THUMT](https://github.com/thumt/THUMT/tree/pytorch) is an open-source toolkit for neural machine translation developed by the Tsinghua Natural Language Processing Group. 4 | 5 | ## Walkthrough 6 | 7 | * [Data Preparation](#data-preparation) 8 | * [Obtaining the Datasets](#obtaining-the-datasets) 9 | * [Running BPE](#running-bpe) 10 | * [Shuffling Training Set](#shuffling-training-set) 11 | * [Generating Vocabularies](#generating-vocabularies) 12 | * [Training](#training) 13 | * [Decoding](#decoding) 14 | 15 | We provide a step-by-step guide with a running example: the WMT 2018 Chinese-English news translation shared task. 16 | 17 | ## Data Preparation 18 | 19 | ### Obtaining the Datasets 20 | 21 | Running THUMT involves three types of datasets: 22 | 23 | * **Training set**: a set of parallel sentences used for training NMT models. 24 | * **Validation set**: a set of source sentences paired with single or multiple target translations used for model selection and hyper-parameter optimization. 25 | * **Test set**: a set of source sentences paired with single or multiple target translations used for evaluating translation performance on unseen texts. 26 | 27 | In this walkthrough, we'll use the preprocessed official [dataset](http://data.statmt.org/wmt18/translation-task/preprocessed/zh-en/). Download and unpack the file `corpus.gz`: 28 | 29 | ```bash 30 | gzip -d corpus.gz 31 | ``` 32 | 33 | The resulting file is `corpus.tsv`. Use the following command to generate source and target files: 34 | 35 | ```bash 36 | cut -f 1 corpus.tsv > corpus.tc.zh 37 | cut -f 2 corpus.tsv > corpus.tc.en 38 | ``` 39 | 40 | `corpus.tc.zh` and `corpus.tc.en` serve as the training set, which contains 24,752,392 pairs of sentences. Note that the Chinese sentences are tokenized and the English sentences are tokenized and truecased. Unpack the file `dev.tgz` using the following command: 41 | 42 | ```bash 43 | tar xvfz dev.tgz 44 | ``` 45 | 46 | After unpacking, `newsdev2017.tc.zh` and `newsdev2017.tc.en` serve as the validation set, which contains 2,002 pairs of sentences. The test set we use is `newstest2017.tc.zh` and `newstest2017.tc.en`, which consists of 2,001 pairs of sentences. Note that both the validation and test sets use single references since there is only one gold-standard English translation for each Chinese sentence. 47 | 48 | ### Running BPE 49 | 50 | For efficiency reasons, only a fraction of the full vocabulary can be used in neural machine translation systems. The most widely used approach for addressing the open vocabulary problem is to use Byte Pair Encoding (BPE). We recommend using BPE for THUMT. 51 | 52 | First, download the source code of BPE using the following command: 53 | 54 | ```bash 55 | git clone https://github.com/rsennrich/subword-nmt.git 56 | ``` 57 | 58 | To encode the training corpora using BPE, you need to generate BPE operations first. The following commands will create two files named `bpe.zh` and `bpe.en`, each containing 32k BPE operations. 59 | 60 | ```bash 61 | python subword-nmt/learn_bpe.py -s 32000 -t < corpus.tc.zh > bpe.zh 62 | python subword-nmt/learn_bpe.py -s 32000 -t < corpus.tc.en > bpe.en 63 | ``` 64 | 65 | Then, run the `apply_bpe.py` script to encode the training set using the generated BPE operations.
66 | 67 | ```bash 68 | python subword-nmt/apply_bpe.py -c bpe.zh < corpus.tc.zh > corpus.tc.32k.zh 69 | python subword-nmt/apply_bpe.py -c bpe.en < corpus.tc.en > corpus.tc.32k.en 70 | ``` 71 | 72 | The source sides of the validation set and the test set also need to be processed using the `apply_bpe.py` script. 73 | 74 | ```bash 75 | python subword-nmt/apply_bpe.py -c bpe.zh < newsdev2017.tc.zh > newsdev2017.tc.32k.zh 76 | python subword-nmt/apply_bpe.py -c bpe.zh < newstest2017.tc.zh > newstest2017.tc.32k.zh 77 | ``` 78 | 79 | Note that while BPE is applied to the source side of the validation and test sets, it does not need to be applied to the target side. This is because, when evaluating the translation outputs, we restore them to the normal tokenization and compare them with the original ground-truth sentences. 80 | 81 | ### Shuffling Training Set 82 | 83 | The next step is to shuffle the training set, which helps improve translation quality. Simply run the following command: 84 | 85 | ```bash 86 | shuffle_corpus.py --corpus corpus.tc.32k.zh corpus.tc.32k.en 87 | ``` 88 | 89 | The resulting files `corpus.tc.32k.zh.shuf` and `corpus.tc.32k.en.shuf` rearrange the sentence pairs randomly. 90 | 91 | ### Generating Vocabularies 92 | 93 | We need to generate vocabularies from the shuffled training set. This can be done by running the `build_vocab.py` script: 94 | 95 | ```bash 96 | build_vocab.py corpus.tc.32k.zh.shuf vocab.32k.zh 97 | build_vocab.py corpus.tc.32k.en.shuf vocab.32k.en 98 | ``` 99 | 100 | The resulting files `vocab.32k.zh.txt` and `vocab.32k.en.txt` are the final source and target vocabularies used for model training. 101 | 102 | ## Training 103 | 104 | We recommend using the Transformer model, which delivers the best translation performance among the three models supported by THUMT. The command for training a Transformer model is given by 105 | 106 | ```bash 107 | thumt-trainer \ 108 | --input corpus.tc.32k.zh.shuf corpus.tc.32k.en.shuf \ 109 | --vocabulary vocab.32k.zh.txt vocab.32k.en.txt \ 110 | --model transformer \ 111 | --validation newsdev2017.tc.32k.zh \ 112 | --references newsdev2017.tc.en \ 113 | --parameters=batch_size=4096,device_list=[0,1,2,3],update_cycle=2 \ 114 | --hparam_set base 115 | ``` 116 | 117 | Note that we set the `batch_size` on each device (e.g. GPU) to 4,096 words instead of 4,096 sentences. By default, the batch size for the Transformer model is defined in terms of the number of words rather than the number of sentences in THUMT. We set `update_cycle` to 2, which means the model parameters are updated every 2 batches. This effectively simulates the setting of `batch_size=32768` (4,096 words × 4 GPUs × 2 accumulated batches) while requiring less GPU memory. If you still run out of GPU memory, try a smaller `batch_size` and a larger `update_cycle`. For newer GPUs like Tesla V100, you can add `--half` to enable mixed-precision training, which can improve training speed and reduce memory usage. 118 | 119 | `device_list=[0,1,2,3]` indicates that `gpu0-3` are used to train the model. THUMT supports training NMT models on multiple GPUs. If `gpu0-7` are available, simply set `device_list=[0,1,2,3,4,5,6,7]` and `update_cycle=1` to keep the same effective batch size of 32,768 (but with the training speed doubled). You may use the `nvidia-smi` command to find unused GPUs. 120 | 121 | By setting `--hparam_set base`, we will train a base Transformer model. The training process will terminate at iteration 100,000 by default.
During the training, the `thumt-trainer` command creates a `train` directory to store intermediate models called `checkpoints`, which will be evaluated on the validation set periodically. 122 | 123 | Note again that while BPE is applied to the source side of the validation set (`newsdev2017.tc.32k.zh`), the target side of the validation set is in the original tokenization (`newsdev2017.tc.en`). 124 | 125 | Only a small number of checkpoints that achieve the highest BLEU scores on the validation set will be saved in the `train/eval` directory. This directory will be used in decoding. 126 | 127 | ## Decoding 128 | 129 | The command for translating the test set using the trained Transformer model is given by 130 | 131 | ```bash 132 | thumt-translator \ 133 | --models transformer \ 134 | --input newstest2017.tc.32k.zh \ 135 | --output newstest2017.trans \ 136 | --vocabulary vocab.32k.zh.txt vocab.32k.en.txt \ 137 | --checkpoints train/eval \ 138 | --parameters=device_list=[0],decode_alpha=1.2 139 | ``` 140 | 141 | Note that decoding hyper-parameters such as `decode_alpha` take effect here, on the test set; setting `decode_alpha` in the training command only affects the translation performance reported on the validation set during training. 142 | 143 | The translation file output by `thumt-translator` is `newstest2017.trans`, which needs to be restored to the normal tokenization using the following command: 144 | 145 | ```bash 146 | sed -r 's/(@@ )|(@@ ?$)//g' < newstest2017.trans > newstest2017.trans.norm 147 | ``` 148 | 149 | Finally, BLEU scores can be calculated using the [`multi-bleu.perl`](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) script: 150 | 151 | ```bash 152 | multi-bleu.perl -lc newstest2017.tc.en < newstest2017.trans.norm > evalResult 153 | ``` 154 | 155 | The resulting `evalResult` file stores the calculated BLEU score.
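For readers who prefer to stay in Python, the following is a minimal equivalent of the `sed` command above for stripping the `@@ ` BPE continuation markers (file names follow the walkthrough):

```python
# Minimal Python equivalent of the sed command above: remove the "@@ " BPE
# continuation markers to restore the normal tokenization.
import re

with open("newstest2017.trans", encoding="utf-8") as src, \
        open("newstest2017.trans.norm", "w", encoding="utf-8") as dst:
    for line in src:
        dst.write(re.sub(r"(@@ )|(@@ ?$)", "", line.rstrip("\n")) + "\n")
```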
156 | 157 | [return to index](index.md) 158 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from setuptools import find_packages 6 | from setuptools import setup 7 | 8 | setup( 9 | name="thumt", 10 | version="1.2.0", 11 | author="The THUMT Authors", 12 | author_email="thumt17@gmail.com", 13 | description="THUMT: An open-source toolkit for neural machine translation", 14 | url="http://thumt.thunlp.org", 15 | entry_points={ 16 | "console_scripts": [ 17 | "thumt-trainer = thumt.bin.trainer:cli_main", 18 | "thumt-translator = thumt.bin.translator:cli_main", 19 | "thumt-scorer=thumt.bin.scorer:cli_main" 20 | ]}, 21 | scripts=[ 22 | "thumt/scripts/average_checkpoints.py", 23 | "thumt/scripts/build_vocab.py", 24 | "thumt/scripts/convert_checkpoint.py", 25 | "thumt/scripts/shuffle_corpus.py"], 26 | packages=find_packages(), 27 | install_requires=[ 28 | "future", 29 | "pillow", 30 | "torch>=1.1.0", 31 | "regex"], 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence"]) 35 | -------------------------------------------------------------------------------- /thumt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-MT/THUMT/4893578278c08f436fbaaf799257cea15b0c9b56/thumt/__init__.py -------------------------------------------------------------------------------- /thumt/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-MT/THUMT/4893578278c08f436fbaaf799257cea15b0c9b56/thumt/bin/__init__.py -------------------------------------------------------------------------------- /thumt/bin/scorer.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin python 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import re 11 | import six 12 | import time 13 | import copy 14 | import torch 15 | import socket 16 | import logging 17 | import argparse 18 | import numpy as np 19 | 20 | import torch.distributed as dist 21 | 22 | import thumt.data as data 23 | import thumt.utils as utils 24 | import thumt.models as models 25 | 26 | logging.getLogger().setLevel(logging.INFO) 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser( 31 | description="Score input sentences with pre-trained checkpoints.", 32 | usage="scorer.py [] [-h | --help]" 33 | ) 34 | 35 | # input files 36 | parser.add_argument("--input", type=str, required=True, nargs=2, 37 | help="Path to input file.") 38 | parser.add_argument("--output", type=str, required=True, 39 | help="Path to output file.") 40 | parser.add_argument("--checkpoint", type=str, required=True, 41 | help="Path to trained checkpoint.") 42 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 43 | help="Path to source and target vocabulary.") 44 | 45 | # model and configuration 46 | parser.add_argument("--model", type=str, required=True, 47 | help="Name of the model.") 48 | parser.add_argument("--parameters", type=str, default="", 49 | help="Additional hyper-parameters.") 50 | parser.add_argument("--half", action="store_true", 51 | help="Enable Half-precision for decoding.") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def default_params(): 57 | params = utils.HParams( 58 | input=None, 59 | output=None, 60 | vocabulary=None, 61 | model=None, 62 | # vocabulary specific 63 | pad="", 64 | bos="", 65 | eos="", 66 | unk="", 67 | append_eos=False, 68 | monte_carlo=False, 69 | device_list=[0], 70 | decode_batch_size=32, 71 | buffer_size=10000, 72 | level="sentence" 73 | ) 74 | 75 | return params 76 | 77 | 78 | def merge_params(params1, params2): 79 | params = utils.HParams() 80 | 81 | for (k, v) in six.iteritems(params1.values()): 82 | params.add_hparam(k, v) 83 | 84 | params_dict = params.values() 85 | 86 | for (k, v) in six.iteritems(params2.values()): 87 | if k in params_dict: 88 | # Override 89 | setattr(params, k, v) 90 | else: 91 | params.add_hparam(k, v) 92 | 93 | return params 94 | 95 | 96 | def import_params(model_dir, model_name, params): 97 | model_dir = os.path.abspath(model_dir) 98 | m_name = os.path.join(model_dir, model_name + ".json") 99 | 100 | if not os.path.exists(m_name): 101 | return params 102 | 103 | with open(m_name) as fd: 104 | logging.info("Restoring model parameters from %s" % m_name) 105 | json_str = fd.readline() 106 | params.parse_json(json_str) 107 | 108 | return params 109 | 110 | 111 | def override_params(params, args): 112 | if args.parameters: 113 | params.parse(args.parameters.lower()) 114 | 115 | params.vocabulary = { 116 | "source": data.Vocabulary(args.vocabulary[0]), 117 | "target": data.Vocabulary(args.vocabulary[1]) 118 | } 119 | 120 | return params 121 | 122 | 123 | def infer_gpu_num(param_str): 124 | result = re.match(r".*device_list=\[(.*?)\].*", param_str) 125 | 126 | if not result: 127 | return 1 128 | 129 | dev_str = result.groups()[-1] 130 | return len(dev_str.split(",")) 131 | 132 | 133 | def main(args): 134 | model_cls = models.get_model(args.model) 135 | # Import and override parameters 136 | # Priorities (low -> high): 137 | # default -> saved -> command 138 | 
params = default_params() 139 | params = merge_params(params, model_cls.default_params()) 140 | params = import_params(args.checkpoint, args.model, params) 141 | params = override_params(params, args) 142 | 143 | params.device = params.device_list[args.local_rank] 144 | dist.init_process_group("nccl", init_method=args.url, 145 | rank=args.local_rank, 146 | world_size=len(params.device_list)) 147 | torch.cuda.set_device(params.device_list[args.local_rank]) 148 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 149 | 150 | if args.half: 151 | torch.set_default_dtype(torch.half) 152 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 153 | 154 | def score_fn(inputs, _model, level="sentence"): 155 | _features, _labels = inputs 156 | _score = _model(_features, _labels, mode="eval", level=level) 157 | return _score 158 | 159 | # Create model 160 | with torch.no_grad(): 161 | model = model_cls(params).cuda() 162 | 163 | if args.half: 164 | model = model.half() 165 | 166 | if not params.monte_carlo: 167 | model.eval() 168 | 169 | model.load_state_dict( 170 | torch.load(utils.latest_checkpoint(args.checkpoint), 171 | map_location="cpu")["model"]) 172 | dataset = data.MTPipeline.get_eval_dataset(args.input, params) 173 | data_iter = iter(dataset) 174 | counter = 0 175 | pad_max = 1024 176 | 177 | # Buffers for synchronization 178 | size = torch.zeros([dist.get_world_size()]).long() 179 | if params.level == "sentence": 180 | t_list = [torch.empty([params.decode_batch_size]).float() 181 | for _ in range(dist.get_world_size())] 182 | else: 183 | t_list = [torch.empty([params.decode_batch_size, pad_max]).float() 184 | for _ in range(dist.get_world_size())] 185 | 186 | if dist.get_rank() == 0: 187 | fd = open(args.output, "w") 188 | else: 189 | fd = None 190 | 191 | while True: 192 | try: 193 | features = next(data_iter) 194 | batch_size = features[0]["source"].shape[0] 195 | except: 196 | features = { 197 | "source": torch.ones([1, 1]).long(), 198 | "source_mask": torch.ones([1, 1]).float(), 199 | "target": torch.ones([1, 1]).long(), 200 | "target_mask": torch.ones([1, 1]).float() 201 | }, torch.ones([1, 1]).long() 202 | batch_size = 0 203 | 204 | t = time.time() 205 | counter += 1 206 | 207 | scores = score_fn(features, model, params.level) 208 | 209 | # Padding 210 | if params.level == "sentence": 211 | pad_batch = params.decode_batch_size - scores.shape[0] 212 | scores = torch.nn.functional.pad(scores, [0, pad_batch]) 213 | else: 214 | pad_batch = params.decode_batch_size - scores.shape[0] 215 | pad_length = pad_max - scores.shape[1] 216 | scores = torch.nn.functional.pad( 217 | scores, (0, pad_length, 0, pad_batch), value=-1) 218 | 219 | # Synchronization 220 | size.zero_() 221 | size[dist.get_rank()].copy_(torch.tensor(batch_size)) 222 | dist.all_reduce(size) 223 | dist.all_gather(t_list, scores.float()) 224 | 225 | if size.sum() == 0: 226 | break 227 | 228 | if dist.get_rank() != 0: 229 | continue 230 | 231 | for i in range(params.decode_batch_size): 232 | for j in range(dist.get_world_size()): 233 | n = size[j] 234 | score = t_list[j][i] 235 | 236 | if i >= n: 237 | continue 238 | 239 | if params.level == "sentence": 240 | fd.write("{:.4f}\n".format(score)) 241 | else: 242 | s_list = score.tolist() 243 | for s in s_list: 244 | if s >= 0: 245 | fd.write("{:.8f} ".format(s)) 246 | else: 247 | fd.write("\n") 248 | break 249 | 250 | t = time.time() - t 251 | logging.info("Finished batch: %d (%.3f sec)" % (counter, t)) 252 | 253 | if dist.get_rank() == 0: 254 | fd.close() 255 | 256 | 257 | # 
Wrap main function 258 | def process_fn(rank, args): 259 | local_args = copy.copy(args) 260 | local_args.local_rank = rank 261 | main(local_args) 262 | 263 | 264 | def cli_main(): 265 | parsed_args = parse_args() 266 | 267 | # Pick a free port 268 | with socket.socket() as s: 269 | s.bind(("localhost", 0)) 270 | port = s.getsockname()[1] 271 | url = "tcp://localhost:" + str(port) 272 | parsed_args.url = url 273 | 274 | world_size = infer_gpu_num(parsed_args.parameters) 275 | 276 | if world_size > 1: 277 | torch.multiprocessing.spawn(process_fn, args=(parsed_args,), 278 | nprocs=world_size) 279 | else: 280 | process_fn(0, parsed_args) 281 | 282 | 283 | if __name__ == "__main__": 284 | cli_main() 285 | -------------------------------------------------------------------------------- /thumt/bin/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import copy 10 | import glob 11 | import logging 12 | import os 13 | import re 14 | import six 15 | import socket 16 | import time 17 | import torch 18 | 19 | import thumt.data as data 20 | import torch.distributed as dist 21 | import thumt.models as models 22 | import thumt.optimizers as optimizers 23 | import thumt.utils as utils 24 | import thumt.utils.summary as summary 25 | 26 | 27 | def parse_args(args=None): 28 | parser = argparse.ArgumentParser( 29 | description="Train a neural machine translation model.", 30 | usage="trainer.py [] [-h | --help]" 31 | ) 32 | 33 | # input files 34 | parser.add_argument("--input", type=str, nargs=2, 35 | help="Path to source and target corpus.") 36 | parser.add_argument("--output", type=str, default="train", 37 | help="Path to load/store checkpoints.") 38 | parser.add_argument("--vocabulary", type=str, nargs=2, 39 | help="Path to source and target vocabulary.") 40 | parser.add_argument("--validation", type=str, 41 | help="Path to validation file.") 42 | parser.add_argument("--references", type=str, 43 | help="Pattern to reference files.") 44 | parser.add_argument("--checkpoint", type=str, 45 | help="Path to pre-trained checkpoint.") 46 | parser.add_argument("--distributed", action="store_true", 47 | help="Enable distributed training.") 48 | parser.add_argument("--local_rank", type=int, 49 | help="Local rank of this process.") 50 | parser.add_argument("--half", action="store_true", 51 | help="Enable mixed-precision training.") 52 | parser.add_argument("--hparam_set", type=str, 53 | help="Name of pre-defined hyper-parameter set.") 54 | 55 | # model and configuration 56 | parser.add_argument("--model", type=str, required=True, 57 | help="Name of the model.") 58 | parser.add_argument("--parameters", type=str, default="", 59 | help="Additional hyper-parameters.") 60 | 61 | return parser.parse_args(args) 62 | 63 | 64 | def default_params(): 65 | params = utils.HParams( 66 | input=["", ""], 67 | output="", 68 | model="transformer", 69 | vocab=["", ""], 70 | pad="", 71 | bos="", 72 | eos="", 73 | unk="", 74 | # Dataset 75 | batch_size=4096, 76 | fixed_batch_size=False, 77 | min_length=1, 78 | max_length=256, 79 | buffer_size=10000, 80 | # Initialization 81 | initializer_gain=1.0, 82 | initializer="uniform_unit_scaling", 83 | # Regularization 84 | scale_l1=0.0, 85 | scale_l2=0.0, 86 | # Training 87 | initial_step=0, 88 | warmup_steps=4000, 89 | train_steps=100000, 90 | 
update_cycle=1, 91 | optimizer="Adam", 92 | adam_beta1=0.9, 93 | adam_beta2=0.999, 94 | adam_epsilon=1e-8, 95 | adadelta_rho=0.95, 96 | adadelta_epsilon=1e-7, 97 | pattern="", 98 | clipping="global_norm", 99 | clip_grad_norm=5.0, 100 | learning_rate=1.0, 101 | initial_learning_rate=0.0, 102 | learning_rate_schedule="linear_warmup_rsqrt_decay", 103 | learning_rate_boundaries=[0], 104 | learning_rate_values=[0.0], 105 | device_list=[0], 106 | # Checkpoint Saving 107 | keep_checkpoint_max=20, 108 | keep_top_checkpoint_max=5, 109 | save_summary=True, 110 | save_checkpoint_secs=0, 111 | save_checkpoint_steps=1000, 112 | # Validation 113 | eval_steps=2000, 114 | eval_secs=0, 115 | top_beams=1, 116 | beam_size=4, 117 | decode_batch_size=32, 118 | decode_alpha=0.6, 119 | decode_ratio=1.0, 120 | decode_length=50, 121 | validation="", 122 | references="", 123 | ) 124 | 125 | return params 126 | 127 | 128 | def import_params(model_dir, model_name, params): 129 | model_dir = os.path.abspath(model_dir) 130 | p_name = os.path.join(model_dir, "params.json") 131 | m_name = os.path.join(model_dir, model_name + ".json") 132 | 133 | if os.path.exists(p_name): 134 | with open(p_name) as fd: 135 | logging.info("Restoring hyper parameters from %s" % p_name) 136 | json_str = fd.readline() 137 | params.parse_json(json_str) 138 | 139 | if os.path.exists(m_name): 140 | with open(m_name) as fd: 141 | logging.info("Restoring model parameters from %s" % m_name) 142 | json_str = fd.readline() 143 | params.parse_json(json_str) 144 | 145 | return params 146 | 147 | 148 | def export_params(output_dir, name, params): 149 | if not os.path.exists(output_dir): 150 | os.makedirs(output_dir) 151 | 152 | # Save params as params.json 153 | filename = os.path.join(output_dir, name) 154 | 155 | with open(filename, "w") as fd: 156 | fd.write(params.to_json()) 157 | 158 | 159 | def merge_params(params1, params2): 160 | params = utils.HParams() 161 | 162 | for (k, v) in six.iteritems(params1.values()): 163 | params.add_hparam(k, v) 164 | 165 | params_dict = params.values() 166 | 167 | for (k, v) in six.iteritems(params2.values()): 168 | if k in params_dict: 169 | # Override 170 | setattr(params, k, v) 171 | else: 172 | params.add_hparam(k, v) 173 | 174 | return params 175 | 176 | 177 | def override_params(params, args): 178 | params.model = args.model or params.model 179 | params.input = args.input or params.input 180 | params.output = args.output or params.output 181 | params.vocab = args.vocabulary or params.vocab 182 | params.validation = args.validation or params.validation 183 | params.references = args.references or params.references 184 | params.parse(args.parameters.lower()) 185 | 186 | params.vocabulary = { 187 | "source": data.Vocabulary(params.vocab[0]), 188 | "target": data.Vocabulary(params.vocab[1]) 189 | } 190 | 191 | return params 192 | 193 | 194 | def collect_params(all_params, params): 195 | collected = utils.HParams() 196 | 197 | for k in six.iterkeys(params.values()): 198 | collected.add_hparam(k, getattr(all_params, k)) 199 | 200 | return collected 201 | 202 | 203 | def print_variables(model, pattern, log=True): 204 | flags = [] 205 | 206 | for (name, var) in model.named_parameters(): 207 | if re.search(pattern, name): 208 | flags.append(True) 209 | else: 210 | flags.append(False) 211 | 212 | weights = {v[0]: v[1] for v in model.named_parameters()} 213 | total_size = 0 214 | 215 | for name in sorted(list(weights)): 216 | if re.search(pattern, name): 217 | v = weights[name] 218 | total_size += v.nelement() 219 | 
220 | if log: 221 | print("%s %s" % (name.ljust(60), str(list(v.shape)).rjust(15))) 222 | 223 | if log: 224 | print("Total trainable variables size: %d" % total_size) 225 | 226 | return flags 227 | 228 | 229 | def exclude_variables(flags, grads_and_vars): 230 | idx = 0 231 | new_grads = [] 232 | new_vars = [] 233 | 234 | for grad, (name, var) in grads_and_vars: 235 | if flags[idx]: 236 | new_grads.append(grad) 237 | new_vars.append((name, var)) 238 | 239 | idx += 1 240 | 241 | return zip(new_grads, new_vars) 242 | 243 | 244 | def save_checkpoint(step, epoch, model, optimizer, params): 245 | if dist.get_rank() == 0: 246 | state = { 247 | "step": step, 248 | "epoch": epoch, 249 | "model": model.state_dict(), 250 | "optimizer": optimizer.state_dict() 251 | } 252 | utils.save(state, params.output, params.keep_checkpoint_max) 253 | 254 | 255 | def infer_gpu_num(param_str): 256 | result = re.match(r".*device_list=\[(.*?)\].*", param_str) 257 | 258 | if not result: 259 | return 1 260 | else: 261 | dev_str = result.groups()[-1] 262 | return len(dev_str.split(",")) 263 | 264 | 265 | def broadcast(model): 266 | for var in model.parameters(): 267 | dist.broadcast(var.data, 0) 268 | 269 | 270 | def get_learning_rate_schedule(params): 271 | if params.learning_rate_schedule == "linear_warmup_rsqrt_decay": 272 | schedule = optimizers.LinearWarmupRsqrtDecay( 273 | params.learning_rate, params.warmup_steps, 274 | initial_learning_rate=params.initial_learning_rate, 275 | summary=params.save_summary) 276 | elif params.learning_rate_schedule == "piecewise_constant_decay": 277 | schedule = optimizers.PiecewiseConstantDecay( 278 | params.learning_rate_boundaries, params.learning_rate_values, 279 | summary=params.save_summary) 280 | elif params.learning_rate_schedule == "linear_exponential_decay": 281 | schedule = optimizers.LinearExponentialDecay( 282 | params.learning_rate, params.warmup_steps, 283 | params.start_decay_step, params.end_decay_step, 284 | dist.get_world_size(), summary=params.save_summary) 285 | elif params.learning_rate_schedule == "constant": 286 | schedule = params.learning_rate 287 | else: 288 | raise ValueError("Unknown schedule %s" % params.learning_rate_schedule) 289 | 290 | return schedule 291 | 292 | 293 | def get_clipper(params): 294 | if params.clipping.lower() == "none": 295 | clipper = None 296 | elif params.clipping.lower() == "adaptive": 297 | clipper = optimizers.adaptive_clipper(0.95) 298 | elif params.clipping.lower() == "global_norm": 299 | clipper = optimizers.global_norm_clipper(params.clip_grad_norm) 300 | else: 301 | raise ValueError("Unknown clipper %s" % params.clipping) 302 | 303 | return clipper 304 | 305 | 306 | def get_optimizer(params, schedule, clipper): 307 | if params.optimizer.lower() == "adam": 308 | optimizer = optimizers.AdamOptimizer(learning_rate=schedule, 309 | beta_1=params.adam_beta1, 310 | beta_2=params.adam_beta2, 311 | epsilon=params.adam_epsilon, 312 | clipper=clipper, 313 | summaries=params.save_summary) 314 | elif params.optimizer.lower() == "adadelta": 315 | optimizer = optimizers.AdadeltaOptimizer( 316 | learning_rate=schedule, rho=params.adadelta_rho, 317 | epsilon=params.adadelta_epsilon, clipper=clipper, 318 | summaries=params.save_summary) 319 | elif params.optimizer.lower() == "sgd": 320 | optimizer = optimizers.SGDOptimizer( 321 | learning_rate=schedule, clipper=clipper, 322 | summaries=params.save_summary) 323 | else: 324 | raise ValueError("Unknown optimizer %s" % params.optimizer) 325 | 326 | return optimizer 327 | 328 | 329 | def 
load_references(pattern): 330 | if not pattern: 331 | return None 332 | 333 | files = glob.glob(pattern) 334 | references = [] 335 | 336 | for name in files: 337 | ref = [] 338 | with open(name, "rb") as fd: 339 | for line in fd: 340 | items = line.strip().split() 341 | ref.append(items) 342 | references.append(ref) 343 | 344 | return list(zip(*references)) 345 | 346 | 347 | def main(args): 348 | model_cls = models.get_model(args.model) 349 | 350 | # Import and override parameters 351 | # Priorities (low -> high): 352 | # default -> saved -> command 353 | params = default_params() 354 | params = merge_params(params, model_cls.default_params(args.hparam_set)) 355 | params = import_params(args.output, args.model, params) 356 | params = override_params(params, args) 357 | 358 | # Initialize distributed utility 359 | if args.distributed: 360 | params.device = args.local_rank 361 | dist.init_process_group("nccl") 362 | torch.cuda.set_device(args.local_rank) 363 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 364 | else: 365 | params.device = params.device_list[args.local_rank] 366 | dist.init_process_group("nccl", init_method=args.url, 367 | rank=args.local_rank, 368 | world_size=len(params.device_list)) 369 | torch.cuda.set_device(params.device_list[args.local_rank]) 370 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 371 | 372 | # Export parameters 373 | if dist.get_rank() == 0: 374 | export_params(params.output, "params.json", params) 375 | export_params(params.output, "%s.json" % params.model, 376 | collect_params(params, model_cls.default_params())) 377 | 378 | model = model_cls(params).cuda() 379 | 380 | if args.half: 381 | model = model.half() 382 | torch.set_default_dtype(torch.half) 383 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 384 | 385 | model.train() 386 | 387 | # Init tensorboard 388 | summary.init(params.output, params.save_summary) 389 | 390 | schedule = get_learning_rate_schedule(params) 391 | clipper = get_clipper(params) 392 | optimizer = get_optimizer(params, schedule, clipper) 393 | 394 | if args.half: 395 | optimizer = optimizers.LossScalingOptimizer(optimizer) 396 | 397 | optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle) 398 | 399 | trainable_flags = print_variables(model, params.pattern, 400 | dist.get_rank() == 0) 401 | 402 | dataset = data.MTPipeline.get_train_dataset(params.input, params) 403 | 404 | if params.validation: 405 | sorted_key, eval_dataset = data.MTPipeline.get_infer_dataset( 406 | params.validation, params) 407 | references = load_references(params.references) 408 | else: 409 | sorted_key = None 410 | eval_dataset = None 411 | references = None 412 | 413 | # Load checkpoint 414 | checkpoint = utils.latest_checkpoint(params.output) 415 | 416 | if args.checkpoint is not None: 417 | # Load pre-trained models 418 | state = torch.load(args.checkpoint, map_location="cpu") 419 | model.load_state_dict(state["model"]) 420 | step = params.initial_step 421 | epoch = 0 422 | broadcast(model) 423 | elif checkpoint is not None: 424 | state = torch.load(checkpoint, map_location="cpu") 425 | step = state["step"] 426 | epoch = state["epoch"] 427 | model.load_state_dict(state["model"]) 428 | 429 | if "optimizer" in state: 430 | optimizer.load_state_dict(state["optimizer"]) 431 | else: 432 | step = 0 433 | epoch = 0 434 | broadcast(model) 435 | 436 | def train_fn(inputs): 437 | features, labels = inputs 438 | loss = model(features, labels) 439 | return loss 440 | 441 | counter = 0 442 | 443 | while True: 444 | for 
features in dataset: 445 | if counter % params.update_cycle == 0: 446 | step += 1 447 | utils.set_global_step(step) 448 | 449 | counter += 1 450 | t = time.time() 451 | loss = train_fn(features) 452 | gradients = optimizer.compute_gradients(loss, 453 | list(model.parameters())) 454 | grads_and_vars = exclude_variables( 455 | trainable_flags, 456 | zip(gradients, list(model.named_parameters()))) 457 | optimizer.apply_gradients(grads_and_vars) 458 | 459 | t = time.time() - t 460 | 461 | summary.scalar("loss", loss, step, write_every_n_steps=1) 462 | summary.scalar("global_step/sec", t, step) 463 | 464 | print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" % 465 | (epoch + 1, step, float(loss), t)) 466 | 467 | if counter % params.update_cycle == 0: 468 | if step >= params.train_steps: 469 | utils.evaluate(model, sorted_key, eval_dataset, 470 | params.output, references, params) 471 | save_checkpoint(step, epoch, model, optimizer, params) 472 | 473 | if dist.get_rank() == 0: 474 | summary.close() 475 | 476 | return 477 | 478 | if step % params.eval_steps == 0: 479 | utils.evaluate(model, sorted_key, eval_dataset, 480 | params.output, references, params) 481 | 482 | if step % params.save_checkpoint_steps == 0: 483 | save_checkpoint(step, epoch, model, optimizer, params) 484 | 485 | epoch += 1 486 | 487 | 488 | # Wrap main function 489 | def process_fn(rank, args): 490 | local_args = copy.copy(args) 491 | local_args.local_rank = rank 492 | main(local_args) 493 | 494 | 495 | def cli_main(): 496 | parsed_args = parse_args() 497 | 498 | if parsed_args.distributed: 499 | main(parsed_args) 500 | else: 501 | # Pick a free port 502 | with socket.socket() as s: 503 | s.bind(("localhost", 0)) 504 | port = s.getsockname()[1] 505 | url = "tcp://localhost:" + str(port) 506 | parsed_args.url = url 507 | 508 | world_size = infer_gpu_num(parsed_args.parameters) 509 | 510 | if world_size > 1: 511 | torch.multiprocessing.spawn(process_fn, args=(parsed_args,), 512 | nprocs=world_size) 513 | else: 514 | process_fn(0, parsed_args) 515 | 516 | 517 | if __name__ == "__main__": 518 | cli_main() 519 | -------------------------------------------------------------------------------- /thumt/bin/translator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import copy 10 | import logging 11 | import os 12 | import re 13 | import six 14 | import socket 15 | import time 16 | import torch 17 | 18 | import thumt.data as data 19 | import torch.distributed as dist 20 | import thumt.models as models 21 | import thumt.utils as utils 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser( 26 | description="Decode input sentences with pre-trained checkpoints.", 27 | usage="translator.py [] [-h | --help]" 28 | ) 29 | 30 | # input files 31 | parser.add_argument("--input", type=str, required=True, nargs="+", 32 | help="Path to input file.") 33 | parser.add_argument("--output", type=str, required=True, 34 | help="Path to output file.") 35 | parser.add_argument("--checkpoints", type=str, required=True, nargs="+", 36 | help="Path to trained checkpoints.") 37 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 38 | help="Path to source and target vocabulary.") 39 | 40 | # model and configuration 41 | parser.add_argument("--models", type=str, required=True, 
nargs="+", 42 | help="Name of the models.") 43 | parser.add_argument("--parameters", type=str, default="", 44 | help="Additional hyper-parameters.") 45 | 46 | # mutually exclusive parameters 47 | group = parser.add_mutually_exclusive_group() 48 | group.add_argument("--half", action="store_true", 49 | help="Enable Half-precision for decoding.") 50 | group.add_argument("--cpu", action="store_true", 51 | help="Enable CPU for decoding.") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def default_params(): 57 | params = utils.HParams( 58 | input=None, 59 | output=None, 60 | vocabulary=None, 61 | # vocabulary specific 62 | pad="", 63 | bos="", 64 | eos="", 65 | unk="", 66 | device_list=[0], 67 | # decoding 68 | top_beams=1, 69 | beam_size=4, 70 | decode_alpha=0.6, 71 | decode_ratio=1.0, 72 | decode_length=50, 73 | decode_batch_size=32, 74 | ) 75 | 76 | return params 77 | 78 | 79 | def merge_params(params1, params2): 80 | params = utils.HParams() 81 | 82 | for (k, v) in six.iteritems(params1.values()): 83 | params.add_hparam(k, v) 84 | 85 | params_dict = params.values() 86 | 87 | for (k, v) in six.iteritems(params2.values()): 88 | if k in params_dict: 89 | # Override 90 | setattr(params, k, v) 91 | else: 92 | params.add_hparam(k, v) 93 | 94 | return params 95 | 96 | 97 | def import_params(model_dir, model_name, params): 98 | model_dir = os.path.abspath(model_dir) 99 | m_name = os.path.join(model_dir, model_name + ".json") 100 | 101 | if not os.path.exists(m_name): 102 | return params 103 | 104 | with open(m_name) as fd: 105 | logging.info("Restoring model parameters from %s" % m_name) 106 | json_str = fd.readline() 107 | params.parse_json(json_str) 108 | 109 | return params 110 | 111 | 112 | def override_params(params, args): 113 | params.parse(args.parameters.lower()) 114 | 115 | params.vocabulary = { 116 | "source": data.Vocabulary(args.vocabulary[0]), 117 | "target": data.Vocabulary(args.vocabulary[1]) 118 | } 119 | 120 | return params 121 | 122 | 123 | def convert_to_string(tensor, params, direction="target"): 124 | ids = tensor.tolist() 125 | 126 | output = [] 127 | 128 | eos_id = params.vocabulary[direction][params.eos] 129 | 130 | for wid in ids: 131 | if wid == eos_id: 132 | break 133 | output.append(params.vocabulary[direction][wid]) 134 | 135 | output = b" ".join(output) 136 | 137 | return output 138 | 139 | 140 | def infer_gpu_num(param_str): 141 | result = re.match(r".*device_list=\[(.*?)\].*", param_str) 142 | 143 | if not result: 144 | return 1 145 | else: 146 | dev_str = result.groups()[-1] 147 | return len(dev_str.split(",")) 148 | 149 | 150 | def main(args): 151 | # Load configs 152 | model_cls_list = [models.get_model(model) for model in args.models] 153 | params_list = [default_params() for _ in range(len(model_cls_list))] 154 | params_list = [ 155 | merge_params(params, model_cls.default_params()) 156 | for params, model_cls in zip(params_list, model_cls_list)] 157 | params_list = [ 158 | import_params(args.checkpoints[i], args.models[i], params_list[i]) 159 | for i in range(len(args.checkpoints))] 160 | params_list = [ 161 | override_params(params_list[i], args) 162 | for i in range(len(model_cls_list))] 163 | 164 | params = params_list[0] 165 | 166 | if args.cpu: 167 | dist.init_process_group("gloo", 168 | init_method=args.url, 169 | rank=args.local_rank, 170 | world_size=1) 171 | torch.set_default_tensor_type(torch.FloatTensor) 172 | else: 173 | params.device = params.device_list[args.local_rank] 174 | dist.init_process_group("nccl", 175 | init_method=args.url, 176 | 
rank=args.local_rank, 177 | world_size=len(params.device_list)) 178 | torch.cuda.set_device(params.device_list[args.local_rank]) 179 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 180 | 181 | if args.half: 182 | torch.set_default_dtype(torch.half) 183 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 184 | 185 | # Create model 186 | with torch.no_grad(): 187 | model_list = [] 188 | 189 | for i in range(len(args.models)): 190 | if args.cpu: 191 | model = model_cls_list[i](params_list[i]) 192 | else: 193 | model = model_cls_list[i](params_list[i]).cuda() 194 | 195 | if args.half: 196 | model = model.half() 197 | 198 | model.eval() 199 | model.load_state_dict( 200 | torch.load(utils.latest_checkpoint(args.checkpoints[i]), 201 | map_location="cpu")["model"]) 202 | 203 | model_list.append(model) 204 | 205 | if len(args.input) == 1: 206 | mode = "infer" 207 | sorted_key, dataset = data.MTPipeline.get_infer_dataset( 208 | args.input[0], params) 209 | else: 210 | # Teacher-forcing 211 | mode = "eval" 212 | dataset = data.MTPipeline.get_eval_dataset(args.input, params) 213 | sorted_key = None 214 | 215 | iterator = iter(dataset) 216 | counter = 0 217 | pad_max = 1024 218 | top_beams = params.top_beams 219 | decode_batch_size = params.decode_batch_size 220 | 221 | # Buffers for synchronization 222 | size = torch.zeros([dist.get_world_size()]).long() 223 | t_list = [torch.empty([decode_batch_size, top_beams, pad_max]).long() 224 | for _ in range(dist.get_world_size())] 225 | 226 | all_outputs = [] 227 | 228 | while True: 229 | try: 230 | features = next(iterator) 231 | 232 | if mode == "eval": 233 | features = features[0] 234 | 235 | batch_size = features["source"].shape[0] 236 | except: 237 | features = { 238 | "source": torch.ones([1, 1]).long(), 239 | "source_mask": torch.ones([1, 1]).float() 240 | } 241 | 242 | if mode == "eval": 243 | features["target"] = torch.ones([1, 1]).long() 244 | features["target_mask"] = torch.ones([1, 1]).float() 245 | 246 | batch_size = 0 247 | 248 | t = time.time() 249 | counter += 1 250 | 251 | # Decode 252 | if mode != "eval": 253 | seqs, _ = utils.beam_search(model_list, features, params) 254 | else: 255 | seqs, _ = utils.argmax_decoding(model_list, features, params) 256 | 257 | # Padding 258 | pad_batch = decode_batch_size - seqs.shape[0] 259 | pad_beams = top_beams - seqs.shape[1] 260 | pad_length = pad_max - seqs.shape[2] 261 | seqs = torch.nn.functional.pad( 262 | seqs, (0, pad_length, 0, pad_beams, 0, pad_batch)) 263 | 264 | # Synchronization 265 | size.zero_() 266 | size[dist.get_rank()].copy_(torch.tensor(batch_size)) 267 | 268 | if args.cpu: 269 | t_list[dist.get_rank()].copy_(seqs) 270 | else: 271 | dist.all_reduce(size) 272 | dist.all_gather(t_list, seqs) 273 | 274 | if size.sum() == 0: 275 | break 276 | 277 | if dist.get_rank() != 0: 278 | continue 279 | 280 | for i in range(decode_batch_size): 281 | for j in range(dist.get_world_size()): 282 | beam_seqs = [] 283 | pad_flag = i >= size[j] 284 | for k in range(top_beams): 285 | seq = convert_to_string(t_list[j][i][k], params) 286 | 287 | if pad_flag: 288 | continue 289 | 290 | beam_seqs.append(seq) 291 | 292 | if pad_flag: 293 | continue 294 | 295 | all_outputs.append(beam_seqs) 296 | 297 | t = time.time() - t 298 | print("Finished batch: %d (%.3f sec)" % (counter, t)) 299 | 300 | if dist.get_rank() == 0: 301 | restored_outputs = [] 302 | if sorted_key is not None: 303 | for idx in range(len(all_outputs)): 304 | restored_outputs.append(all_outputs[sorted_key[idx]]) 305 | else: 306 | 
restored_outputs = all_outputs 307 | 308 | with open(args.output, "wb") as fd: 309 | if top_beams == 1: 310 | for seqs in restored_outputs: 311 | fd.write(seqs[0] + b"\n") 312 | else: 313 | for idx, seqs in enumerate(restored_outputs): 314 | for k, seq in enumerate(seqs): 315 | fd.write(b"%d\t%d\t" % (idx, k)) 316 | fd.write(seq + b"\n") 317 | 318 | 319 | # Wrap main function 320 | def process_fn(rank, args): 321 | local_args = copy.copy(args) 322 | local_args.local_rank = rank 323 | main(local_args) 324 | 325 | 326 | def cli_main(): 327 | parsed_args = parse_args() 328 | 329 | # Pick a free port 330 | with socket.socket() as s: 331 | s.bind(("localhost", 0)) 332 | port = s.getsockname()[1] 333 | url = "tcp://localhost:" + str(port) 334 | parsed_args.url = url 335 | 336 | if parsed_args.cpu: 337 | world_size = 1 338 | else: 339 | world_size = infer_gpu_num(parsed_args.parameters) 340 | 341 | if world_size > 1: 342 | torch.multiprocessing.spawn(process_fn, args=(parsed_args,), 343 | nprocs=world_size) 344 | else: 345 | process_fn(0, parsed_args) 346 | 347 | 348 | if __name__ == "__main__": 349 | cli_main() 350 | -------------------------------------------------------------------------------- /thumt/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-Present The THUMT Authors 3 | 4 | from thumt.data.dataset import Dataset, ElementSpec, MapFunc, TextLineDataset 5 | from thumt.data.pipeline import MTPipeline 6 | from thumt.data.vocab import Vocabulary 7 | -------------------------------------------------------------------------------- /thumt/data/iterator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-Present The THUMT Authors 3 | 4 | import abc 5 | import time 6 | import queue 7 | import threading 8 | 9 | from typing import Any, Dict, List, NoReturn, Tuple, Union 10 | 11 | 12 | def _profile(msg: str, enable: bool = True): 13 | def decorator(func): 14 | def on_call(*args, **kwargs): 15 | start_time = time.perf_counter() 16 | ret = func(*args, **kwargs) 17 | 18 | if enable: 19 | print(msg, time.perf_counter() - start_time) 20 | return ret 21 | 22 | return on_call 23 | 24 | return decorator 25 | 26 | 27 | def _maybe_to_tuple(x): 28 | return x if isinstance(x, tuple) else (x,) 29 | 30 | 31 | def _unzip(x): 32 | return list(zip(*x)) 33 | 34 | 35 | class _FileWrapper(object): 36 | 37 | def __init__(self, buffer: List): 38 | self._buffer = buffer 39 | self._index = 0 40 | 41 | def __iter__(self): 42 | return self 43 | 44 | def __next__(self): 45 | if self._index >= len(self._buffer): 46 | raise StopIteration 47 | 48 | line = self._buffer[self._index] 49 | self._index += 1 50 | 51 | return line 52 | 53 | def readline(self): 54 | try: 55 | line = self._buffer[self._index] 56 | self._index += 1 57 | except: 58 | line = "" 59 | return line 60 | 61 | def readlines(self): 62 | return self._buffer 63 | 64 | def seek(self, offset: int): 65 | self._index = offset 66 | 67 | def tell(self): 68 | return self._index 69 | 70 | 71 | class _DatasetWorker(threading.Thread): 72 | 73 | def init(self, dataset: "Dataset", id: int = 0, buffer_size: int = 64): 74 | self._iterator = iter(dataset) 75 | self._buffer = queue.Queue(buffer_size) 76 | self._buffer_size = buffer_size 77 | self._empty = False 78 | self._id = id 79 | 80 | def get(self) -> Any: 81 | if self._empty and self._buffer.empty(): 82 | return None 83 | 84 | return self._buffer.get() 85 | 86 | 
def run(self) -> None: 87 | while True: 88 | try: 89 | self._buffer.put(next(self._iterator)) 90 | except StopIteration: 91 | break 92 | 93 | self._empty = True 94 | 95 | def is_empty(self) -> bool: 96 | return self._empty 97 | 98 | 99 | class IteratorBase(object): 100 | 101 | def __init__(self): 102 | pass 103 | 104 | def __iter__(self) -> "IteratorBase": 105 | return self 106 | 107 | def state(self) -> Dict: 108 | return {} 109 | 110 | @abc.abstractmethod 111 | def __next__(self) -> NoReturn: 112 | raise NotImplementedError("IteratorBase.__next__ not implemented.") 113 | 114 | 115 | class _BackgroundDSIter(IteratorBase): 116 | 117 | def __init__(self, dataset: "BackgroundDataset"): 118 | self._thread = _DatasetWorker(daemon=True) 119 | self._thread.init(dataset._dataset) 120 | self._thread.start() 121 | 122 | def __next__(self) -> Any: 123 | item = self._thread.get() 124 | 125 | if item is None: 126 | self._thread.join() 127 | raise StopIteration 128 | 129 | return item 130 | 131 | 132 | class _BucketDSIter(IteratorBase): 133 | 134 | def __init__(self, dataset: "BucketDataset"): 135 | self._pad = dataset.pad 136 | self._bucket_boundaries = dataset.bucket_boundaries 137 | self._batch_sizes = dataset.batch_sizes 138 | self._iterator = iter(dataset._dataset) 139 | self._spec = dataset.element_spec 140 | self._buckets = [[] for _ in dataset.batch_sizes] 141 | self._priority = [k for k in range(len(dataset.batch_sizes))] 142 | self._min_length = dataset.min_length 143 | self._max_length = dataset.max_length 144 | self._max_fill = max(dataset.batch_sizes) 145 | self._bucket_map = {} 146 | 147 | # length to bucket index 148 | max_len = max(dataset.bucket_boundaries) 149 | idx = 0 150 | 151 | # [0, max_boundary] 152 | for i in range(0, max_len + 1): 153 | for idx in range(len(self._bucket_boundaries)): 154 | if i <= self._bucket_boundaries[idx]: 155 | self._bucket_map[i] = idx 156 | break 157 | 158 | super(_BucketDSIter, self).__init__() 159 | 160 | 161 | def __iter__(self) -> "_BucketDSIter": 162 | return self 163 | 164 | @_profile("_BucketDSIter", False) 165 | def __next__(self) -> Union[List[List[int]], 166 | Tuple[List[List[int]], ...]]: 167 | try: 168 | while True: 169 | idx = self._get_bucket() 170 | 171 | if idx >= 0: 172 | return self._get_content(idx) 173 | else: 174 | self._fill() 175 | except StopIteration: 176 | idx = self._get_nonempty_bucket() 177 | 178 | if idx < 0: 179 | raise StopIteration 180 | 181 | return self._get_content(idx) 182 | 183 | @_profile("_BucketDSIter_fill", False) 184 | def _fill(self) -> None: 185 | for i in range(self._max_fill): 186 | items = next(self._iterator) 187 | 188 | if not isinstance(items, tuple): 189 | items = (items,) 190 | 191 | max_length = max([len(item) for item in items]) 192 | 193 | if max_length < self._min_length: 194 | continue 195 | 196 | if max_length > self._max_length: 197 | continue 198 | 199 | if max_length in self._bucket_map: 200 | idx = self._bucket_map[max_length] 201 | self._buckets[idx].append(items) 202 | else: 203 | self._buckets[-1].append(items) 204 | 205 | def _get_content(self, idx: int) -> List: 206 | idx = self._priority.pop(idx) 207 | self._priority.append(idx) 208 | 209 | bucket = self._buckets[idx] 210 | outs = bucket[:self._batch_sizes[idx]] 211 | self._buckets[idx] = bucket[self._batch_sizes[idx]:] 212 | 213 | content = tuple([list(item) for item in zip(*outs)]) 214 | content = self._pad_batch(content) 215 | 216 | if self._spec.elem_type is List[List[int]]: 217 | return self._pad_batch(content)[0] 218 | else: 219 
| return self._pad_batch(content) 220 | 221 | def _pad_batch(self, batch: tuple) -> List: 222 | for bat in batch: 223 | max_len = max([len(item) for item in bat]) 224 | 225 | for seq in bat: 226 | for _ in range(len(seq), max_len): 227 | seq.append(self._pad) 228 | 229 | return batch 230 | 231 | def _get_bucket(self) -> int: 232 | for i, idx in enumerate(self._priority): 233 | if len(self._buckets[idx]) >= self._batch_sizes[idx]: 234 | return i 235 | 236 | return -1 237 | 238 | def _get_nonempty_bucket(self) -> int: 239 | for i, idx in enumerate(self._priority): 240 | if len(self._buckets[idx]) > 0: 241 | return i 242 | 243 | return -1 244 | 245 | 246 | class _LookupDSIter(IteratorBase): 247 | 248 | def __init__(self, dataset: "LookupDataset"): 249 | self._unk_id = dataset.unk_id 250 | self._vocabulary = dataset.vocabulary 251 | self._iterator = iter(dataset._dataset) 252 | 253 | @_profile("_LookupDSIter", False) 254 | def __next__(self) -> List[int]: 255 | outputs = [] 256 | 257 | for s in next(self._iterator): 258 | if s not in self._vocabulary: 259 | outputs.append(self._unk_id) 260 | else: 261 | outputs.append(self._vocabulary[s]) 262 | 263 | return outputs 264 | 265 | 266 | class _MapDSIter(IteratorBase): 267 | 268 | def __init__(self, dataset: "MapDataset"): 269 | self._fn = dataset._fn 270 | self._iterator = iter(dataset._dataset) 271 | 272 | @_profile("_LookupDSIter", False) 273 | def __next__(self) -> Any: 274 | item = next(self._iterator) 275 | 276 | return self._fn(item) 277 | 278 | 279 | class _PaddedBatchDSIter(IteratorBase): 280 | 281 | def __init__(self, dataset: "PaddedBatchDataset"): 282 | self._pad = dataset.pad 283 | self._batch_size = dataset.batch_size 284 | self._iterator = iter(dataset._dataset) 285 | self._spec = dataset.element_spec 286 | 287 | super(_PaddedBatchDSIter, self).__init__() 288 | 289 | 290 | def __iter__(self) -> "_PaddedBatchDSIter": 291 | return self 292 | 293 | @_profile("_PaddedBatchDSIter", False) 294 | def __next__(self) -> Union[List[List[int]], 295 | Tuple[List[List[int]], ...]]: 296 | bucket = [] 297 | 298 | try: 299 | for _ in range(self._batch_size): 300 | bucket.append(_maybe_to_tuple(next(self._iterator))) 301 | except StopIteration: 302 | if len(bucket) == 0: 303 | raise StopIteration 304 | 305 | # unzip 306 | bucket = list(map(lambda x: list(x), _unzip(bucket))) 307 | max_lens = map(lambda x: max(list(map(lambda v: len(v), x))), bucket) 308 | 309 | outputs = [] 310 | 311 | for seqs, max_len in zip(bucket, max_lens): 312 | outputs.append(self._pad_batch(seqs, max_len)) 313 | 314 | if self._spec.elem_type is List[List[int]]: 315 | return bucket[0] 316 | else: 317 | return bucket 318 | 319 | def _pad_batch(self, seqs: List, max_len: int) -> List: 320 | for seq in seqs: 321 | for _ in range(max_len - len(seq)): 322 | seq.append(self._pad) 323 | 324 | return seqs 325 | 326 | class _RepeatDSIter(IteratorBase): 327 | 328 | def __init__(self, dataset: "RepeatDataset"): 329 | self._dataset = dataset 330 | self._iterator = iter(dataset._dataset) 331 | self._n = 0 332 | self._count = dataset.count 333 | 334 | @_profile("_RepeatDSIter", False) 335 | def __next__(self) -> Any: 336 | try: 337 | return next(self._iterator) 338 | except StopIteration: 339 | self._n = self._n + 1 340 | 341 | if self._count <= 0 or self._n < self._count: 342 | self._iterator = iter(self._dataset) 343 | return next(self._iterator) 344 | 345 | raise StopIteration 346 | 347 | 348 | class _ShardDSIter(IteratorBase): 349 | 350 | def __init__(self, dataset: "ShardDataset"): 351 
| self._num_shards = dataset.num_shards 352 | self._index = dataset._index 353 | self._n = 0 354 | self._iterator = iter(dataset._dataset) 355 | 356 | @_profile("_ShardDsIter", False) 357 | def __next__(self) -> Any: 358 | while self._n != self._index: 359 | next(self._iterator) 360 | self._n = (self._n + 1) % self._num_shards 361 | 362 | self._n = (self._n + 1) % self._num_shards 363 | 364 | return next(self._iterator) 365 | 366 | 367 | class _TextLineDSIter(IteratorBase): 368 | 369 | def __init__(self, dataset: "TextLineDataset"): 370 | if isinstance(dataset.input_source, str): 371 | self._file = open(dataset.input_source, "rb") 372 | else: 373 | self._file = _FileWrapper(dataset.input_source) 374 | 375 | @_profile("_TextLineDSIter", False) 376 | def __next__(self) -> bytes: 377 | return next(self._file) 378 | 379 | 380 | class _TokenizedLineDSIter(IteratorBase): 381 | 382 | def __init__(self, dataset: "Dataset"): 383 | self._bos = dataset.bos 384 | self._eos = dataset.eos 385 | self._tokenizer = dataset.tokenizer 386 | self._iterator = iter(dataset._dataset) 387 | 388 | @_profile("_TokenizedLineDSIter", False) 389 | def __next__(self) -> List[bytes]: 390 | val = self._tokenizer.encode(next(self._iterator)) 391 | 392 | if self._bos: 393 | val.insert(0, self._bos) 394 | 395 | if self._eos: 396 | val.append(self._eos) 397 | 398 | return val 399 | 400 | 401 | class _ZipDSIter(IteratorBase): 402 | 403 | def __init__(self, dataset: "ZipDataset"): 404 | self._iterators = [iter(ds) for ds in dataset._datasets] 405 | 406 | @_profile("_ZipDSIter", False) 407 | def __next__(self) -> Tuple: 408 | outputs = [] 409 | 410 | for iterator in self._iterators: 411 | outputs.append(next(iterator)) 412 | 413 | return tuple(outputs) 414 | 415 | 416 | _DATASET_TO_ITER = { 417 | "BackgroundDataset": _BackgroundDSIter, 418 | "BucketDataset": _BucketDSIter, 419 | "LookupDataset": _LookupDSIter, 420 | "MapDataset": _MapDSIter, 421 | "PaddedBatchDataset": _PaddedBatchDSIter, 422 | "RepeatDataset": _RepeatDSIter, 423 | "ShardDataset": _ShardDSIter, 424 | "TextLineDataset": _TextLineDSIter, 425 | "TokenizedLineDataset": _TokenizedLineDSIter, 426 | "ZipDataset": _ZipDSIter 427 | } 428 | 429 | 430 | class Iterator(IteratorBase): 431 | 432 | def __init__(self, dataset: "Dataset"): 433 | self._iterator = _DATASET_TO_ITER[dataset.name](dataset) 434 | 435 | def __next__(self): 436 | return next(self._iterator) 437 | -------------------------------------------------------------------------------- /thumt/data/pipeline.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-Present The THUMT Authors 3 | 4 | import torch 5 | 6 | from thumt.data.dataset import Dataset, ElementSpec, MapFunc, TextLineDataset 7 | from thumt.data.vocab import Vocabulary 8 | from thumt.tokenizers import WhiteSpaceTokenizer 9 | 10 | 11 | def _sort_input_file(filename, reverse=True): 12 | with open(filename, "rb") as fd: 13 | inputs = [line.strip() for line in fd] 14 | 15 | input_lens = [ 16 | (i, len(line.split())) for i, line in enumerate(inputs)] 17 | 18 | sorted_input_lens = sorted(input_lens, key=lambda x: x[1], 19 | reverse=reverse) 20 | sorted_keys = {} 21 | sorted_inputs = [] 22 | 23 | for i, (idx, _) in enumerate(sorted_input_lens): 24 | sorted_inputs.append(inputs[idx]) 25 | sorted_keys[idx] = i 26 | 27 | return sorted_keys, sorted_inputs 28 | 29 | 30 | class MTPipeline(object): 31 | 32 | @staticmethod 33 | def get_train_dataset(filenames, params, cpu=False): 34 | 
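        # Overview of the stages assembled below: read the raw source/target
        # text files, whitespace-tokenize (appending EOS to the source and
        # label streams, prepending BOS to the target), map tokens to ids with
        # the vocabularies (unknown tokens fall back to UNK), shard the zipped
        # dataset across distributed workers, bucket examples by sequence
        # length into batches (sized by token budget unless
        # params.fixed_batch_size is set), and finally convert each batch into
        # (features, labels) tensors on a background thread.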
src_vocab = params.vocabulary["source"] 35 | tgt_vocab = params.vocabulary["target"] 36 | 37 | src_dataset = TextLineDataset(filenames[0]) 38 | tgt_dataset = TextLineDataset(filenames[1]) 39 | lab_dataset = TextLineDataset(filenames[1]) 40 | 41 | src_dataset = src_dataset.tokenize(WhiteSpaceTokenizer(), 42 | None, params.eos) 43 | tgt_dataset = tgt_dataset.tokenize(WhiteSpaceTokenizer(), 44 | params.bos, None) 45 | lab_dataset = lab_dataset.tokenize(WhiteSpaceTokenizer(), 46 | None, params.eos) 47 | src_dataset = Dataset.lookup(src_dataset, src_vocab, 48 | src_vocab[params.unk]) 49 | tgt_dataset = Dataset.lookup(tgt_dataset, tgt_vocab, 50 | tgt_vocab[params.unk]) 51 | lab_dataset = Dataset.lookup(lab_dataset, tgt_vocab, 52 | tgt_vocab[params.unk]) 53 | 54 | dataset = Dataset.zip((src_dataset, tgt_dataset, lab_dataset)) 55 | dataset = dataset.shard(torch.distributed.get_world_size(), 56 | torch.distributed.get_rank()) 57 | 58 | 59 | def bucket_boundaries(max_length, min_length=8, step=8): 60 | x = min_length 61 | boundaries = [] 62 | 63 | while x <= max_length: 64 | boundaries.append(x + 1) 65 | x += step 66 | 67 | return boundaries 68 | 69 | batch_size = params.batch_size 70 | max_length = (params.max_length // 8) * 8 71 | min_length = params.min_length 72 | boundaries = bucket_boundaries(max_length) 73 | batch_sizes = [max(1, batch_size // (x - 1)) 74 | if not params.fixed_batch_size else batch_size 75 | for x in boundaries] + [1] 76 | 77 | dataset = Dataset.bucket_by_sequence_length( 78 | dataset, boundaries, batch_sizes, pad=src_vocab[params.pad], 79 | min_length=params.min_length, max_length=params.max_length) 80 | 81 | def map_fn(inputs): 82 | src_seq, tgt_seq, labels = inputs 83 | src_seq = torch.tensor(src_seq) 84 | tgt_seq = torch.tensor(tgt_seq) 85 | labels = torch.tensor(labels) 86 | src_mask = src_seq != params.vocabulary["source"][params.pad] 87 | tgt_mask = tgt_seq != params.vocabulary["target"][params.pad] 88 | src_mask = src_mask.float() 89 | tgt_mask = tgt_mask.float() 90 | 91 | if not cpu: 92 | src_seq = src_seq.cuda(params.device) 93 | src_mask = src_mask.cuda(params.device) 94 | tgt_seq = tgt_seq.cuda(params.device) 95 | tgt_mask = tgt_mask.cuda(params.device) 96 | 97 | features = { 98 | "source": src_seq, 99 | "source_mask": src_mask, 100 | "target": tgt_seq, 101 | "target_mask": tgt_mask 102 | } 103 | 104 | return features, labels 105 | 106 | map_obj = MapFunc(map_fn, ElementSpec("Tensor", "{key: [None, None]}")) 107 | 108 | dataset = dataset.map(map_obj) 109 | dataset = dataset.background() 110 | 111 | return dataset 112 | 113 | @staticmethod 114 | def get_eval_dataset(filenames, params, cpu=False): 115 | src_vocab = params.vocabulary["source"] 116 | tgt_vocab = params.vocabulary["target"] 117 | 118 | src_dataset = TextLineDataset(filenames[0]) 119 | tgt_dataset = TextLineDataset(filenames[1]) 120 | lab_dataset = TextLineDataset(filenames[1]) 121 | 122 | src_dataset = src_dataset.tokenize(WhiteSpaceTokenizer(), 123 | None, params.eos) 124 | tgt_dataset = tgt_dataset.tokenize(WhiteSpaceTokenizer(), 125 | params.bos, None) 126 | lab_dataset = lab_dataset.tokenize(WhiteSpaceTokenizer(), 127 | None, params.eos) 128 | src_dataset = Dataset.lookup(src_dataset, src_vocab, 129 | src_vocab[params.unk]) 130 | tgt_dataset = Dataset.lookup(tgt_dataset, tgt_vocab, 131 | tgt_vocab[params.unk]) 132 | lab_dataset = Dataset.lookup(lab_dataset, tgt_vocab, 133 | tgt_vocab[params.unk]) 134 | 135 | dataset = Dataset.zip((src_dataset, tgt_dataset, lab_dataset)) 136 | dataset = 
dataset.shard(torch.distributed.get_world_size(), 137 | torch.distributed.get_rank()) 138 | 139 | dataset = dataset.padded_batch(params.decode_batch_size, 140 | pad=src_vocab[params.pad]) 141 | 142 | def map_fn(inputs): 143 | src_seq, tgt_seq, labels = inputs 144 | src_seq = torch.tensor(src_seq) 145 | tgt_seq = torch.tensor(tgt_seq) 146 | labels = torch.tensor(labels) 147 | src_mask = src_seq != params.vocabulary["source"][params.pad] 148 | tgt_mask = tgt_seq != params.vocabulary["target"][params.pad] 149 | src_mask = src_mask.float() 150 | tgt_mask = tgt_mask.float() 151 | 152 | if not cpu: 153 | src_seq = src_seq.cuda(params.device) 154 | src_mask = src_mask.cuda(params.device) 155 | tgt_seq = tgt_seq.cuda(params.device) 156 | tgt_mask = tgt_mask.cuda(params.device) 157 | 158 | features = { 159 | "source": src_seq, 160 | "source_mask": src_mask, 161 | "target": tgt_seq, 162 | "target_mask": tgt_mask 163 | } 164 | 165 | return features, labels 166 | 167 | map_obj = MapFunc(map_fn, ElementSpec("Tensor", "{key: [None, None]}")) 168 | 169 | dataset = dataset.map(map_obj) 170 | dataset = dataset.background() 171 | 172 | return dataset 173 | 174 | @staticmethod 175 | def get_infer_dataset(filename, params, cpu=False): 176 | sorted_keys, sorted_data = _sort_input_file(filename) 177 | src_vocab = params.vocabulary["source"] 178 | 179 | src_dataset = TextLineDataset(sorted_data) 180 | src_dataset = src_dataset.tokenize(WhiteSpaceTokenizer(), 181 | None, params.eos) 182 | src_dataset = Dataset.lookup(src_dataset, src_vocab, 183 | src_vocab[params.unk]) 184 | dataset = src_dataset.shard(torch.distributed.get_world_size(), 185 | torch.distributed.get_rank()) 186 | 187 | dataset = dataset.padded_batch(params.decode_batch_size, 188 | pad=src_vocab[params.pad]) 189 | 190 | def map_fn(inputs): 191 | src_seq = torch.tensor(inputs) 192 | src_mask = src_seq != params.vocabulary["source"][params.pad] 193 | src_mask = src_mask.float() 194 | 195 | if not cpu: 196 | src_seq = src_seq.cuda(params.device) 197 | src_mask = src_mask.cuda(params.device) 198 | 199 | features = { 200 | "source": src_seq, 201 | "source_mask": src_mask, 202 | } 203 | 204 | return features 205 | 206 | map_obj = MapFunc(map_fn, ElementSpec("Tensor", "{key: [None, None]}")) 207 | 208 | dataset = dataset.map(map_obj) 209 | dataset = dataset.background() 210 | 211 | return sorted_keys, dataset 212 | -------------------------------------------------------------------------------- /thumt/data/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-Present The THUMT Authors 3 | 4 | import numpy as np 5 | import six 6 | import torch 7 | 8 | from typing import Union 9 | 10 | 11 | class Vocabulary(object): 12 | 13 | def __init__(self, filename): 14 | self._idx2word = {} 15 | self._word2idx = {} 16 | cnt = 0 17 | 18 | with open(filename, "rb") as fd: 19 | for line in fd: 20 | self._word2idx[line.strip()] = cnt 21 | self._idx2word[cnt] = line.strip() 22 | cnt = cnt + 1 23 | 24 | def __getitem__(self, key: Union[bytes, int]): 25 | if isinstance(key, int): 26 | return self._idx2word[key] 27 | elif isinstance(key, bytes): 28 | return self._word2idx[key] 29 | elif isinstance(key, str): 30 | key = key.encode("utf-8") 31 | return self._word2idx[key] 32 | else: 33 | raise LookupError("Cannot lookup key %s." 
% key) 34 | 35 | def __contains__(self, key): 36 | if isinstance(key, str): 37 | key = key.encode("utf-8") 38 | 39 | return key in self._word2idx 40 | 41 | def __iter__(self): 42 | return six.iterkeys(self._word2idx) 43 | 44 | def __len__(self): 45 | return len(self._idx2word) 46 | -------------------------------------------------------------------------------- /thumt/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import thumt.models.transformer 9 | 10 | 11 | def get_model(name): 12 | name = name.lower() 13 | 14 | if name == "transformer": 15 | return thumt.models.transformer.Transformer 16 | else: 17 | raise LookupError("Unknown model %s" % name) 18 | -------------------------------------------------------------------------------- /thumt/models/transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | import thumt.utils as utils 13 | import thumt.modules as modules 14 | 15 | 16 | class AttentionSubLayer(modules.Module): 17 | 18 | def __init__(self, params, name="attention"): 19 | super(AttentionSubLayer, self).__init__(name=name) 20 | 21 | self.dropout = params.residual_dropout 22 | self.normalization = params.normalization 23 | 24 | with utils.scope(name): 25 | self.attention = modules.MultiHeadAttention( 26 | params.hidden_size, params.num_heads, params.attention_dropout) 27 | self.layer_norm = modules.LayerNorm(params.hidden_size) 28 | 29 | def forward(self, x, bias, memory=None, state=None): 30 | if self.normalization == "before": 31 | y = self.layer_norm(x) 32 | else: 33 | y = x 34 | 35 | if self.training or state is None: 36 | y = self.attention(y, bias, memory, None) 37 | else: 38 | kv = [state["k"], state["v"]] 39 | y, k, v = self.attention(y, bias, memory, kv) 40 | state["k"], state["v"] = k, v 41 | 42 | y = nn.functional.dropout(y, self.dropout, self.training) 43 | 44 | if self.normalization == "before": 45 | return x + y 46 | else: 47 | return self.layer_norm(x + y) 48 | 49 | 50 | class FFNSubLayer(modules.Module): 51 | 52 | def __init__(self, params, dtype=None, name="ffn_layer"): 53 | super(FFNSubLayer, self).__init__(name=name) 54 | 55 | self.dropout = params.residual_dropout 56 | self.normalization = params.normalization 57 | 58 | with utils.scope(name): 59 | self.ffn_layer = modules.FeedForward(params.hidden_size, 60 | params.filter_size, 61 | dropout=params.relu_dropout) 62 | self.layer_norm = modules.LayerNorm(params.hidden_size) 63 | 64 | def forward(self, x): 65 | if self.normalization == "before": 66 | y = self.layer_norm(x) 67 | else: 68 | y = x 69 | 70 | y = self.ffn_layer(y) 71 | y = nn.functional.dropout(y, self.dropout, self.training) 72 | 73 | if self.normalization == "before": 74 | return x + y 75 | else: 76 | return self.layer_norm(x + y) 77 | 78 | 79 | class TransformerEncoderLayer(modules.Module): 80 | 81 | def __init__(self, params, name="layer"): 82 | super(TransformerEncoderLayer, self).__init__(name=name) 83 | 84 | with utils.scope(name): 85 | self.self_attention = AttentionSubLayer(params) 86 | self.feed_forward 
= FFNSubLayer(params) 87 | 88 | def forward(self, x, bias): 89 | x = self.self_attention(x, bias) 90 | x = self.feed_forward(x) 91 | return x 92 | 93 | 94 | class TransformerDecoderLayer(modules.Module): 95 | 96 | def __init__(self, params, name="layer"): 97 | super(TransformerDecoderLayer, self).__init__(name=name) 98 | 99 | with utils.scope(name): 100 | self.self_attention = AttentionSubLayer(params, 101 | name="self_attention") 102 | self.encdec_attention = AttentionSubLayer(params, 103 | name="encdec_attention") 104 | self.feed_forward = FFNSubLayer(params) 105 | 106 | def __call__(self, x, attn_bias, encdec_bias, memory, state=None): 107 | x = self.self_attention(x, attn_bias, state=state) 108 | x = self.encdec_attention(x, encdec_bias, memory) 109 | x = self.feed_forward(x) 110 | return x 111 | 112 | 113 | class TransformerEncoder(modules.Module): 114 | 115 | def __init__(self, params, name="encoder"): 116 | super(TransformerEncoder, self).__init__(name=name) 117 | 118 | self.normalization = params.normalization 119 | 120 | with utils.scope(name): 121 | self.layers = nn.ModuleList([ 122 | TransformerEncoderLayer(params, name="layer_%d" % i) 123 | for i in range(params.num_encoder_layers)]) 124 | if self.normalization == "before": 125 | self.layer_norm = modules.LayerNorm(params.hidden_size) 126 | else: 127 | self.layer_norm = None 128 | 129 | def forward(self, x, bias): 130 | for layer in self.layers: 131 | x = layer(x, bias) 132 | 133 | if self.normalization == "before": 134 | x = self.layer_norm(x) 135 | 136 | return x 137 | 138 | 139 | class TransformerDecoder(modules.Module): 140 | 141 | def __init__(self, params, name="decoder"): 142 | super(TransformerDecoder, self).__init__(name=name) 143 | 144 | self.normalization = params.normalization 145 | 146 | with utils.scope(name): 147 | self.layers = nn.ModuleList([ 148 | TransformerDecoderLayer(params, name="layer_%d" % i) 149 | for i in range(params.num_decoder_layers)]) 150 | 151 | if self.normalization == "before": 152 | self.layer_norm = modules.LayerNorm(params.hidden_size) 153 | else: 154 | self.layer_norm = None 155 | 156 | def forward(self, x, attn_bias, encdec_bias, memory, state=None): 157 | for i, layer in enumerate(self.layers): 158 | if state is not None: 159 | x = layer(x, attn_bias, encdec_bias, memory, 160 | state["decoder"]["layer_%d" % i]) 161 | else: 162 | x = layer(x, attn_bias, encdec_bias, memory, None) 163 | 164 | if self.normalization == "before": 165 | x = self.layer_norm(x) 166 | 167 | return x 168 | 169 | 170 | class Transformer(modules.Module): 171 | 172 | def __init__(self, params, name="transformer"): 173 | super(Transformer, self).__init__(name=name) 174 | self.params = params 175 | 176 | with utils.scope(name): 177 | self.build_embedding(params) 178 | self.encoding = modules.PositionalEmbedding() 179 | self.encoder = TransformerEncoder(params) 180 | self.decoder = TransformerDecoder(params) 181 | 182 | self.criterion = modules.SmoothedCrossEntropyLoss( 183 | params.label_smoothing) 184 | self.dropout = params.residual_dropout 185 | self.hidden_size = params.hidden_size 186 | self.num_encoder_layers = params.num_encoder_layers 187 | self.num_decoder_layers = params.num_decoder_layers 188 | self.reset_parameters() 189 | 190 | def build_embedding(self, params): 191 | svoc_size = len(params.vocabulary["source"]) 192 | tvoc_size = len(params.vocabulary["target"]) 193 | 194 | if params.shared_source_target_embedding and svoc_size != tvoc_size: 195 | raise ValueError("Cannot share source and target 
embedding.") 196 | 197 | if not params.shared_embedding_and_softmax_weights: 198 | self.softmax_weights = torch.nn.Parameter( 199 | torch.empty([tvoc_size, params.hidden_size])) 200 | self.add_name(self.softmax_weights, "softmax_weights") 201 | 202 | if not params.shared_source_target_embedding: 203 | self.source_embedding = torch.nn.Parameter( 204 | torch.empty([svoc_size, params.hidden_size])) 205 | self.target_embedding = torch.nn.Parameter( 206 | torch.empty([tvoc_size, params.hidden_size])) 207 | self.add_name(self.source_embedding, "source_embedding") 208 | self.add_name(self.target_embedding, "target_embedding") 209 | else: 210 | self.weights = torch.nn.Parameter( 211 | torch.empty([svoc_size, params.hidden_size])) 212 | self.add_name(self.weights, "weights") 213 | 214 | self.bias = torch.nn.Parameter(torch.zeros([params.hidden_size])) 215 | self.add_name(self.bias, "bias") 216 | 217 | @property 218 | def src_embedding(self): 219 | if self.params.shared_source_target_embedding: 220 | return self.weights 221 | else: 222 | return self.source_embedding 223 | 224 | @property 225 | def tgt_embedding(self): 226 | if self.params.shared_source_target_embedding: 227 | return self.weights 228 | else: 229 | return self.target_embedding 230 | 231 | @property 232 | def softmax_embedding(self): 233 | if not self.params.shared_embedding_and_softmax_weights: 234 | return self.softmax_weights 235 | else: 236 | return self.tgt_embedding 237 | 238 | def reset_parameters(self): 239 | nn.init.normal_(self.src_embedding, mean=0.0, 240 | std=self.params.hidden_size ** -0.5) 241 | nn.init.normal_(self.tgt_embedding, mean=0.0, 242 | std=self.params.hidden_size ** -0.5) 243 | 244 | if not self.params.shared_embedding_and_softmax_weights: 245 | nn.init.normal_(self.softmax_weights, mean=0.0, 246 | std=self.params.hidden_size ** -0.5) 247 | 248 | def encode(self, features, state): 249 | src_seq = features["source"] 250 | src_mask = features["source_mask"] 251 | enc_attn_bias = self.masking_bias(src_mask) 252 | 253 | inputs = torch.nn.functional.embedding(src_seq, self.src_embedding) 254 | inputs = inputs * (self.hidden_size ** 0.5) 255 | inputs = inputs + self.bias 256 | inputs = nn.functional.dropout(self.encoding(inputs), self.dropout, 257 | self.training) 258 | 259 | enc_attn_bias = enc_attn_bias.to(inputs) 260 | encoder_output = self.encoder(inputs, enc_attn_bias) 261 | 262 | state["encoder_output"] = encoder_output 263 | state["enc_attn_bias"] = enc_attn_bias 264 | 265 | return state 266 | 267 | def decode(self, features, state, mode="infer"): 268 | tgt_seq = features["target"] 269 | 270 | enc_attn_bias = state["enc_attn_bias"] 271 | dec_attn_bias = self.causal_bias(tgt_seq.shape[1]) 272 | 273 | targets = torch.nn.functional.embedding(tgt_seq, self.tgt_embedding) 274 | targets = targets * (self.hidden_size ** 0.5) 275 | 276 | decoder_input = torch.cat( 277 | [targets.new_zeros([targets.shape[0], 1, targets.shape[-1]]), 278 | targets[:, 1:, :]], dim=1) 279 | decoder_input = nn.functional.dropout(self.encoding(decoder_input), 280 | self.dropout, self.training) 281 | 282 | encoder_output = state["encoder_output"] 283 | dec_attn_bias = dec_attn_bias.to(targets) 284 | 285 | if mode == "infer": 286 | decoder_input = decoder_input[:, -1:, :] 287 | dec_attn_bias = dec_attn_bias[:, :, -1:, :] 288 | 289 | decoder_output = self.decoder(decoder_input, dec_attn_bias, 290 | enc_attn_bias, encoder_output, state) 291 | 292 | decoder_output = torch.reshape(decoder_output, [-1, self.hidden_size]) 293 | decoder_output = 
torch.transpose(decoder_output, -1, -2) 294 | logits = torch.matmul(self.softmax_embedding, decoder_output) 295 | logits = torch.transpose(logits, 0, 1) 296 | 297 | return logits, state 298 | 299 | def forward(self, features, labels, mode="train", level="sentence"): 300 | mask = features["target_mask"] 301 | 302 | state = self.empty_state(features["target"].shape[0], 303 | labels.device) 304 | state = self.encode(features, state) 305 | logits, _ = self.decode(features, state, mode=mode) 306 | loss = self.criterion(logits, labels) 307 | mask = mask.to(torch.float32) 308 | 309 | # Prevent FP16 overflow 310 | if loss.dtype == torch.float16: 311 | loss = loss.to(torch.float32) 312 | 313 | if mode == "eval": 314 | if level == "sentence": 315 | return -torch.sum(loss * mask, 1) 316 | else: 317 | return torch.exp(-loss) * mask - (1 - mask) 318 | 319 | return (torch.sum(loss * mask) / torch.sum(mask)).to(logits) 320 | 321 | def empty_state(self, batch_size, device): 322 | state = { 323 | "decoder": { 324 | "layer_%d" % i: { 325 | "k": torch.zeros([batch_size, 0, self.hidden_size], 326 | device=device), 327 | "v": torch.zeros([batch_size, 0, self.hidden_size], 328 | device=device) 329 | } for i in range(self.num_decoder_layers) 330 | } 331 | } 332 | 333 | return state 334 | 335 | @staticmethod 336 | def masking_bias(mask, inf=-1e9): 337 | ret = (1.0 - mask) * inf 338 | return torch.unsqueeze(torch.unsqueeze(ret, 1), 1) 339 | 340 | @staticmethod 341 | def causal_bias(length, inf=-1e9): 342 | ret = torch.ones([length, length]) * inf 343 | ret = torch.triu(ret, diagonal=1) 344 | return torch.reshape(ret, [1, 1, length, length]) 345 | 346 | @staticmethod 347 | def base_params(): 348 | params = utils.HParams( 349 | pad="", 350 | bos="", 351 | eos="", 352 | unk="", 353 | hidden_size=512, 354 | filter_size=2048, 355 | num_heads=8, 356 | num_encoder_layers=6, 357 | num_decoder_layers=6, 358 | attention_dropout=0.0, 359 | residual_dropout=0.1, 360 | relu_dropout=0.0, 361 | label_smoothing=0.1, 362 | normalization="after", 363 | shared_embedding_and_softmax_weights=False, 364 | shared_source_target_embedding=False, 365 | # Override default parameters 366 | warmup_steps=4000, 367 | train_steps=100000, 368 | learning_rate=7e-4, 369 | learning_rate_schedule="linear_warmup_rsqrt_decay", 370 | batch_size=4096, 371 | fixed_batch_size=False, 372 | adam_beta1=0.9, 373 | adam_beta2=0.98, 374 | adam_epsilon=1e-9, 375 | clip_grad_norm=0.0 376 | ) 377 | 378 | return params 379 | 380 | @staticmethod 381 | def base_params_v2(): 382 | params = Transformer.base_params() 383 | params.attention_dropout = 0.1 384 | params.relu_dropout = 0.1 385 | params.learning_rate = 12e-4 386 | params.warmup_steps = 8000 387 | params.normalization = "before" 388 | params.adam_beta2 = 0.997 389 | 390 | return params 391 | 392 | @staticmethod 393 | def big_params(): 394 | params = Transformer.base_params() 395 | params.hidden_size = 1024 396 | params.filter_size = 4096 397 | params.num_heads = 16 398 | params.residual_dropout = 0.3 399 | params.learning_rate = 5e-4 400 | params.train_steps = 300000 401 | 402 | return params 403 | 404 | @staticmethod 405 | def big_params_v2(): 406 | params = Transformer.base_params_v2() 407 | params.hidden_size = 1024 408 | params.filter_size = 4096 409 | params.num_heads = 16 410 | params.residual_dropout = 0.3 411 | params.learning_rate = 7e-4 412 | params.train_steps = 300000 413 | 414 | return params 415 | 416 | @staticmethod 417 | def default_params(name=None): 418 | if name == "base": 419 | return 
Transformer.base_params() 420 | elif name == "base_v2": 421 | return Transformer.base_params_v2() 422 | elif name == "big": 423 | return Transformer.big_params() 424 | elif name == "big_v2": 425 | return Transformer.big_params_v2() 426 | else: 427 | return Transformer.base_params() 428 | -------------------------------------------------------------------------------- /thumt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from thumt.modules.affine import Affine 2 | from thumt.modules.attention import Attention 3 | from thumt.modules.attention import MultiHeadAttention 4 | from thumt.modules.attention import MultiHeadAdditiveAttention 5 | from thumt.modules.embedding import PositionalEmbedding 6 | from thumt.modules.feed_forward import FeedForward 7 | from thumt.modules.layer_norm import LayerNorm 8 | from thumt.modules.losses import SmoothedCrossEntropyLoss 9 | from thumt.modules.module import Module 10 | from thumt.modules.recurrent import LSTMCell, GRUCell 11 | -------------------------------------------------------------------------------- /thumt/modules/affine.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | import thumt.utils as utils 13 | from thumt.modules.module import Module 14 | 15 | 16 | class Affine(Module): 17 | 18 | def __init__(self, in_features, out_features, bias=True, name="affine"): 19 | super(Affine, self).__init__(name=name) 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | 23 | with utils.scope(name): 24 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 25 | self.add_name(self.weight, "weight") 26 | if bias: 27 | self.bias = nn.Parameter(torch.Tensor(out_features)) 28 | self.add_name(self.bias, "bias") 29 | else: 30 | self.register_parameter('bias', None) 31 | 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 36 | if self.bias is not None: 37 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 38 | bound = 1 / math.sqrt(fan_in) 39 | nn.init.uniform_(self.bias, -bound, bound) 40 | 41 | def forward(self, input): 42 | return nn.functional.linear(input, self.weight, self.bias) 43 | 44 | def extra_repr(self): 45 | return 'in_features={}, out_features={}, bias={}'.format( 46 | self.in_features, self.out_features, self.bias is not None 47 | ) 48 | -------------------------------------------------------------------------------- /thumt/modules/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | import thumt.utils as utils 11 | 12 | from thumt.modules.module import Module 13 | from thumt.modules.affine import Affine 14 | 15 | 16 | class Attention(Module): 17 | 18 | def __init__(self, q_size, k_size, hidden_size, name="attention"): 19 | super(Attention, self).__init__(name) 20 | 21 | self._q_size = q_size 22 | self._k_size = k_size 23 | self._hidden_size = hidden_size 24 | 25 | with utils.scope(name): 26 | 
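            # Additive (Bahdanau-style) attention: queries and keys are first
            # projected into a shared hidden space and scored as
            # v^T tanh(W_q q + W_k k); the three Affine transforms created
            # below provide W_q, W_k and v respectively.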
self.q_transform = Affine(q_size, hidden_size, name="q_transform") 27 | self.k_transform = Affine(k_size, hidden_size, name="k_transform") 28 | self.v_transform = Affine(hidden_size, 1, 29 | name="v_transform") 30 | 31 | self.reset_parameters() 32 | 33 | def compute_cache(self, memory): 34 | return self.k_transform(memory) 35 | 36 | def forward(self, query, bias, memory, cache=None): 37 | q = self.q_transform(query) 38 | 39 | if cache is None: 40 | k = self.k_transform(memory) 41 | else: 42 | k = cache 43 | 44 | # q: [batch, 1, hidden_size] 45 | # k: [batch, length, hidden_size] 46 | logits = self.v_transform(torch.tanh(q + k)) 47 | # [batch, length, 1] 48 | logits = torch.transpose(logits, 1, 2) 49 | # [batch, 1, 1, length] 50 | logits = torch.unsqueeze(logits, 2) 51 | 52 | if bias is not None: 53 | logits = logits + bias 54 | 55 | weights = torch.softmax(logits, dim=-1) 56 | 57 | # [batch, 1, length] 58 | weights = torch.squeeze(weights, 2) 59 | output = torch.matmul(weights, memory) 60 | 61 | return output 62 | 63 | def reset_parameters(self, initializer="uniform_scaling", **kwargs): 64 | if initializer == "uniform_scaling": 65 | # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size) 66 | nn.init.xavier_uniform_(self.q_transform.weight) 67 | nn.init.xavier_uniform_(self.k_transform.weight) 68 | nn.init.xavier_uniform_(self.v_transform.weight) 69 | nn.init.constant_(self.q_transform.bias, 0.0) 70 | nn.init.constant_(self.k_transform.bias, 0.0) 71 | nn.init.constant_(self.v_transform.bias, 0.0) 72 | elif initializer == "uniform": 73 | nn.init.uniform_(self.q_transform.weight, -0.04, 0.04) 74 | nn.init.uniform_(self.k_transform.weight, -0.04, 0.04) 75 | nn.init.uniform_(self.v_transform.weight, -0.04, 0.04) 76 | nn.init.uniform_(self.q_transform.bias, -0.04, 0.04) 77 | nn.init.uniform_(self.k_transform.bias, -0.04, 0.04) 78 | nn.init.uniform_(self.v_transform.bias, -0.04, 0.04) 79 | else: 80 | raise ValueError("Unknown initializer %d" % initializer) 81 | 82 | 83 | class MultiHeadAttentionBase(Module): 84 | 85 | def __init__(self, name="multihead_attention_base"): 86 | super(MultiHeadAttentionBase, self).__init__(name=name) 87 | 88 | @staticmethod 89 | def split_heads(x, heads): 90 | batch = x.shape[0] 91 | length = x.shape[1] 92 | channels = x.shape[2] 93 | 94 | y = torch.reshape(x, [batch, length, heads, channels // heads]) 95 | return torch.transpose(y, 2, 1) 96 | 97 | @staticmethod 98 | def combine_heads(x): 99 | batch = x.shape[0] 100 | heads = x.shape[1] 101 | length = x.shape[2] 102 | channels = x.shape[3] 103 | 104 | y = torch.transpose(x, 2, 1) 105 | 106 | return torch.reshape(y, [batch, length, heads * channels]) 107 | 108 | 109 | class MultiHeadAttention(MultiHeadAttentionBase): 110 | 111 | def __init__(self, hidden_size, num_heads, dropout=0.0, 112 | name="multihead_attention"): 113 | super(MultiHeadAttention, self).__init__(name=name) 114 | 115 | self.num_heads = num_heads 116 | self.hidden_size = hidden_size 117 | self.dropout = dropout 118 | 119 | with utils.scope(name): 120 | self.q_transform = Affine(hidden_size, hidden_size, 121 | name="q_transform") 122 | self.k_transform = Affine(hidden_size, hidden_size, 123 | name="k_transform") 124 | self.v_transform = Affine(hidden_size, hidden_size, 125 | name="v_transform") 126 | self.o_transform = Affine(hidden_size, hidden_size, 127 | name="o_transform") 128 | 129 | self.reset_parameters() 130 | 131 | def forward(self, query, bias, memory=None, kv=None): 132 | q = self.q_transform(query) 133 | 134 | if memory is not None: 135 | if kv 
is not None: 136 | k, v = kv 137 | else: 138 | k, v = None, None 139 | 140 | # encoder-decoder attention 141 | k = k or self.k_transform(memory) 142 | v = v or self.v_transform(memory) 143 | else: 144 | # self-attention 145 | k = self.k_transform(query) 146 | v = self.v_transform(query) 147 | 148 | if kv is not None: 149 | k = torch.cat([kv[0], k], dim=1) 150 | v = torch.cat([kv[1], v], dim=1) 151 | 152 | # split heads 153 | qh = self.split_heads(q, self.num_heads) 154 | kh = self.split_heads(k, self.num_heads) 155 | vh = self.split_heads(v, self.num_heads) 156 | 157 | # scale query 158 | qh = qh * (self.hidden_size // self.num_heads) ** -0.5 159 | 160 | # dot-product attention 161 | kh = torch.transpose(kh, -2, -1) 162 | logits = torch.matmul(qh, kh) 163 | 164 | if bias is not None: 165 | logits = logits + bias 166 | 167 | weights = torch.nn.functional.dropout(torch.softmax(logits, dim=-1), 168 | p=self.dropout, 169 | training=self.training) 170 | 171 | x = torch.matmul(weights, vh) 172 | 173 | # combine heads 174 | output = self.o_transform(self.combine_heads(x)) 175 | 176 | if kv is not None: 177 | return output, k, v 178 | 179 | return output 180 | 181 | def reset_parameters(self, initializer="uniform_scaling", **kwargs): 182 | if initializer == "uniform_scaling": 183 | # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size) 184 | nn.init.xavier_uniform_(self.q_transform.weight, 2 ** -0.5) 185 | nn.init.xavier_uniform_(self.k_transform.weight, 2 ** -0.5) 186 | nn.init.xavier_uniform_(self.v_transform.weight, 2 ** -0.5) 187 | nn.init.xavier_uniform_(self.o_transform.weight) 188 | nn.init.constant_(self.q_transform.bias, 0.0) 189 | nn.init.constant_(self.k_transform.bias, 0.0) 190 | nn.init.constant_(self.v_transform.bias, 0.0) 191 | nn.init.constant_(self.o_transform.bias, 0.0) 192 | else: 193 | raise ValueError("Unknown initializer %d" % initializer) 194 | 195 | 196 | class MultiHeadAdditiveAttention(MultiHeadAttentionBase): 197 | 198 | def __init__(self, q_size, k_size, hidden_size, num_heads, dropout=0.0, 199 | name="multihead_attention"): 200 | super(MultiHeadAdditiveAttention, self).__init__(name=name) 201 | 202 | self.num_heads = num_heads 203 | self.hidden_size = hidden_size 204 | self.dropout = dropout 205 | 206 | with utils.scope(name): 207 | self.q_transform = Affine(q_size, hidden_size, 208 | name="q_transform") 209 | self.k_transform = Affine(k_size, hidden_size, 210 | name="k_transform") 211 | self.v_transform = Affine(hidden_size, num_heads, 212 | name="v_transform") 213 | self.o_transform = Affine(k_size, k_size, 214 | name="o_transform") 215 | 216 | self.reset_parameters() 217 | 218 | def compute_cache(self, memory): 219 | return self.k_transform(memory) 220 | 221 | def forward(self, query, bias, memory, cache=None): 222 | q = self.q_transform(query) 223 | 224 | if cache is None: 225 | k = self.k_transform(memory) 226 | else: 227 | k = cache 228 | 229 | # split heads 230 | qh = self.split_heads(q, self.num_heads) 231 | kh = self.split_heads(k, self.num_heads) 232 | # q: [batch, 1, hidden_size] 233 | # k: [batch, length, hidden_size] 234 | logits = self.v_transform(torch.tanh(q + k)) 235 | # [batch, length, num_heads] 236 | logits = torch.transpose(logits, 1, 2) 237 | # [batch, num_heads, 1, length] 238 | logits = torch.unsqueeze(logits, 2) 239 | 240 | if bias is not None: 241 | logits = logits + bias 242 | 243 | weights = torch.nn.functional.dropout(torch.softmax(logits, dim=-1), 244 | p=self.dropout, 245 | training=self.training) 246 | 247 | vh = self.split_heads(memory, 
self.num_heads) 248 | x = torch.matmul(weights, vh) 249 | 250 | # combine heads 251 | output = self.o_transform(self.combine_heads(x)) 252 | 253 | return output 254 | 255 | def reset_parameters(self, initializer="uniform_scaling", **kwargs): 256 | if initializer == "uniform_scaling": 257 | # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size) 258 | nn.init.xavier_uniform_(self.q_transform.weight, 2 ** -0.5) 259 | nn.init.xavier_uniform_(self.k_transform.weight, 2 ** -0.5) 260 | nn.init.xavier_uniform_(self.v_transform.weight, 2 ** -0.5) 261 | nn.init.xavier_uniform_(self.o_transform.weight) 262 | nn.init.constant_(self.q_transform.bias, 0.0) 263 | nn.init.constant_(self.k_transform.bias, 0.0) 264 | nn.init.constant_(self.v_transform.bias, 0.0) 265 | nn.init.constant_(self.o_transform.bias, 0.0) 266 | elif initializer == "uniform": 267 | nn.init.uniform_(self.q_transform.weight, -0.04, 0.04) 268 | nn.init.uniform_(self.k_transform.weight, -0.04, 0.04) 269 | nn.init.uniform_(self.v_transform.weight, -0.04, 0.04) 270 | nn.init.uniform_(self.o_transform.weight, -0.04, 0.04) 271 | nn.init.uniform_(self.q_transform.bias, -0.04, 0.04) 272 | nn.init.uniform_(self.k_transform.bias, -0.04, 0.04) 273 | nn.init.uniform_(self.v_transform.bias, -0.04, 0.04) 274 | nn.init.uniform_(self.o_transform.bias, -0.04, 0.04) 275 | else: 276 | raise ValueError("Unknown initializer %d" % initializer) 277 | -------------------------------------------------------------------------------- /thumt/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | 11 | 12 | class PositionalEmbedding(torch.nn.Module): 13 | 14 | def __init__(self): 15 | super(PositionalEmbedding, self).__init__() 16 | 17 | def forward(self, inputs): 18 | if inputs.dim() != 3: 19 | raise ValueError("The rank of input must be 3.") 20 | 21 | length = inputs.shape[1] 22 | channels = inputs.shape[2] 23 | half_dim = channels // 2 24 | 25 | positions = torch.arange(length, dtype=inputs.dtype, 26 | device=inputs.device) 27 | dimensions = torch.arange(half_dim, dtype=inputs.dtype, 28 | device=inputs.device) 29 | 30 | scale = math.log(10000.0) / float(half_dim - 1) 31 | dimensions.mul_(-scale).exp_() 32 | 33 | scaled_time = positions.unsqueeze(1) * dimensions.unsqueeze(0) 34 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 35 | dim=1) 36 | 37 | if channels % 2 == 1: 38 | pad = torch.zeros([signal.shape[0], 1], dtype=inputs.dtype, 39 | device=inputs.device) 40 | signal = torch.cat([signal, pad], axis=1) 41 | 42 | return inputs + torch.reshape(signal, [1, -1, channels]).to(inputs) 43 | -------------------------------------------------------------------------------- /thumt/modules/feed_forward.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | import thumt.utils as utils 11 | 12 | from thumt.modules.module import Module 13 | from thumt.modules.affine import Affine 14 | 15 | 16 | class FeedForward(Module): 17 | 18 | def __init__(self, input_size, hidden_size, output_size=None, dropout=0.0, 19 | 
name="feed_forward"): 20 | super(FeedForward, self).__init__(name=name) 21 | 22 | self.input_size = input_size 23 | self.hidden_size = hidden_size 24 | self.output_size = output_size or input_size 25 | self.dropout = dropout 26 | 27 | with utils.scope(name): 28 | self.input_transform = Affine(input_size, hidden_size, 29 | name="input_transform") 30 | self.output_transform = Affine(hidden_size, self.output_size, 31 | name="output_transform") 32 | 33 | self.reset_parameters() 34 | 35 | def forward(self, x): 36 | h = nn.functional.relu(self.input_transform(x)) 37 | h = nn.functional.dropout(h, self.dropout, self.training) 38 | return self.output_transform(h) 39 | 40 | def reset_parameters(self): 41 | nn.init.xavier_uniform_(self.input_transform.weight) 42 | nn.init.xavier_uniform_(self.output_transform.weight) 43 | nn.init.constant_(self.input_transform.bias, 0.0) 44 | nn.init.constant_(self.output_transform.bias, 0.0) 45 | -------------------------------------------------------------------------------- /thumt/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numbers 9 | import torch 10 | import torch.nn as nn 11 | import thumt.utils as utils 12 | 13 | from thumt.modules.module import Module 14 | 15 | 16 | class LayerNorm(Module): 17 | 18 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, 19 | name="layer_norm"): 20 | super(LayerNorm, self).__init__(name=name) 21 | if isinstance(normalized_shape, numbers.Integral): 22 | normalized_shape = (normalized_shape,) 23 | self.normalized_shape = tuple(normalized_shape) 24 | self.eps = eps 25 | self.elementwise_affine = elementwise_affine 26 | 27 | with utils.scope(name): 28 | if self.elementwise_affine: 29 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) 30 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) 31 | self.add_name(self.weight, "weight") 32 | self.add_name(self.bias, "bias") 33 | else: 34 | self.register_parameter('weight', None) 35 | self.register_parameter('bias', None) 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | if self.elementwise_affine: 40 | nn.init.ones_(self.weight) 41 | nn.init.zeros_(self.bias) 42 | 43 | def forward(self, input): 44 | return nn.functional.layer_norm( 45 | input, self.normalized_shape, self.weight, self.bias, self.eps) 46 | 47 | def extra_repr(self): 48 | return '{normalized_shape}, eps={eps}, ' \ 49 | 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 50 | -------------------------------------------------------------------------------- /thumt/modules/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | 11 | 12 | class SmoothedCrossEntropyLoss(torch.nn.Module): 13 | 14 | def __init__(self, smoothing=0.0, normalize=True): 15 | super(SmoothedCrossEntropyLoss, self).__init__() 16 | self.smoothing = smoothing 17 | self.normalize = normalize 18 | 19 | def forward(self, logits, labels): 20 | shape = labels.shape 21 | logits = torch.reshape(logits, [-1, logits.shape[-1]]) 22 | labels = torch.reshape(labels, [-1]) 23 | 
24 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 25 | batch_idx = torch.arange(labels.shape[0], device=logits.device) 26 | loss = log_probs[batch_idx, labels] 27 | 28 | if not self.smoothing or not self.training: 29 | return -torch.reshape(loss, shape) 30 | 31 | n = logits.shape[-1] - 1.0 32 | p = 1.0 - self.smoothing 33 | q = self.smoothing / n 34 | 35 | if log_probs.dtype != torch.float16: 36 | sum_probs = torch.sum(log_probs, dim=-1) 37 | loss = p * loss + q * (sum_probs - loss) 38 | else: 39 | # Prevent FP16 overflow 40 | sum_probs = torch.sum(log_probs.to(torch.float32), dim=-1) 41 | loss = loss.to(torch.float32) 42 | loss = p * loss + q * (sum_probs - loss) 43 | loss = loss.to(torch.float16) 44 | 45 | loss = -torch.reshape(loss, shape) 46 | 47 | if self.normalize: 48 | normalizing = -(p * math.log(p) + n * q * math.log(q + 1e-20)) 49 | return loss - normalizing 50 | else: 51 | return loss 52 | -------------------------------------------------------------------------------- /thumt/modules/module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import thumt.utils as utils 12 | 13 | 14 | class Module(nn.Module): 15 | 16 | def __init__(self, name=""): 17 | super(Module, self).__init__() 18 | scope = utils.get_scope() 19 | self._name = scope + "/" + name if scope else name 20 | 21 | def add_name(self, tensor, name): 22 | tensor.tensor_name = utils.unique_name(name) 23 | 24 | @property 25 | def name(self): 26 | return self._name 27 | -------------------------------------------------------------------------------- /thumt/modules/recurrent.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import thumt.utils as utils 12 | 13 | from thumt.modules.module import Module 14 | from thumt.modules.affine import Affine 15 | from thumt.modules.layer_norm import LayerNorm 16 | 17 | 18 | class GRUCell(Module): 19 | 20 | def __init__(self, input_size, output_size, normalization=False, 21 | name="gru"): 22 | super(GRUCell, self).__init__(name=name) 23 | 24 | self.input_size = input_size 25 | self.output_size = output_size 26 | 27 | with utils.scope(name): 28 | self.reset_gate = Affine(input_size + output_size, output_size, 29 | bias=False, name="reset_gate") 30 | self.update_gate = Affine(input_size + output_size, output_size, 31 | bias=False, name="update_gate") 32 | self.transform = Affine(input_size + output_size, output_size, 33 | name="transform") 34 | 35 | def forward(self, x, h): 36 | r = torch.sigmoid(self.reset_gate(torch.cat([x, h], -1))) 37 | u = torch.sigmoid(self.update_gate(torch.cat([x, h], -1))) 38 | c = self.transform(torch.cat([x, r * h], -1)) 39 | 40 | new_h = (1.0 - u) * h + u * torch.tanh(h) 41 | 42 | return new_h, new_h 43 | 44 | def init_state(self, batch_size, dtype, device): 45 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 46 | device=device) 47 | return h 48 | 49 | def mask_state(self, h, prev_h, mask): 50 | mask = mask[:, None] 51 | new_h = mask * h + (1.0 - mask) * prev_h 52 | return new_h 53 | 
54 | def reset_parameters(self, initializer="uniform"): 55 | if initializer == "uniform_scaling": 56 | nn.init.xavier_uniform_(self.gates.weight) 57 | nn.init.constant_(self.gates.bias, 0.0) 58 | elif initializer == "uniform": 59 | nn.init.uniform_(self.gates.weight, -0.08, 0.08) 60 | nn.init.uniform_(self.gates.bias, -0.08, 0.08) 61 | else: 62 | raise ValueError("Unknown initializer %d" % initializer) 63 | 64 | 65 | class LSTMCell(Module): 66 | 67 | def __init__(self, input_size, output_size, normalization=False, 68 | activation=None, name="lstm"): 69 | super(LSTMCell, self).__init__(name=name) 70 | 71 | self.input_size = input_size 72 | self.output_size = output_size 73 | self.activation = activation 74 | 75 | with utils.scope(name): 76 | self.gates = Affine(input_size + output_size, 4 * output_size, 77 | name="gates") 78 | if normalization: 79 | self.layer_norm = LayerNorm([4, output_size]) 80 | else: 81 | self.layer_norm = None 82 | 83 | self.reset_parameters() 84 | 85 | def forward(self, x, state): 86 | c, h = state 87 | 88 | gates = self.gates(torch.cat([x, h], 1)) 89 | 90 | if self.layer_norm is not None: 91 | combined = self.layer_norm( 92 | torch.reshape(gates, [-1, 4, self.output_size])) 93 | else: 94 | combined = torch.reshape(gates, [-1, 4, self.output_size]) 95 | 96 | i, j, f, o = torch.unbind(combined, 1) 97 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o) 98 | 99 | new_c = f * c + i * torch.tanh(j) 100 | 101 | if self.activation is None: 102 | # Do not use tanh activation 103 | new_h = o * new_c 104 | else: 105 | new_h = o * self.activation(new_c) 106 | 107 | return new_h, (new_c, new_h) 108 | 109 | def init_state(self, batch_size, dtype, device): 110 | c = torch.zeros([batch_size, self.output_size], dtype=dtype, 111 | device=device) 112 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 113 | device=device) 114 | return c, h 115 | 116 | def mask_state(self, state, prev_state, mask): 117 | c, h = state 118 | prev_c, prev_h = prev_state 119 | mask = mask[:, None] 120 | new_c = mask * c + (1.0 - mask) * prev_c 121 | new_h = mask * h + (1.0 - mask) * prev_h 122 | return new_c, new_h 123 | 124 | def reset_parameters(self, initializer="uniform"): 125 | if initializer == "uniform_scaling": 126 | nn.init.xavier_uniform_(self.gates.weight) 127 | nn.init.constant_(self.gates.bias, 0.0) 128 | elif initializer == "uniform": 129 | nn.init.uniform_(self.gates.weight, -0.04, 0.04) 130 | nn.init.uniform_(self.gates.bias, -0.04, 0.04) 131 | else: 132 | raise ValueError("Unknown initializer %d" % initializer) 133 | -------------------------------------------------------------------------------- /thumt/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from thumt.optimizers.optimizers import AdamOptimizer 2 | from thumt.optimizers.optimizers import AdadeltaOptimizer 3 | from thumt.optimizers.optimizers import SGDOptimizer 4 | from thumt.optimizers.optimizers import MultiStepOptimizer 5 | from thumt.optimizers.optimizers import LossScalingOptimizer 6 | from thumt.optimizers.schedules import LinearWarmupRsqrtDecay 7 | from thumt.optimizers.schedules import PiecewiseConstantDecay 8 | from thumt.optimizers.schedules import LinearExponentialDecay 9 | from thumt.optimizers.clipping import ( 10 | adaptive_clipper, global_norm_clipper, value_clipper) 11 | -------------------------------------------------------------------------------- /thumt/optimizers/clipping.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | 10 | 11 | def global_norm_clipper(value): 12 | def clip_fn(gradients, grad_norm): 13 | if not float(value) or grad_norm < value: 14 | return False, gradients 15 | 16 | scale = value / grad_norm 17 | 18 | gradients = [grad.data.mul_(scale) 19 | if grad is not None else None for grad in gradients] 20 | 21 | return False, gradients 22 | 23 | return clip_fn 24 | 25 | 26 | def value_clipper(clip_min, clip_max): 27 | def clip_fn(gradients, grad_norm): 28 | gradients = [ 29 | grad.data.clamp_(clip_min, clip_max) 30 | if grad is not None else None for grad in gradients] 31 | 32 | return False, None 33 | 34 | return clip_fn 35 | 36 | 37 | def adaptive_clipper(rho): 38 | norm_avg = 0.0 39 | norm_stddev = 0.0 40 | log_norm_avg = 0.0 41 | log_norm_sqr = 0.0 42 | 43 | def clip_fn(gradients, grad_norm): 44 | nonlocal norm_avg 45 | nonlocal norm_stddev 46 | nonlocal log_norm_avg 47 | nonlocal log_norm_sqr 48 | 49 | norm = grad_norm 50 | log_norm = math.log(norm) 51 | 52 | avg = rho * norm_avg + (1.0 - rho) * norm 53 | log_avg = rho * log_norm_avg + (1.0 - rho) * log_norm 54 | log_sqr = rho * log_norm_sqr + (1.0 - rho) * (log_norm ** 2) 55 | stddev = (log_sqr - (log_avg ** 2)) ** -0.5 56 | 57 | norm_avg = avg 58 | log_norm_avg = log_avg 59 | log_norm_sqr = log_sqr 60 | norm_stddev = rho * stddev + (1.0 - rho) * stddev 61 | 62 | reject = False 63 | 64 | if norm > norm_avg + 4 * math.exp(norm_stddev): 65 | reject = True 66 | 67 | return reject, gradients 68 | 69 | return clip_fn 70 | -------------------------------------------------------------------------------- /thumt/optimizers/optimizers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import re 9 | import math 10 | import torch 11 | import torch.distributed as dist 12 | import thumt.utils as utils 13 | import thumt.utils.summary as summary 14 | 15 | from thumt.optimizers.schedules import LearningRateSchedule 16 | 17 | 18 | def _save_summary(grads_and_vars): 19 | total_norm = 0.0 20 | 21 | for grad, var in grads_and_vars: 22 | if grad is None: 23 | continue 24 | 25 | _, var = var 26 | grad_norm = grad.data.norm() 27 | total_norm += grad_norm ** 2 28 | summary.histogram(var.tensor_name, var, 29 | utils.get_global_step()) 30 | summary.scalar("norm/" + var.tensor_name, var.norm(), 31 | utils.get_global_step()) 32 | summary.scalar("grad_norm/" + var.tensor_name, grad_norm, 33 | utils.get_global_step()) 34 | 35 | total_norm = total_norm ** 0.5 36 | summary.scalar("grad_norm", total_norm, utils.get_global_step()) 37 | 38 | return float(total_norm) 39 | 40 | 41 | def _compute_grad_norm(gradients): 42 | total_norm = 0.0 43 | 44 | for grad in gradients: 45 | if grad is None: 46 | continue 47 | 48 | total_norm += float(grad.data.norm() ** 2) 49 | 50 | return float(total_norm ** 0.5) 51 | 52 | 53 | class Optimizer(object): 54 | 55 | def __init__(self, name, **kwargs): 56 | self._name = name 57 | self._iterations = 0 58 | self._slots = {} 59 | 60 | def detach_gradients(self, gradients): 61 | for grad in gradients: 62 | if grad is not None: 63 | 
grad.detach_() 64 | 65 | def scale_gradients(self, gradients, scale): 66 | for grad in gradients: 67 | if grad is not None: 68 | grad.mul_(scale) 69 | 70 | def sync_gradients(self, gradients, compress=True): 71 | grad_vec = utils.params_to_vec(gradients) 72 | 73 | if compress: 74 | grad_vec_half = grad_vec.half() 75 | dist.all_reduce(grad_vec_half) 76 | grad_vec = grad_vec_half.to(grad_vec) 77 | else: 78 | dist.all_reduce(grad_vec) 79 | 80 | utils.vec_to_params(grad_vec, gradients) 81 | 82 | def zero_gradients(self, gradients): 83 | for grad in gradients: 84 | if grad is not None: 85 | grad.zero_() 86 | 87 | def compute_gradients(self, loss, var_list, aggregate=False): 88 | var_list = list(var_list) 89 | grads = [v.grad if v is not None else None for v in var_list] 90 | 91 | self.detach_gradients(grads) 92 | 93 | if not aggregate: 94 | self.zero_gradients(grads) 95 | 96 | loss.backward() 97 | return [v.grad if v is not None else None for v in var_list] 98 | 99 | def apply_gradients(self, grads_and_vars): 100 | raise NotImplementedError("Not implemented") 101 | 102 | @property 103 | def iterations(self): 104 | return self._iterations 105 | 106 | def state_dict(self): 107 | raise NotImplementedError("Not implemented") 108 | 109 | def load_state_dict(self): 110 | raise NotImplementedError("Not implemented") 111 | 112 | 113 | class SGDOptimizer(Optimizer): 114 | 115 | def __init__(self, learning_rate, summaries=True, name="SGD", **kwargs): 116 | super(SGDOptimizer, self).__init__(name, **kwargs) 117 | self._learning_rate = learning_rate 118 | self._summaries = summaries 119 | self._clipper = None 120 | 121 | if "clipper" in kwargs and kwargs["clipper"] is not None: 122 | self._clipper = kwargs["clipper"] 123 | 124 | def apply_gradients(self, grads_and_vars): 125 | self._iterations += 1 126 | lr = self._learning_rate 127 | grads, var_list = list(zip(*grads_and_vars)) 128 | 129 | if self._summaries: 130 | grad_norm = _save_summary(zip(grads, var_list)) 131 | else: 132 | grad_norm = _compute_grad_norm(grads) 133 | 134 | if self._clipper is not None: 135 | reject, grads = self._clipper(grads, grad_norm) 136 | 137 | if reject: 138 | return 139 | 140 | for grad, var in zip(grads, var_list): 141 | if grad is None: 142 | continue 143 | 144 | # Convert if grad is not FP32 145 | grad = grad.data.float() 146 | _, var = var 147 | 148 | if isinstance(lr, LearningRateSchedule): 149 | lr = lr(self._iterations) 150 | 151 | step_size = lr 152 | 153 | if var.dtype == torch.float32: 154 | var.data.add_(grad, alpha=-step_size) 155 | else: 156 | fp32_var = var.data.float() 157 | fp32_var.add_(grad, alpha=-step_size) 158 | var.data.copy_(fp32_var) 159 | 160 | def state_dict(self): 161 | state = { 162 | "iterations": self._iterations, 163 | } 164 | 165 | if not isinstance(self._learning_rate, LearningRateSchedule): 166 | state["learning_rate"] = self._learning_rate 167 | 168 | return state 169 | 170 | def load_state_dict(self, state): 171 | self._iterations = state.get("iterations", self._iterations) 172 | 173 | 174 | class AdamOptimizer(Optimizer): 175 | 176 | def __init__(self, learning_rate=0.01, beta_1=0.9, beta_2=0.999, 177 | epsilon=1e-7, name="Adam", **kwargs): 178 | super(AdamOptimizer, self).__init__(name, **kwargs) 179 | self._learning_rate = learning_rate 180 | self._beta_1 = beta_1 181 | self._beta_2 = beta_2 182 | self._epsilon = epsilon 183 | self._summaries = True 184 | self._clipper = None 185 | 186 | if "summaries" in kwargs and not kwargs["summaries"]: 187 | self._summaries = False 188 | 189 | if 
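The optimizers in optimizers.py share a small interface: `compute_gradients` runs the backward pass and returns the raw `.grad` tensors, while `apply_gradients` consumes `(gradient, (name, parameter))` pairs and optionally routes the gradients through one of the clipping closures from clipping.py. A minimal sketch with a hypothetical toy model (summaries are disabled so no global step or `tensor_name` metadata is needed):

```python
import torch
from thumt.optimizers import SGDOptimizer, global_norm_clipper

model = torch.nn.Linear(16, 4)
optimizer = SGDOptimizer(learning_rate=0.1, summaries=False,
                         clipper=global_norm_clipper(5.0))

named_params = list(model.named_parameters())
loss = model(torch.randn(8, 16)).pow(2).mean()

# compute_gradients runs backward(); apply_gradients expects the variables as
# (name, parameter) tuples, matching the `_, var = var` unpacking in the code.
grads = optimizer.compute_gradients(loss, [p for _, p in named_params])
optimizer.apply_gradients(list(zip(grads, named_params)))
```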
"clipper" in kwargs and kwargs["clipper"] is not None: 190 | self._clipper = kwargs["clipper"] 191 | 192 | def apply_gradients(self, grads_and_vars): 193 | self._iterations += 1 194 | lr = self._learning_rate 195 | beta_1 = self._beta_1 196 | beta_2 = self._beta_2 197 | epsilon = self._epsilon 198 | grads, var_list = list(zip(*grads_and_vars)) 199 | 200 | if self._summaries: 201 | grad_norm = _save_summary(zip(grads, var_list)) 202 | else: 203 | grad_norm = _compute_grad_norm(grads) 204 | 205 | if self._clipper is not None: 206 | reject, grads = self._clipper(grads, grad_norm) 207 | 208 | if reject: 209 | return 210 | 211 | for grad, var in zip(grads, var_list): 212 | if grad is None: 213 | continue 214 | 215 | # Convert if grad is not FP32 216 | grad = grad.data.float() 217 | name, var = var 218 | 219 | if self._slots.get(name, None) is None: 220 | self._slots[name] = {} 221 | self._slots[name]["m"] = torch.zeros_like(var.data, 222 | dtype=torch.float32) 223 | self._slots[name]["v"] = torch.zeros_like(var.data, 224 | dtype=torch.float32) 225 | 226 | m, v = self._slots[name]["m"], self._slots[name]["v"] 227 | 228 | bias_corr_1 = 1 - beta_1 ** self._iterations 229 | bias_corr_2 = 1 - beta_2 ** self._iterations 230 | 231 | m.mul_(beta_1).add_(grad, alpha=1 - beta_1) 232 | v.mul_(beta_2).addcmul_(grad, grad, value=1 - beta_2) 233 | denom = (v.sqrt() / math.sqrt(bias_corr_2)).add_(epsilon) 234 | 235 | if isinstance(lr, LearningRateSchedule): 236 | lr = lr(self._iterations) 237 | 238 | step_size = lr / bias_corr_1 239 | 240 | if var.dtype == torch.float32: 241 | var.data.addcdiv_(m, denom, value=-step_size) 242 | else: 243 | fp32_var = var.data.float() 244 | fp32_var.addcdiv_(m, denom, value=-step_size) 245 | var.data.copy_(fp32_var) 246 | 247 | def state_dict(self): 248 | state = { 249 | "beta_1": self._beta_1, 250 | "beta_2": self._beta_2, 251 | "epsilon": self._epsilon, 252 | "iterations": self._iterations, 253 | "slot": self._slots 254 | } 255 | 256 | if not isinstance(self._learning_rate, LearningRateSchedule): 257 | state["learning_rate"] = self._learning_rate 258 | 259 | return state 260 | 261 | def load_state_dict(self, state): 262 | self._iterations = state.get("iterations", self._iterations) 263 | 264 | slots = state.get("slot", {}) 265 | self._slots = {} 266 | 267 | for key in slots: 268 | m, v = slots[key]["m"], slots[key]["v"] 269 | self._slots[key] = {} 270 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32) 271 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32) 272 | self._slots[key]["m"].copy_(m) 273 | self._slots[key]["v"].copy_(v) 274 | 275 | 276 | class AdadeltaOptimizer(Optimizer): 277 | 278 | def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-07, 279 | name="Adadelta", **kwargs): 280 | super(AdadeltaOptimizer, self).__init__(name, **kwargs) 281 | self._learning_rate = learning_rate 282 | self._rho = rho 283 | self._epsilon = epsilon 284 | self._summaries = True 285 | self._clipper = None 286 | 287 | if "summaries" in kwargs and not kwargs["summaries"]: 288 | self._summaries = False 289 | 290 | if "clipper" in kwargs and kwargs["clipper"] is not None: 291 | self._clipper = kwargs["clipper"] 292 | 293 | def apply_gradients(self, grads_and_vars): 294 | self._iterations += 1 295 | lr = self._learning_rate 296 | rho = self._rho 297 | epsilon = self._epsilon 298 | 299 | grads, var_list = list(zip(*grads_and_vars)) 300 | 301 | if self._summaries: 302 | grad_norm = _save_summary(zip(grads, var_list)) 303 | else: 304 | grad_norm = 
_compute_grad_norm(grads) 305 | 306 | if self._clipper is not None: 307 | reject, grads = self._clipper(grads, grad_norm) 308 | 309 | if reject: 310 | return 311 | 312 | for grad, var in zip(grads, var_list): 313 | if grad is None: 314 | continue 315 | 316 | # Convert if grad is not FP32 317 | grad = grad.data.float() 318 | name, var = var 319 | 320 | if self._slots.get(name, None) is None: 321 | self._slots[name] = {} 322 | self._slots[name]["m"] = torch.zeros_like(var.data, 323 | dtype=torch.float32) 324 | self._slots[name]["v"] = torch.zeros_like(var.data, 325 | dtype=torch.float32) 326 | 327 | square_avg = self._slots[name]["m"] 328 | acc_delta = self._slots[name]["v"] 329 | 330 | if isinstance(lr, LearningRateSchedule): 331 | lr = lr(self._iterations) 332 | 333 | square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) 334 | std = square_avg.add(epsilon).sqrt_() 335 | delta = acc_delta.add(epsilon).sqrt_().div_(std).mul_(grad) 336 | acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) 337 | 338 | if var.dtype == torch.float32: 339 | var.data.add_(delta, alpha=-lr) 340 | else: 341 | fp32_var = var.data.float() 342 | fp32_var.add_(delta, alpha=-lr) 343 | var.data.copy_(fp32_var) 344 | 345 | def state_dict(self): 346 | state = { 347 | "rho": self._rho, 348 | "epsilon": self._epsilon, 349 | "iterations": self._iterations, 350 | "slot": self._slots 351 | } 352 | 353 | if not isinstance(self._learning_rate, LearningRateSchedule): 354 | state["learning_rate"] = self._learning_rate 355 | 356 | return state 357 | 358 | def load_state_dict(self, state): 359 | self._iterations = state.get("iterations", self._iterations) 360 | 361 | slots = state.get("slot", {}) 362 | self._slots = {} 363 | 364 | for key in slots: 365 | m, v = slots[key]["m"], slots[key]["v"] 366 | self._slots[key] = {} 367 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32) 368 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32) 369 | self._slots[key]["m"].copy_(m) 370 | self._slots[key]["v"].copy_(v) 371 | 372 | 373 | class LossScalingOptimizer(Optimizer): 374 | 375 | def __init__(self, optimizer, scale=2.0**7, increment_period=2000, 376 | multiplier=2.0, name="LossScalingOptimizer", **kwargs): 377 | super(LossScalingOptimizer, self).__init__(name, **kwargs) 378 | self._optimizer = optimizer 379 | self._scale = scale 380 | self._increment_period = increment_period 381 | self._multiplier = multiplier 382 | self._num_good_steps = 0 383 | self._summaries = True 384 | 385 | if "summaries" in kwargs and not kwargs["summaries"]: 386 | self._summaries = False 387 | 388 | def _update_if_finite_grads(self): 389 | if self._num_good_steps + 1 > self._increment_period: 390 | self._scale *= self._multiplier 391 | self._scale = min(self._scale, 2.0**16) 392 | self._num_good_steps = 0 393 | else: 394 | self._num_good_steps += 1 395 | 396 | def _update_if_not_finite_grads(self): 397 | self._scale = max(self._scale / self._multiplier, 1) 398 | 399 | def compute_gradients(self, loss, var_list, aggregate=False): 400 | var_list = list(var_list) 401 | grads = [v.grad if v is not None else None for v in var_list] 402 | 403 | self.detach_gradients(grads) 404 | 405 | if not aggregate: 406 | self.zero_gradients(grads) 407 | 408 | loss = loss * self._scale 409 | loss.backward() 410 | 411 | return [v.grad if v is not None else None for v in var_list] 412 | 413 | def apply_gradients(self, grads_and_vars): 414 | self._iterations += 1 415 | grads, var_list = list(zip(*grads_and_vars)) 416 | new_grads = [] 417 | 418 | if 
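`AdamOptimizer` keeps its first/second-moment slots keyed by parameter name and accepts either a float learning rate or a `LearningRateSchedule` (exported from `thumt.optimizers`), which it evaluates at the current iteration count. A hedged sketch along the same lines as the SGD example (toy model, summaries off):

```python
import torch
from thumt.optimizers import AdamOptimizer, LinearWarmupRsqrtDecay

model = torch.nn.Linear(16, 4)
schedule = LinearWarmupRsqrtDecay(learning_rate=7e-4, warmup_steps=4000,
                                  summary=False)
optimizer = AdamOptimizer(learning_rate=schedule, beta_1=0.9, beta_2=0.98,
                          summaries=False)

named_params = list(model.named_parameters())
loss = model(torch.randn(8, 16)).pow(2).mean()
grads = optimizer.compute_gradients(loss, [p for _, p in named_params])
optimizer.apply_gradients(list(zip(grads, named_params)))

state = optimizer.state_dict()  # includes "slot" with the per-parameter m/v buffers
```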
self._summaries: 419 | summary.scalar("optimizer/scale", self._scale, 420 | utils.get_global_step()) 421 | 422 | for grad in grads: 423 | if grad is None: 424 | new_grads.append(None) 425 | continue 426 | 427 | norm = grad.data.norm() 428 | 429 | if not torch.isfinite(norm): 430 | self._update_if_not_finite_grads() 431 | return 432 | else: 433 | # Rescale gradients 434 | new_grads.append(grad.data.float().mul_(1.0 / self._scale)) 435 | 436 | self._update_if_finite_grads() 437 | self._optimizer.apply_gradients(zip(new_grads, var_list)) 438 | 439 | def state_dict(self): 440 | state = { 441 | "scale": self._scale, 442 | "increment_period": self._increment_period, 443 | "multiplier": self._multiplier, 444 | "num_good_steps": self._num_good_steps, 445 | "optimizer": self._optimizer.state_dict() 446 | } 447 | return state 448 | 449 | def load_state_dict(self, state): 450 | self._num_good_steps = state.get("num_good_steps", 451 | self._num_good_steps) 452 | self._optimizer.load_state_dict(state.get("optimizer", {})) 453 | 454 | 455 | class MultiStepOptimizer(Optimizer): 456 | 457 | def __init__(self, optimizer, n=1, compress=True, 458 | name="MultiStepOptimizer", **kwargs): 459 | super(MultiStepOptimizer, self).__init__(name, **kwargs) 460 | self._n = n 461 | self._optimizer = optimizer 462 | self._compress = compress 463 | 464 | def compute_gradients(self, loss, var_list, aggregate=False): 465 | if self._iterations % self._n == 0: 466 | return self._optimizer.compute_gradients(loss, var_list, aggregate) 467 | else: 468 | return self._optimizer.compute_gradients(loss, var_list, True) 469 | 470 | def apply_gradients(self, grads_and_vars): 471 | size = dist.get_world_size() 472 | grads, var_list = list(zip(*grads_and_vars)) 473 | self._iterations += 1 474 | 475 | if self._n == 1: 476 | if size > 1: 477 | self.sync_gradients(grads, compress=self._compress) 478 | self.scale_gradients(grads, 1.0 / size) 479 | 480 | self._optimizer.apply_gradients(zip(grads, var_list)) 481 | else: 482 | if self._iterations % self._n != 0: 483 | return 484 | 485 | if size > 1: 486 | self.sync_gradients(grads, compress=self._compress) 487 | 488 | self.scale_gradients(grads, 1.0 / (self._n * size)) 489 | self._optimizer.apply_gradients(zip(grads, var_list)) 490 | 491 | def state_dict(self): 492 | state = { 493 | "n": self._n, 494 | "iterations": self._iterations, 495 | "compress": self._compress, 496 | "optimizer": self._optimizer.state_dict() 497 | } 498 | return state 499 | 500 | def load_state_dict(self, state): 501 | self._iterations = state.get("iterations", self._iterations) 502 | self._optimizer.load_state_dict(state.get("optimizer", {})) 503 | -------------------------------------------------------------------------------- /thumt/optimizers/schedules.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import thumt.utils as utils 9 | import thumt.utils.summary as summary 10 | 11 | 12 | class LearningRateSchedule(object): 13 | 14 | def __call__(self, step): 15 | raise NotImplementedError("Not implemented.") 16 | 17 | def get_config(self): 18 | raise NotImplementedError("Not implemented.") 19 | 20 | @classmethod 21 | def from_config(cls, config): 22 | return cls(**config) 23 | 24 | 25 | 26 | class LinearWarmupRsqrtDecay(LearningRateSchedule): 27 | 28 | def __init__(self, learning_rate, 
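`LossScalingOptimizer` implements dynamic loss scaling for mixed-precision training: the loss is multiplied by the current scale before `backward()`, the gradients are divided by the same scale, and the step is skipped (with the scale halved) whenever a non-finite gradient norm appears; after `increment_period` good steps the scale doubles again, capped at 2**16. A sketch of wrapping an inner optimizer (an FP32 toy model is used here only to show the call flow; in practice the model would hold half-precision parameters):

```python
import torch
from thumt.optimizers import AdamOptimizer, LossScalingOptimizer

model = torch.nn.Linear(16, 4)
inner = AdamOptimizer(learning_rate=1e-3, summaries=False)
optimizer = LossScalingOptimizer(inner, scale=2.0 ** 7, summaries=False)

named_params = list(model.named_parameters())
loss = model(torch.randn(8, 16)).pow(2).mean()

grads = optimizer.compute_gradients(loss, [p for _, p in named_params])  # scaled backward
optimizer.apply_gradients(list(zip(grads, named_params)))                # unscale, then step
```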
warmup_steps, initial_learning_rate=0.0, 29 | summary=True): 30 | super(LinearWarmupRsqrtDecay, self).__init__() 31 | 32 | if initial_learning_rate <= 0: 33 | if warmup_steps > 0: 34 | initial_learning_rate = learning_rate / warmup_steps 35 | else: 36 | initial_learning_rate = 0.0 37 | elif initial_learning_rate >= learning_rate: 38 | raise ValueError("The maximum learning rate: %f must be " 39 | "higher than the initial learning rate:" 40 | " %f" % (learning_rate, initial_learning_rate)) 41 | 42 | self._initial_learning_rate = initial_learning_rate 43 | self._maximum_learning_rate = learning_rate 44 | self._warmup_steps = warmup_steps 45 | self._summary = summary 46 | 47 | def __call__(self, step): 48 | if step <= self._warmup_steps: 49 | lr_step = self._maximum_learning_rate - self._initial_learning_rate 50 | lr_step /= self._warmup_steps 51 | lr = self._initial_learning_rate + lr_step * step 52 | else: 53 | lr = self._maximum_learning_rate 54 | 55 | if self._warmup_steps != 0: 56 | # approximately hidden_size ** -0.5 57 | lr = lr * self._warmup_steps ** 0.5 58 | 59 | lr = lr * (step ** -0.5) 60 | 61 | if self._summary: 62 | summary.scalar("learning_rate", lr, utils.get_global_step()) 63 | 64 | return lr 65 | 66 | def get_config(self): 67 | return { 68 | "learning_rate": self._maximum_learning_rate, 69 | "initial_learning_rate": self._initial_learning_rate, 70 | "warmup_steps": self._warmup_steps 71 | } 72 | 73 | 74 | class PiecewiseConstantDecay(LearningRateSchedule): 75 | 76 | def __init__(self, boundaries, values, summary=True): 77 | super(PiecewiseConstantDecay, self).__init__() 78 | 79 | if len(boundaries) != len(values) - 1: 80 | raise ValueError("The length of boundaries should be 1" 81 | " less than the length of values") 82 | 83 | self._boundaries = boundaries 84 | self._values = values 85 | self._summary = summary 86 | 87 | def __call__(self, step): 88 | boundaries = self._boundaries 89 | values = self._values 90 | learning_rate = values[0] 91 | 92 | if step <= boundaries[0]: 93 | learning_rate = values[0] 94 | elif step > boundaries[-1]: 95 | learning_rate = values[-1] 96 | else: 97 | for low, high, v in zip(boundaries[:-1], boundaries[1:], 98 | values[1:-1]): 99 | 100 | if step > low and step <= high: 101 | learning_rate = v 102 | break 103 | 104 | if self._summary: 105 | summary.scalar("learning_rate", learning_rate, 106 | utils.get_global_step()) 107 | 108 | return learning_rate 109 | 110 | def get_config(self): 111 | return { 112 | "boundaries": self._boundaries, 113 | "values": self._values, 114 | } 115 | 116 | 117 | class LinearExponentialDecay(LearningRateSchedule): 118 | 119 | def __init__(self, learning_rate, warmup_steps, start_decay_step, 120 | end_decay_step, n, summary=True): 121 | super(LinearExponentialDecay, self).__init__() 122 | 123 | self._learning_rate = learning_rate 124 | self._warmup_steps = warmup_steps 125 | self._start_decay_step = start_decay_step 126 | self._end_decay_step = end_decay_step 127 | self._n = n 128 | self._summary = summary 129 | 130 | def __call__(self, step): 131 | # See reference: The Best of Both Worlds: Combining Recent Advances 132 | # in Neural Machine Translation 133 | n = self._n 134 | p = self._warmup_steps / n 135 | s = n * self._start_decay_step 136 | e = n * self._end_decay_step 137 | 138 | learning_rate = self._learning_rate 139 | 140 | learning_rate *= min( 141 | 1.0 + (n - 1) * step / float(n * p), 142 | n, 143 | n * ((2 * n) ** (float(s - n * step) / float(e - s)))) 144 | 145 | if self._summary: 146 | 
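For concreteness, the two schedules completed above return the following values (a small sketch; `summary=False` so no global step is required):

```python
from thumt.optimizers import LinearWarmupRsqrtDecay, PiecewiseConstantDecay

warmup = LinearWarmupRsqrtDecay(learning_rate=7e-4, warmup_steps=4000,
                                summary=False)
warmup(4000)    # 7e-4: the peak rate is reached at the end of warmup
warmup(16000)   # 7e-4 * (4000 / 16000) ** 0.5 = 3.5e-4: inverse-sqrt decay

piecewise = PiecewiseConstantDecay([10000, 20000], [1e-3, 5e-4, 1e-4],
                                   summary=False)
piecewise(5000)     # 1e-3
piecewise(15000)    # 5e-4
piecewise(30000)    # 1e-4
```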
summary.scalar("learning_rate", learning_rate, 147 | utils.get_global_step()) 148 | 149 | return learning_rate 150 | 151 | def get_config(self): 152 | return { 153 | "learning_rate": self._learning_rate, 154 | "warmup_steps": self._warmup_steps, 155 | "start_decay_step": self._start_decay_step, 156 | "end_decay_step": self._end_decay_step, 157 | } 158 | -------------------------------------------------------------------------------- /thumt/scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import glob 11 | import argparse 12 | import collections 13 | import torch 14 | import shutil 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="Create vocabulary") 19 | 20 | parser.add_argument("--path", help="checkpoint directory") 21 | parser.add_argument("--output", default="average", 22 | help="Output path") 23 | parser.add_argument("--checkpoints", default=5, type=int, 24 | help="Number of checkpoints to average") 25 | 26 | return parser.parse_args() 27 | 28 | 29 | def list_checkpoints(path): 30 | names = glob.glob(os.path.join(path, "*.pt")) 31 | 32 | if not names: 33 | return None 34 | 35 | vals = [] 36 | 37 | for name in names: 38 | counter = int(name.rstrip(".pt").split("-")[-1]) 39 | vals.append([counter, name]) 40 | 41 | return [item[1] for item in sorted(vals)] 42 | 43 | 44 | def main(args): 45 | checkpoints = list_checkpoints(args.path) 46 | 47 | if not checkpoints: 48 | raise ValueError("No checkpoint to average") 49 | 50 | checkpoints = checkpoints[-args.checkpoints:] 51 | values = collections.OrderedDict() 52 | 53 | for checkpoint in checkpoints: 54 | print("Loading checkpoint: %s" % checkpoint) 55 | state = torch.load(checkpoint, map_location="cpu")["model"] 56 | 57 | for key in state: 58 | if key not in values: 59 | values[key] = state[key].float().clone() 60 | else: 61 | values[key].add_(state[key].float()) 62 | 63 | for key in values: 64 | values[key].div_(len(checkpoints)) 65 | 66 | state = {"step": 0, "epoch": 0, "model": values} 67 | 68 | if not os.path.exists(args.output): 69 | os.makedirs(args.output) 70 | 71 | torch.save(state, os.path.join(args.output, "average-0.pt")) 72 | params_pattern = os.path.join(args.path, "*.json") 73 | params_files = glob.glob(params_pattern) 74 | 75 | for name in params_files: 76 | new_name = name.replace(args.path.rstrip("/"), args.output.rstrip("/")) 77 | shutil.copy(name, new_name) 78 | 79 | 80 | if __name__ == "__main__": 81 | main(parse_args()) 82 | -------------------------------------------------------------------------------- /thumt/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import collections 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Create vocabulary") 15 | 16 | parser.add_argument("corpus", help="input corpus") 17 | parser.add_argument("output", default="vocab.txt", 18 | help="Output vocabulary name") 19 | parser.add_argument("--limit", default=0, type=int, 
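`average_checkpoints.py` averages the newest `--checkpoints` model files, ordering them by the numeric step embedded in the file name rather than lexicographically. The ordering step from `list_checkpoints`, in isolation (hypothetical file names):

```python
# Mirrors list_checkpoints(): sort by the numeric suffix, not as strings.
names = ["model-2.pt", "model-10.pt", "model-1.pt"]
vals = sorted([[int(name.rstrip(".pt").split("-")[-1]), name] for name in names])
[item[1] for item in vals]   # ['model-1.pt', 'model-2.pt', 'model-10.pt']
```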
help="Vocabulary size") 20 | parser.add_argument("--control", type=str, default=",,", 21 | help="Add control symbols to vocabulary. " 22 | "Control symbols are separated by comma.") 23 | 24 | return parser.parse_args() 25 | 26 | 27 | def count_words(filename): 28 | counter = collections.Counter() 29 | 30 | with open(filename, "rb") as fd: 31 | for line in fd: 32 | words = line.strip().split() 33 | counter.update(words) 34 | 35 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 36 | words, counts = list(zip(*count_pairs)) 37 | 38 | return words, counts 39 | 40 | 41 | def control_symbols(string): 42 | if not string: 43 | return [] 44 | else: 45 | symbs = string.strip().split(",") 46 | return [sym.encode("ascii") for sym in symbs] 47 | 48 | 49 | def save_vocab(name, vocab): 50 | if name.split(".")[-1] != "txt": 51 | name = name + ".txt" 52 | 53 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0])) 54 | words, _ = list(zip(*pairs)) 55 | 56 | with open(name, "wb") as f: 57 | for word in words: 58 | f.write(word) 59 | f.write("\n".encode("ascii")) 60 | 61 | 62 | def main(args): 63 | vocab = {} 64 | limit = args.limit 65 | count = 0 66 | 67 | words, counts = count_words(args.corpus) 68 | ctl_symbols = control_symbols(args.control) 69 | 70 | for sym in ctl_symbols: 71 | vocab[sym] = len(vocab) 72 | 73 | for word, freq in zip(words, counts): 74 | if limit and len(vocab) >= limit: 75 | break 76 | 77 | if word in vocab: 78 | print("Warning: found duplicate token %s, ignored" % word) 79 | continue 80 | 81 | vocab[word] = len(vocab) 82 | count += freq 83 | 84 | save_vocab(args.output, vocab) 85 | 86 | print("Total words: %d" % sum(counts)) 87 | print("Unique words: %d" % len(words)) 88 | print("Vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts))) 89 | 90 | 91 | if __name__ == "__main__": 92 | main(parse_args()) 93 | -------------------------------------------------------------------------------- /thumt/scripts/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import os 11 | import sys 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | import torch 16 | 17 | 18 | def convert_tensor(variables, name, tensor): 19 | # 1. replace '/' with '.' 20 | name = name.replace("/", ".") 21 | # 2. strip "transformer." 22 | if "transformer" in name: 23 | name = name[12:] 24 | # 3. layer_* -> layers.* 25 | name = name.replace("layer_", "layers.") 26 | name = name.replace("layers.norm", "layer_norm") 27 | # 4. offset -> bias 28 | name = name.replace("offset", "bias") 29 | # 5. scale -> weight 30 | name = name.replace("scale", "weight") 31 | # 6. matrix -> weight, transpose 32 | if "matrix" in name: 33 | name = name.replace("matrix", "weight") 34 | tensor = tensor.transpose() 35 | # 7. 
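`build_vocab.py` places the control symbols first and then adds words by descending frequency (ties broken alphabetically), reporting how much of the corpus the truncated vocabulary covers. The core ordering and coverage computation, in isolation on a made-up corpus:

```python
import collections

counter = collections.Counter(b"the cat sat on the mat the cat".split())
pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
# [(b'the', 3), (b'cat', 2), (b'mat', 1), (b'on', 1), (b'sat', 1)]
words, counts = zip(*pairs)
coverage = 100.0 * sum(counts[:3]) / sum(counts)   # keeping 3 words covers 75.00%
```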
multihead_attention -> attention 36 | name = name.replace("multihead_attention", "attention") 37 | variables[name] = torch.tensor(tensor) 38 | 39 | 40 | def main(): 41 | if len(sys.argv) != 3: 42 | print("convert_checkpoint.py input output") 43 | exit(-1) 44 | 45 | var_list = tf.train.list_variables(sys.argv[1]) 46 | variables = {} 47 | reader = tf.train.load_checkpoint(sys.argv[1]) 48 | 49 | for (name, _) in var_list: 50 | tensor = reader.get_tensor(name) 51 | if not name.startswith("transformer") or "Adam" in name: 52 | continue 53 | 54 | if "qkv_transform" in name: 55 | if "matrix" in name: 56 | n1 = name.replace("qkv_transform", "q_transform") 57 | n2 = name.replace("qkv_transform", "k_transform") 58 | n3 = name.replace("qkv_transform", "v_transform") 59 | v1, v2, v3 = np.split(tensor, 3, axis=1) 60 | convert_tensor(variables, n1, v1) 61 | convert_tensor(variables, n2, v2) 62 | convert_tensor(variables, n3, v3) 63 | elif "bias" in name: 64 | n1 = name.replace("qkv_transform", "q_transform") 65 | n2 = name.replace("qkv_transform", "k_transform") 66 | n3 = name.replace("qkv_transform", "v_transform") 67 | v1, v2, v3 = np.split(tensor, 3) 68 | convert_tensor(variables, n1, v1) 69 | convert_tensor(variables, n2, v2) 70 | convert_tensor(variables, n3, v3) 71 | elif "kv_transform" in name: 72 | if "matrix" in name: 73 | n1 = name.replace("kv_transform", "k_transform") 74 | n2 = name.replace("kv_transform", "v_transform") 75 | v1, v2 = np.split(tensor, 2, axis=1) 76 | convert_tensor(variables, n1, v1) 77 | convert_tensor(variables, n2, v2) 78 | elif "bias" in name: 79 | n1 = name.replace("kv_transform", "k_transform") 80 | n2 = name.replace("kv_transform", "v_transform") 81 | v1, v2 = np.split(tensor, 2) 82 | convert_tensor(variables, n1, v1) 83 | convert_tensor(variables, n2, v2) 84 | elif "multihead_attention/output_transform" in name: 85 | name = name.replace("multihead_attention/output_transform", 86 | "multihead_attention/o_transform") 87 | convert_tensor(variables, name, tensor) 88 | elif "ffn_layer/output_layer/linear" in name: 89 | name = name.replace("ffn_layer/output_layer/linear", 90 | "ffn_layer/output_transform") 91 | convert_tensor(variables, name, tensor) 92 | elif "ffn_layer/input_layer/linear" in name: 93 | name = name.replace("ffn_layer/input_layer/linear", 94 | "ffn_layer/input_transform") 95 | convert_tensor(variables, name, tensor) 96 | else: 97 | convert_tensor(variables, name, tensor) 98 | 99 | torch.save({"model": variables}, sys.argv[2]) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /thumt/scripts/shuffle_corpus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2017-2020 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import numpy 11 | 12 | 13 | def parseargs(): 14 | parser = argparse.ArgumentParser(description="Shuffle corpus") 15 | 16 | parser.add_argument("--corpus", nargs="+", required=True, 17 | help="input corpora") 18 | parser.add_argument("--suffix", type=str, default="shuf", 19 | help="Suffix of output files") 20 | parser.add_argument("--seed", type=int, help="Random seed") 21 | 22 | return parser.parse_args() 23 | 24 | 25 | def main(args): 26 | name = args.corpus 27 | suffix = "." 
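The TensorFlow checkpoints store the query/key/value projection as one fused matrix, so the converter splits it into the three separate weights the PyTorch model expects (and `convert_tensor` additionally transposes every `matrix` when renaming it to `weight`). The split itself, with a hypothetical hidden size of 4:

```python
import numpy as np

fused_weight = np.arange(4 * 12).reshape(4, 12)    # [hidden, 3 * hidden] qkv matrix
q_w, k_w, v_w = np.split(fused_weight, 3, axis=1)  # three [hidden, hidden] blocks

fused_bias = np.arange(12)
q_b, k_b, v_b = np.split(fused_bias, 3)            # three [hidden] bias vectors
```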
+ args.suffix 28 | stream = [open(item, "rb") for item in name] 29 | data = [fd.readlines() for fd in stream] 30 | minlen = min([len(lines) for lines in data]) 31 | 32 | if args.seed: 33 | numpy.random.seed(args.seed) 34 | 35 | indices = numpy.arange(minlen) 36 | numpy.random.shuffle(indices) 37 | 38 | newstream = [open(item + suffix, "wb") for item in name] 39 | 40 | for idx in indices.tolist(): 41 | lines = [item[idx] for item in data] 42 | 43 | for line, fd in zip(lines, newstream): 44 | fd.write(line) 45 | 46 | for fdr, fdw in zip(stream, newstream): 47 | fdr.close() 48 | fdw.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | main(parseargs()) 53 | -------------------------------------------------------------------------------- /thumt/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from thumt.tokenizers.tokenizer import Tokenizer, WhiteSpaceTokenizer 2 | from thumt.tokenizers.unicode_tokenizer import UnicodeTokenizer 3 | -------------------------------------------------------------------------------- /thumt/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from typing import List, NoReturn 4 | 5 | 6 | class Tokenizer(object): 7 | 8 | def __init__(self, name: str): 9 | self._name = name 10 | 11 | @abc.abstractmethod 12 | def __repr__(self) -> NoReturn: 13 | raise NotImplementedError("Tokenizer.__repr__ not implemented.") 14 | 15 | @property 16 | def name(self) -> str: 17 | return self._name 18 | 19 | @abc.abstractmethod 20 | def encode(self, inp: bytes) -> NoReturn: 21 | raise NotImplementedError("Tokenizer.encode not implemented.") 22 | 23 | @abc.abstractmethod 24 | def decode(self, inp: List[bytes]) -> NoReturn: 25 | raise NotImplementedError("Tokenizer.decode not implemented.") 26 | 27 | 28 | class WhiteSpaceTokenizer(Tokenizer): 29 | 30 | def __init__(self): 31 | super(WhiteSpaceTokenizer, self).__init__("WhiteSpaceTokenizer") 32 | 33 | def __repr__(self) -> str: 34 | return "WhiteSpaceTokenizer()" 35 | 36 | def encode(self, inp: bytes) -> List[bytes]: 37 | return inp.strip().split() 38 | 39 | def decode(self, inp: List[bytes]) -> bytes: 40 | return b" ".join(inp) 41 | -------------------------------------------------------------------------------- /thumt/tokenizers/unicode_tokenizer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import json 3 | import base64 4 | import collections 5 | import regex as re 6 | 7 | from typing import List, NoReturn 8 | from thumt.tokenizers.tokenizer import Tokenizer 9 | 10 | 11 | _RULES = [ 12 | # Open/Initial puncutation 13 | [ 14 | ("([\\p{Ps}\\p{Pi}])(.)", "\\1 \\2"), 15 | ("([\\p{Ps}\\p{Pi}]) (.)", "\\1\\2") 16 | ], 17 | # Close/Final puncutation 18 | [ 19 | ("(.)([\\p{Pe}\\p{Pf}])", "\\1 \\2"), 20 | ("(.) ([\\p{Pe}\\p{Pf}])", "\\1\\2") 21 | ], 22 | # Tokenize the following symbols 23 | [ 24 | ("([|~\\\\^_`#&*+<=>@/\\-])", " \\1 "), 25 | ("[ ]?([|~\\\\^_`#&*+<=>@/\\-])[ ]?", "\\1"), 26 | ], 27 | # Tokenize colon 28 | [ 29 | ("([\\p{L}]): ", "\\1 : "), 30 | ("([\\p{L}]) : ", "\\1: ") 31 | ], 32 | # Tokenize period and comma 33 | [ 34 | ("(.)([\\.,!?;]) ", "\\1 \\2 "), 35 | ("(.) ([\\.,!?;]) ", "\\1\\2 "), 36 | ], 37 | # Tokenize period and comma at end of the input 38 | [ 39 | ("(.)([\\.,!?;])$", "\\1 \\2"), 40 | ("(.) 
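`WhiteSpaceTokenizer` is the simplest `Tokenizer`: it works on bytes and just splits and re-joins on whitespace. A quick round trip:

```python
from thumt.tokenizers import WhiteSpaceTokenizer

tokenizer = WhiteSpaceTokenizer()
tokens = tokenizer.encode(b"a simple test sentence")
# [b'a', b'simple', b'test', b'sentence']
tokenizer.decode(tokens)
# b'a simple test sentence'
```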
([\\.,!?;])$", "\\1\\2"), 41 | ], 42 | # Tokenize quotation mark 43 | [ 44 | ("([\\p{L}])\"([\\p{L}])", "\\1 \\2"), 45 | ("([\\p{L}]) ([\\p{L}])", "\\1\"\\2"), 46 | ], 47 | [ 48 | ("\"([\\p{L}\\p{N}])", " \\1"), 49 | (" ([\\p{L}\\p{N}])", "\"\\1"), 50 | ], 51 | [ 52 | ("([\\p{L}\\p{N}])\"", "\\1 "), 53 | ("([\\p{L}\\p{N}]) ", "\\1\""), 54 | ], 55 | # Tokenize Apostrophe 56 | [ 57 | ("([\\p{L}])'([\\p{L}])", "\\1 \\2"), 58 | ("([\\p{L}]) ([\\p{L}])", "\\1'\\2"), 59 | ], 60 | [ 61 | ("'([\\p{L}])", " \\1"), 62 | (" ([\\p{L}])", "\"\\1"), 63 | ], 64 | [ 65 | ("([\\p{L}])'", "\\1 "), 66 | ("([\\p{L}]) ", "\\1\""), 67 | ], 68 | # Replace control/separators with space 69 | [ 70 | ("[\\p{C}\\p{Z}]+", " ") 71 | ], 72 | # Remove starting space 73 | [ 74 | ("^ (.*)", "\\1") 75 | ], 76 | # Remove trailing space 77 | [ 78 | ("(.*) $", "\\1") 79 | ] 80 | ] 81 | 82 | _TOKEN_PATTERNS = [re.compile(rule[0][0]) for rule in _RULES] 83 | _TOKEN_REPL = [rule[0][1] for rule in _RULES] 84 | 85 | _DETOKEN_PATTERNS = [ 86 | re.compile(rule[1][0]) if len(rule) == 2 else None for rule in _RULES 87 | ][::-1] 88 | _DETOKEN_REPL = [ 89 | rule[1][1] if len(rule) == 2 else None for rule in _RULES 90 | ][::-1] 91 | 92 | 93 | class UnicodeTokenizer(Tokenizer): 94 | 95 | def __init__(self, name="unicode_tokenizer"): 96 | super(UnicodeTokenizer, self).__init__() 97 | 98 | def encode(self, inp: bytes) -> List[bytes]: 99 | inp_str = inp 100 | for pat, repl in zip(_TOKEN_PATTERNS, _TOKEN_REPL): 101 | input_str = re.sub(pat, repl, input_str) 102 | 103 | return input_str 104 | 105 | def decode(self, inp: List[bytes]) -> bytes: 106 | input_str = b" ".join(inp) 107 | 108 | for pat, repl in zip(_DETOKEN_PATTERNS, _DETOKEN_REPL): 109 | if not pat: 110 | continue 111 | input_str = re.sub(pat, repl, input_str) 112 | 113 | return input_str 114 | -------------------------------------------------------------------------------- /thumt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from thumt.utils.hparams import HParams 2 | from thumt.utils.inference import beam_search, argmax_decoding 3 | from thumt.utils.evaluation import evaluate 4 | from thumt.utils.checkpoint import save, latest_checkpoint 5 | from thumt.utils.scope import scope, get_scope, unique_name 6 | from thumt.utils.misc import get_global_step, set_global_step 7 | from thumt.utils.convert_params import params_to_vec, vec_to_params 8 | -------------------------------------------------------------------------------- /thumt/utils/bleu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | 10 | from collections import Counter 11 | 12 | 13 | def closest_length(candidate, references): 14 | clen = len(candidate) 15 | closest_diff = 9999 16 | closest_len = 9999 17 | 18 | for reference in references: 19 | rlen = len(reference) 20 | diff = abs(rlen - clen) 21 | 22 | if diff < closest_diff: 23 | closest_diff = diff 24 | closest_len = rlen 25 | elif diff == closest_diff: 26 | closest_len = rlen if rlen < closest_len else closest_len 27 | 28 | return closest_len 29 | 30 | 31 | def shortest_length(references): 32 | return min([len(ref) for ref in references]) 33 | 34 | 35 | def modified_precision(candidate, references, n): 36 | tngrams = len(candidate) + 1 - n 37 | counts = 
Counter([tuple(candidate[i:i+n]) for i in range(tngrams)]) 38 | 39 | if len(counts) == 0: 40 | return 0, 0 41 | 42 | max_counts = {} 43 | for reference in references: 44 | rngrams = len(reference) + 1 - n 45 | ngrams = [tuple(reference[i:i+n]) for i in range(rngrams)] 46 | ref_counts = Counter(ngrams) 47 | for ngram in counts: 48 | mcount = 0 if ngram not in max_counts else max_counts[ngram] 49 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram] 50 | max_counts[ngram] = max(mcount, rcount) 51 | 52 | clipped_counts = {} 53 | 54 | for ngram, count in counts.items(): 55 | clipped_counts[ngram] = min(count, max_counts[ngram]) 56 | 57 | return float(sum(clipped_counts.values())), float(sum(counts.values())) 58 | 59 | 60 | def brevity_penalty(trans, refs, mode="closest"): 61 | bp_c = 0.0 62 | bp_r = 0.0 63 | 64 | for candidate, references in zip(trans, refs): 65 | bp_c += len(candidate) 66 | 67 | if mode == "shortest": 68 | bp_r += shortest_length(references) 69 | else: 70 | bp_r += closest_length(candidate, references) 71 | 72 | # Prevent zero divide 73 | bp_c = bp_c or 1.0 74 | 75 | return math.exp(min(0, 1.0 - bp_r / bp_c)) 76 | 77 | 78 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None): 79 | p_norm = [0 for _ in range(n)] 80 | p_denorm = [0 for _ in range(n)] 81 | 82 | for candidate, references in zip(trans, refs): 83 | for i in range(n): 84 | ccount, tcount = modified_precision(candidate, references, i + 1) 85 | p_norm[i] += ccount 86 | p_denorm[i] += tcount 87 | 88 | bleu_n = [0 for _ in range(n)] 89 | 90 | for i in range(n): 91 | # add one smoothing 92 | if smooth and i > 0: 93 | p_norm[i] += 1 94 | p_denorm[i] += 1 95 | 96 | if p_norm[i] == 0 or p_denorm[i] == 0: 97 | bleu_n[i] = -9999 98 | else: 99 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i])) 100 | 101 | if weights: 102 | if len(weights) != n: 103 | raise ValueError("len(weights) != n: invalid weight number") 104 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)]) 105 | else: 106 | log_precision = sum(bleu_n) / float(n) 107 | 108 | bp = brevity_penalty(trans, refs, bp) 109 | 110 | score = bp * math.exp(log_precision) 111 | 112 | return score 113 | -------------------------------------------------------------------------------- /thumt/utils/bpe.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | # Modified from subword-nmt 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import re 10 | 11 | 12 | class BPE(object): 13 | 14 | def __init__(self, bpe_path, merges=-1, separator="@@"): 15 | with open(bpe_path, "r", encoding="utf-8") as fd: 16 | firstline = fd.readline() 17 | 18 | if not firstline.startswith("#version:"): 19 | raise ValueError("THUMT only support BPE version >= 0.2.") 20 | 21 | codes = tuple([item.strip("\r\n").split(" ") 22 | for (n, item) in enumerate(fd) 23 | if (n < merges or merges == -1)]) 24 | 25 | for _, item in enumerate(codes): 26 | if len(item) != 2: 27 | raise ValueError("Error: invalid BPE codes found.") 28 | 29 | self._codes = {} 30 | 31 | for (i, code) in enumerate(codes): 32 | if tuple(code) not in self._codes: 33 | self._codes[tuple(code)] = i 34 | 35 | self._separator = separator 36 | 37 | def _get_pairs(self, word): 38 | pairs = set() 39 | prev_char = word[0] 40 | 41 | for char in word[1:]: 42 | pairs.add((prev_char, char)) 43 | prev_char = char 44 | 45 | return pairs 
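`bleu()` takes tokenized hypotheses and, for each one, a list of tokenized references, and returns the corpus-level score; add-one smoothing and per-n-gram weights are optional. A small made-up example:

```python
from thumt.utils.bleu import bleu

hyps = [["the", "cat", "sat", "on", "the", "mat"]]
refs = [[["the", "cat", "is", "on", "the", "mat"],
         ["there", "is", "a", "cat", "on", "the", "mat"]]]
score = bleu(hyps, refs, smooth=True)   # float in [0, 1]
```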
46 | 47 | 48 | def _encode_word(self, orig): 49 | word = tuple(orig[:-1]) + (orig[-1] + "",) 50 | pairs = self._get_pairs(word) 51 | 52 | if not pairs: 53 | return (orig,) 54 | 55 | while True: 56 | bigram = min(pairs, key=lambda x: self._codes.get(x, float("inf"))) 57 | 58 | if bigram not in self._codes: 59 | break 60 | 61 | first, second = bigram 62 | new_word = [] 63 | 64 | i = 0 65 | 66 | while i < len(word): 67 | try: 68 | j = word.index(first, i) 69 | new_word.extend(word[i:j]) 70 | i = j 71 | except: 72 | new_word.extend(word[i:]) 73 | break 74 | 75 | if word[i] == first and word[i + 1] == second: 76 | if i < len(word) - 1: 77 | new_word.append(first + second) 78 | i += 2 79 | else: 80 | new_word.append(word[i]) 81 | i += 1 82 | else: 83 | new_word.append(word[i]) 84 | i += 1 85 | 86 | new_word = tuple(new_word) 87 | word = new_word 88 | 89 | if len(word) == 1: 90 | break 91 | else: 92 | pairs = self._get_pairs(word) 93 | 94 | if word[-1] == "": 95 | word = word[:-1] 96 | elif word[-1].endswith(""): 97 | word = word[:-1] + (word[-1].replace("", ""),) 98 | 99 | return word 100 | 101 | def encode(self, s): 102 | words = s.strip().split() 103 | output = [] 104 | 105 | for word in words: 106 | if not word: 107 | continue 108 | 109 | new_word = self._encode_word(word) 110 | 111 | for item in new_word[:-1]: 112 | output.append(item + self._separator) 113 | 114 | output.append(new_word[-1]) 115 | 116 | return output 117 | 118 | @staticmethod 119 | def decode(s): 120 | if isinstance(s, str): 121 | return re.sub("(@@ )|(@@ ?$)", "", s) 122 | else: 123 | return re.sub(b"(@@ )|(@@ ?$)", b"", s) 124 | -------------------------------------------------------------------------------- /thumt/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import glob 10 | import torch 11 | 12 | 13 | def oldest_checkpoint(path): 14 | names = glob.glob(os.path.join(path, "*.pt")) 15 | 16 | if not names: 17 | return None 18 | 19 | oldest_counter = 10000000 20 | checkpoint_name = names[0] 21 | 22 | for name in names: 23 | counter = name.rstrip(".pt").split("-")[-1] 24 | 25 | if not counter.isdigit(): 26 | continue 27 | else: 28 | counter = int(counter) 29 | 30 | if counter < oldest_counter: 31 | checkpoint_name = name 32 | oldest_counter = counter 33 | 34 | return checkpoint_name 35 | 36 | 37 | def latest_checkpoint(path): 38 | names = glob.glob(os.path.join(path, "*.pt")) 39 | 40 | if not names: 41 | return None 42 | 43 | latest_counter = 0 44 | checkpoint_name = names[0] 45 | 46 | for name in names: 47 | counter = name.rstrip(".pt").split("-")[-1] 48 | 49 | if not counter.isdigit(): 50 | continue 51 | else: 52 | counter = int(counter) 53 | 54 | if counter > latest_counter: 55 | checkpoint_name = name 56 | latest_counter = counter 57 | 58 | return checkpoint_name 59 | 60 | 61 | def save(state, path, max_to_keep=None): 62 | checkpoints = glob.glob(os.path.join(path, "*.pt")) 63 | 64 | if not checkpoints: 65 | counter = 1 66 | else: 67 | checkpoint = latest_checkpoint(path) 68 | counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1 69 | 70 | if max_to_keep and len(checkpoints) >= max_to_keep: 71 | checkpoint = oldest_checkpoint(path) 72 | os.remove(checkpoint) 73 | 74 | checkpoint = os.path.join(path, "model-%d.pt" % counter) 75 | print("Saving checkpoint: 
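`BPE.encode` splits words into subword units terminated by the separator, and the static `BPE.decode` undoes that segmentation on plain strings or bytes without needing a codes file. For instance, on made-up segmented text:

```python
from thumt.utils.bpe import BPE

BPE.decode("un@@ believ@@ able translation@@ s")   # 'unbelievable translations'
BPE.decode(b"sub@@ word units")                    # b'subword units'
```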
%s" % checkpoint) 76 | torch.save(state, checkpoint) 77 | -------------------------------------------------------------------------------- /thumt/utils/convert_params.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | # Modified from torch.nn.utils.convert_parameters.py 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import torch 10 | 11 | 12 | def params_to_vec(parameters): 13 | r"""Convert parameters to one vector 14 | 15 | Arguments: 16 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 17 | parameters of a model. 18 | 19 | Returns: 20 | The parameters represented by a single vector 21 | """ 22 | 23 | # Flag for the device where the parameter is located 24 | param_device = None 25 | vec = [] 26 | 27 | for param in parameters: 28 | if param is None: 29 | continue 30 | 31 | # Ensure the parameters are located in the same device 32 | param_device = _check_param_device(param, param_device) 33 | vec.append(param.view(-1)) 34 | 35 | return torch.cat(vec) 36 | 37 | 38 | def vec_to_params(vec, parameters): 39 | r"""Convert one vector to the parameters 40 | 41 | Arguments: 42 | vec (Tensor): a single vector represents the parameters of a model. 43 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 44 | parameters of a model. 45 | """ 46 | 47 | # Ensure vec of type Tensor 48 | if not isinstance(vec, torch.Tensor): 49 | raise TypeError("expected torch.Tensor, but got: {}" 50 | .format(torch.typename(vec))) 51 | 52 | # Flag for the device where the parameter is located 53 | param_device = None 54 | 55 | # Pointer for slicing the vector for each parameter 56 | pointer = 0 57 | 58 | for param in parameters: 59 | if param is None: 60 | continue 61 | 62 | # Ensure the parameters are located in the same device 63 | param_device = _check_param_device(param, param_device) 64 | 65 | # The length of the parameter 66 | num_param = param.numel() 67 | 68 | # Slice the vector, reshape it, and replace the old data of the parameter 69 | param.data = vec[pointer:pointer + num_param].view_as(param).data 70 | 71 | # Increment the pointer 72 | pointer += num_param 73 | 74 | 75 | def _check_param_device(param, old_param_device): 76 | r"""This helper function is to check if the parameters are located 77 | in the same device. Currently, the conversion between model parameters 78 | and single vector form is not supported for multiple allocations, 79 | e.g. parameters in different GPUs, or mixture of CPU/GPU. 80 | 81 | Arguments: 82 | param ([Tensor]): a Tensor of a parameter of a model 83 | old_param_device (int): the device where the first parameter of a 84 | model is allocated. 
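These helpers flatten a list of tensors into one contiguous vector and scatter a vector back into the original shapes; `sync_gradients` uses them to all-reduce every gradient with a single collective call. A CPU round trip with toy tensors:

```python
import torch
from thumt.utils import params_to_vec, vec_to_params

params = [torch.ones(2, 3), torch.zeros(5)]
vec = params_to_vec(params)         # shape [11]; all tensors must share a device
vec_to_params(vec * 2.0, params)    # writes the values back, reshaped per tensor
params[0]                           # tensor of 2.0s with shape [2, 3]
```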
85 | 86 | Returns: 87 | old_param_device (int): report device for the first time 88 | """ 89 | 90 | # Meet the first parameter 91 | if old_param_device is None: 92 | old_param_device = param.get_device() if param.is_cuda else -1 93 | else: 94 | warn = False 95 | 96 | if param.is_cuda: # Check if in same GPU 97 | warn = (param.get_device() != old_param_device) 98 | else: # Check if in CPU 99 | warn = (old_param_device != -1) 100 | 101 | if warn: 102 | raise TypeError("Found two parameters on different devices," 103 | " this is currently not supported.") 104 | 105 | return old_param_device 106 | -------------------------------------------------------------------------------- /thumt/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import glob 10 | import operator 11 | import os 12 | import shutil 13 | import time 14 | import torch 15 | 16 | import torch.distributed as dist 17 | 18 | from thumt.utils.checkpoint import save, latest_checkpoint 19 | from thumt.utils.inference import beam_search 20 | from thumt.utils.bleu import bleu 21 | from thumt.utils.bpe import BPE 22 | from thumt.utils.misc import get_global_step 23 | from thumt.utils.summary import scalar 24 | 25 | 26 | def _save_log(filename, result): 27 | metric, global_step, score = result 28 | 29 | with open(filename, "a") as fd: 30 | time = datetime.datetime.now() 31 | msg = "%s: %s at step %d: %f\n" % (time, metric, global_step, score) 32 | fd.write(msg) 33 | 34 | 35 | def _read_score_record(filename): 36 | # "checkpoint_name": score 37 | records = [] 38 | 39 | if not os.path.exists(filename): 40 | return records 41 | 42 | with open(filename) as fd: 43 | for line in fd: 44 | name, score = line.strip().split(":") 45 | name = name.strip()[1:-1] 46 | score = float(score) 47 | records.append([name, score]) 48 | 49 | return records 50 | 51 | 52 | def _save_score_record(filename, records): 53 | keys = [] 54 | 55 | for record in records: 56 | checkpoint_name = record[0] 57 | step = int(checkpoint_name.strip().split("-")[-1].rstrip(".pt")) 58 | keys.append((step, record)) 59 | 60 | sorted_keys = sorted(keys, key=operator.itemgetter(0), 61 | reverse=True) 62 | sorted_records = [item[1] for item in sorted_keys] 63 | 64 | with open(filename, "w") as fd: 65 | for record in sorted_records: 66 | checkpoint_name, score = record 67 | fd.write("\"%s\": %f\n" % (checkpoint_name, score)) 68 | 69 | 70 | def _add_to_record(records, record, max_to_keep): 71 | added = None 72 | removed = None 73 | models = {} 74 | 75 | for (name, score) in records: 76 | models[name] = score 77 | 78 | if len(records) < max_to_keep: 79 | if record[0] not in models: 80 | added = record[0] 81 | records.append(record) 82 | else: 83 | sorted_records = sorted(records, key=lambda x: -x[1]) 84 | worst_score = sorted_records[-1][1] 85 | current_score = record[1] 86 | 87 | if current_score >= worst_score: 88 | if record[0] not in models: 89 | added = record[0] 90 | removed = sorted_records[-1][0] 91 | records = sorted_records[:-1] + [record] 92 | 93 | # Sort 94 | records = sorted(records, key=lambda x: -x[1]) 95 | 96 | return added, removed, records 97 | 98 | 99 | def _convert_to_string(tensor, params, direction="target"): 100 | ids = tensor.tolist() 101 | 102 | output = [] 103 | 104 | eos_id = 
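The record helpers implement the "keep the best N checkpoints" policy: the internal `_add_to_record` reports which checkpoint enters the record and which one should be deleted. In isolation, with made-up names and BLEU scores:

```python
from thumt.utils.evaluation import _add_to_record

records = [["model-10.pt", 25.1], ["model-20.pt", 26.4]]
added, removed, records = _add_to_record(records, ["model-30.pt", 27.0],
                                         max_to_keep=2)
# added == "model-30.pt", removed == "model-10.pt"
# records == [["model-30.pt", 27.0], ["model-20.pt", 26.4]]
```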
params.vocabulary[direction][params.eos] 105 | 106 | for wid in ids: 107 | if wid == eos_id: 108 | break 109 | output.append(params.vocabulary[direction][wid]) 110 | 111 | output = b" ".join(output) 112 | 113 | return output 114 | 115 | 116 | def _evaluate_model(model, sorted_key, dataset, references, params): 117 | # Create model 118 | with torch.no_grad(): 119 | model.eval() 120 | 121 | iterator = iter(dataset) 122 | counter = 0 123 | pad_max = 1024 124 | 125 | # Buffers for synchronization 126 | size = torch.zeros([dist.get_world_size()]).long() 127 | t_list = [torch.empty([params.decode_batch_size, pad_max]).long() 128 | for _ in range(dist.get_world_size())] 129 | results = [] 130 | 131 | while True: 132 | try: 133 | features = next(iterator) 134 | batch_size = features["source"].shape[0] 135 | except: 136 | features = { 137 | "source": torch.ones([1, 1]).long(), 138 | "source_mask": torch.ones([1, 1]).float() 139 | } 140 | batch_size = 0 141 | 142 | t = time.time() 143 | counter += 1 144 | 145 | # Decode 146 | seqs, _ = beam_search([model], features, params) 147 | 148 | # Padding 149 | seqs = torch.squeeze(seqs, dim=1) 150 | pad_batch = params.decode_batch_size - seqs.shape[0] 151 | pad_length = pad_max - seqs.shape[1] 152 | seqs = torch.nn.functional.pad(seqs, (0, pad_length, 0, pad_batch)) 153 | 154 | # Synchronization 155 | size.zero_() 156 | size[dist.get_rank()].copy_(torch.tensor(batch_size)) 157 | dist.all_reduce(size) 158 | dist.all_gather(t_list, seqs) 159 | 160 | if size.sum() == 0: 161 | break 162 | 163 | if dist.get_rank() != 0: 164 | continue 165 | 166 | for i in range(params.decode_batch_size): 167 | for j in range(dist.get_world_size()): 168 | n = size[j] 169 | seq = _convert_to_string(t_list[j][i], params) 170 | 171 | if i >= n: 172 | continue 173 | 174 | # Restore BPE segmentation 175 | seq = BPE.decode(seq) 176 | 177 | results.append(seq.split()) 178 | 179 | t = time.time() - t 180 | print("Finished batch: %d (%.3f sec)" % (counter, t)) 181 | 182 | model.train() 183 | 184 | if dist.get_rank() == 0: 185 | restored_results = [] 186 | 187 | for idx in range(len(results)): 188 | restored_results.append(results[sorted_key[idx]]) 189 | 190 | return bleu(restored_results, references) 191 | 192 | return 0.0 193 | 194 | 195 | def evaluate(model, sorted_key, dataset, base_dir, references, params): 196 | if not references: 197 | return 198 | 199 | base_dir = base_dir.rstrip("/") 200 | save_path = os.path.join(base_dir, "eval") 201 | record_name = os.path.join(save_path, "record") 202 | log_name = os.path.join(save_path, "log") 203 | max_to_keep = params.keep_top_checkpoint_max 204 | 205 | if dist.get_rank() == 0: 206 | # Create directory and copy files 207 | if not os.path.exists(save_path): 208 | print("Making dir: %s" % save_path) 209 | os.makedirs(save_path) 210 | 211 | params_pattern = os.path.join(base_dir, "*.json") 212 | params_files = glob.glob(params_pattern) 213 | 214 | for name in params_files: 215 | new_name = name.replace(base_dir, save_path) 216 | shutil.copy(name, new_name) 217 | 218 | # Do validation here 219 | global_step = get_global_step() 220 | 221 | if dist.get_rank() == 0: 222 | print("Validating model at step %d" % global_step) 223 | 224 | score = _evaluate_model(model, sorted_key, dataset, references, params) 225 | 226 | # Save records 227 | if dist.get_rank() == 0: 228 | scalar("BLEU/score", score, global_step, write_every_n_steps=1) 229 | print("BLEU at step %d: %f" % (global_step, score)) 230 | 231 | # Save checkpoint to save_path 232 | 
save({"model": model.state_dict(), "step": global_step}, save_path) 233 | 234 | _save_log(log_name, ("BLEU", global_step, score)) 235 | records = _read_score_record(record_name) 236 | record = [latest_checkpoint(save_path).split("/")[-1], score] 237 | 238 | added, removed, records = _add_to_record(records, record, max_to_keep) 239 | 240 | if added is None: 241 | # Remove latest checkpoint 242 | filename = latest_checkpoint(save_path) 243 | print("Removing %s" % filename) 244 | files = glob.glob(filename + "*") 245 | 246 | for name in files: 247 | os.remove(name) 248 | 249 | if removed is not None: 250 | filename = os.path.join(save_path, removed) 251 | print("Removing %s" % filename) 252 | files = glob.glob(filename + "*") 253 | 254 | for name in files: 255 | os.remove(name) 256 | 257 | _save_score_record(record_name, records) 258 | 259 | best_score = records[0][1] 260 | print("Best score at step %d: %f" % (global_step, best_score)) 261 | -------------------------------------------------------------------------------- /thumt/utils/hparams.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | # Modified from TensorFlow (tf.contrib.training.HParams) 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import json 10 | import logging 11 | import re 12 | import six 13 | 14 | 15 | def parse_values(values, type_map): 16 | ret = {} 17 | param_re = re.compile(r"(?P[a-zA-Z][\w]*)\s*=\s*" 18 | r"((?P[^,\[]*)|\[(?P[^\]]*)\])($|,)") 19 | pos = 0 20 | 21 | while pos < len(values): 22 | m = param_re.match(values, pos) 23 | 24 | if not m: 25 | raise ValueError( 26 | "Malformed hyperparameter value: %s" % values[pos:]) 27 | 28 | # Check that there is a comma between parameters and move past it. 29 | pos = m.end() 30 | # Parse the values. 
31 | m_dict = m.groupdict() 32 | name = m_dict["name"] 33 | 34 | if name not in type_map: 35 | raise ValueError("Unknown hyperparameter type for %s" % name) 36 | 37 | def parse_fail(): 38 | raise ValueError("Could not parse hparam %s in %s" % (name, values)) 39 | 40 | if type_map[name] == bool: 41 | def parse_bool(value): 42 | if value == "true": 43 | return True 44 | elif value == "false": 45 | return False 46 | else: 47 | try: 48 | return bool(int(value)) 49 | except ValueError: 50 | parse_fail() 51 | parse = parse_bool 52 | else: 53 | parse = type_map[name] 54 | 55 | 56 | if m_dict["val"] is not None: 57 | try: 58 | ret[name] = parse(m_dict["val"]) 59 | except ValueError: 60 | parse_fail() 61 | elif m_dict["vals"] is not None: 62 | elements = filter(None, re.split("[ ,]", m_dict["vals"])) 63 | try: 64 | ret[name] = [parse(e) for e in elements] 65 | except ValueError: 66 | parse_fail() 67 | else: 68 | parse_fail() 69 | 70 | return ret 71 | 72 | 73 | class HParams(object): 74 | 75 | def __init__(self, **kwargs): 76 | self._hparam_types = {} 77 | 78 | for name, value in six.iteritems(kwargs): 79 | self.add_hparam(name, value) 80 | 81 | def add_hparam(self, name, value): 82 | if getattr(self, name, None) is not None: 83 | raise ValueError("Hyperparameter name is reserved: %s" % name) 84 | if isinstance(value, (list, tuple)): 85 | if not value: 86 | raise ValueError("Multi-valued hyperparameters cannot be" 87 | " empty: %s" % name) 88 | self._hparam_types[name] = (type(value[0]), True) 89 | else: 90 | self._hparam_types[name] = (type(value), False) 91 | setattr(self, name, value) 92 | 93 | def parse(self, values): 94 | type_map = dict() 95 | 96 | for name, t in six.iteritems(self._hparam_types): 97 | param_type, _ = t 98 | type_map[name] = param_type 99 | 100 | values_map = parse_values(values, type_map) 101 | return self._set_from_map(values_map) 102 | 103 | def _set_from_map(self, values_map): 104 | for name, value in six.iteritems(values_map): 105 | if name not in self._hparam_types: 106 | logging.debug("%s not found in hparams." 
% name) 107 | continue 108 | 109 | _, is_list = self._hparam_types[name] 110 | 111 | if isinstance(value, list): 112 | if not is_list: 113 | raise ValueError("Must not pass a list for single-valued " 114 | "parameter: %s" % name) 115 | setattr(self, name, value) 116 | else: 117 | if is_list: 118 | raise ValueError("Must pass a list for multi-valued " 119 | "parameter: %s" % name) 120 | setattr(self, name, value) 121 | return self 122 | 123 | def to_json(self): 124 | return json.dumps(self.values()) 125 | 126 | def parse_json(self, values_json): 127 | values_map = json.loads(values_json) 128 | return self._set_from_map(values_map) 129 | 130 | def values(self): 131 | return {n: getattr(self, n) for n in six.iterkeys(self._hparam_types)} 132 | 133 | def __str__(self): 134 | return str(sorted(six.iteritems(self.values()))) 135 | -------------------------------------------------------------------------------- /thumt/utils/inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | 11 | from collections import namedtuple 12 | from thumt.utils.nest import map_structure 13 | 14 | 15 | def _merge_first_two_dims(tensor): 16 | shape = list(tensor.shape) 17 | shape[1] *= shape[0] 18 | return torch.reshape(tensor, shape[1:]) 19 | 20 | 21 | def _split_first_two_dims(tensor, dim_0, dim_1): 22 | shape = [dim_0, dim_1] + list(tensor.shape)[1:] 23 | return torch.reshape(tensor, shape) 24 | 25 | 26 | def _tile_to_beam_size(tensor, beam_size): 27 | tensor = torch.unsqueeze(tensor, 1) 28 | tile_dims = [1] * int(tensor.dim()) 29 | tile_dims[1] = beam_size 30 | 31 | return tensor.repeat(tile_dims) 32 | 33 | 34 | def _gather_2d(params, indices, name=None): 35 | batch_size = params.shape[0] 36 | range_size = indices.shape[1] 37 | batch_pos = torch.arange(batch_size * range_size, device=params.device) 38 | batch_pos = batch_pos // range_size 39 | batch_pos = torch.reshape(batch_pos, [batch_size, range_size]) 40 | output = params[batch_pos, indices] 41 | 42 | return output 43 | 44 | 45 | class BeamSearchState(namedtuple("BeamSearchState", 46 | ("inputs", "state", "finish"))): 47 | pass 48 | 49 | 50 | def _get_inference_fn(model_fns, features): 51 | def inference_fn(inputs, state): 52 | local_features = { 53 | "source": features["source"], 54 | "source_mask": features["source_mask"], 55 | "target": inputs, 56 | "target_mask": torch.ones(*inputs.shape).to(inputs).float() 57 | } 58 | 59 | outputs = [] 60 | next_state = [] 61 | 62 | for (model_fn, model_state) in zip(model_fns, state): 63 | if model_state: 64 | logits, new_state = model_fn(local_features, model_state) 65 | outputs.append(torch.nn.functional.log_softmax(logits, 66 | dim=-1)) 67 | next_state.append(new_state) 68 | else: 69 | logits = model_fn(local_features) 70 | outputs.append(torch.nn.functional.log_softmax(logits, 71 | dim=-1)) 72 | next_state.append({}) 73 | 74 | # Ensemble 75 | log_prob = sum(outputs) / float(len(outputs)) 76 | 77 | return log_prob.float(), next_state 78 | 79 | return inference_fn 80 | 81 | 82 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha, 83 | pad_id, eos_id, min_length, max_length, inf=-1e9): 84 | # Compute log probabilities 85 | seqs, log_probs = state.inputs[:2] 86 | flat_seqs = _merge_first_two_dims(seqs) 87 | flat_state = map_structure(lambda 
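`HParams` mirrors the old `tf.contrib.training.HParams` container: hyperparameters registered in the constructor keep their types and can later be overridden from a JSON string (or a `name=value` list via `parse`). A small sketch:

```python
from thumt.utils import HParams

params = HParams(hidden_size=512, dropout=0.1, hidden_sizes=[2048, 2048])
params.hidden_size                      # 512
params.parse_json('{"dropout": 0.3}')   # overrides a known hyperparameter
params.dropout                          # 0.3
params.to_json()                        # JSON string with all current values
```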
x: _merge_first_two_dims(x), state.state) 88 | step_log_probs, next_state = func(flat_seqs, flat_state) 89 | step_log_probs = _split_first_two_dims(step_log_probs, batch_size, 90 | beam_size) 91 | next_state = map_structure( 92 | lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state) 93 | curr_log_probs = torch.unsqueeze(log_probs, 2) + step_log_probs 94 | 95 | # Apply length penalty 96 | length_penalty = ((5.0 + float(time + 1)) / 6.0) ** alpha 97 | curr_scores = curr_log_probs / length_penalty 98 | vocab_size = curr_scores.shape[-1] 99 | 100 | # Prevent null translation 101 | min_length_flags = torch.ge(min_length, time + 1).float().mul_(inf) 102 | curr_scores[:, :, eos_id].add_(min_length_flags) 103 | 104 | # Select top-k candidates 105 | # [batch_size, beam_size * vocab_size] 106 | curr_scores = torch.reshape(curr_scores, [-1, beam_size * vocab_size]) 107 | # [batch_size, 2 * beam_size] 108 | top_scores, top_indices = torch.topk(curr_scores, k=2*beam_size) 109 | # Shape: [batch_size, 2 * beam_size] 110 | beam_indices = top_indices // vocab_size 111 | symbol_indices = top_indices % vocab_size 112 | # Expand sequences 113 | # [batch_size, 2 * beam_size, time] 114 | candidate_seqs = _gather_2d(seqs, beam_indices) 115 | candidate_seqs = torch.cat([candidate_seqs, 116 | torch.unsqueeze(symbol_indices, 2)], 2) 117 | 118 | # Expand sequences 119 | # Suppress finished sequences 120 | flags = torch.eq(symbol_indices, eos_id).to(torch.bool) 121 | # [batch, 2 * beam_size] 122 | alive_scores = top_scores + flags.to(torch.float32) * inf 123 | # [batch, beam_size] 124 | alive_scores, alive_indices = torch.topk(alive_scores, beam_size) 125 | alive_symbols = _gather_2d(symbol_indices, alive_indices) 126 | alive_indices = _gather_2d(beam_indices, alive_indices) 127 | alive_seqs = _gather_2d(seqs, alive_indices) 128 | # [batch_size, beam_size, time + 1] 129 | alive_seqs = torch.cat([alive_seqs, torch.unsqueeze(alive_symbols, 2)], 2) 130 | alive_state = map_structure( 131 | lambda x: _gather_2d(x, alive_indices), 132 | next_state) 133 | alive_log_probs = alive_scores * length_penalty 134 | # Check length constraint 135 | length_flags = torch.le(max_length, time + 1).float() 136 | alive_log_probs = alive_log_probs + length_flags * inf 137 | alive_scores = alive_scores + length_flags * inf 138 | 139 | # Select finished sequences 140 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish 141 | # [batch, 2 * beam_size] 142 | step_fin_scores = top_scores + (1.0 - flags.to(torch.float32)) * inf 143 | # [batch, 3 * beam_size] 144 | fin_flags = torch.cat([prev_fin_flags, flags], dim=1) 145 | fin_scores = torch.cat([prev_fin_scores, step_fin_scores], dim=1) 146 | # [batch, beam_size] 147 | fin_scores, fin_indices = torch.topk(fin_scores, beam_size) 148 | fin_flags = _gather_2d(fin_flags, fin_indices) 149 | pad_seqs = prev_fin_seqs.new_full([batch_size, beam_size, 1], pad_id) 150 | prev_fin_seqs = torch.cat([prev_fin_seqs, pad_seqs], dim=2) 151 | fin_seqs = torch.cat([prev_fin_seqs, candidate_seqs], dim=1) 152 | fin_seqs = _gather_2d(fin_seqs, fin_indices) 153 | 154 | new_state = BeamSearchState( 155 | inputs=(alive_seqs, alive_log_probs, alive_scores), 156 | state=alive_state, 157 | finish=(fin_flags, fin_seqs, fin_scores), 158 | ) 159 | 160 | return new_state 161 | 162 | 163 | def beam_search(models, features, params): 164 | if not isinstance(models, (list, tuple)): 165 | raise ValueError("'models' must be a list or tuple") 166 | 167 | beam_size = params.beam_size 168 | top_beams = 
params.top_beams 169 | alpha = params.decode_alpha 170 | decode_ratio = params.decode_ratio 171 | decode_length = params.decode_length 172 | 173 | pad_id = params.vocabulary["target"][params.pad] 174 | bos_id = params.vocabulary["target"][params.bos] 175 | eos_id = params.vocabulary["target"][params.eos] 176 | 177 | min_val = -1e9 178 | shape = features["source"].shape 179 | device = features["source"].device 180 | batch_size = shape[0] 181 | seq_length = shape[1] 182 | 183 | # Compute initial state if necessary 184 | states = [] 185 | funcs = [] 186 | 187 | for model in models: 188 | state = model.empty_state(batch_size, device) 189 | states.append(model.encode(features, state)) 190 | funcs.append(model.decode) 191 | 192 | # For source sequence length 193 | max_length = features["source_mask"].sum(1) * decode_ratio 194 | max_length = max_length.long() + decode_length 195 | max_step = max_length.max() 196 | # [batch, beam_size] 197 | max_length = torch.unsqueeze(max_length, 1).repeat([1, beam_size]) 198 | min_length = torch.ones_like(max_length) 199 | 200 | # Expand the inputs 201 | # [batch, length] => [batch * beam_size, length] 202 | features["source"] = torch.unsqueeze(features["source"], 1) 203 | features["source"] = features["source"].repeat([1, beam_size, 1]) 204 | features["source"] = torch.reshape(features["source"], 205 | [batch_size * beam_size, seq_length]) 206 | features["source_mask"] = torch.unsqueeze(features["source_mask"], 1) 207 | features["source_mask"] = features["source_mask"].repeat([1, beam_size, 1]) 208 | features["source_mask"] = torch.reshape(features["source_mask"], 209 | [batch_size * beam_size, seq_length]) 210 | 211 | decoding_fn = _get_inference_fn(funcs, features) 212 | 213 | states = map_structure( 214 | lambda x: _tile_to_beam_size(x, beam_size), 215 | states) 216 | 217 | # Initial beam search state 218 | init_seqs = torch.full([batch_size, beam_size, 1], bos_id, device=device) 219 | init_seqs = init_seqs.long() 220 | init_log_probs = init_seqs.new_tensor( 221 | [[0.] 
+ [min_val] * (beam_size - 1)], dtype=torch.float32) 222 | init_log_probs = init_log_probs.repeat([batch_size, 1]) 223 | init_scores = torch.zeros_like(init_log_probs) 224 | fin_seqs = torch.zeros([batch_size, beam_size, 1], dtype=torch.int64, 225 | device=device) 226 | fin_scores = torch.full([batch_size, beam_size], min_val, 227 | dtype=torch.float32, device=device) 228 | fin_flags = torch.zeros([batch_size, beam_size], dtype=torch.bool, 229 | device=device) 230 | 231 | state = BeamSearchState( 232 | inputs=(init_seqs, init_log_probs, init_scores), 233 | state=states, 234 | finish=(fin_flags, fin_seqs, fin_scores), 235 | ) 236 | 237 | for time in range(max_step): 238 | state = _beam_search_step(time, decoding_fn, state, batch_size, 239 | beam_size, alpha, pad_id, eos_id, 240 | min_length, max_length) 241 | max_penalty = ((5.0 + max_step) / 6.0) ** alpha 242 | best_alive_score = torch.max(state.inputs[1][:, 0] / max_penalty) 243 | worst_finished_score = torch.min(state.finish[2]) 244 | cond = torch.gt(worst_finished_score, best_alive_score) 245 | is_finished = bool(cond) 246 | 247 | if is_finished: 248 | break 249 | 250 | final_state = state 251 | alive_seqs = final_state.inputs[0] 252 | alive_scores = final_state.inputs[2] 253 | final_flags = final_state.finish[0].byte() 254 | final_seqs = final_state.finish[1] 255 | final_scores = final_state.finish[2] 256 | 257 | final_seqs = torch.where(final_flags[:, :, None], final_seqs, alive_seqs) 258 | final_scores = torch.where(final_flags, final_scores, alive_scores) 259 | 260 | # Append extra 261 | final_seqs = torch.nn.functional.pad(final_seqs, (0, 1, 0, 0, 0, 0), 262 | value=eos_id) 263 | 264 | return final_seqs[:, :top_beams, 1:], final_scores[:, :top_beams] 265 | 266 | 267 | def argmax_decoding(models, features, params): 268 | if not isinstance(models, (list, tuple)): 269 | raise ValueError("'models' must be a list or tuple") 270 | 271 | # Compute initial state if necessary 272 | log_probs = [] 273 | shape = features["target"].shape 274 | device = features["target"].device 275 | batch_size = features["target"].shape[0] 276 | target_mask = features["target_mask"] 277 | target_length = target_mask.sum(1).long() 278 | eos_id = params.vocabulary["target"][params.eos] 279 | 280 | for model in models: 281 | state = model.empty_state(batch_size, device) 282 | state = model.encode(features, state) 283 | logits, _ = model.decode(features, state, "eval") 284 | log_probs.append(torch.nn.functional.log_softmax(logits, dim=-1)) 285 | 286 | log_prob = sum(log_probs) / len(models) 287 | ret = torch.max(log_prob, -1) 288 | values = torch.reshape(ret.values, shape) 289 | indices = torch.reshape(ret.indices, shape) 290 | 291 | batch_pos = torch.arange(batch_size, device=device) 292 | seq_pos = target_length - 1 293 | indices[batch_pos, seq_pos] = eos_id 294 | 295 | return indices[:, None, :], torch.sum(values * target_mask, -1)[:, None] 296 | -------------------------------------------------------------------------------- /thumt/utils/misc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | _GLOBAL_STEP = 0 9 | 10 | 11 | def get_global_step(): 12 | return _GLOBAL_STEP 13 | 14 | 15 | def set_global_step(step): 16 | global _GLOBAL_STEP 17 | _GLOBAL_STEP = step 18 | 
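
The beam search above keeps tensors in a [batch, beam, ...] layout and flattens them to [batch * beam, ...] whenever the decoder is called, using the private helpers defined at the top of thumt/utils/inference.py. The following is a minimal sketch (not part of the repository) that exercises those helpers on dummy tensors; the batch size, beam size, and sequence length are arbitrary values chosen purely for illustration.

# Illustration only: shape helpers from thumt/utils/inference.py on dummy data.
import torch

from thumt.utils.inference import (_gather_2d, _merge_first_two_dims,
                                   _split_first_two_dims, _tile_to_beam_size)

batch_size, beam_size, length = 2, 3, 5  # arbitrary sizes for illustration

# [batch, length] -> [batch, beam, length]: every beam starts from the same prefix.
seqs = torch.arange(batch_size * length).reshape(batch_size, length)
tiled = _tile_to_beam_size(seqs, beam_size)
assert tiled.shape == (batch_size, beam_size, length)

# The decoder is invoked on a flat [batch * beam, length] view ...
flat = _merge_first_two_dims(tiled)
assert flat.shape == (batch_size * beam_size, length)

# ... and the per-beam structure is restored afterwards.
assert torch.equal(_split_first_two_dims(flat, batch_size, beam_size), tiled)

# _gather_2d reorders the beam dimension independently for each batch entry,
# which is how surviving hypotheses are selected after each top-k step.
hyps = torch.arange(batch_size * beam_size * length).reshape(
    batch_size, beam_size, length)
indices = torch.tensor([[2, 0, 1], [1, 1, 0]])
picked = _gather_2d(hyps, indices)
assert torch.equal(picked[0, 0], hyps[0, 2])
assert torch.equal(picked[1, 2], hyps[1, 0])

The same gather is applied to the decoder state via map_structure inside _beam_search_step, which is what keeps every surviving hypothesis aligned with its cached state.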
-------------------------------------------------------------------------------- /thumt/utils/nest.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | # Modified from TensorFlow 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import collections.abc 10 | import six 11 | 12 | 
13 | def _sorted(dict_): 14 | try: 15 | return sorted(six.iterkeys(dict_)) 16 | except TypeError: 17 | raise TypeError("nest only supports dicts with sortable keys.") 18 | 19 | 20 | def _sequence_like(instance, args): 21 | if isinstance(instance, dict): 22 | result = dict(zip(_sorted(instance), args)) 23 | return type(instance)((key, result[key]) 24 | for key in six.iterkeys(instance)) 
25 | elif (isinstance(instance, tuple) and 26 | hasattr(instance, "_fields") and 27 | isinstance(instance._fields, collections.abc.Sequence) and 28 | all(isinstance(f, six.string_types) for f in instance._fields)): 29 | # This is a namedtuple 30 | return type(instance)(*args) 31 | else: 32 | # Not a namedtuple 33 | return type(instance)(args) 34 | 35 | 
36 | def _yield_value(iterable): 37 | if isinstance(iterable, dict): 38 | for key in _sorted(iterable): 39 | yield iterable[key] 40 | else: 41 | for value in iterable: 42 | yield value 43 | 44 | 45 | def _yield_flat_nest(nest): 46 | for n in _yield_value(nest): 47 | if is_sequence(n): 48 | for ni in _yield_flat_nest(n): 49 | yield ni 50 | else: 51 | yield n 52 | 53 | 
54 | def is_sequence(seq): 55 | if isinstance(seq, dict): 56 | return True 57 | if isinstance(seq, set): 58 | print("Sets are not currently considered sequences, but this may " 59 | "change in the future, so consider avoiding using them.") 60 | return (isinstance(seq, collections.abc.Sequence) 61 | and not isinstance(seq, six.string_types)) 62 | 63 | 
64 | def flatten(nest): 65 | if is_sequence(nest): 66 | return list(_yield_flat_nest(nest)) 67 | else: 68 | return [nest] 69 | 70 | 
71 | def _recursive_assert_same_structure(nest1, nest2, check_types): 72 | is_sequence_nest1 = is_sequence(nest1) 73 | if is_sequence_nest1 != is_sequence(nest2): 74 | raise ValueError( 75 | "The two structures don't have the same nested structure.\n\n" 76 | "First structure: %s\n\nSecond structure: %s." % (nest1, nest2)) 77 | 78 | if not is_sequence_nest1: 79 | return # finished checking 80 | 
81 | if check_types: 82 | type_nest1 = type(nest1) 83 | type_nest2 = type(nest2) 84 | if type_nest1 != type_nest2: 85 | raise TypeError( 86 | "The two structures don't have the same sequence type. First " 87 | "structure has type %s, while second structure has type %s." 88 | % (type_nest1, type_nest2)) 89 | 90 | if isinstance(nest1, dict): 91 | keys1 = set(six.iterkeys(nest1)) 92 | keys2 = set(six.iterkeys(nest2)) 93 | if keys1 != keys2: 94 | raise ValueError( 95 | "The two dictionaries don't have the same set of keys. 
" 96 | "First structure has keys {}, while second structure has" 97 | " keys {}.".format(keys1, keys2)) 98 | 99 | nest1_as_sequence = [n for n in _yield_value(nest1)] 100 | nest2_as_sequence = [n for n in _yield_value(nest2)] 101 | for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence): 102 | _recursive_assert_same_structure(n1, n2, check_types) 103 | 104 | 105 | def assert_same_structure(nest1, nest2, check_types=True): 106 | len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 107 | len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 108 | if len_nest1 != len_nest2: 109 | raise ValueError("The two structures don't have the same number of " 110 | "elements.\n\nFirst structure (%i elements): %s\n\n" 111 | "Second structure (%i elements): %s" 112 | % (len_nest1, nest1, len_nest2, nest2)) 113 | _recursive_assert_same_structure(nest1, nest2, check_types) 114 | 115 | 116 | def flatten_dict_items(dictionary): 117 | if not isinstance(dictionary, dict): 118 | raise TypeError("input must be a dictionary") 119 | flat_dictionary = {} 120 | for i, v in six.iteritems(dictionary): 121 | if not is_sequence(i): 122 | if i in flat_dictionary: 123 | raise ValueError( 124 | "Could not flatten dictionary: key %s is not unique." % i) 125 | flat_dictionary[i] = v 126 | else: 127 | flat_i = flatten(i) 128 | flat_v = flatten(v) 129 | if len(flat_i) != len(flat_v): 130 | raise ValueError( 131 | "Could not flatten dictionary. Key had %d elements, but" 132 | " value had %d elements. Key: %s, value: %s." 133 | % (len(flat_i), len(flat_v), flat_i, flat_v)) 134 | for new_i, new_v in zip(flat_i, flat_v): 135 | if new_i in flat_dictionary: 136 | raise ValueError( 137 | "Could not flatten dictionary: key %s is not unique." 138 | % (new_i)) 139 | flat_dictionary[new_i] = new_v 140 | return flat_dictionary 141 | 142 | 143 | def _packed_nest_with_indices(structure, flat, index): 144 | packed = [] 145 | for s in _yield_value(structure): 146 | if is_sequence(s): 147 | new_index, child = _packed_nest_with_indices(s, flat, index) 148 | packed.append(_sequence_like(s, child)) 149 | index = new_index 150 | else: 151 | packed.append(flat[index]) 152 | index += 1 153 | return index, packed 154 | 155 | 156 | def pack_sequence_as(structure, flat_sequence): 157 | if not is_sequence(flat_sequence): 158 | raise TypeError("flat_sequence must be a sequence") 159 | 160 | if not is_sequence(structure): 161 | if len(flat_sequence) != 1: 162 | raise ValueError("Structure is a scalar but len(flat_sequence) ==" 163 | " %d > 1" % len(flat_sequence)) 164 | return flat_sequence[0] 165 | 166 | flat_structure = flatten(structure) 167 | if len(flat_structure) != len(flat_sequence): 168 | raise ValueError( 169 | "Could not pack sequence. Structure had %d elements, but " 170 | "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." 
171 | % (len(flat_structure), len(flat_sequence), structure, 172 | flat_sequence)) 173 | 174 | _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) 175 | return _sequence_like(structure, packed) 176 | 177 | 178 | def map_structure(func, *structure, **check_types_dict): 179 | if not callable(func): 180 | raise TypeError("func must be callable, got: %s" % func) 181 | 182 | if not structure: 183 | raise ValueError("Must provide at least one structure") 184 | 185 | if check_types_dict: 186 | if "check_types" not in check_types_dict or len(check_types_dict) > 1: 187 | raise ValueError("Only valid keyword argument is check_types") 188 | check_types = check_types_dict["check_types"] 189 | else: 190 | check_types = True 191 | 192 | for other in structure[1:]: 193 | assert_same_structure(structure[0], other, check_types=check_types) 194 | 195 | flat_structure = [flatten(s) for s in structure] 196 | entries = zip(*flat_structure) 197 | 198 | return pack_sequence_as( 199 | structure[0], [func(*x) for x in entries]) 200 | -------------------------------------------------------------------------------- /thumt/utils/scope.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | # Modified from TensorFlow (tf.name_scope) 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import re 10 | import contextlib 11 | 12 | # global variable 13 | _NAME_STACK = "" 14 | _NAMES_IN_USE = {} 15 | _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$") 16 | _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$") 17 | 18 | 19 | def unique_name(name, mark_as_used=True): 20 | global _NAME_STACK 21 | 22 | if _NAME_STACK: 23 | name = _NAME_STACK + "/" + name 24 | 25 | i = _NAMES_IN_USE.get(name, 0) 26 | 27 | if mark_as_used: 28 | _NAMES_IN_USE[name] = i + 1 29 | 30 | if i > 0: 31 | base_name = name 32 | 33 | while name in _NAMES_IN_USE: 34 | name = "%s_%d" % (base_name, i) 35 | i += 1 36 | 37 | if mark_as_used: 38 | _NAMES_IN_USE[name] = 1 39 | 40 | return name 41 | 42 | 43 | @contextlib.contextmanager 44 | def scope(name): 45 | global _NAME_STACK 46 | 47 | if name: 48 | if _NAME_STACK: 49 | # check name 50 | if not _VALID_SCOPE_NAME_REGEX.match(name): 51 | raise ValueError("'%s' is not a valid scope name" % name) 52 | else: 53 | # check name strictly 54 | if not _VALID_OP_NAME_REGEX.match(name): 55 | raise ValueError("'%s' is not a valid scope name" % name) 56 | 57 | try: 58 | old_stack = _NAME_STACK 59 | 60 | if not name: 61 | new_stack = None 62 | elif name and name[-1] == "/": 63 | new_stack = name[:-1] 64 | else: 65 | new_stack = unique_name(name) 66 | 67 | _NAME_STACK = new_stack 68 | 69 | yield "" if new_stack is None else new_stack + "/" 70 | finally: 71 | _NAME_STACK = old_stack 72 | 73 | 74 | def get_scope(): 75 | return _NAME_STACK 76 | -------------------------------------------------------------------------------- /thumt/utils/summary.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import queue 9 | import threading 10 | import torch 11 | 12 | import torch.distributed as dist 13 | import torch.utils.tensorboard as tensorboard 14 | 15 | _SUMMARY_WRITER = None 16 | _QUEUE = None 17 
| _THREAD = None 18 | 19 | 20 | class SummaryWorker(threading.Thread): 21 | 22 | def run(self): 23 | global _QUEUE 24 | 25 | while True: 26 | item = _QUEUE.get() 27 | name, kwargs = item 28 | 29 | if name == "stop": 30 | break 31 | 32 | self.write_summary(name, **kwargs) 33 | 34 | def write_summary(self, name, **kwargs): 35 | if name == "scalar": 36 | _SUMMARY_WRITER.add_scalar(**kwargs) 37 | elif name == "histogram": 38 | _SUMMARY_WRITER.add_histogram(**kwargs) 39 | 40 | def stop(self): 41 | global _QUEUE 42 | _QUEUE.put(("stop", None)) 43 | self.join() 44 | 45 | 46 | def init(log_dir, enable=True): 47 | global _SUMMARY_WRITER 48 | global _QUEUE 49 | global _THREAD 50 | 51 | if enable and dist.get_rank() == 0: 52 | _SUMMARY_WRITER = tensorboard.SummaryWriter(log_dir) 53 | _QUEUE = queue.Queue() 54 | thread = SummaryWorker(daemon=True) 55 | thread.start() 56 | _THREAD = thread 57 | 58 | 59 | def scalar(tag, scalar_value, global_step=None, walltime=None, 60 | write_every_n_steps=100): 61 | 62 | if _SUMMARY_WRITER is not None: 63 | if global_step % write_every_n_steps == 0: 64 | scalar_value = float(scalar_value) 65 | kwargs = dict(tag=tag, scalar_value=scalar_value, 66 | global_step=global_step, walltime=walltime) 67 | _QUEUE.put(("scalar", kwargs)) 68 | 69 | 70 | def histogram(tag, values, global_step=None, bins="tensorflow", walltime=None, 71 | max_bins=None, write_every_n_steps=100): 72 | 73 | if _SUMMARY_WRITER is not None: 74 | if global_step % write_every_n_steps == 0: 75 | values = values.detach().cpu() 76 | kwargs = dict(tag=tag, values=values, global_step=global_step, 77 | bins=bins, walltime=walltime, max_bins=max_bins) 78 | _QUEUE.put(("histogram", kwargs)) 79 | 80 | 81 | def close(): 82 | if _SUMMARY_WRITER is not None: 83 | _THREAD.stop() 84 | _SUMMARY_WRITER.close() 85 | --------------------------------------------------------------------------------
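
thumt/utils/nest.py mirrors TensorFlow's nest utilities and is what lets beam_search tile, split, and gather an entire decoder-state structure with a single map_structure call. Below is a small sketch (not part of the repository) of how these functions behave on a nested dict; the layer names, keys, tensor shapes, and beam size are made up for illustration, since the real state layout is whatever each model's empty_state and encode return.

# Illustration only: thumt.utils.nest on a toy nested "decoder state".
import torch

from thumt.utils.nest import flatten, map_structure, pack_sequence_as

state = {"layer_0": {"k": torch.zeros(2, 4), "v": torch.zeros(2, 4)},
         "layer_1": {"k": torch.zeros(2, 4), "v": torch.zeros(2, 4)}}

# Apply a function to every leaf tensor and rebuild the same nesting,
# e.g. tiling each tensor to a (hypothetical) beam size of 3.
tiled = map_structure(lambda x: x.unsqueeze(1).repeat(1, 3, 1), state)
print(tiled["layer_1"]["v"].shape)  # torch.Size([2, 3, 4])

# flatten yields the leaves in a deterministic (sorted-key) order, and
# pack_sequence_as restores the original nesting around a flat list.
leaves = flatten(state)
assert len(leaves) == 4
rebuilt = pack_sequence_as(state, leaves)
assert set(rebuilt) == {"layer_0", "layer_1"}

map_structure is built from exactly these two primitives: it flattens every input structure, applies the function leaf by leaf, and re-packs the results into the shape of the first structure.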