├── LICENSE ├── README.md ├── results ├── conll05 │ ├── ensemble │ │ ├── conll05.brown.result │ │ ├── conll05.dev.result │ │ └── conll05.wsj.result │ └── single │ │ ├── conll05.brown.result │ │ ├── conll05.dev.result │ │ └── conll05.wsj.result └── conll12 │ ├── ensemble │ ├── conll12.dev.result │ └── conll12.test.result │ └── single │ ├── conll12.devel.result │ └── conll12.test.result └── tagger ├── __init__.py ├── bin ├── predictor.py └── trainer.py ├── data ├── __init__.py ├── dataset.py ├── embedding.py └── vocab.py ├── models ├── __init__.py └── deepatt.py ├── modules ├── __init__.py ├── affine.py ├── attention.py ├── embedding.py ├── feed_forward.py ├── layer_norm.py ├── losses.py ├── module.py └── recurrent.py ├── optimizers ├── __init__.py ├── clipping.py ├── optimizers.py └── schedules.py ├── scripts ├── build_vocab.py └── convert_to_conll.py └── utils ├── __init__.py ├── checkpoint.py ├── hparams.py ├── misc.py ├── scope.py ├── summary.py └── validation.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Natural Language Processing Lab at Xiamen University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tagger 2 | 3 | This is the source code for the paper "[Deep Semantic Role Labeling with Self-Attention](https://arxiv.org/abs/1712.01586)". 
4 |
5 | ## Contents
6 |
7 | * [Basics](#basics)
8 | * [Notice](#notice)
9 | * [Prerequisites](#prerequisites)
10 | * [Walkthrough](#walkthrough)
11 | * [Data](#data)
12 | * [Training](#training)
13 | * [Decoding](#decoding)
14 | * [Benchmarks](#benchmarks)
15 | * [Pretrained Models](#pretrained-models)
16 | * [License](#license)
17 | * [Citation](#citation)
18 | * [Contact](#contact)
19 |
20 | ## Basics
21 |
22 | ### Notice
23 |
24 | The original code used in the paper was implemented in TensorFlow 1.0, which is now obsolete. We have re-implemented our methods in PyTorch, building on [THUMT](https://github.com/THUNLP-MT/THUMT). The differences are as follows:
25 |
26 | * We only implement the DeepAtt-FFN model
27 | * Model ensembling is currently not available
28 |
29 | Please check the git history if you want to use the TensorFlow implementation.
30 |
31 | ### Prerequisites
32 |
33 | * Python 3
34 | * PyTorch
35 | * TensorFlow-2.0 (CPU version)
36 | * GloVe embeddings and `srlconll` scripts
37 |
38 | ## Walkthrough
39 |
40 | ### Data
41 |
42 | #### Training Data
43 |
44 | We follow the same procedures described in the [deep_srl](https://github.com/luheng/deep_srl) repository to convert the CoNLL datasets.
45 | The GloVe embeddings and `srlconll` scripts can also be found in that repository.
46 |
47 | If you follow these procedures, the processed data will have the following format, where the leading integer is the 0-based index of the predicate and the tags after `|||` are the per-token SRL labels:
48 | ```
49 | 2 My cats love hats . ||| B-A0 I-A0 B-V B-A1 O
50 | ```
51 |
52 | *The CoNLL datasets are not publicly available, so we cannot provide them.*
53 |
54 | #### Vocabulary
55 |
56 | You can use the `build_vocab.py` script to generate vocabularies. The command is as follows:
57 |
58 | ```bash
59 | python tagger/scripts/build_vocab.py --limit LIMIT --lower TRAIN_FILE OUTPUT_DIR
60 | ```
61 |
62 | where `LIMIT` specifies the vocabulary size. This command creates two vocabulary files named `vocab.txt` and `label.txt` in `OUTPUT_DIR`.
63 |
64 | ### Training
65 |
66 | Once you have finished the procedures described above, you can start the training stage.
67 |
68 | #### Preparing the validation script
69 |
70 | An external validation script is required to enable validation during training.
71 | Below is the validation script we used when training an FFN model on the CoNLL-2005 dataset.
72 | Please make sure that the validation script runs properly before starting training.
73 |
74 | ```bash
75 | #!/usr/bin/env bash
76 | SRLPATH=/PATH/TO/SRLCONLL
77 | TAGGERPATH=/PATH/TO/TAGGER
78 | DATAPATH=/PATH/TO/DATA
79 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
80 | DEVICE=0
81 |
82 | export PYTHONPATH=$TAGGERPATH:$PYTHONPATH
83 | export PERL5LIB="$SRLPATH/lib:$PERL5LIB"
84 | export PATH="$SRLPATH/bin:$PATH"
85 |
86 | python $TAGGERPATH/tagger/bin/predictor.py \
87 | --input $DATAPATH/conll05.devel.txt \
88 | --checkpoint train \
89 | --model deepatt \
90 | --vocab $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
91 | --parameters=device=$DEVICE,embedding=$EMBPATH/glove.6B.100d.txt \
92 | --output tmp.txt
93 |
94 | python $TAGGERPATH/tagger/scripts/convert_to_conll.py tmp.txt $DATAPATH/conll05.devel.props.gold.txt output
95 | perl $SRLPATH/bin/srl-eval.pl $DATAPATH/conll05.devel.props.* output
96 | ```
97 |
98 | #### Training command
99 |
100 | The command below is what we used to train a model on the CoNLL-2005 dataset. The content of `run.sh` is the validation script described in the previous section.
101 |
102 | ```bash
103 | #!/usr/bin/env bash
104 | SRLPATH=/PATH/TO/SRLCONLL
105 | TAGGERPATH=/PATH/TO/TAGGER
106 | DATAPATH=/PATH/TO/DATA
107 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
108 | DEVICE=[0]
109 |
110 | export PYTHONPATH=$TAGGERPATH:$PYTHONPATH
111 | export PERL5LIB="$SRLPATH/lib:$PERL5LIB"
112 | export PATH="$SRLPATH/bin:$PATH"
113 |
114 | python $TAGGERPATH/tagger/bin/trainer.py \
115 | --model deepatt \
116 | --input $DATAPATH/conll05.train.txt \
117 | --output train \
118 | --vocabulary $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
119 | --parameters="save_summary=false,feature_size=100,hidden_size=200,filter_size=800,"`
120 | `"residual_dropout=0.2,num_hidden_layers=10,attention_dropout=0.1,"`
121 | `"relu_dropout=0.1,batch_size=4096,optimizer=adadelta,initializer=orthogonal,"`
122 | `"initializer_gain=1.0,train_steps=600000,"`
123 | `"learning_rate_schedule=piecewise_constant_decay,"`
124 | `"learning_rate_values=[1.0,0.5,0.25,],"`
125 | `"learning_rate_boundaries=[400000,50000],device_list=$DEVICE,"`
126 | `"clip_grad_norm=1.0,embedding=$EMBPATH/glove.6B.100d.txt,script=run.sh"
127 | ```
128 |
129 | ### Decoding
130 |
131 | The following is the command we used to generate outputs (a sketch for scoring the resulting `tmp.txt` is given at the end of this README):
132 |
133 | ```bash
134 | #!/usr/bin/env bash
135 | SRLPATH=/PATH/TO/SRLCONLL
136 | TAGGERPATH=/PATH/TO/TAGGER
137 | DATAPATH=/PATH/TO/DATA
138 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
139 | DEVICE=0
140 |
141 | python $TAGGERPATH/tagger/bin/predictor.py \
142 | --input $DATAPATH/conll05.test.wsj.txt \
143 | --checkpoint train/best \
144 | --model deepatt \
145 | --vocab $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
146 | --parameters=device=$DEVICE,embedding=$EMBPATH/glove.6B.100d.txt \
147 | --output tmp.txt
148 |
149 | ```
150 |
151 | ## Benchmarks
152 |
153 | We performed 4 runs on the CoNLL-2005 dataset. The results are shown below.
154 |
155 | | Runs | Dev-P | Dev-R | Dev-F1 | WSJ-P | WSJ-R | WSJ-F1 | BROWN-P | BROWN-R | BROWN-F1 |
156 | | :----: | :---: | :---: | :----: | :---: | :---: | :----: | :-----: | :-----: | :------: |
157 | | Paper | 82.6 | 83.6 | 83.1 | 84.5 | 85.2 | 84.8 | 73.5 | 74.6 | 74.1 |
158 | | Run0 | 82.9 | 83.7 | 83.3 | 84.6 | 85.0 | 84.8 | 73.5 | 74.0 | 73.8 |
159 | | Run1 | 82.3 | 83.4 | 82.9 | 84.4 | 85.3 | 84.8 | 72.5 | 73.9 | 73.2 |
160 | | Run2 | 82.7 | 83.6 | 83.2 | 84.8 | 85.4 | 85.1 | 73.2 | 73.9 | 73.6 |
161 | | Run3 | 82.3 | 83.6 | 82.9 | 84.3 | 84.9 | 84.6 | 72.3 | 73.6 | 72.9 |
162 |
163 | ## Pretrained Models
164 |
165 | The pretrained models of the TensorFlow implementation can be downloaded from [Google Drive](https://drive.google.com/open?id=1jvBlpOmqGdZEqnFrdWJkH1xHsGU2OjiP).
166 |
167 | ## LICENSE
168 |
169 | BSD
170 |
171 | ## Citation
172 |
173 | If you use our code, please cite our paper:
174 |
175 | ```
176 | @inproceedings{tan2018deep,
177 |   title = {Deep Semantic Role Labeling with Self-Attention},
178 |   author = {Tan, Zhixing and Wang, Mingxuan and Xie, Jun and Chen, Yidong and Shi, Xiaodong},
179 |   booktitle = {AAAI Conference on Artificial Intelligence},
180 |   year = {2018}
181 | }
182 | ```
183 |
184 | ## Contact
185 |
186 | This code was written by Zhixing Tan. If you have any problems, feel free to send an email.
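
Scoring the decoded outputs (referenced from the [Decoding](#decoding) section): you can reuse the conversion and evaluation steps from the validation script to score `tmp.txt`. This is only a sketch: it assumes the same `SRLPATH`, `TAGGERPATH` and `DATAPATH` variables as the decoding script, and the WSJ test gold-proposition file names below (`conll05.test.wsj.props.gold.txt`, `conll05.test.wsj.props.*`) are assumptions modeled on the development-set names, so adjust them to match your preprocessed data:

```bash
# Make the srlconll scripts visible to Perl, as in the validation script.
export PERL5LIB="$SRLPATH/lib:$PERL5LIB"
export PATH="$SRLPATH/bin:$PATH"

# Convert the tagger output to CoNLL props format and score it with srl-eval.pl.
python $TAGGERPATH/tagger/scripts/convert_to_conll.py tmp.txt $DATAPATH/conll05.test.wsj.props.gold.txt output
perl $SRLPATH/bin/srl-eval.pl $DATAPATH/conll05.test.wsj.props.* output
```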
187 | -------------------------------------------------------------------------------- /tagger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XMUNLP/Tagger/02e1fd323ac747bfe5f7b8824c6b416fd90f33a1/tagger/__init__.py -------------------------------------------------------------------------------- /tagger/bin/predictor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import logging 10 | import os 11 | import six 12 | import time 13 | import torch 14 | 15 | import tagger.data as data 16 | import tagger.models as models 17 | import tagger.utils as utils 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser( 22 | description="Predict using SRL models", 23 | usage="translator.py [] [-h | --help]" 24 | ) 25 | 26 | # input files 27 | parser.add_argument("--input", type=str, required=True, 28 | help="Path of input file") 29 | parser.add_argument("--output", type=str, required=True, 30 | help="Path of output file") 31 | parser.add_argument("--checkpoint", type=str, required=True, 32 | help="Path of trained models") 33 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 34 | help="Path of source and target vocabulary") 35 | 36 | # model and configuration 37 | parser.add_argument("--model", type=str, required=True, 38 | help="Name of the model") 39 | parser.add_argument("--parameters", type=str, default="", 40 | help="Additional hyper parameters") 41 | parser.add_argument("--half", action="store_true", 42 | help="Use half precision for decoding") 43 | 44 | return parser.parse_args() 45 | 46 | 47 | def default_params(): 48 | params = utils.HParams( 49 | input=None, 50 | output=None, 51 | vocabulary=None, 52 | embedding="", 53 | # vocabulary specific 54 | pad="", 55 | bos="", 56 | eos="", 57 | unk="", 58 | device=0, 59 | decode_batch_size=128 60 | ) 61 | 62 | return params 63 | 64 | 65 | def merge_params(params1, params2): 66 | params = utils.HParams() 67 | 68 | for (k, v) in six.iteritems(params1.values()): 69 | params.add_hparam(k, v) 70 | 71 | params_dict = params.values() 72 | 73 | for (k, v) in six.iteritems(params2.values()): 74 | if k in params_dict: 75 | # Override 76 | setattr(params, k, v) 77 | else: 78 | params.add_hparam(k, v) 79 | 80 | return params 81 | 82 | 83 | def import_params(model_dir, model_name, params): 84 | model_dir = os.path.abspath(model_dir) 85 | m_name = os.path.join(model_dir, model_name + ".json") 86 | 87 | if not os.path.exists(m_name): 88 | return params 89 | 90 | with open(m_name) as fd: 91 | logging.info("Restoring model parameters from %s" % m_name) 92 | json_str = fd.readline() 93 | params.parse_json(json_str) 94 | 95 | return params 96 | 97 | 98 | def override_params(params, args): 99 | params.parse(args.parameters) 100 | 101 | src_vocab, src_w2idx, src_idx2w = data.load_vocabulary(args.vocabulary[0]) 102 | tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(args.vocabulary[1]) 103 | 104 | params.vocabulary = { 105 | "source": src_vocab, "target": tgt_vocab 106 | } 107 | params.lookup = { 108 | "source": src_w2idx, "target": tgt_w2idx 109 | } 110 | params.mapping = { 111 | "source": src_idx2w, "target": tgt_idx2w 112 | } 113 | 114 | return params 115 | 116 | 117 | def convert_to_string(inputs, 
tensor, params): 118 | inputs = torch.squeeze(inputs) 119 | inputs = inputs.tolist() 120 | tensor = torch.squeeze(tensor, dim=1) 121 | tensor = tensor.tolist() 122 | decoded = [] 123 | 124 | for wids, lids in zip(inputs, tensor): 125 | output = [] 126 | for wid, lid in zip(wids, lids): 127 | if wid == 0: 128 | break 129 | output.append(params.mapping["target"][lid]) 130 | decoded.append(b" ".join(output)) 131 | 132 | return decoded 133 | 134 | 135 | def main(args): 136 | # Load configs 137 | model_cls = models.get_model(args.model) 138 | params = default_params() 139 | params = merge_params(params, model_cls.default_params()) 140 | params = import_params(args.checkpoint, args.model, params) 141 | params = override_params(params, args) 142 | torch.cuda.set_device(params.device) 143 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 144 | 145 | # Create model 146 | with torch.no_grad(): 147 | model = model_cls(params).cuda() 148 | 149 | if args.half: 150 | model = model.half() 151 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 152 | 153 | model.eval() 154 | model.load_state_dict( 155 | torch.load(utils.best_checkpoint(args.checkpoint), 156 | map_location="cpu")["model"]) 157 | 158 | # Decoding 159 | dataset = data.get_dataset(args.input, "infer", params) 160 | fd = open(args.output, "wb") 161 | counter = 0 162 | 163 | if params.embedding is not None: 164 | embedding = data.load_glove_embedding(params.embedding) 165 | else: 166 | embedding = None 167 | 168 | for features in dataset: 169 | t = time.time() 170 | counter += 1 171 | features = data.lookup(features, "infer", params, embedding) 172 | 173 | labels = model.argmax_decode(features) 174 | batch = convert_to_string(features["inputs"], labels, params) 175 | 176 | for seq in batch: 177 | fd.write(seq) 178 | fd.write(b"\n") 179 | 180 | t = time.time() - t 181 | print("Finished batch: %d (%.3f sec)" % (counter, t)) 182 | 183 | fd.close() 184 | 185 | 186 | if __name__ == "__main__": 187 | main(parse_args()) 188 | -------------------------------------------------------------------------------- /tagger/bin/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import copy 10 | import glob 11 | import logging 12 | import os 13 | import re 14 | import six 15 | import socket 16 | import threading 17 | import time 18 | import torch 19 | 20 | import tagger.data as data 21 | import torch.distributed as dist 22 | import tagger.models as models 23 | import tagger.optimizers as optimizers 24 | import tagger.utils as utils 25 | import tagger.utils.summary as summary 26 | from tagger.utils.validation import ValidationWorker 27 | 28 | 29 | def parse_args(args=None): 30 | parser = argparse.ArgumentParser( 31 | description="Training SRL tagger", 32 | usage="trainer.py [] [-h | --help]" 33 | ) 34 | 35 | # input files 36 | parser.add_argument("--input", type=str, 37 | help="Path of the training corpus") 38 | parser.add_argument("--output", type=str, default="train", 39 | help="Path to saved models") 40 | parser.add_argument("--vocabulary", type=str, nargs=2, 41 | help="Path of source and target vocabulary") 42 | parser.add_argument("--checkpoint", type=str, 43 | help="Path to pre-trained checkpoint") 44 | parser.add_argument("--distributed", action="store_true", 45 | help="Enable distributed 
training mode") 46 | parser.add_argument("--local_rank", type=int, 47 | help="Local rank of this process") 48 | parser.add_argument("--half", action="store_true", 49 | help="Enable mixed precision training") 50 | parser.add_argument("--hparam_set", type=str, 51 | help="Name of pre-defined hyper parameter set") 52 | 53 | # model and configuration 54 | parser.add_argument("--model", type=str, required=True, 55 | help="Name of the model") 56 | parser.add_argument("--parameters", type=str, default="", 57 | help="Additional hyper parameters") 58 | 59 | return parser.parse_args(args) 60 | 61 | 62 | def default_params(): 63 | params = utils.HParams( 64 | input="", 65 | output="", 66 | model="transformer", 67 | vocab=["", ""], 68 | pad="", 69 | bos="", 70 | eos="", 71 | unk="", 72 | # Dataset 73 | batch_size=4096, 74 | fixed_batch_size=False, 75 | min_length=1, 76 | max_length=256, 77 | buffer_size=10000, 78 | # Initialization 79 | initializer_gain=1.0, 80 | initializer="uniform_unit_scaling", 81 | # Regularization 82 | scale_l1=0.0, 83 | scale_l2=0.0, 84 | # Training 85 | script="", 86 | warmup_steps=4000, 87 | train_steps=100000, 88 | update_cycle=1, 89 | optimizer="Adam", 90 | adam_beta1=0.9, 91 | adam_beta2=0.999, 92 | adam_epsilon=1e-8, 93 | adadelta_rho=0.95, 94 | adadelta_epsilon=1e-6, 95 | clipping="global_norm", 96 | clip_grad_norm=5.0, 97 | learning_rate=1.0, 98 | learning_rate_schedule="linear_warmup_rsqrt_decay", 99 | learning_rate_boundaries=[0], 100 | learning_rate_values=[0.0], 101 | device_list=[0], 102 | embedding="", 103 | # Validation 104 | keep_top_k=50, 105 | frequency=10, 106 | # Checkpoint Saving 107 | keep_checkpoint_max=20, 108 | keep_top_checkpoint_max=5, 109 | save_summary=True, 110 | save_checkpoint_secs=0, 111 | save_checkpoint_steps=1000, 112 | ) 113 | 114 | return params 115 | 116 | 117 | def import_params(model_dir, model_name, params): 118 | model_dir = os.path.abspath(model_dir) 119 | p_name = os.path.join(model_dir, "params.json") 120 | m_name = os.path.join(model_dir, model_name + ".json") 121 | 122 | if not os.path.exists(p_name) or not os.path.exists(m_name): 123 | return params 124 | 125 | with open(p_name) as fd: 126 | logging.info("Restoring hyper parameters from %s" % p_name) 127 | json_str = fd.readline() 128 | params.parse_json(json_str) 129 | 130 | with open(m_name) as fd: 131 | logging.info("Restoring model parameters from %s" % m_name) 132 | json_str = fd.readline() 133 | params.parse_json(json_str) 134 | 135 | return params 136 | 137 | 138 | def export_params(output_dir, name, params): 139 | if not os.path.exists(output_dir): 140 | os.makedirs(output_dir) 141 | 142 | # Save params as params.json 143 | filename = os.path.join(output_dir, name) 144 | 145 | with open(filename, "w") as fd: 146 | fd.write(params.to_json()) 147 | 148 | 149 | def merge_params(params1, params2): 150 | params = utils.HParams() 151 | 152 | for (k, v) in six.iteritems(params1.values()): 153 | params.add_hparam(k, v) 154 | 155 | params_dict = params.values() 156 | 157 | for (k, v) in six.iteritems(params2.values()): 158 | if k in params_dict: 159 | # Override 160 | setattr(params, k, v) 161 | else: 162 | params.add_hparam(k, v) 163 | 164 | return params 165 | 166 | 167 | def override_params(params, args): 168 | params.model = args.model or params.model 169 | params.input = args.input or params.input 170 | params.output = args.output or params.output 171 | params.vocab = args.vocabulary or params.vocab 172 | params.parse(args.parameters) 173 | 174 | src_vocab, src_w2idx, 
src_idx2w = data.load_vocabulary(params.vocab[0]) 175 | tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(params.vocab[1]) 176 | 177 | params.vocabulary = { 178 | "source": src_vocab, "target": tgt_vocab 179 | } 180 | params.lookup = { 181 | "source": src_w2idx, "target": tgt_w2idx 182 | } 183 | params.mapping = { 184 | "source": src_idx2w, "target": tgt_idx2w 185 | } 186 | 187 | return params 188 | 189 | 190 | def collect_params(all_params, params): 191 | collected = utils.HParams() 192 | 193 | for k in six.iterkeys(params.values()): 194 | collected.add_hparam(k, getattr(all_params, k)) 195 | 196 | return collected 197 | 198 | 199 | def print_variables(model): 200 | weights = {v[0]: v[1] for v in model.named_parameters()} 201 | total_size = 0 202 | 203 | for name in sorted(list(weights)): 204 | v = weights[name] 205 | print("%s %s" % (name.ljust(60), str(list(v.shape)).rjust(15))) 206 | total_size += v.nelement() 207 | 208 | print("Total trainable variables size: %d" % total_size) 209 | 210 | 211 | def save_checkpoint(step, epoch, model, optimizer, params): 212 | if dist.get_rank() == 0: 213 | state = { 214 | "step": step, 215 | "epoch": epoch, 216 | "model": model.state_dict(), 217 | "optimizer": optimizer.state_dict() 218 | } 219 | utils.save(state, params.output, params.keep_checkpoint_max) 220 | 221 | 222 | def infer_gpu_num(param_str): 223 | result = re.match(r".*device_list=\[(.*?)\].*", param_str) 224 | 225 | if not result: 226 | return 1 227 | else: 228 | dev_str = result.groups()[-1] 229 | return len(dev_str.split(",")) 230 | 231 | 232 | def get_clipper(params): 233 | if params.clipping.lower() == "none": 234 | clipper = None 235 | elif params.clipping.lower() == "adaptive": 236 | clipper = optimizers.adaptive_clipper(0.95) 237 | elif params.clipping.lower() == "global_norm": 238 | clipper = optimizers.global_norm_clipper(params.clip_grad_norm) 239 | else: 240 | raise ValueError("Unknown clipper %s" % params.clipping) 241 | 242 | return clipper 243 | 244 | 245 | def get_learning_rate_schedule(params): 246 | if params.learning_rate_schedule == "linear_warmup_rsqrt_decay": 247 | schedule = optimizers.LinearWarmupRsqrtDecay(params.learning_rate, 248 | params.warmup_steps) 249 | elif params.learning_rate_schedule == "piecewise_constant_decay": 250 | schedule = optimizers.PiecewiseConstantDecay( 251 | params.learning_rate_boundaries, params.learning_rate_values) 252 | elif params.learning_rate_schedule == "linear_exponential_decay": 253 | schedule = optimizers.LinearExponentialDecay(params.learning_rate, 254 | params.warmup_steps, params.start_decay_step, 255 | params.end_decay_step, 256 | dist.get_world_size()) 257 | else: 258 | raise ValueError("Unknown schedule %s" % params.learning_rate_schedule) 259 | 260 | return schedule 261 | 262 | 263 | def broadcast(model): 264 | for var in model.parameters(): 265 | dist.broadcast(var.data, 0) 266 | 267 | 268 | def main(args): 269 | model_cls = models.get_model(args.model) 270 | 271 | # Import and override parameters 272 | # Priorities (low -> high): 273 | # default -> saved -> command 274 | params = default_params() 275 | params = merge_params(params, model_cls.default_params(args.hparam_set)) 276 | params = import_params(args.output, args.model, params) 277 | params = override_params(params, args) 278 | 279 | # Initialize distributed utility 280 | if args.distributed: 281 | dist.init_process_group("nccl") 282 | torch.cuda.set_device(args.local_rank) 283 | else: 284 | dist.init_process_group("nccl", init_method=args.url, 285 | 
rank=args.local_rank, 286 | world_size=len(params.device_list)) 287 | torch.cuda.set_device(params.device_list[args.local_rank]) 288 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 289 | 290 | # Export parameters 291 | if dist.get_rank() == 0: 292 | export_params(params.output, "params.json", params) 293 | export_params(params.output, "%s.json" % params.model, 294 | collect_params(params, model_cls.default_params())) 295 | 296 | model = model_cls(params).cuda() 297 | model.load_embedding(params.embedding) 298 | 299 | if args.half: 300 | model = model.half() 301 | torch.set_default_dtype(torch.half) 302 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 303 | 304 | model.train() 305 | 306 | # Init tensorboard 307 | summary.init(params.output, params.save_summary) 308 | schedule = get_learning_rate_schedule(params) 309 | clipper = get_clipper(params) 310 | 311 | if params.optimizer.lower() == "adam": 312 | optimizer = optimizers.AdamOptimizer(learning_rate=schedule, 313 | beta_1=params.adam_beta1, 314 | beta_2=params.adam_beta2, 315 | epsilon=params.adam_epsilon, 316 | clipper=clipper) 317 | elif params.optimizer.lower() == "adadelta": 318 | optimizer = optimizers.AdadeltaOptimizer( 319 | learning_rate=schedule, rho=params.adadelta_rho, 320 | epsilon=params.adadelta_epsilon, clipper=clipper) 321 | else: 322 | raise ValueError("Unknown optimizer %s" % params.optimizer) 323 | 324 | if args.half: 325 | optimizer = optimizers.LossScalingOptimizer(optimizer) 326 | 327 | optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle) 328 | 329 | if dist.get_rank() == 0: 330 | print_variables(model) 331 | 332 | dataset = data.get_dataset(params.input, "train", params) 333 | 334 | # Load checkpoint 335 | checkpoint = utils.latest_checkpoint(params.output) 336 | 337 | if checkpoint is not None: 338 | state = torch.load(checkpoint, map_location="cpu") 339 | step = state["step"] 340 | epoch = state["epoch"] 341 | model.load_state_dict(state["model"]) 342 | 343 | if "optimizer" in state: 344 | optimizer.load_state_dict(state["optimizer"]) 345 | else: 346 | step = 0 347 | epoch = 0 348 | broadcast(model) 349 | 350 | def train_fn(inputs): 351 | features, labels = inputs 352 | loss = model(features, labels) 353 | return loss 354 | 355 | counter = 0 356 | should_save = False 357 | 358 | if params.script: 359 | thread = ValidationWorker(daemon=True) 360 | thread.init(params) 361 | thread.start() 362 | else: 363 | thread = None 364 | 365 | def step_fn(features, step): 366 | t = time.time() 367 | features = data.lookup(features, "train", params) 368 | loss = train_fn(features) 369 | gradients = optimizer.compute_gradients(loss, 370 | list(model.parameters())) 371 | if params.clip_grad_norm: 372 | torch.nn.utils.clip_grad_norm_(model.parameters(), 373 | params.clip_grad_norm) 374 | 375 | optimizer.apply_gradients(zip(gradients, 376 | list(model.named_parameters()))) 377 | 378 | t = time.time() - t 379 | 380 | summary.scalar("loss", loss, step, write_every_n_steps=1) 381 | summary.scalar("global_step/sec", t, step) 382 | 383 | print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" % 384 | (epoch + 1, step, float(loss), t)) 385 | 386 | try: 387 | while True: 388 | for features in dataset: 389 | if counter % params.update_cycle == 0: 390 | step += 1 391 | utils.set_global_step(step) 392 | should_save = True 393 | 394 | counter += 1 395 | step_fn(features, step) 396 | 397 | if step % params.save_checkpoint_steps == 0: 398 | if should_save: 399 | save_checkpoint(step, epoch, model, optimizer, 
params) 400 | should_save = False 401 | 402 | if step >= params.train_steps: 403 | if should_save: 404 | save_checkpoint(step, epoch, model, optimizer, params) 405 | 406 | if dist.get_rank() == 0: 407 | summary.close() 408 | 409 | return 410 | 411 | epoch += 1 412 | finally: 413 | if thread is not None: 414 | thread.stop() 415 | thread.join() 416 | 417 | 418 | # Wrap main function 419 | def process_fn(rank, args): 420 | local_args = copy.copy(args) 421 | local_args.local_rank = rank 422 | main(local_args) 423 | 424 | 425 | if __name__ == "__main__": 426 | parsed_args = parse_args() 427 | 428 | if parsed_args.distributed: 429 | main(parsed_args) 430 | else: 431 | # Pick a free port 432 | with socket.socket() as s: 433 | s.bind(("localhost", 0)) 434 | port = s.getsockname()[1] 435 | url = "tcp://localhost:" + str(port) 436 | parsed_args.url = url 437 | 438 | world_size = infer_gpu_num(parsed_args.parameters) 439 | 440 | if world_size > 1: 441 | torch.multiprocessing.spawn(process_fn, args=(parsed_args,), 442 | nprocs=world_size) 443 | else: 444 | process_fn(0, parsed_args) 445 | -------------------------------------------------------------------------------- /tagger/data/__init__.py: -------------------------------------------------------------------------------- 1 | from tagger.data.dataset import get_dataset 2 | from tagger.data.vocab import load_vocabulary, lookup 3 | from tagger.data.embedding import load_glove_embedding 4 | -------------------------------------------------------------------------------- /tagger/data/dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import queue 9 | import torch 10 | import threading 11 | import tensorflow as tf 12 | 13 | 14 | _QUEUE = None 15 | _THREAD = None 16 | _LOCK = threading.Lock() 17 | 18 | 19 | def build_input_fn(filename, mode, params): 20 | def train_input_fn(): 21 | dataset = tf.data.TextLineDataset(filename) 22 | dataset = dataset.prefetch(params.buffer_size) 23 | dataset = dataset.shuffle(params.buffer_size) 24 | 25 | # Split "|||" 26 | dataset = dataset.map( 27 | lambda x: tf.strings.split([x], sep="|||", maxsplit=2), 28 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 29 | dataset = dataset.map( 30 | lambda x: (x.values[0], x.values[1]), 31 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 32 | dataset = dataset.map( 33 | lambda x, y: (tf.strings.split([x]).values, 34 | tf.strings.split([y]).values), 35 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 36 | dataset = dataset.map( 37 | lambda x, y: ({ 38 | "preds": tf.strings.to_number(x[0], tf.int32), 39 | "inputs": tf.strings.lower(x[1:]) 40 | }, y), 41 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 42 | dataset = dataset.map( 43 | lambda x, y: ({ 44 | "preds": tf.one_hot(x["preds"], tf.shape(x["inputs"])[0], 45 | dtype=tf.int32), 46 | "inputs": x["inputs"] 47 | }, y), 48 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 49 | 50 | def bucket_boundaries(max_length, min_length=8, step=8): 51 | x = min_length 52 | boundaries = [] 53 | 54 | while x <= max_length: 55 | boundaries.append(x + 1) 56 | x += step 57 | 58 | return boundaries 59 | 60 | batch_size = params.batch_size 61 | max_length = (params.max_length // 8) * 8 62 | min_length = params.min_length 63 | boundaries = bucket_boundaries(max_length) 64 | batch_sizes = [max(1, batch_size 
// (x - 1)) 65 | if not params.fixed_batch_size else batch_size 66 | for x in boundaries] + [1] 67 | 68 | def element_length_func(x, y): 69 | return tf.shape(x["inputs"])[0] 70 | 71 | def valid_size(x, y): 72 | size = element_length_func(x, y) 73 | return tf.logical_and(size >= min_length, size <= max_length) 74 | 75 | transformation_fn = tf.data.experimental.bucket_by_sequence_length( 76 | element_length_func, 77 | boundaries, 78 | batch_sizes, 79 | padded_shapes=({ 80 | "inputs": tf.TensorShape([None]), 81 | "preds": tf.TensorShape([None]), 82 | }, tf.TensorShape([None])), 83 | padding_values=({ 84 | "inputs": params.pad, 85 | "preds": 0, 86 | }, params.pad), 87 | pad_to_bucket_boundary=True) 88 | 89 | dataset = dataset.filter(valid_size) 90 | dataset = dataset.apply(transformation_fn) 91 | 92 | return dataset 93 | 94 | 95 | def infer_input_fn(): 96 | dataset = tf.data.TextLineDataset(filename) 97 | 98 | # Split "|||" 99 | dataset = dataset.map( 100 | lambda x: tf.strings.split([x], sep="|||", maxsplit=2), 101 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 102 | dataset = dataset.map( 103 | lambda x: (x.values[0], x.values[1]), 104 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 105 | dataset = dataset.map( 106 | lambda x, y: (tf.strings.split([x]).values, 107 | tf.strings.split([y]).values), 108 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 109 | dataset = dataset.map( 110 | lambda x, y: ({ 111 | "preds": tf.strings.to_number(x[0], tf.int32), 112 | "inputs": tf.strings.lower(x[1:]) 113 | }, y), 114 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 115 | dataset = dataset.map( 116 | lambda x, y: ({ 117 | "preds": tf.one_hot(x["preds"], tf.shape(x["inputs"])[0], 118 | dtype=tf.int32), 119 | "inputs": x["inputs"] 120 | }, y), 121 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 122 | 123 | dataset = dataset.padded_batch( 124 | params.decode_batch_size, 125 | padded_shapes=({ 126 | "inputs": tf.TensorShape([None]), 127 | "preds": tf.TensorShape([None]), 128 | }, tf.TensorShape([None])), 129 | padding_values=({ 130 | "inputs": params.pad, 131 | "preds": 0, 132 | }, params.pad), 133 | ) 134 | 135 | return dataset 136 | 137 | if mode == "train": 138 | return train_input_fn 139 | else: 140 | return infer_input_fn 141 | 142 | 143 | class DatasetWorker(threading.Thread): 144 | 145 | def init(self, dataset): 146 | self._dataset = dataset 147 | self._stop = False 148 | 149 | def run(self): 150 | global _QUEUE 151 | global _LOCK 152 | 153 | while not self._stop: 154 | for feature in self._dataset: 155 | _QUEUE.put(feature) 156 | 157 | def stop(self): 158 | self._stop = True 159 | 160 | 161 | class Dataset(object): 162 | 163 | def __iter__(self): 164 | return self 165 | 166 | def __next__(self): 167 | global _QUEUE 168 | return _QUEUE.get() 169 | 170 | def stop(self): 171 | global _THREAD 172 | _THREAD.stop() 173 | _THREAD.join() 174 | 175 | 176 | def get_dataset(filenames, mode, params): 177 | global _QUEUE 178 | global _THREAD 179 | 180 | input_fn = build_input_fn(filenames, mode, params) 181 | 182 | with tf.device("/cpu:0"): 183 | dataset = input_fn() 184 | 185 | if mode != "train": 186 | return dataset 187 | else: 188 | _QUEUE = queue.Queue(100) 189 | thread = DatasetWorker(daemon=True) 190 | thread.init(dataset) 191 | thread.start() 192 | _THREAD = thread 193 | return Dataset() 194 | -------------------------------------------------------------------------------- /tagger/data/embedding.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | 10 | 11 | def load_glove_embedding(filename, vocab=None): 12 | fd = open(filename, "r") 13 | emb = {} 14 | fan_out = 0 15 | 16 | for line in fd: 17 | items = line.strip().split() 18 | word = items[0].encode("utf-8") 19 | value = [float(item) for item in items[1:]] 20 | fan_out = len(value) 21 | emb[word] = np.array(value, "float32") 22 | 23 | if not vocab: 24 | return emb 25 | 26 | ivoc = {} 27 | 28 | for item in vocab: 29 | ivoc[vocab[item]] = item 30 | 31 | new_emb = np.zeros([len(ivoc), fan_out], "float32") 32 | 33 | for i in ivoc: 34 | word = ivoc[i] 35 | if word not in emb: 36 | fan_in = len(ivoc) 37 | scale = 3.0 / max(1.0, (fan_in + fan_out) / 2.0) 38 | new_emb[i] = np.random.uniform(-scale, scale, [fan_out]) 39 | else: 40 | new_emb[i] = emb[word] 41 | 42 | return new_emb 43 | -------------------------------------------------------------------------------- /tagger/data/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def _lookup(x, vocab, embedding=None, feature_size=0): 13 | x = x.tolist() 14 | y = [] 15 | unk_mask = [] 16 | embeddings = [] 17 | 18 | for _, batch in enumerate(x): 19 | ids = [] 20 | mask = [] 21 | emb = [] 22 | 23 | for _, v in enumerate(batch): 24 | if v in vocab: 25 | ids.append(vocab[v]) 26 | mask.append(1.0) 27 | 28 | if embedding is not None: 29 | emb.append(np.zeros([feature_size])) 30 | else: 31 | ids.append(2) 32 | 33 | if embedding is not None and v in embedding: 34 | mask.append(0.0) 35 | emb.append(embedding[v]) 36 | else: 37 | mask.append(1.0) 38 | emb.append(np.zeros([feature_size])) 39 | 40 | y.append(ids) 41 | unk_mask.append(mask) 42 | embeddings.append(emb) 43 | 44 | ids = torch.LongTensor(np.array(y, dtype="int32")).cuda() 45 | mask = torch.Tensor(np.array(unk_mask, dtype="float32")).cuda() 46 | 47 | if embedding is not None: 48 | emb = torch.Tensor(np.array(embeddings, dtype="float32")).cuda() 49 | else: 50 | emb = None 51 | 52 | return ids, mask, emb 53 | 54 | 55 | def load_vocabulary(filename): 56 | vocab = [] 57 | with open(filename, "rb") as fd: 58 | for line in fd: 59 | vocab.append(line.strip()) 60 | 61 | word2idx = {} 62 | idx2word = {} 63 | 64 | for idx, word in enumerate(vocab): 65 | word2idx[word] = idx 66 | idx2word[idx] = word 67 | 68 | return vocab, word2idx, idx2word 69 | 70 | 71 | def lookup(inputs, mode, params, embedding=None): 72 | if mode == "train": 73 | features, labels = inputs 74 | preds, seqs = features["preds"], features["inputs"] 75 | preds = torch.LongTensor(preds.numpy()).cuda() 76 | seqs = seqs.numpy() 77 | labels = labels.numpy() 78 | 79 | seqs, _, _ = _lookup(seqs, params.lookup["source"]) 80 | labels, _, _ = _lookup(labels, params.lookup["target"]) 81 | 82 | features = { 83 | "preds": preds, 84 | "inputs": seqs 85 | } 86 | 87 | return features, labels 88 | else: 89 | features, _ = inputs 90 | preds, seqs = features["preds"], features["inputs"] 91 | preds = torch.LongTensor(preds.numpy()).cuda() 92 | seqs = seqs.numpy() 93 | 94 | seqs, unk_mask, emb = _lookup(seqs, params.lookup["source"], embedding, 95 | 
params.feature_size) 96 | 97 | features = { 98 | "preds": preds, 99 | "inputs": seqs, 100 | "mask": unk_mask 101 | } 102 | 103 | if emb is not None: 104 | features["embedding"] = emb 105 | 106 | return features 107 | -------------------------------------------------------------------------------- /tagger/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tagger.models.deepatt 9 | 10 | 11 | def get_model(name): 12 | name = name.lower() 13 | 14 | if name == "deepatt": 15 | return tagger.models.deepatt.DeepAtt 16 | else: 17 | raise LookupError("Unknown model %s" % name) 18 | -------------------------------------------------------------------------------- /tagger/models/deepatt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | import tagger.utils as utils 13 | import tagger.modules as modules 14 | 15 | from tagger.data import load_glove_embedding 16 | 17 | 18 | class AttentionSubLayer(modules.Module): 19 | 20 | def __init__(self, params, name="attention"): 21 | super(AttentionSubLayer, self).__init__(name=name) 22 | 23 | with utils.scope(name): 24 | self.attention = modules.MultiHeadAttention( 25 | params.hidden_size, params.num_heads, params.attention_dropout) 26 | self.layer_norm = modules.LayerNorm(params.hidden_size) 27 | 28 | self.dropout = params.residual_dropout 29 | 30 | def forward(self, x, bias): 31 | y = self.attention(x, bias) 32 | y = nn.functional.dropout(y, self.dropout, self.training) 33 | 34 | return self.layer_norm(x + y) 35 | 36 | 37 | class FFNSubLayer(modules.Module): 38 | 39 | def __init__(self, params, dtype=None, name="ffn_layer"): 40 | super(FFNSubLayer, self).__init__(name=name) 41 | 42 | with utils.scope(name): 43 | self.ffn_layer = modules.FeedForward(params.hidden_size, 44 | params.filter_size, 45 | dropout=params.relu_dropout) 46 | self.layer_norm = modules.LayerNorm(params.hidden_size) 47 | self.dropout = params.residual_dropout 48 | 49 | def forward(self, x): 50 | y = self.ffn_layer(x) 51 | y = nn.functional.dropout(y, self.dropout, self.training) 52 | 53 | return self.layer_norm(x + y) 54 | 55 | 56 | class DeepAttEncoderLayer(modules.Module): 57 | 58 | def __init__(self, params, name="layer"): 59 | super(DeepAttEncoderLayer, self).__init__(name=name) 60 | 61 | with utils.scope(name): 62 | self.self_attention = AttentionSubLayer(params) 63 | self.feed_forward = FFNSubLayer(params) 64 | 65 | def forward(self, x, bias): 66 | x = self.feed_forward(x) 67 | x = self.self_attention(x, bias) 68 | return x 69 | 70 | 71 | class DeepAttEncoder(modules.Module): 72 | 73 | def __init__(self, params, name="encoder"): 74 | super(DeepAttEncoder, self).__init__(name=name) 75 | 76 | with utils.scope(name): 77 | self.layers = nn.ModuleList([ 78 | DeepAttEncoderLayer(params, name="layer_%d" % i) 79 | for i in range(params.num_hidden_layers)]) 80 | 81 | def forward(self, x, bias): 82 | for layer in self.layers: 83 | x = layer(x, bias) 84 | return x 85 | 86 | 87 | class DeepAtt(modules.Module): 88 | 89 | def __init__(self, params, 
name="deepatt"): 90 | super(DeepAtt, self).__init__(name=name) 91 | self.params = params 92 | 93 | with utils.scope(name): 94 | self.build_embedding(params) 95 | self.encoding = modules.PositionalEmbedding() 96 | self.encoder = DeepAttEncoder(params) 97 | self.classifier = modules.Affine(params.hidden_size, 98 | len(params.vocabulary["target"]), 99 | name="softmax") 100 | 101 | self.criterion = modules.SmoothedCrossEntropyLoss( 102 | params.label_smoothing) 103 | self.dropout = params.residual_dropout 104 | self.hidden_size = params.hidden_size 105 | self.reset_parameters() 106 | 107 | def build_embedding(self, params): 108 | vocab_size = len(params.vocabulary["source"]) 109 | 110 | self.embedding = torch.nn.Parameter( 111 | torch.empty([vocab_size, params.feature_size])) 112 | self.weights = torch.nn.Parameter( 113 | torch.empty([2, params.feature_size])) 114 | self.bias = torch.nn.Parameter(torch.zeros([params.hidden_size])) 115 | self.add_name(self.embedding, "embedding") 116 | self.add_name(self.weights, "weights") 117 | self.add_name(self.bias, "bias") 118 | 119 | def reset_parameters(self): 120 | nn.init.normal_(self.embedding, mean=0.0, 121 | std=self.params.feature_size ** -0.5) 122 | nn.init.normal_(self.weights, mean=0.0, 123 | std=self.params.feature_size ** -0.5) 124 | nn.init.normal_(self.classifier.weight, mean=0.0, 125 | std=self.params.hidden_size ** -0.5) 126 | nn.init.zeros_(self.classifier.bias) 127 | 128 | def encode(self, features): 129 | seq = features["inputs"] 130 | pred = features["preds"] 131 | mask = torch.ne(seq, 0).float().cuda() 132 | enc_attn_bias = self.masking_bias(mask) 133 | 134 | inputs = torch.nn.functional.embedding(seq, self.embedding) 135 | 136 | if "embedding" in features and not self.training: 137 | embedding = features["embedding"] 138 | unk_mask = features["mask"].to(mask)[:, :, None] 139 | inputs = inputs * unk_mask + (1.0 - unk_mask) * embedding 140 | 141 | preds = torch.nn.functional.embedding(pred, self.weights) 142 | inputs = torch.cat([inputs, preds], axis=-1) 143 | inputs = inputs * (self.hidden_size ** 0.5) 144 | inputs = inputs + self.bias 145 | 146 | inputs = nn.functional.dropout(self.encoding(inputs), self.dropout, 147 | self.training) 148 | 149 | enc_attn_bias = enc_attn_bias.to(inputs) 150 | encoder_output = self.encoder(inputs, enc_attn_bias) 151 | logits = self.classifier(encoder_output) 152 | 153 | return logits 154 | 155 | def argmax_decode(self, features): 156 | logits = self.encode(features) 157 | return torch.argmax(logits, -1) 158 | 159 | def forward(self, features, labels): 160 | mask = torch.ne(features["inputs"], 0).float().cuda() 161 | logits = self.encode(features) 162 | loss = self.criterion(logits, labels) 163 | mask = mask.to(logits) 164 | 165 | return torch.sum(loss * mask) / torch.sum(mask) 166 | 167 | def load_embedding(self, path): 168 | if not path: 169 | return 170 | emb = load_glove_embedding(path, self.params.lookup["source"]) 171 | 172 | with torch.no_grad(): 173 | self.embedding.copy_(torch.tensor(emb)) 174 | 175 | @staticmethod 176 | def masking_bias(mask, inf=-1e9): 177 | ret = (1.0 - mask) * inf 178 | return torch.unsqueeze(torch.unsqueeze(ret, 1), 1) 179 | 180 | @staticmethod 181 | def base_params(): 182 | params = utils.HParams( 183 | pad="", 184 | bos="", 185 | eos="", 186 | unk="", 187 | feature_size=100, 188 | hidden_size=200, 189 | filter_size=800, 190 | num_heads=8, 191 | num_hidden_layers=10, 192 | attention_dropout=0.0, 193 | residual_dropout=0.1, 194 | relu_dropout=0.0, 195 | 
label_smoothing=0.1, 196 | clip_grad_norm=0.0 197 | ) 198 | 199 | return params 200 | 201 | @staticmethod 202 | def default_params(name=None): 203 | return DeepAtt.base_params() 204 | -------------------------------------------------------------------------------- /tagger/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from tagger.modules.attention import MultiHeadAttention 2 | from tagger.modules.embedding import PositionalEmbedding 3 | from tagger.modules.feed_forward import FeedForward 4 | from tagger.modules.layer_norm import LayerNorm 5 | from tagger.modules.losses import SmoothedCrossEntropyLoss 6 | from tagger.modules.module import Module 7 | from tagger.modules.affine import Affine 8 | from tagger.modules.recurrent import LSTMCell, GRUCell, HighwayLSTMCell, DynamicLSTMCell 9 | -------------------------------------------------------------------------------- /tagger/modules/affine.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | import tagger.utils as utils 13 | from tagger.modules.module import Module 14 | 15 | 16 | class Affine(Module): 17 | 18 | def __init__(self, in_features, out_features, bias=True, name="affine"): 19 | super(Affine, self).__init__(name=name) 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | 23 | with utils.scope(name): 24 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 25 | self.add_name(self.weight, "weight") 26 | if bias: 27 | self.bias = nn.Parameter(torch.Tensor(out_features)) 28 | self.add_name(self.bias, "bias") 29 | else: 30 | self.register_parameter('bias', None) 31 | 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 36 | if self.bias is not None: 37 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 38 | bound = 1 / math.sqrt(fan_in) 39 | nn.init.uniform_(self.bias, -bound, bound) 40 | 41 | def orthogonal_initialize(self, gain=1.0): 42 | nn.init.orthogonal_(self.weight, gain) 43 | nn.init.zeros_(self.bias) 44 | 45 | def forward(self, input): 46 | return nn.functional.linear(input, self.weight, self.bias) 47 | 48 | def extra_repr(self): 49 | return 'in_features={}, out_features={}, bias={}'.format( 50 | self.in_features, self.out_features, self.bias is not None 51 | ) 52 | -------------------------------------------------------------------------------- /tagger/modules/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | import tagger.utils as utils 11 | 12 | from tagger.modules.module import Module 13 | from tagger.modules.affine import Affine 14 | 15 | 16 | class MultiHeadAttention(Module): 17 | 18 | def __init__(self, hidden_size, num_heads, dropout=0.0, 19 | name="multihead_attention"): 20 | super(MultiHeadAttention, self).__init__(name=name) 21 | 22 | self.num_heads = num_heads 23 | self.hidden_size = hidden_size 24 | self.dropout = dropout 25 | 26 | with utils.scope(name): 27 | 
self.qkv_transform = Affine(hidden_size, 3 * hidden_size, 28 | name="qkv_transform") 29 | self.o_transform = Affine(hidden_size, hidden_size, 30 | name="o_transform") 31 | 32 | self.reset_parameters() 33 | 34 | def forward(self, query, bias): 35 | qkv = self.qkv_transform(query) 36 | q, k, v = torch.split(qkv, self.hidden_size, dim=-1) 37 | 38 | # split heads 39 | qh = self.split_heads(q, self.num_heads) 40 | kh = self.split_heads(k, self.num_heads) 41 | vh = self.split_heads(v, self.num_heads) 42 | 43 | # scale query 44 | qh = qh * (self.hidden_size // self.num_heads) ** -0.5 45 | 46 | # dot-product attention 47 | kh = torch.transpose(kh, -2, -1) 48 | logits = torch.matmul(qh, kh) 49 | 50 | if bias is not None: 51 | logits = logits + bias 52 | 53 | weights = torch.nn.functional.dropout(torch.softmax(logits, dim=-1), 54 | p=self.dropout, 55 | training=self.training) 56 | 57 | x = torch.matmul(weights, vh) 58 | 59 | # combine heads 60 | output = self.o_transform(self.combine_heads(x)) 61 | 62 | return output 63 | 64 | def reset_parameters(self, initializer="orthogonal"): 65 | if initializer == "orthogonal": 66 | self.qkv_transform.orthogonal_initialize() 67 | self.o_transform.orthogonal_initialize() 68 | else: 69 | # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size) 70 | nn.init.xavier_uniform_(self.qkv_transform.weight) 71 | nn.init.xavier_uniform_(self.o_transform.weight) 72 | nn.init.constant_(self.qkv_transform.bias, 0.0) 73 | nn.init.constant_(self.o_transform.bias, 0.0) 74 | 75 | @staticmethod 76 | def split_heads(x, heads): 77 | batch = x.shape[0] 78 | length = x.shape[1] 79 | channels = x.shape[2] 80 | 81 | y = torch.reshape(x, [batch, length, heads, channels // heads]) 82 | return torch.transpose(y, 2, 1) 83 | 84 | @staticmethod 85 | def combine_heads(x): 86 | batch = x.shape[0] 87 | heads = x.shape[1] 88 | length = x.shape[2] 89 | channels = x.shape[3] 90 | 91 | y = torch.transpose(x, 2, 1) 92 | 93 | return torch.reshape(y, [batch, length, heads * channels]) 94 | -------------------------------------------------------------------------------- /tagger/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | 11 | 12 | class PositionalEmbedding(torch.nn.Module): 13 | 14 | def __init__(self): 15 | super(PositionalEmbedding, self).__init__() 16 | 17 | def forward(self, inputs): 18 | if inputs.dim() != 3: 19 | raise ValueError("The rank of input must be 3.") 20 | 21 | length = inputs.shape[1] 22 | channels = inputs.shape[2] 23 | half_dim = channels // 2 24 | 25 | positions = torch.arange(length, dtype=inputs.dtype, 26 | device=inputs.device) 27 | dimensions = torch.arange(half_dim, dtype=inputs.dtype, 28 | device=inputs.device) 29 | 30 | scale = math.log(10000.0) / float(half_dim - 1) 31 | dimensions.mul_(-scale).exp_() 32 | 33 | scaled_time = positions.unsqueeze(1) * dimensions.unsqueeze(0) 34 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 35 | dim=1) 36 | 37 | if channels % 2 == 1: 38 | pad = torch.zeros([signal.shape[0], 1], dtype=inputs.dtype, 39 | device=inputs.device) 40 | signal = torch.cat([signal, pad], axis=1) 41 | 42 | return inputs + torch.reshape(signal, [1, -1, channels]).to(inputs) 43 | -------------------------------------------------------------------------------- 
/tagger/modules/feed_forward.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | import tagger.utils as utils 11 | 12 | from tagger.modules.module import Module 13 | from tagger.modules.affine import Affine 14 | 15 | 16 | class FeedForward(Module): 17 | 18 | def __init__(self, input_size, hidden_size, output_size=None, dropout=0.0, 19 | name="feed_forward"): 20 | super(FeedForward, self).__init__(name=name) 21 | 22 | self.input_size = input_size 23 | self.hidden_size = hidden_size 24 | self.output_size = output_size or input_size 25 | self.dropout = dropout 26 | 27 | with utils.scope(name): 28 | self.input_transform = Affine(input_size, hidden_size, 29 | name="input_transform") 30 | self.output_transform = Affine(hidden_size, self.output_size, 31 | name="output_transform") 32 | 33 | self.reset_parameters() 34 | 35 | def forward(self, x): 36 | h = nn.functional.relu(self.input_transform(x)) 37 | h = nn.functional.dropout(h, self.dropout, self.training) 38 | return self.output_transform(h) 39 | 40 | def reset_parameters(self, initializer="orthogonal"): 41 | if initializer == "orthogonal": 42 | self.input_transform.orthogonal_initialize() 43 | self.output_transform.orthogonal_initialize() 44 | else: 45 | nn.init.xavier_uniform_(self.input_transform.weight) 46 | nn.init.xavier_uniform_(self.output_transform.weight) 47 | nn.init.constant_(self.input_transform.bias, 0.0) 48 | nn.init.constant_(self.output_transform.bias, 0.0) 49 | -------------------------------------------------------------------------------- /tagger/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numbers 9 | import torch 10 | import torch.nn as nn 11 | import tagger.utils as utils 12 | 13 | from tagger.modules.module import Module 14 | 15 | 16 | class LayerNorm(Module): 17 | 18 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, 19 | name="layer_norm"): 20 | super(LayerNorm, self).__init__(name=name) 21 | if isinstance(normalized_shape, numbers.Integral): 22 | normalized_shape = (normalized_shape,) 23 | self.normalized_shape = tuple(normalized_shape) 24 | self.eps = eps 25 | self.elementwise_affine = elementwise_affine 26 | 27 | with utils.scope(name): 28 | if self.elementwise_affine: 29 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) 30 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) 31 | self.add_name(self.weight, "weight") 32 | self.add_name(self.bias, "bias") 33 | else: 34 | self.register_parameter('weight', None) 35 | self.register_parameter('bias', None) 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | if self.elementwise_affine: 40 | nn.init.ones_(self.weight) 41 | nn.init.zeros_(self.bias) 42 | 43 | def forward(self, input): 44 | return nn.functional.layer_norm( 45 | input, self.normalized_shape, self.weight, self.bias, self.eps) 46 | 47 | def extra_repr(self): 48 | return '{normalized_shape}, eps={eps}, ' \ 49 | 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 50 | 
-------------------------------------------------------------------------------- /tagger/modules/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | 11 | 12 | class SmoothedCrossEntropyLoss(torch.nn.Module): 13 | 14 | def __init__(self, smoothing=0.0, normalize=True): 15 | super(SmoothedCrossEntropyLoss, self).__init__() 16 | self.smoothing = smoothing 17 | self.normalize = normalize 18 | 19 | def forward(self, logits, labels): 20 | shape = labels.shape 21 | logits = torch.reshape(logits, [-1, logits.shape[-1]]) 22 | labels = torch.reshape(labels, [-1]) 23 | 24 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 25 | batch_idx = torch.arange(labels.shape[0], device=logits.device) 26 | loss = log_probs[batch_idx, labels] 27 | 28 | if not self.smoothing: 29 | return -torch.reshape(loss, shape) 30 | 31 | n = logits.shape[-1] - 1.0 32 | p = 1.0 - self.smoothing 33 | q = self.smoothing / n 34 | 35 | if log_probs.dtype != torch.float16: 36 | sum_probs = torch.sum(log_probs, dim=-1) 37 | loss = p * loss + q * (sum_probs - loss) 38 | else: 39 | # Prevent FP16 overflow 40 | sum_probs = torch.sum(log_probs.to(torch.float32), dim=-1) 41 | loss = loss.to(torch.float32) 42 | loss = p * loss + q * (sum_probs - loss) 43 | loss = loss.to(torch.float16) 44 | 45 | loss = -torch.reshape(loss, shape) 46 | 47 | if self.normalize: 48 | normalizing = -(p * math.log(p) + n * q * math.log(q + 1e-20)) 49 | return loss - normalizing 50 | else: 51 | return loss 52 | -------------------------------------------------------------------------------- /tagger/modules/module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import tagger.utils as utils 12 | 13 | 14 | class Module(nn.Module): 15 | 16 | def __init__(self, name=""): 17 | super(Module, self).__init__() 18 | scope = utils.get_scope() 19 | self._name = scope + "/" + name if scope else name 20 | 21 | def add_name(self, tensor, name): 22 | tensor.tensor_name = utils.unique_name(name) 23 | 24 | @property 25 | def name(self): 26 | return self._name 27 | -------------------------------------------------------------------------------- /tagger/modules/recurrent.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import tagger.utils as utils 12 | 13 | from tagger.modules.module import Module 14 | from tagger.modules.affine import Affine 15 | from tagger.modules.layer_norm import LayerNorm 16 | 17 | 18 | class GRUCell(Module): 19 | 20 | def __init__(self, input_size, output_size, normalization=False, 21 | name="gru"): 22 | super(GRUCell, self).__init__(name=name) 23 | 24 | self.input_size = input_size 25 | self.output_size = output_size 26 | 27 | with utils.scope(name): 28 | self.reset_gate = Affine(input_size + output_size, output_size, 29 | 
bias=False, name="reset_gate") 30 | self.update_gate = Affine(input_size + output_size, output_size, 31 | bias=False, name="update_gate") 32 | self.transform = Affine(input_size + output_size, output_size, 33 | name="transform") 34 | 35 | def forward(self, x, h): 36 | r = torch.sigmoid(self.reset_gate(torch.cat([x, h], -1))) 37 | u = torch.sigmoid(self.update_gate(torch.cat([x, h], -1))) 38 | c = self.transform(torch.cat([x, r * h], -1)) 39 | 40 | new_h = (1.0 - u) * h + u * torch.tanh(c) 41 | 42 | return new_h, new_h 43 | 44 | def init_state(self, batch_size, dtype, device): 45 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 46 | device=device) 47 | return h 48 | 49 | def mask_state(self, h, prev_h, mask): 50 | mask = mask[:, None] 51 | new_h = mask * h + (1.0 - mask) * prev_h 52 | return new_h 53 | 54 | def reset_parameters(self, initializer="uniform"): 55 | if initializer == "uniform_scaling": 56 | for gate in (self.reset_gate, self.update_gate, self.transform): 57 | nn.init.xavier_uniform_(gate.weight) 58 | elif initializer == "uniform": 59 | for gate in (self.reset_gate, self.update_gate, self.transform): 60 | nn.init.uniform_(gate.weight, -0.08, 0.08) 61 | else: 62 | raise ValueError("Unknown initializer %s" % initializer) 63 | 64 | 65 | class LSTMCell(Module): 66 | 67 | def __init__(self, input_size, output_size, normalization=False, 68 | activation=torch.tanh, name="lstm"): 69 | super(LSTMCell, self).__init__(name=name) 70 | 71 | self.input_size = input_size 72 | self.output_size = output_size 73 | self.activation = activation 74 | 75 | with utils.scope(name): 76 | self.gates = Affine(input_size + output_size, 4 * output_size, 77 | name="gates") 78 | if normalization: 79 | self.layer_norm = LayerNorm([4, output_size]) 80 | else: 81 | self.layer_norm = None 82 | 83 | self.reset_parameters() 84 | 85 | def forward(self, x, state): 86 | c, h = state 87 | 88 | gates = self.gates(torch.cat([x, h], 1)) 89 | 90 | if self.layer_norm is not None: 91 | combined = self.layer_norm( 92 | torch.reshape(gates, [-1, 4, self.output_size])) 93 | else: 94 | combined = torch.reshape(gates, [-1, 4, self.output_size]) 95 | 96 | i, j, f, o = torch.unbind(combined, 1) 97 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o) 98 | 99 | new_c = f * c + i * torch.tanh(j) 100 | 101 | if self.activation is None: 102 | # Do not use tanh activation 103 | new_h = o * new_c 104 | else: 105 | new_h = o * self.activation(new_c) 106 | 107 | return new_h, (new_c, new_h) 108 | 109 | def init_state(self, batch_size, dtype, device): 110 | c = torch.zeros([batch_size, self.output_size], dtype=dtype, 111 | device=device) 112 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 113 | device=device) 114 | return c, h 115 | 116 | def mask_state(self, state, prev_state, mask): 117 | c, h = state 118 | prev_c, prev_h = prev_state 119 | mask = mask[:, None] 120 | new_c = mask * c + (1.0 - mask) * prev_c 121 | new_h = mask * h + (1.0 - mask) * prev_h 122 | return new_c, new_h 123 | 124 | def reset_parameters(self, initializer="orthogonal"): 125 | if initializer == "uniform_scaling": 126 | nn.init.xavier_uniform_(self.gates.weight) 127 | nn.init.constant_(self.gates.bias, 0.0) 128 | elif initializer == "uniform": 129 | nn.init.uniform_(self.gates.weight, -0.04, 0.04) 130 | nn.init.uniform_(self.gates.bias, -0.04, 0.04) 131 | elif initializer == "orthogonal": 132 | self.gates.orthogonal_initialize() 133 | else: 134 | raise ValueError("Unknown initializer %s" % initializer) 135 | 136 | 137 | 138 | class HighwayLSTMCell(Module):
139 | 140 | def __init__(self, input_size, output_size, name="lstm"): 141 | super(HighwayLSTMCell, self).__init__(name=name) 142 | 143 | self.input_size = input_size 144 | self.output_size = output_size 145 | 146 | with utils.scope(name): 147 | self.gates = Affine(input_size + output_size, 5 * output_size, 148 | name="gates") 149 | self.trans = Affine(input_size, output_size, name="trans") 150 | 151 | self.reset_parameters() 152 | 153 | def forward(self, x, state): 154 | c, h = state 155 | 156 | gates = self.gates(torch.cat([x, h], 1)) 157 | combined = torch.reshape(gates, [-1, 5, self.output_size]) 158 | i, j, f, o, t = torch.unbind(combined, 1) 159 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o) 160 | t = torch.sigmoid(t) 161 | 162 | new_c = f * c + i * torch.tanh(j) 163 | tmp_h = o * torch.tanh(new_c) 164 | new_h = t * tmp_h + (1.0 - t) * self.trans(x) 165 | 166 | return new_h, (new_c, new_h) 167 | 168 | def init_state(self, batch_size, dtype, device): 169 | c = torch.zeros([batch_size, self.output_size], dtype=dtype, 170 | device=device) 171 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 172 | device=device) 173 | return c, h 174 | 175 | def mask_state(self, state, prev_state, mask): 176 | c, h = state 177 | prev_c, prev_h = prev_state 178 | mask = mask[:, None] 179 | new_c = mask * c + (1.0 - mask) * prev_c 180 | new_h = mask * h + (1.0 - mask) * prev_h 181 | return new_c, new_h 182 | 183 | def reset_parameters(self, initializer="orthogonal"): 184 | if initializer == "uniform_scaling": 185 | nn.init.xavier_uniform_(self.gates.weight) 186 | nn.init.constant_(self.gates.bias, 0.0) 187 | elif initializer == "uniform": 188 | nn.init.uniform_(self.gates.weight, -0.04, 0.04) 189 | nn.init.uniform_(self.gates.bias, -0.04, 0.04) 190 | elif initializer == "orthogonal": 191 | self.gates.orthogonal_initialize() 192 | self.trans.orthogonal_initialize() 193 | else: 194 | raise ValueError("Unknown initializer %d" % initializer) 195 | 196 | 197 | class DynamicLSTMCell(Module): 198 | 199 | def __init__(self, input_size, output_size, k=2, num_cells=4, name="lstm"): 200 | super(DynamicLSTMCell, self).__init__(name=name) 201 | 202 | self.input_size = input_size 203 | self.output_size = output_size 204 | self.num_cells = num_cells 205 | self.k = k 206 | 207 | with utils.scope(name): 208 | self.gates = Affine(input_size + output_size, 209 | 4 * output_size * num_cells, 210 | name="gates") 211 | self.topk_gate = Affine(input_size + output_size, 212 | num_cells, name="controller") 213 | 214 | 215 | self.reset_parameters() 216 | 217 | @staticmethod 218 | def top_k_softmax(logits, k, n): 219 | top_logits, top_indices = torch.topk(logits, k=min(k + 1, n)) 220 | 221 | top_k_logits = top_logits[:, :k] 222 | top_k_indices = top_indices[:, :k] 223 | 224 | probs = torch.softmax(top_k_logits, dim=-1) 225 | batch = top_k_logits.shape[0] 226 | k = top_k_logits.shape[1] 227 | 228 | # Flat to 1D 229 | indices_flat = torch.reshape(top_k_indices, [-1]) 230 | indices_flat = indices_flat + torch.div( 231 | torch.arange(batch * k, device=logits.device), k) * n 232 | 233 | tensor = torch.zeros([batch * n], dtype=logits.dtype, 234 | device=logits.device) 235 | tensor = tensor.scatter_add(0, indices_flat.long(), 236 | torch.reshape(probs, [-1])) 237 | 238 | return torch.reshape(tensor, [batch, n]) 239 | 240 | def forward(self, x, state): 241 | c, h = state 242 | feats = torch.cat([x, h], dim=-1) 243 | 244 | logits = self.topk_gate(feats) 245 | # [batch, num_cells] 246 | gate = 
self.top_k_softmax(logits, self.k, self.num_cells) 247 | 248 | # [batch, 4 * num_cells * dim] 249 | combined = self.gates(feats) 250 | combined = torch.reshape(combined, 251 | [-1, self.num_cells, 4, self.output_size]) 252 | 253 | i, j, f, o = torch.unbind(combined, 2) 254 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o) 255 | 256 | # [batch, num_cells, dim] 257 | new_c = f * c[:, None, :] + i * torch.tanh(j) 258 | new_h = o * torch.tanh(new_c) 259 | 260 | gate = gate[:, None, :] 261 | new_c = torch.matmul(gate, new_c) 262 | new_h = torch.matmul(gate, new_h) 263 | 264 | new_c = torch.squeeze(new_c, 1) 265 | new_h = torch.squeeze(new_h, 1) 266 | 267 | return new_h, (new_c, new_h) 268 | 269 | def init_state(self, batch_size, dtype, device): 270 | c = torch.zeros([batch_size, self.output_size], dtype=dtype, 271 | device=device) 272 | h = torch.zeros([batch_size, self.output_size], dtype=dtype, 273 | device=device) 274 | return c, h 275 | 276 | def mask_state(self, state, prev_state, mask): 277 | c, h = state 278 | prev_c, prev_h = prev_state 279 | mask = mask[:, None] 280 | new_c = mask * c + (1.0 - mask) * prev_c 281 | new_h = mask * h + (1.0 - mask) * prev_h 282 | return new_c, new_h 283 | 284 | def reset_parameters(self, initializer="orthogonal"): 285 | if initializer == "uniform_scaling": 286 | nn.init.xavier_uniform_(self.gates.weight) 287 | nn.init.constant_(self.gates.bias, 0.0) 288 | elif initializer == "uniform": 289 | nn.init.uniform_(self.gates.weight, -0.04, 0.04) 290 | nn.init.uniform_(self.gates.bias, -0.04, 0.04) 291 | elif initializer == "orthogonal": 292 | weight = self.gates.weight.view( 293 | [self.input_size + self.output_size, self.num_cells, 294 | 4 * self.output_size]) 295 | nn.init.orthogonal_(weight, 1.0) 296 | nn.init.constant_(self.gates.bias, 0.0) 297 | else: 298 | raise ValueError("Unknown initializer %d" % initializer) 299 | -------------------------------------------------------------------------------- /tagger/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from tagger.optimizers.optimizers import AdamOptimizer 2 | from tagger.optimizers.optimizers import AdadeltaOptimizer 3 | from tagger.optimizers.optimizers import MultiStepOptimizer 4 | from tagger.optimizers.optimizers import LossScalingOptimizer 5 | from tagger.optimizers.schedules import LinearWarmupRsqrtDecay 6 | from tagger.optimizers.schedules import PiecewiseConstantDecay 7 | from tagger.optimizers.clipping import ( 8 | adaptive_clipper, global_norm_clipper, value_clipper) 9 | -------------------------------------------------------------------------------- /tagger/optimizers/clipping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | 10 | 11 | def global_norm_clipper(value): 12 | def clip_fn(gradients, grad_norm): 13 | if not float(value) or grad_norm < value: 14 | return False, gradients 15 | 16 | scale = value / grad_norm 17 | 18 | gradients = [grad.data.mul_(scale) 19 | if grad is not None else None for grad in gradients] 20 | 21 | return False, gradients 22 | 23 | return clip_fn 24 | 25 | 26 | def value_clipper(clip_min, clip_max): 27 | def clip_fn(gradients, grad_norm): 28 | gradients = [ 29 | grad.data.clamp_(clip_min, clip_max) 30 | if grad is not None else None for grad in 
gradients] 31 | 32 | return False, gradients 33 | 34 | return clip_fn 35 | 36 | 37 | def adaptive_clipper(rho): 38 | norm_avg = 0.0 39 | norm_stddev = 0.0 40 | log_norm_avg = 0.0 41 | log_norm_sqr = 0.0 42 | 43 | def clip_fn(gradients, grad_norm): 44 | nonlocal norm_avg 45 | nonlocal norm_stddev 46 | nonlocal log_norm_avg 47 | nonlocal log_norm_sqr 48 | 49 | norm = grad_norm 50 | log_norm = math.log(norm) 51 | 52 | avg = rho * norm_avg + (1.0 - rho) * norm 53 | log_avg = rho * log_norm_avg + (1.0 - rho) * log_norm 54 | log_sqr = rho * log_norm_sqr + (1.0 - rho) * (log_norm ** 2) 55 | stddev = (log_sqr - (log_avg ** 2)) ** 0.5 56 | 57 | norm_avg = avg 58 | log_norm_avg = log_avg 59 | log_norm_sqr = log_sqr 60 | norm_stddev = rho * norm_stddev + (1.0 - rho) * stddev 61 | 62 | reject = False 63 | 64 | if norm > norm_avg + 4 * math.exp(norm_stddev): 65 | reject = True 66 | 67 | return reject, gradients 68 | 69 | return clip_fn 70 | -------------------------------------------------------------------------------- /tagger/optimizers/optimizers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.distributed as dist 11 | import tagger.utils as utils 12 | import tagger.utils.summary as summary 13 | 14 | from tagger.optimizers.schedules import LearningRateSchedule 15 | 16 | 17 | def _save_summary(grads_and_vars): 18 | total_norm = 0.0 19 | 20 | for grad, var in grads_and_vars: 21 | if grad is None: 22 | continue 23 | 24 | _, var = var 25 | grad_norm = grad.data.norm() 26 | total_norm += grad_norm ** 2 27 | summary.histogram(var.tensor_name, var, 28 | utils.get_global_step()) 29 | summary.scalar("norm/" + var.tensor_name, var.norm(), 30 | utils.get_global_step()) 31 | summary.scalar("grad_norm/" + var.tensor_name, grad_norm, 32 | utils.get_global_step()) 33 | 34 | total_norm = total_norm ** 0.5 35 | summary.scalar("grad_norm", total_norm, utils.get_global_step()) 36 | 37 | return float(total_norm) 38 | 39 | 40 | def _compute_grad_norm(gradients): 41 | total_norm = 0.0 42 | 43 | for grad in gradients: 44 | total_norm += float(grad.data.norm() ** 2) 45 | 46 | return float(total_norm ** 0.5) 47 | 48 | 49 | class Optimizer(object): 50 | 51 | def __init__(self, name, **kwargs): 52 | self._name = name 53 | self._iterations = 0 54 | self._slots = {} 55 | 56 | def detach_gradients(self, gradients): 57 | for grad in gradients: 58 | if grad is not None: 59 | grad.detach_() 60 | 61 | def scale_gradients(self, gradients, scale): 62 | for grad in gradients: 63 | if grad is not None: 64 | grad.mul_(scale) 65 | 66 | def sync_gradients(self, gradients, compress=True): 67 | grad_vec = torch.nn.utils.parameters_to_vector(gradients) 68 | 69 | if compress: 70 | grad_vec_half = grad_vec.half() 71 | dist.all_reduce(grad_vec_half) 72 | grad_vec = grad_vec_half.to(grad_vec) 73 | else: 74 | dist.all_reduce(grad_vec) 75 | 76 | torch.nn.utils.vector_to_parameters(grad_vec, gradients) 77 | 78 | def zero_gradients(self, gradients): 79 | for grad in gradients: 80 | if grad is not None: 81 | grad.zero_() 82 | 83 | def compute_gradients(self, loss, var_list, aggregate=False): 84 | var_list = list(var_list) 85 | grads = [v.grad if v is not None else None for v in var_list] 86 | 87 | self.detach_gradients(grads) 88 | 89 | if not aggregate: 90 | self.zero_gradients(grads)
91 | 92 | loss.backward() 93 | return [v.grad if v is not None else None for v in var_list] 94 | 95 | def apply_gradients(self, grads_and_vars): 96 | raise NotImplementedError("Not implemented") 97 | 98 | @property 99 | def iterations(self): 100 | return self._iterations 101 | 102 | def state_dict(self): 103 | raise NotImplementedError("Not implemented") 104 | 105 | def load_state_dict(self): 106 | raise NotImplementedError("Not implemented") 107 | 108 | 109 | class SGDOptimizer(Optimizer): 110 | 111 | def __init__(self, learning_rate, summaries=True, name="SGD", **kwargs): 112 | super(SGDOptimizer, self).__init__(name, **kwargs) 113 | self._learning_rate = learning_rate 114 | self._summaries = summaries 115 | self._clipper = None 116 | 117 | if "clipper" in kwargs and kwargs["clipper"] is not None: 118 | self._clipper = kwargs["clipper"] 119 | 120 | def apply_gradients(self, grads_and_vars): 121 | self._iterations += 1 122 | lr = self._learning_rate 123 | grads, var_list = list(zip(*grads_and_vars)) 124 | 125 | if self._summaries: 126 | grad_norm = _save_summary(zip(grads, var_list)) 127 | else: 128 | grad_norm = _compute_grad_norm(grads) 129 | 130 | if self._clipper is not None: 131 | reject, grads = self._clipper(grads, grad_norm) 132 | 133 | if reject: 134 | return 135 | 136 | for grad, var in zip(grads, var_list): 137 | if grad is None: 138 | continue 139 | 140 | # Convert if grad is not FP32 141 | grad = grad.data.float() 142 | _, var = var 143 | step_size = lr 144 | 145 | if var.dtype == torch.float32: 146 | var.data.add_(-step_size, grad) 147 | else: 148 | fp32_var = var.data.float() 149 | fp32_var.add_(-step_size, grad) 150 | var.data.copy_(fp32_var) 151 | 152 | def state_dict(self): 153 | state = { 154 | "iterations": self._iterations, 155 | } 156 | 157 | if not isinstance(self._learning_rate, LearningRateSchedule): 158 | state["learning_rate"] = self._learning_rate 159 | 160 | return state 161 | 162 | def load_state_dict(self, state): 163 | self._learning_rate = state.get("learning_rate", self._learning_rate) 164 | self._iterations = state.get("iterations", self._iterations) 165 | 166 | 167 | class AdamOptimizer(Optimizer): 168 | 169 | def __init__(self, learning_rate=0.01, beta_1=0.9, beta_2=0.999, 170 | epsilon=1e-7, name="Adam", **kwargs): 171 | super(AdamOptimizer, self).__init__(name, **kwargs) 172 | self._learning_rate = learning_rate 173 | self._beta_1 = beta_1 174 | self._beta_2 = beta_2 175 | self._epsilon = epsilon 176 | self._summaries = True 177 | self._clipper = None 178 | 179 | if "summaries" in kwargs and not kwargs["summaries"]: 180 | self._summaries = False 181 | 182 | if "clipper" in kwargs and kwargs["clipper"] is not None: 183 | self._clipper = kwargs["clipper"] 184 | 185 | def apply_gradients(self, grads_and_vars): 186 | self._iterations += 1 187 | lr = self._learning_rate 188 | beta_1 = self._beta_1 189 | beta_2 = self._beta_2 190 | epsilon = self._epsilon 191 | grads, var_list = list(zip(*grads_and_vars)) 192 | 193 | if self._summaries: 194 | grad_norm = _save_summary(zip(grads, var_list)) 195 | else: 196 | grad_norm = _compute_grad_norm(grads) 197 | 198 | if self._clipper is not None: 199 | reject, grads = self._clipper(grads, grad_norm) 200 | 201 | if reject: 202 | return 203 | 204 | for grad, var in zip(grads, var_list): 205 | if grad is None: 206 | continue 207 | 208 | # Convert if grad is not FP32 209 | grad = grad.data.float() 210 | name, var = var 211 | 212 | if self._slots.get(name, None) is None: 213 | self._slots[name] = {} 214 | 
self._slots[name]["m"] = torch.zeros_like(var.data, 215 | dtype=torch.float32) 216 | self._slots[name]["v"] = torch.zeros_like(var.data, 217 | dtype=torch.float32) 218 | 219 | m, v = self._slots[name]["m"], self._slots[name]["v"] 220 | 221 | bias_corr_1 = 1 - beta_1 ** self._iterations 222 | bias_corr_2 = 1 - beta_2 ** self._iterations 223 | 224 | m.mul_(beta_1).add_(1 - beta_1, grad) 225 | v.mul_(beta_2).addcmul_(1 - beta_2, grad, grad) 226 | denom = (v.sqrt() / math.sqrt(bias_corr_2)).add_(epsilon) 227 | 228 | if isinstance(lr, LearningRateSchedule): 229 | lr = lr(self._iterations) 230 | 231 | step_size = lr / bias_corr_1 232 | 233 | if var.dtype == torch.float32: 234 | var.data.addcdiv_(-step_size, m, denom) 235 | else: 236 | fp32_var = var.data.float() 237 | fp32_var.addcdiv_(-step_size, m, denom) 238 | var.data.copy_(fp32_var) 239 | 240 | def state_dict(self): 241 | state = { 242 | "beta_1": self._beta_1, 243 | "beta_2": self._beta_2, 244 | "epsilon": self._epsilon, 245 | "iterations": self._iterations, 246 | "slot": self._slots 247 | } 248 | 249 | if not isinstance(self._learning_rate, LearningRateSchedule): 250 | state["learning_rate"] = self._learning_rate 251 | 252 | return state 253 | 254 | def load_state_dict(self, state): 255 | self._learning_rate = state.get("learning_rate", self._learning_rate) 256 | self._beta_1 = state.get("beta_1", self._beta_1) 257 | self._beta_2 = state.get("beta_2", self._beta_2) 258 | self._epsilon = state.get("epsilon", self._epsilon) 259 | self._iterations = state.get("iterations", self._iterations) 260 | 261 | slots = state.get("slot", {}) 262 | self._slots = {} 263 | 264 | for key in slots: 265 | m, v = slots[key]["m"], slots[key]["v"] 266 | self._slots[key] = {} 267 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32) 268 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32) 269 | self._slots[key]["m"].copy_(m) 270 | self._slots[key]["v"].copy_(v) 271 | 272 | 273 | class AdadeltaOptimizer(Optimizer): 274 | 275 | def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-07, 276 | name="Adadelta", **kwargs): 277 | super(AdadeltaOptimizer, self).__init__(name, **kwargs) 278 | self._learning_rate = learning_rate 279 | self._rho = rho 280 | self._epsilon = epsilon 281 | self._summaries = True 282 | 283 | if "summaries" in kwargs and not kwargs["summaries"]: 284 | self._summaries = False 285 | 286 | if "clipper" in kwargs and kwargs["clipper"] is not None: 287 | self._clipper = kwargs["clipper"] 288 | 289 | def apply_gradients(self, grads_and_vars): 290 | self._iterations += 1 291 | lr = self._learning_rate 292 | rho = self._rho 293 | epsilon = self._epsilon 294 | 295 | grads, var_list = list(zip(*grads_and_vars)) 296 | 297 | if self._summaries: 298 | grad_norm = _save_summary(zip(grads, var_list)) 299 | else: 300 | grad_norm = _compute_grad_norm(grads) 301 | 302 | if self._clipper is not None: 303 | reject, grads = self._clipper(grads, grad_norm) 304 | 305 | if reject: 306 | return 307 | 308 | for grad, var in zip(grads, var_list): 309 | if grad is None: 310 | continue 311 | 312 | # Convert if grad is not FP32 313 | grad = grad.data.float() 314 | name, var = var 315 | 316 | if self._slots.get(name, None) is None: 317 | self._slots[name] = {} 318 | self._slots[name]["m"] = torch.zeros_like(var.data, 319 | dtype=torch.float32) 320 | self._slots[name]["v"] = torch.zeros_like(var.data, 321 | dtype=torch.float32) 322 | 323 | square_avg = self._slots[name]["m"] 324 | acc_delta = self._slots[name]["v"] 325 | 326 | if 
isinstance(lr, LearningRateSchedule): 327 | lr = lr(self._iterations) 328 | 329 | square_avg.mul_(rho).addcmul_(1 - rho, grad, grad) 330 | std = square_avg.add(epsilon).sqrt_() 331 | delta = acc_delta.add(epsilon).sqrt_().div_(std).mul_(grad) 332 | acc_delta.mul_(rho).addcmul_(1 - rho, delta, delta) 333 | 334 | if var.dtype == torch.float32: 335 | var.data.add_(-lr, delta) 336 | else: 337 | fp32_var = var.data.float() 338 | fp32_var.add_(-lr, delta) 339 | var.data.copy_(fp32_var) 340 | 341 | def state_dict(self): 342 | state = { 343 | "rho": self._rho, 344 | "epsilon": self._epsilon, 345 | "iterations": self._iterations, 346 | "slot": self._slots 347 | } 348 | 349 | if not isinstance(self._learning_rate, LearningRateSchedule): 350 | state["learning_rate"] = self._learning_rate 351 | 352 | return state 353 | 354 | def load_state_dict(self, state): 355 | self._learning_rate = state.get("learning_rate", self._learning_rate) 356 | self._rho = state.get("rho", self._rho) 357 | self._epsilon = state.get("epsilon", self._epsilon) 358 | self._iterations = state.get("iterations", self._iterations) 359 | 360 | slots = state.get("slot", {}) 361 | self._slots = {} 362 | 363 | for key in slots: 364 | m, v = slots[key]["m"], slots[key]["v"] 365 | self._slots[key] = {} 366 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32) 367 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32) 368 | self._slots[key]["m"].copy_(m) 369 | self._slots[key]["v"].copy_(v) 370 | 371 | 372 | class LossScalingOptimizer(Optimizer): 373 | 374 | def __init__(self, optimizer, scale=2.0**7, increment_period=2000, 375 | multiplier=2.0, name="LossScalingOptimizer", **kwargs): 376 | super(LossScalingOptimizer, self).__init__(name, **kwargs) 377 | self._optimizer = optimizer 378 | self._scale = scale 379 | self._increment_period = increment_period 380 | self._multiplier = multiplier 381 | self._num_good_steps = 0 382 | self._summaries = True 383 | 384 | if "summaries" in kwargs and not kwargs["summaries"]: 385 | self._summaries = False 386 | 387 | def _update_if_finite_grads(self): 388 | if self._num_good_steps + 1 > self._increment_period: 389 | self._scale *= self._multiplier 390 | self._scale = min(self._scale, 2.0**16) 391 | self._num_good_steps = 0 392 | else: 393 | self._num_good_steps += 1 394 | 395 | def _update_if_not_finite_grads(self): 396 | self._scale = max(self._scale / self._multiplier, 1) 397 | 398 | def compute_gradients(self, loss, var_list, aggregate=False): 399 | var_list = list(var_list) 400 | grads = [v.grad if v is not None else None for v in var_list] 401 | 402 | self.detach_gradients(grads) 403 | 404 | if not aggregate: 405 | self.zero_gradients(grads) 406 | 407 | loss = loss * self._scale 408 | loss.backward() 409 | 410 | return [v.grad if v is not None else None for v in var_list] 411 | 412 | def apply_gradients(self, grads_and_vars): 413 | self._iterations += 1 414 | grads, var_list = list(zip(*grads_and_vars)) 415 | new_grads = [] 416 | 417 | if self._summaries: 418 | summary.scalar("optimizer/scale", self._scale, 419 | utils.get_global_step()) 420 | 421 | for grad in grads: 422 | if grad is None: 423 | new_grads.append(None) 424 | continue 425 | 426 | norm = grad.data.norm() 427 | 428 | if not torch.isfinite(norm): 429 | self._update_if_not_finite_grads() 430 | return 431 | else: 432 | # Rescale gradients 433 | new_grads.append(grad.data.float().mul_(1.0 / self._scale)) 434 | 435 | self._update_if_finite_grads() 436 | self._optimizer.apply_gradients(zip(new_grads, var_list)) 
437 | 438 | def state_dict(self): 439 | state = { 440 | "scale": self._scale, 441 | "increment_period": self._increment_period, 442 | "multiplier": self._multiplier, 443 | "num_good_steps": self._num_good_steps, 444 | "optimizer": self._optimizer.state_dict() 445 | } 446 | return state 447 | 448 | def load_state_dict(self, state): 449 | self._scale = state.get("scale", self._scale) 450 | self._increment_period = state.get("increment_period", 451 | self._increment_period) 452 | self._multiplier = state.get("multiplier", self._multiplier) 453 | self._num_good_steps = state.get("num_good_steps", 454 | self._num_good_steps) 455 | self._optimizer.load_state_dict(state.get("optimizer", {})) 456 | 457 | 458 | class MultiStepOptimizer(Optimizer): 459 | 460 | def __init__(self, optimizer, n=1, compress=True, 461 | name="MultiStepOptimizer", **kwargs): 462 | super(MultiStepOptimizer, self).__init__(name, **kwargs) 463 | self._n = n 464 | self._optimizer = optimizer 465 | self._compress = compress 466 | 467 | def compute_gradients(self, loss, var_list, aggregate=False): 468 | if self._iterations % self._n == 0: 469 | return self._optimizer.compute_gradients(loss, var_list, aggregate) 470 | else: 471 | return self._optimizer.compute_gradients(loss, var_list, True) 472 | 473 | def apply_gradients(self, grads_and_vars): 474 | size = dist.get_world_size() 475 | grads, var_list = list(zip(*grads_and_vars)) 476 | self._iterations += 1 477 | 478 | if self._n == 1: 479 | if size > 1: 480 | self.sync_gradients(grads, compress=self._compress) 481 | self.scale_gradients(grads, 1.0 / size) 482 | 483 | self._optimizer.apply_gradients(zip(grads, var_list)) 484 | else: 485 | if self._iterations % self._n != 0: 486 | return 487 | 488 | if size > 1: 489 | self.sync_gradients(grads, compress=self._compress) 490 | 491 | self.scale_gradients(grads, 1.0 / (self._n * size)) 492 | self._optimizer.apply_gradients(zip(grads, var_list)) 493 | 494 | def state_dict(self): 495 | state = { 496 | "n": self._n, 497 | "iterations": self._iterations, 498 | "compress": self._compress, 499 | "optimizer": self._optimizer.state_dict() 500 | } 501 | return state 502 | 503 | def load_state_dict(self, state): 504 | self._n = state.get("n", self._n) 505 | self._iterations = state.get("iterations", self._iterations) 506 | self._compress = state.get("compress", self._iterations) 507 | self._optimizer.load_state_dict(state.get("optimizer", {})) 508 | -------------------------------------------------------------------------------- /tagger/optimizers/schedules.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | 9 | import tagger.utils as utils 10 | import tagger.utils.summary as summary 11 | 12 | 13 | class LearningRateSchedule(object): 14 | 15 | def __call__(self, step): 16 | raise NotImplementedError("Not implemented.") 17 | 18 | def get_config(self): 19 | raise NotImplementedError("Not implemented.") 20 | 21 | @classmethod 22 | def from_config(cls, config): 23 | return cls(**config) 24 | 25 | 26 | 27 | class LinearWarmupRsqrtDecay(LearningRateSchedule): 28 | 29 | def __init__(self, learning_rate, warmup_steps, initial_learning_rate=0.0, 30 | summary=True): 31 | super(LinearWarmupRsqrtDecay, self).__init__() 32 | 33 | if not initial_learning_rate: 34 | initial_learning_rate = learning_rate / warmup_steps 35 | 36 | 
self._initial_learning_rate = initial_learning_rate 37 | self._maximum_learning_rate = learning_rate 38 | self._warmup_steps = warmup_steps 39 | self._summary = summary 40 | 41 | def __call__(self, step): 42 | if step <= self._warmup_steps: 43 | lr_step = self._maximum_learning_rate - self._initial_learning_rate 44 | lr_step /= self._warmup_steps 45 | lr = self._initial_learning_rate + lr_step * step 46 | else: 47 | step = step / self._warmup_steps 48 | lr = self._maximum_learning_rate * (step ** -0.5) 49 | 50 | if self._summary: 51 | summary.scalar("learning_rate", lr, utils.get_global_step()) 52 | 53 | return lr 54 | 55 | def get_config(self): 56 | return { 57 | "learning_rate": self._maximum_learning_rate, 58 | "initial_learning_rate": self._initial_learning_rate, 59 | "warmup_steps": self._warmup_steps 60 | } 61 | 62 | 63 | class PiecewiseConstantDecay(LearningRateSchedule): 64 | 65 | def __init__(self, boundaries, values, summary=True, name=None): 66 | super(PiecewiseConstantDecay, self).__init__() 67 | 68 | if len(boundaries) != len(values) - 1: 69 | raise ValueError("The length of boundaries should be 1" 70 | " less than the length of values") 71 | 72 | self._boundaries = boundaries 73 | self._values = values 74 | self._summary = summary 75 | 76 | def __call__(self, step): 77 | boundaries = self._boundaries 78 | values = self._values 79 | learning_rate = values[0] 80 | 81 | if step <= boundaries[0]: 82 | learning_rate = values[0] 83 | elif step > boundaries[-1]: 84 | learning_rate = values[-1] 85 | else: 86 | for low, high, v in zip(boundaries[:-1], boundaries[1:], 87 | values[1:-1]): 88 | 89 | if step > low and step <= high: 90 | learning_rate = v 91 | break 92 | 93 | if self._summary: 94 | summary.scalar("learning_rate", learning_rate, 95 | utils.get_global_step()) 96 | 97 | return learning_rate 98 | 99 | def get_config(self): 100 | return { 101 | "boundaries": self._boundaries, 102 | "values": self._values, 103 | } 104 | 105 | 106 | class LinearExponentialDecay(LearningRateSchedule): 107 | 108 | def __init__(self, learning_rate, warmup_steps, start_decay_step, 109 | end_decay_step, n, summary=True): 110 | super(LinearExponentialDecay, self).__init__() 111 | 112 | self._learning_rate = learning_rate 113 | self._warmup_steps = warmup_steps 114 | self._start_decay_step = start_decay_step 115 | self._end_decay_step = end_decay_step 116 | self._n = n 117 | self._summary = summary 118 | 119 | def __call__(self, step): 120 | # See reference: The Best of Both Worlds: Combining Recent Advances 121 | # in Neural Machine Translation 122 | n = self._n 123 | p = self._warmup_steps / n 124 | s = n * self._start_decay_step 125 | e = n * self._end_decay_step 126 | 127 | learning_rate = self._learning_rate 128 | 129 | learning_rate *= min( 130 | 1.0 + (n - 1) * step / float(n * p), 131 | n, 132 | n * ((2 * n) ** (float(s - n * step) / float(e - s)))) 133 | 134 | if self._summary: 135 | summary.scalar("learning_rate", learning_rate, 136 | utils.get_global_step()) 137 | 138 | return learning_rate 139 | 140 | def get_config(self): 141 | return { 142 | "learning_rate": self._learning_rate, 143 | "warmup_steps": self._warmup_steps, 144 | "start_decay_step": self._start_decay_step, 145 | "end_decay_step": self._end_decay_step, 146 | } 147 |
"learning_rate": self._learning_rate, 277 | "warmup_steps": self._warmup_steps, 278 | "start_decay_step": self._start_decay_step, 279 | "end_decay_step": self._end_decay_step, 280 | } 281 | -------------------------------------------------------------------------------- /tagger/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | # build_vocab.py 2 | # author: Playinf 3 | # email: playinf@stu.xmu.edu.cn 4 | 5 | import argparse 6 | import collections 7 | 8 | 9 | def count_items(filename, lower=False): 10 | counter = collections.Counter() 11 | label_counter = collections.Counter() 12 | 13 | with open(filename, "r") as fd: 14 | for line in fd: 15 | words, labels = line.strip().split("|||") 16 | words = words.strip().split() 17 | labels = labels.strip().split() 18 | 19 | if lower: 20 | words = [item.lower() for item in words[1:]] 21 | else: 22 | words = words[1:] 23 | 24 | counter.update(words) 25 | label_counter.update(labels) 26 | 27 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 28 | words, counts = list(zip(*count_pairs)) 29 | count_pairs = sorted(label_counter.items(), key=lambda x: (-x[1], x[0])) 30 | labels, _ = list(zip(*count_pairs)) 31 | 32 | return words, labels, counts 33 | 34 | 35 | def special_tokens(string): 36 | if not string: 37 | return [] 38 | else: 39 | return string.strip().split(":") 40 | 41 | 42 | def save_vocab(name, vocab): 43 | if name.split(".")[-1] != "txt": 44 | name = name + ".txt" 45 | 46 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0])) 47 | words, ids = list(zip(*pairs)) 48 | 49 | with open(name, "w") as f: 50 | for word in words: 51 | f.write(word + "\n") 52 | 53 | 54 | def write_vocab(name, vocab): 55 | with open(name, "w") as f: 56 | for word in vocab: 57 | f.write(word + "\n") 58 | 59 | 60 | def parse_args(): 61 | msg = "build vocabulary" 62 | parser = argparse.ArgumentParser(description=msg) 63 | 64 | msg = "input corpus" 65 | parser.add_argument("corpus", help=msg) 66 | msg = "output vocabulary name" 67 | parser.add_argument("output", default="vocab.txt", help=msg) 68 | msg = "limit" 69 | parser.add_argument("--limit", default=0, type=int, help=msg) 70 | msg = "add special token, separated by colon" 71 | parser.add_argument("--special", type=str, default=":", 72 | help=msg) 73 | msg = "use lowercase" 74 | parser.add_argument("--lower", action="store_true", help=msg) 75 | 76 | return parser.parse_args() 77 | 78 | 79 | def main(args): 80 | vocab = {} 81 | limit = args.limit 82 | count = 0 83 | 84 | words, labels, counts = count_items(args.corpus, args.lower) 85 | special = special_tokens(args.special) 86 | 87 | for token in special: 88 | vocab[token] = len(vocab) 89 | 90 | for word, freq in zip(words, counts): 91 | if limit and len(vocab) >= limit: 92 | break 93 | 94 | if word in vocab: 95 | print("warning: found duplicate token %s, ignored" % word) 96 | continue 97 | 98 | vocab[word] = len(vocab) 99 | count += freq 100 | 101 | save_vocab(args.output + "/vocab.txt", vocab) 102 | write_vocab(args.output + "/label.txt", labels) 103 | 104 | print("total words: %d" % sum(counts)) 105 | print("unique words: %d" % len(words)) 106 | print("vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts))) 107 | 108 | 109 | if __name__ == "__main__": 110 | main(parse_args()) 111 | -------------------------------------------------------------------------------- /tagger/scripts/convert_to_conll.py: -------------------------------------------------------------------------------- 1 | # 
convert_to_conll.py 2 | # author: Playinf 3 | # email: playinf@stu.xmu.edu.cn 4 | 5 | import sys 6 | 7 | 8 | def convert_bio(labels): 9 | n = len(labels) 10 | tags = [] 11 | 12 | tag = [] 13 | count = 0 14 | 15 | # B I* 16 | for label in labels: 17 | count += 1 18 | 19 | if count == n: 20 | next_l = None 21 | else: 22 | next_l = labels[count] 23 | 24 | if label == "O": 25 | if tag: 26 | tags.append(tag) 27 | tag = [] 28 | tags.append([label]) 29 | continue 30 | 31 | tag.append(label[2:]) 32 | 33 | if not next_l or next_l[0] == "B": 34 | tags.append(tag) 35 | tag = [] 36 | 37 | new_tag = [] 38 | 39 | for tag in tags: 40 | if len(tag) == 1: 41 | if tag[0] == "O": 42 | new_tag.append("*") 43 | else: 44 | new_tag.append("(" + tag[0] + "*)") 45 | continue 46 | 47 | label = tag[0] 48 | n = len(tag) 49 | 50 | for i in range(n): 51 | if i == 0: 52 | new_tag.append("(" + label + "*") 53 | elif i == n - 1: 54 | new_tag.append("*)") 55 | else: 56 | new_tag.append("*") 57 | 58 | return new_tag 59 | 60 | 61 | def print_sentence_to_conll(fout, tokens, labels): 62 | for label_column in labels: 63 | assert len(label_column) == len(tokens) 64 | for i in range(len(tokens)): 65 | fout.write(tokens[i].ljust(15)) 66 | for label_column in labels: 67 | fout.write(label_column[i].rjust(15)) 68 | fout.write("\n") 69 | fout.write("\n") 70 | 71 | 72 | def print_to_conll(pred_labels, gold_props_file, output_filename): 73 | fout = open(output_filename, 'w') 74 | seq_ptr = 0 75 | num_props_for_sentence = 0 76 | tokens_buf = [] 77 | 78 | for line in open(gold_props_file, 'r'): 79 | line = line.strip() 80 | if line == "" and len(tokens_buf) > 0: 81 | print_sentence_to_conll( 82 | fout, 83 | tokens_buf, 84 | pred_labels[seq_ptr:seq_ptr+num_props_for_sentence] 85 | ) 86 | seq_ptr += num_props_for_sentence 87 | tokens_buf = [] 88 | num_props_for_sentence = 0 89 | else: 90 | info = line.split() 91 | num_props_for_sentence = len(info) - 1 92 | tokens_buf.append(info[0]) 93 | 94 | # Output last sentence. 
95 | if len(tokens_buf) > 0: 96 | print_sentence_to_conll( 97 | fout, 98 | tokens_buf, 99 | pred_labels[seq_ptr:seq_ptr+num_props_for_sentence] 100 | ) 101 | 102 | fout.close() 103 | 104 | 105 | if __name__ == "__main__": 106 | all_labels = [] 107 | with open(sys.argv[1]) as fd: 108 | for text_line in fd: 109 | labs = text_line.strip().split() 110 | labs = convert_bio(labs) 111 | all_labels.append(labs) 112 | 113 | print_to_conll(all_labels, sys.argv[2], sys.argv[3]) 114 | -------------------------------------------------------------------------------- /tagger/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from tagger.utils.hparams import HParams 2 | from tagger.utils.checkpoint import save, latest_checkpoint, best_checkpoint 3 | from tagger.utils.scope import scope, get_scope, unique_name 4 | from tagger.utils.misc import get_global_step, set_global_step 5 | -------------------------------------------------------------------------------- /tagger/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import glob 10 | import torch 11 | 12 | 13 | def oldest_checkpoint(path): 14 | names = glob.glob(os.path.join(path, "*.pt")) 15 | 16 | if not names: 17 | return None 18 | 19 | oldest_counter = 10000000 20 | checkpoint_name = names[0] 21 | 22 | for name in names: 23 | counter = name.rstrip(".pt").split("-")[-1] 24 | 25 | if not counter.isdigit(): 26 | continue 27 | else: 28 | counter = int(counter) 29 | 30 | if counter < oldest_counter: 31 | checkpoint_name = name 32 | oldest_counter = counter 33 | 34 | return checkpoint_name 35 | 36 | 37 | def best_checkpoint(path): 38 | if not os.path.exists(os.path.join(path, "checkpoint")): 39 | return latest_checkpoint(path) 40 | 41 | with open(os.path.join(path, "checkpoint")) as fd: 42 | line = fd.readline() 43 | name = line.strip().split()[-1][1:-1] 44 | 45 | return os.path.join(path, name) 46 | 47 | 48 | def latest_checkpoint(path): 49 | names = glob.glob(os.path.join(path, "*.pt")) 50 | 51 | if not names: 52 | return None 53 | 54 | latest_counter = 0 55 | checkpoint_name = names[0] 56 | 57 | for name in names: 58 | counter = name.rstrip(".pt").split("-")[-1] 59 | 60 | if not counter.isdigit(): 61 | continue 62 | else: 63 | counter = int(counter) 64 | 65 | if counter > latest_counter: 66 | checkpoint_name = name 67 | latest_counter = counter 68 | 69 | return checkpoint_name 70 | 71 | 72 | def save(state, path, max_to_keep=None): 73 | checkpoints = glob.glob(os.path.join(path, "*.pt")) 74 | 75 | if max_to_keep and len(checkpoints) >= max_to_keep: 76 | checkpoint = oldest_checkpoint(path) 77 | os.remove(checkpoint) 78 | 79 | if not checkpoints: 80 | counter = 1 81 | else: 82 | checkpoint = latest_checkpoint(path) 83 | counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1 84 | 85 | checkpoint = os.path.join(path, "model-%d.pt" % counter) 86 | print("Saving checkpoint: %s" % checkpoint) 87 | torch.save(state, checkpoint) 88 | -------------------------------------------------------------------------------- /tagger/utils/hparams.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | # Modified from TensorFlow (tf.contrib.training.HParams) 4 | 5 | 
from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import json 10 | import logging 11 | import re 12 | import six 13 | 14 | 15 | def parse_values(values, type_map): 16 | ret = {} 17 | param_re = re.compile(r"(?P<name>[a-zA-Z][\w]*)\s*=\s*" 18 | r"((?P<val>[^,\[]*)|\[(?P<vals>[^\]]*)\])($|,)") 19 | pos = 0 20 | 21 | while pos < len(values): 22 | m = param_re.match(values, pos) 23 | 24 | if not m: 25 | raise ValueError( 26 | "Malformed hyperparameter value: %s" % values[pos:]) 27 | 28 | # Check that there is a comma between parameters and move past it. 29 | pos = m.end() 30 | # Parse the values. 31 | m_dict = m.groupdict() 32 | name = m_dict["name"] 33 | 34 | if name not in type_map: 35 | raise ValueError("Unknown hyperparameter type for %s" % name) 36 | 37 | def parse_fail(): 38 | raise ValueError("Could not parse hparam %s in %s" % (name, values)) 39 | 40 | if type_map[name] == bool: 41 | def parse_bool(value): 42 | if value == "true": 43 | return True 44 | elif value == "false": 45 | return False 46 | else: 47 | try: 48 | return bool(int(value)) 49 | except ValueError: 50 | parse_fail() 51 | parse = parse_bool 52 | else: 53 | parse = type_map[name] 54 | 55 | 56 | if m_dict["val"] is not None: 57 | try: 58 | ret[name] = parse(m_dict["val"]) 59 | except ValueError: 60 | parse_fail() 61 | elif m_dict["vals"] is not None: 62 | elements = filter(None, re.split("[ ,]", m_dict["vals"])) 63 | try: 64 | ret[name] = [parse(e) for e in elements] 65 | except ValueError: 66 | parse_fail() 67 | else: 68 | parse_fail() 69 | 70 | return ret 71 | 72 | 73 | class HParams(object): 74 | 75 | def __init__(self, **kwargs): 76 | self._hparam_types = {} 77 | 78 | for name, value in six.iteritems(kwargs): 79 | self.add_hparam(name, value) 80 | 81 | def add_hparam(self, name, value): 82 | if getattr(self, name, None) is not None: 83 | raise ValueError("Hyperparameter name is reserved: %s" % name) 84 | if isinstance(value, (list, tuple)): 85 | if not value: 86 | raise ValueError("Multi-valued hyperparameters cannot be" 87 | " empty: %s" % name) 88 | self._hparam_types[name] = (type(value[0]), True) 89 | else: 90 | self._hparam_types[name] = (type(value), False) 91 | setattr(self, name, value) 92 | 93 | def parse(self, values): 94 | type_map = dict() 95 | 96 | for name, t in six.iteritems(self._hparam_types): 97 | param_type, _ = t 98 | type_map[name] = param_type 99 | 100 | values_map = parse_values(values, type_map) 101 | return self._set_from_map(values_map) 102 | 103 | def _set_from_map(self, values_map): 104 | for name, value in six.iteritems(values_map): 105 | if name not in self._hparam_types: 106 | logging.debug("%s not found in hparams."
% name) 107 | continue 108 | 109 | _, is_list = self._hparam_types[name] 110 | 111 | if isinstance(value, list): 112 | if not is_list: 113 | raise ValueError("Must not pass a list for single-valued " 114 | "parameter: %s" % name) 115 | setattr(self, name, value) 116 | else: 117 | if is_list: 118 | raise ValueError("Must pass a list for multi-valued " 119 | "parameter: %s" % name) 120 | setattr(self, name, value) 121 | return self 122 | 123 | def to_json(self): 124 | return json.dumps(self.values()) 125 | 126 | def parse_json(self, values_json): 127 | values_map = json.loads(values_json) 128 | return self._set_from_map(values_map) 129 | 130 | def values(self): 131 | return {n: getattr(self, n) for n in six.iterkeys(self._hparam_types)} 132 | 133 | def __str__(self): 134 | return str(sorted(six.iteritems(self.values))) 135 | -------------------------------------------------------------------------------- /tagger/utils/misc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | _GLOBAL_STEP = 0 9 | 10 | 11 | def get_global_step(): 12 | return _GLOBAL_STEP 13 | 14 | 15 | def set_global_step(step): 16 | global _GLOBAL_STEP 17 | _GLOBAL_STEP = step 18 | -------------------------------------------------------------------------------- /tagger/utils/scope.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2017-2019 The THUMT Authors 3 | # Modified from TensorFlow (tf.name_scope) 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import re 10 | import contextlib 11 | 12 | # global variable 13 | _NAME_STACK = "" 14 | _NAMES_IN_USE = {} 15 | _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$") 16 | _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$") 17 | 18 | 19 | def unique_name(name, mark_as_used=True): 20 | global _NAME_STACK 21 | 22 | if _NAME_STACK: 23 | name = _NAME_STACK + "/" + name 24 | 25 | i = _NAMES_IN_USE.get(name, 0) 26 | 27 | if mark_as_used: 28 | _NAMES_IN_USE[name] = i + 1 29 | 30 | if i > 0: 31 | base_name = name 32 | 33 | while name in _NAMES_IN_USE: 34 | name = "%s_%d" % (base_name, i) 35 | i += 1 36 | 37 | if mark_as_used: 38 | _NAMES_IN_USE[name] = 1 39 | 40 | return name 41 | 42 | 43 | @contextlib.contextmanager 44 | def scope(name): 45 | global _NAME_STACK 46 | 47 | if name: 48 | if _NAME_STACK: 49 | # check name 50 | if not _VALID_SCOPE_NAME_REGEX.match(name): 51 | raise ValueError("'%s' is not a valid scope name" % name) 52 | else: 53 | # check name strictly 54 | if not _VALID_OP_NAME_REGEX.match(name): 55 | raise ValueError("'%s' is not a valid scope name" % name) 56 | 57 | try: 58 | old_stack = _NAME_STACK 59 | 60 | if not name: 61 | new_stack = None 62 | elif name and name[-1] == "/": 63 | new_stack = name[:-1] 64 | else: 65 | new_stack = unique_name(name) 66 | 67 | _NAME_STACK = new_stack 68 | 69 | yield "" if new_stack is None else new_stack + "/" 70 | finally: 71 | _NAME_STACK = old_stack 72 | 73 | 74 | def get_scope(): 75 | return _NAME_STACK 76 | -------------------------------------------------------------------------------- /tagger/utils/summary.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 
2017-2020 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import queue 9 | import threading 10 | import torch 11 | 12 | import torch.distributed as dist 13 | import torch.utils.tensorboard as tensorboard 14 | 15 | _SUMMARY_WRITER = None 16 | _QUEUE = None 17 | _THREAD = None 18 | 19 | 20 | class SummaryWorker(threading.Thread): 21 | 22 | def run(self): 23 | global _QUEUE 24 | 25 | while True: 26 | item = _QUEUE.get() 27 | name, kwargs = item 28 | 29 | if name == "stop": 30 | break 31 | 32 | self.write_summary(name, **kwargs) 33 | 34 | def write_summary(self, name, **kwargs): 35 | if name == "scalar": 36 | _SUMMARY_WRITER.add_scalar(**kwargs) 37 | elif name == "histogram": 38 | _SUMMARY_WRITER.add_histogram(**kwargs) 39 | 40 | def stop(self): 41 | global _QUEUE 42 | _QUEUE.put(("stop", None)) 43 | self.join() 44 | 45 | 46 | def init(log_dir, enable=True): 47 | global _SUMMARY_WRITER 48 | global _QUEUE 49 | global _THREAD 50 | 51 | if enable and dist.get_rank() == 0: 52 | _SUMMARY_WRITER = tensorboard.SummaryWriter(log_dir) 53 | _QUEUE = queue.Queue() 54 | thread = SummaryWorker(daemon=True) 55 | thread.start() 56 | _THREAD = thread 57 | 58 | 59 | def scalar(tag, scalar_value, global_step=None, walltime=None, 60 | write_every_n_steps=100): 61 | 62 | if _SUMMARY_WRITER is not None: 63 | if global_step % write_every_n_steps == 0: 64 | scalar_value = float(scalar_value) 65 | kwargs = dict(tag=tag, scalar_value=scalar_value, 66 | global_step=global_step, walltime=walltime) 67 | _QUEUE.put(("scalar", kwargs)) 68 | 69 | 70 | def histogram(tag, values, global_step=None, bins="tensorflow", walltime=None, 71 | max_bins=None, write_every_n_steps=100): 72 | 73 | if _SUMMARY_WRITER is not None: 74 | if global_step % write_every_n_steps == 0: 75 | values = values.detach().cpu() 76 | kwargs = dict(tag=tag, values=values, global_step=global_step, 77 | bins=bins, walltime=walltime, max_bins=max_bins) 78 | _QUEUE.put(("histogram", kwargs)) 79 | 80 | 81 | def close(): 82 | if _SUMMARY_WRITER is not None: 83 | _THREAD.stop() 84 | _SUMMARY_WRITER.close() 85 | -------------------------------------------------------------------------------- /tagger/utils/validation.py: -------------------------------------------------------------------------------- 1 | # validation.py 2 | # author: Playinf 3 | # email: playinf@stu.xmu.edu.cn 4 | 5 | import os 6 | import time 7 | import threading 8 | import subprocess 9 | 10 | from tagger.utils.checkpoint import latest_checkpoint 11 | 12 | 13 | def get_current_model(filename): 14 | try: 15 | with open(filename) as fd: 16 | line = fd.readline() 17 | if not line: 18 | return None 19 | 20 | name = line.strip().split(":")[1] 21 | return name.strip()[1:-1] 22 | except: 23 | return None 24 | 25 | 26 | def read_record(filename): 27 | record = [] 28 | 29 | try: 30 | with open(filename) as fd: 31 | for line in fd: 32 | line = line.strip().split(":") 33 | val = float(line[0]) 34 | name = line[1].strip()[1:-1] 35 | record.append((val, name)) 36 | except: 37 | pass 38 | 39 | return record 40 | 41 | 42 | def write_record(filename, record): 43 | # sort 44 | sorted_record = sorted(record, key=lambda x: -x[0]) 45 | 46 | with open(filename, "w") as fd: 47 | for item in sorted_record: 48 | val, name = item 49 | fd.write("%f: \"%s\"\n" % (val, name)) 50 | 51 | 52 | def write_checkpoint(filename, record): 53 | # sort 54 | sorted_record = sorted(record, key=lambda x: -x[0]) 55 | 56 | with 
open(filename, "w") as fd: 57 | fd.write("model_checkpoint_path: \"%s\"\n" % sorted_record[0][1]) 58 | for item in sorted_record: 59 | val, name = item 60 | fd.write("all_model_checkpoint_paths: \"%s\"\n" % name) 61 | 62 | 63 | def add_to_record(record, item, capacity): 64 | added = None 65 | removed = None 66 | models = {} 67 | 68 | for (val, name) in record: 69 | models[name] = val 70 | 71 | if len(record) < capacity: 72 | if item[1] not in models: 73 | added = item[1] 74 | record.append(item) 75 | else: 76 | sorted_record = sorted(record, key=lambda x: -x[0]) 77 | worst_score = sorted_record[-1][0] 78 | current_score = item[0] 79 | 80 | if current_score >= worst_score: 81 | if item[1] not in models: 82 | added = item[1] 83 | removed = sorted_record[-1][1] 84 | record = sorted_record[:-1] + [item] 85 | 86 | return added, removed, record 87 | 88 | 89 | class ValidationWorker(threading.Thread): 90 | 91 | def init(self, params): 92 | self._params = params 93 | self._stop = False 94 | 95 | def run(self): 96 | params = self._params 97 | best_dir = params.output + "/best" 98 | last_checkpoint = None 99 | 100 | # create directory 101 | if not os.path.exists(best_dir): 102 | os.mkdir(best_dir) 103 | record = [] 104 | else: 105 | record = read_record(best_dir + "/top") 106 | 107 | while not self._stop: 108 | try: 109 | time.sleep(params.frequency) 110 | model_name = latest_checkpoint(params.output) 111 | 112 | if model_name is None: 113 | continue 114 | 115 | if model_name == last_checkpoint: 116 | continue 117 | 118 | last_checkpoint = model_name 119 | 120 | model_name = model_name.split("/")[-1] 121 | # prediction and evaluation 122 | child = subprocess.Popen("bash %s" % params.script, 123 | shell=True, stdout=subprocess.PIPE, 124 | stderr=subprocess.PIPE) 125 | info = child.communicate()[0] 126 | 127 | if not info: 128 | continue 129 | 130 | info = info.strip().split(b"\n") 131 | overall = None 132 | 133 | for line in info[::-1]: 134 | if line.find(b"Overall") > 0: 135 | overall = line 136 | break 137 | 138 | if not overall: 139 | continue 140 | 141 | f_score = float(overall.strip().split()[-1]) 142 | 143 | # save best model 144 | item = (f_score, model_name) 145 | added, removed, record = add_to_record(record, item, 146 | params.keep_top_k) 147 | log_fd = open(best_dir + "/log", "a") 148 | log_fd.write("%s: %f\n" % (model_name, f_score)) 149 | log_fd.close() 150 | 151 | if added is not None: 152 | model_path = params.output + "/" + model_name + "*" 153 | # copy model 154 | os.system("cp %s %s" % (model_path, best_dir)) 155 | # update checkpoint 156 | write_record(best_dir + "/top", record) 157 | write_checkpoint(best_dir + "/checkpoint", record) 158 | 159 | if removed is not None: 160 | # remove old model 161 | model_name = params.output + "/best/" + removed + "*" 162 | os.system("rm %s" % model_name) 163 | except Exception as e: 164 | print(e) 165 | 166 | def stop(self): 167 | self._stop = True 168 | --------------------------------------------------------------------------------
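A minimal end-to-end sketch (not part of the repository) of how the optimizer, learning-rate schedule and gradient-clipping pieces above can be wired together. The toy `nn.Linear` model and every hyperparameter value are illustrative assumptions. Note that `compute_gradients` takes the parameters themselves, while `apply_gradients` expects `(gradient, (name, parameter))` pairs, matching the unpacking inside `AdamOptimizer.apply_gradients`; summaries are disabled here because the summary path assumes parameters registered through `Module.add_name`.

```[python]
import torch
import torch.nn as nn

from tagger.optimizers import AdamOptimizer, LinearWarmupRsqrtDecay
from tagger.optimizers import global_norm_clipper

# Toy model and data; all sizes and hyperparameters are illustrative only.
model = nn.Linear(8, 4)
schedule = LinearWarmupRsqrtDecay(learning_rate=1.0, warmup_steps=400,
                                  summary=False)
optimizer = AdamOptimizer(learning_rate=schedule, beta_1=0.9, beta_2=0.98,
                          epsilon=1e-9, summaries=False,
                          clipper=global_norm_clipper(5.0))

x, y = torch.randn(16, 8), torch.randn(16, 4)
loss = torch.mean((model(x) - y) ** 2)

named_params = list(model.named_parameters())

# compute_gradients runs loss.backward() and returns the .grad tensors.
grads = optimizer.compute_gradients(loss, [p for _, p in named_params])

# apply_gradients consumes (gradient, (name, parameter)) pairs.
optimizer.apply_gradients(zip(grads, named_params))
```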