├── LICENSE
├── README.md
├── results
│   ├── conll05
│   │   ├── ensemble
│   │   │   ├── conll05.brown.result
│   │   │   ├── conll05.dev.result
│   │   │   └── conll05.wsj.result
│   │   └── single
│   │       ├── conll05.brown.result
│   │       ├── conll05.dev.result
│   │       └── conll05.wsj.result
│   └── conll12
│       ├── ensemble
│       │   ├── conll12.dev.result
│       │   └── conll12.test.result
│       └── single
│           ├── conll12.devel.result
│           └── conll12.test.result
└── tagger
    ├── __init__.py
    ├── bin
    │   ├── predictor.py
    │   └── trainer.py
    ├── data
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── embedding.py
    │   └── vocab.py
    ├── models
    │   ├── __init__.py
    │   └── deepatt.py
    ├── modules
    │   ├── __init__.py
    │   ├── affine.py
    │   ├── attention.py
    │   ├── embedding.py
    │   ├── feed_forward.py
    │   ├── layer_norm.py
    │   ├── losses.py
    │   ├── module.py
    │   └── recurrent.py
    ├── optimizers
    │   ├── __init__.py
    │   ├── clipping.py
    │   ├── optimizers.py
    │   └── schedules.py
    ├── scripts
    │   ├── build_vocab.py
    │   └── convert_to_conll.py
    └── utils
        ├── __init__.py
        ├── checkpoint.py
        ├── hparams.py
        ├── misc.py
        ├── scope.py
        ├── summary.py
        └── validation.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Natural Language Processing Lab at Xiamen University
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice, this
11 | list of conditions and the following disclaimer in the documentation and/or
12 | other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from this
16 | software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tagger
2 |
3 | This is the source code for the paper "[Deep Semantic Role Labeling with Self-Attention](https://arxiv.org/abs/1712.01586)".
4 |
5 | ## Contents
6 |
7 | * [Basics](#basics)
8 | * [Notice](#notice)
9 | * [Prerequisites](#prerequisites)
10 | * [Walkthrough](#walkthrough)
11 | * [Data](#data)
12 | * [Training](#training)
13 | * [Decoding](#decoding)
14 | * [Benchmarks](#benchmarks)
15 | * [Pretrained Models](#pretrained-models)
16 | * [License](#license)
17 | * [Citation](#citation)
18 | * [Contact](#contact)
19 |
20 | ## Basics
21 |
22 | ### Notice
23 |
24 | The original code used in the paper was implemented in TensorFlow 1.0, which is now obsolete. We have re-implemented our methods in PyTorch; the re-implementation is based on [THUMT](https://github.com/THUNLP-MT/THUMT). The differences are as follows:
25 |
26 | * Only the DeepAtt-FFN model is implemented
27 | * Model ensembling is currently not available
28 |
29 | Please check the git history if you want to use the TensorFlow implementation.
30 |
31 | ### Prerequisites
32 |
33 | * Python 3
34 | * PyTorch
35 | * TensorFlow-2.0 (CPU version)
36 | * GloVe embeddings and `srlconll` scripts
37 |
38 | ## Walkthrough
39 |
40 | ### Data
41 |
42 | #### Training Data
43 |
44 | We follow the same procedures described in the [deep_srl](https://github.com/luheng/deep_srl) repository to convert the CoNLL datasets.
45 | The GloVe embeddings and `srlconll` scripts can also be found in that repository.
46 |
47 | After following these procedures, the processed data will have the following format:
48 | ```
49 | 2 My cats love hats . ||| B-A0 I-A0 B-V B-A1 O
50 | ```
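
Each line pairs a sentence with one predicate and its argument labels: the leading integer is the 0-based index of the predicate token ("love" above), the remaining tokens before `|||` are the words, and the tokens after `|||` are the per-word SRL tags. A minimal parsing sketch in Python (for illustration only; the actual pipeline does this with `tf.data` in `tagger/data/dataset.py`):

```[python]
def parse_srl_line(line):
    # Example: "2 My cats love hats . ||| B-A0 I-A0 B-V B-A1 O"
    left, right = line.strip().split("|||", 1)
    fields = left.split()
    pred_index = int(fields[0])   # 0-based position of the predicate token
    words = fields[1:]
    labels = right.split()
    assert len(words) == len(labels)
    return pred_index, words, labels
```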
51 |
52 | *The CoNLL datasets are not publicly available. We cannot provide these datasets.*
53 |
54 | #### Vocabulary
55 |
56 | You can use the `build_vocab.py` script to generate vocabularies. The command is as follows:
57 |
58 | ```[bash]
59 | python tagger/scripts/build_vocab.py --limit LIMIT --lower TRAIN_FILE OUTPUT_DIR
60 | ```
61 |
62 | where `LIMIT` specifies the vocabulary size. This command will create two vocabularies named `vocab.txt` and `label.txt` in the `OUTPUT_DIR`.
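
For example (the vocabulary size, training file name, and output directory below are placeholders; substitute your own values):

```[bash]
python tagger/scripts/build_vocab.py --limit 100000 --lower \
    $DATAPATH/conll05.train.txt vocab_dir
```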
63 |
64 | ### Training
65 |
66 | Once you have finished the procedures described above, you can start training.
67 |
68 | #### Preparing the validation script
69 |
70 | An external validation script is required to enable the validation functionality.
71 | Here is the validation script we used when training an FFN model on the CoNLL-2005 dataset.
72 | Please make sure that the validation script can run properly.
73 |
74 | ```[bash]
75 | #!/usr/bin/env bash
76 | SRLPATH=/PATH/TO/SRLCONLL
77 | TAGGERPATH=/PATH/TO/TAGGER
78 | DATAPATH=/PATH/TO/DATA
79 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
80 | DEVICE=0
81 |
82 | export PYTHONPATH=$TAGGERPATH:$PYTHONPATH
83 | export PERL5LIB="$SRLPATH/lib:$PERL5LIB"
84 | export PATH="$SRLPATH/bin:$PATH"
85 |
86 | python $TAGGERPATH/tagger/bin/predictor.py \
87 | --input $DATAPATH/conll05.devel.txt \
88 | --checkpoint train \
89 | --model deepatt \
90 | --vocab $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
91 | --parameters=device=$DEVICE,embedding=$EMBPATH/glove.6B.100d.txt \
92 | --output tmp.txt
93 |
94 | python $TAGGERPATH/tagger/scripts/convert_to_conll.py tmp.txt $DATAPATH/conll05.devel.props.gold.txt output
95 | perl $SRLPATH/bin/srl-eval.pl $DATAPATH/conll05.devel.props.* output
96 | ```
97 |
98 | #### Training command
99 |
100 | The command below is what we used to train a model on the CoNLL-2005 dataset. The content of `run.sh` is the validation script described in the section above.
101 |
102 | ```[bash]
103 | #!/usr/bin/env bash
104 | SRLPATH=/PATH/TO/SRLCONLL
105 | TAGGERPATH=/PATH/TO/TAGGER
106 | DATAPATH=/PATH/TO/DATA
107 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
108 | DEVICE=[0]
109 |
110 | export PYTHONPATH=$TAGGERPATH:$PYTHONPATH
111 | export PERL5LIB="$SRLPATH/lib:$PERL5LIB"
112 | export PATH="$SRLPATH/bin:$PATH"
113 |
114 | python $TAGGERPATH/tagger/bin/trainer.py \
115 | --model deepatt \
116 | --input $DATAPATH/conll05.train.txt \
117 | --output train \
118 | --vocabulary $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
119 | --parameters="save_summary=false,feature_size=100,hidden_size=200,filter_size=800,"`
120 | `"residual_dropout=0.2,num_hidden_layers=10,attention_dropout=0.1,"`
121 | `"relu_dropout=0.1,batch_size=4096,optimizer=adadelta,initializer=orthogonal,"`
122 | `"initializer_gain=1.0,train_steps=600000,"`
123 | `"learning_rate_schedule=piecewise_constant_decay,"`
124 | `"learning_rate_values=[1.0,0.5,0.25,],"`
125 |     `"learning_rate_boundaries=[400000,500000],device_list=$DEVICE,"`
126 | `"clip_grad_norm=1.0,embedding=$EMBPATH/glove.6B.100d.txt,script=run.sh"
127 | ```
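
Note that the paired backticks at the ends and beginnings of the `--parameters` lines are empty command substitutions; their only effect is to join the quoted fragments into a single `--parameters` value. Keep them as-is, or put the whole value on one line, if you edit the command.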
128 |
129 | ### Decoding
130 |
131 | The following is the command used to generate outputs:
132 |
133 | ```[bash]
134 | #!/usr/bin/env bash
135 | SRLPATH=/PATH/TO/SRLCONLL
136 | TAGGERPATH=/PATH/TO/TAGGER
137 | DATAPATH=/PATH/TO/DATA
138 | EMBPATH=/PATH/TO/GLOVE_EMBEDDING
139 | DEVICE=0
140 |
141 | python $TAGGERPATH/tagger/bin/predictor.py \
142 | --input $DATAPATH/conll05.test.wsj.txt \
143 | --checkpoint train/best \
144 | --model deepatt \
145 | --vocab $DATAPATH/deep_srl/word_dict $DATAPATH/deep_srl/label_dict \
146 | --parameters=device=$DEVICE,embedding=$EMBPATH/glove.6B.100d.txt \
147 | --output tmp.txt
148 |
149 | ```
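
To score the decoded output, convert it to the CoNLL format and run `srl-eval.pl`, exactly as the validation script above does. A sketch (the gold props file names for the WSJ test set are assumptions; use whatever your data preparation produced):

```[bash]
python $TAGGERPATH/tagger/scripts/convert_to_conll.py tmp.txt $DATAPATH/conll05.test.wsj.props.gold.txt output
perl $SRLPATH/bin/srl-eval.pl $DATAPATH/conll05.test.wsj.props.* output
```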
150 |
151 | ## Benchmarks
152 |
153 | We performed 4 runs on the CoNLL-2005 dataset. The results are shown below.
154 |
155 | | Runs | Dev-P | Dev-R | Dev-F1 | WSJ-P | WSJ-R | WSJ-F1 | BROWN-P | BROWN-R | BROWN-F1 |
156 | | :----: | :---: | :---: | :----: | :---: | :---: | :----: | :-----: | :-----: | :------: |
157 | | Paper | 82.6 | 83.6 | 83.1 | 84.5 | 85.2 | 84.8 | 73.5 | 74.6 | 74.1 |
158 | | Run0 | 82.9 | 83.7 | 83.3 | 84.6 | 85.0 | 84.8 | 73.5 | 74.0 | 73.8 |
159 | | Run1 | 82.3 | 83.4 | 82.9 | 84.4 | 85.3 | 84.8 | 72.5 | 73.9 | 73.2 |
160 | | Run2 | 82.7 | 83.6 | 83.2 | 84.8 | 85.4 | 85.1 | 73.2 | 73.9 | 73.6 |
161 | | Run3 | 82.3 | 83.6 | 82.9 | 84.3 | 84.9 | 84.6 | 72.3 | 73.6 | 72.9 |
162 |
163 | ## Pretrained Models
164 |
165 | The pretrained models of TensorFlow implementation can be downloaded at [Google Drive](https://drive.google.com/open?id=1jvBlpOmqGdZEqnFrdWJkH1xHsGU2OjiP).
166 |
167 | ## License
168 |
169 | BSD
170 |
171 | ## Citation
172 |
173 | If you use our code, please cite our paper:
174 |
175 | ```
176 | @inproceedings{tan2018deep,
177 | title = {Deep Semantic Role Labeling with Self-Attention},
178 | author = {Tan, Zhixing and Wang, Mingxuan and Xie, Jun and Chen, Yidong and Shi, Xiaodong},
179 | booktitle = {AAAI Conference on Artificial Intelligence},
180 | year = {2018}
181 | }
182 | ```
183 |
184 | ## Contact
185 |
186 | This code was written by Zhixing Tan. If you have any questions, feel free to send an email.
187 |
--------------------------------------------------------------------------------
/tagger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XMUNLP/Tagger/02e1fd323ac747bfe5f7b8824c6b416fd90f33a1/tagger/__init__.py
--------------------------------------------------------------------------------
/tagger/bin/predictor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import logging
10 | import os
11 | import six
12 | import time
13 | import torch
14 |
15 | import tagger.data as data
16 | import tagger.models as models
17 | import tagger.utils as utils
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(
22 | description="Predict using SRL models",
23 |         usage="predictor.py [<args>] [-h | --help]"
24 | )
25 |
26 | # input files
27 | parser.add_argument("--input", type=str, required=True,
28 | help="Path of input file")
29 | parser.add_argument("--output", type=str, required=True,
30 | help="Path of output file")
31 | parser.add_argument("--checkpoint", type=str, required=True,
32 | help="Path of trained models")
33 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True,
34 | help="Path of source and target vocabulary")
35 |
36 | # model and configuration
37 | parser.add_argument("--model", type=str, required=True,
38 | help="Name of the model")
39 | parser.add_argument("--parameters", type=str, default="",
40 | help="Additional hyper parameters")
41 | parser.add_argument("--half", action="store_true",
42 | help="Use half precision for decoding")
43 |
44 | return parser.parse_args()
45 |
46 |
47 | def default_params():
48 | params = utils.HParams(
49 | input=None,
50 | output=None,
51 | vocabulary=None,
52 | embedding="",
53 | # vocabulary specific
54 |         pad="<pad>",
55 |         bos="<bos>",
56 |         eos="<eos>",
57 |         unk="<unk>",
58 | device=0,
59 | decode_batch_size=128
60 | )
61 |
62 | return params
63 |
64 |
65 | def merge_params(params1, params2):
66 | params = utils.HParams()
67 |
68 | for (k, v) in six.iteritems(params1.values()):
69 | params.add_hparam(k, v)
70 |
71 | params_dict = params.values()
72 |
73 | for (k, v) in six.iteritems(params2.values()):
74 | if k in params_dict:
75 | # Override
76 | setattr(params, k, v)
77 | else:
78 | params.add_hparam(k, v)
79 |
80 | return params
81 |
82 |
83 | def import_params(model_dir, model_name, params):
84 | model_dir = os.path.abspath(model_dir)
85 | m_name = os.path.join(model_dir, model_name + ".json")
86 |
87 | if not os.path.exists(m_name):
88 | return params
89 |
90 | with open(m_name) as fd:
91 | logging.info("Restoring model parameters from %s" % m_name)
92 | json_str = fd.readline()
93 | params.parse_json(json_str)
94 |
95 | return params
96 |
97 |
98 | def override_params(params, args):
99 | params.parse(args.parameters)
100 |
101 | src_vocab, src_w2idx, src_idx2w = data.load_vocabulary(args.vocabulary[0])
102 | tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(args.vocabulary[1])
103 |
104 | params.vocabulary = {
105 | "source": src_vocab, "target": tgt_vocab
106 | }
107 | params.lookup = {
108 | "source": src_w2idx, "target": tgt_w2idx
109 | }
110 | params.mapping = {
111 | "source": src_idx2w, "target": tgt_idx2w
112 | }
113 |
114 | return params
115 |
116 |
117 | def convert_to_string(inputs, tensor, params):
118 | inputs = torch.squeeze(inputs)
119 | inputs = inputs.tolist()
120 | tensor = torch.squeeze(tensor, dim=1)
121 | tensor = tensor.tolist()
122 | decoded = []
123 |
124 | for wids, lids in zip(inputs, tensor):
125 | output = []
126 | for wid, lid in zip(wids, lids):
127 | if wid == 0:
128 | break
129 | output.append(params.mapping["target"][lid])
130 | decoded.append(b" ".join(output))
131 |
132 | return decoded
133 |
134 |
135 | def main(args):
136 | # Load configs
137 | model_cls = models.get_model(args.model)
138 | params = default_params()
139 | params = merge_params(params, model_cls.default_params())
140 | params = import_params(args.checkpoint, args.model, params)
141 | params = override_params(params, args)
142 | torch.cuda.set_device(params.device)
143 | torch.set_default_tensor_type(torch.cuda.FloatTensor)
144 |
145 | # Create model
146 | with torch.no_grad():
147 | model = model_cls(params).cuda()
148 |
149 | if args.half:
150 | model = model.half()
151 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
152 |
153 | model.eval()
154 | model.load_state_dict(
155 | torch.load(utils.best_checkpoint(args.checkpoint),
156 | map_location="cpu")["model"])
157 |
158 | # Decoding
159 | dataset = data.get_dataset(args.input, "infer", params)
160 | fd = open(args.output, "wb")
161 | counter = 0
162 |
163 |     if params.embedding:  # an empty string means no pre-trained embedding
164 | embedding = data.load_glove_embedding(params.embedding)
165 | else:
166 | embedding = None
167 |
168 | for features in dataset:
169 | t = time.time()
170 | counter += 1
171 | features = data.lookup(features, "infer", params, embedding)
172 |
173 | labels = model.argmax_decode(features)
174 | batch = convert_to_string(features["inputs"], labels, params)
175 |
176 | for seq in batch:
177 | fd.write(seq)
178 | fd.write(b"\n")
179 |
180 | t = time.time() - t
181 | print("Finished batch: %d (%.3f sec)" % (counter, t))
182 |
183 | fd.close()
184 |
185 |
186 | if __name__ == "__main__":
187 | main(parse_args())
188 |
--------------------------------------------------------------------------------
/tagger/bin/trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import copy
10 | import glob
11 | import logging
12 | import os
13 | import re
14 | import six
15 | import socket
16 | import threading
17 | import time
18 | import torch
19 |
20 | import tagger.data as data
21 | import torch.distributed as dist
22 | import tagger.models as models
23 | import tagger.optimizers as optimizers
24 | import tagger.utils as utils
25 | import tagger.utils.summary as summary
26 | from tagger.utils.validation import ValidationWorker
27 |
28 |
29 | def parse_args(args=None):
30 | parser = argparse.ArgumentParser(
31 | description="Training SRL tagger",
32 |         usage="trainer.py [<args>] [-h | --help]"
33 | )
34 |
35 | # input files
36 | parser.add_argument("--input", type=str,
37 | help="Path of the training corpus")
38 | parser.add_argument("--output", type=str, default="train",
39 | help="Path to saved models")
40 | parser.add_argument("--vocabulary", type=str, nargs=2,
41 | help="Path of source and target vocabulary")
42 | parser.add_argument("--checkpoint", type=str,
43 | help="Path to pre-trained checkpoint")
44 | parser.add_argument("--distributed", action="store_true",
45 | help="Enable distributed training mode")
46 | parser.add_argument("--local_rank", type=int,
47 | help="Local rank of this process")
48 | parser.add_argument("--half", action="store_true",
49 | help="Enable mixed precision training")
50 | parser.add_argument("--hparam_set", type=str,
51 | help="Name of pre-defined hyper parameter set")
52 |
53 | # model and configuration
54 | parser.add_argument("--model", type=str, required=True,
55 | help="Name of the model")
56 | parser.add_argument("--parameters", type=str, default="",
57 | help="Additional hyper parameters")
58 |
59 | return parser.parse_args(args)
60 |
61 |
62 | def default_params():
63 | params = utils.HParams(
64 | input="",
65 | output="",
66 | model="transformer",
67 | vocab=["", ""],
68 |         pad="<pad>",
69 |         bos="<bos>",
70 |         eos="<eos>",
71 |         unk="<unk>",
72 | # Dataset
73 | batch_size=4096,
74 | fixed_batch_size=False,
75 | min_length=1,
76 | max_length=256,
77 | buffer_size=10000,
78 | # Initialization
79 | initializer_gain=1.0,
80 | initializer="uniform_unit_scaling",
81 | # Regularization
82 | scale_l1=0.0,
83 | scale_l2=0.0,
84 | # Training
85 | script="",
86 | warmup_steps=4000,
87 | train_steps=100000,
88 | update_cycle=1,
89 | optimizer="Adam",
90 | adam_beta1=0.9,
91 | adam_beta2=0.999,
92 | adam_epsilon=1e-8,
93 | adadelta_rho=0.95,
94 | adadelta_epsilon=1e-6,
95 | clipping="global_norm",
96 | clip_grad_norm=5.0,
97 | learning_rate=1.0,
98 | learning_rate_schedule="linear_warmup_rsqrt_decay",
99 | learning_rate_boundaries=[0],
100 | learning_rate_values=[0.0],
101 | device_list=[0],
102 | embedding="",
103 | # Validation
104 | keep_top_k=50,
105 | frequency=10,
106 | # Checkpoint Saving
107 | keep_checkpoint_max=20,
108 | keep_top_checkpoint_max=5,
109 | save_summary=True,
110 | save_checkpoint_secs=0,
111 | save_checkpoint_steps=1000,
112 | )
113 |
114 | return params
115 |
116 |
117 | def import_params(model_dir, model_name, params):
118 | model_dir = os.path.abspath(model_dir)
119 | p_name = os.path.join(model_dir, "params.json")
120 | m_name = os.path.join(model_dir, model_name + ".json")
121 |
122 | if not os.path.exists(p_name) or not os.path.exists(m_name):
123 | return params
124 |
125 | with open(p_name) as fd:
126 | logging.info("Restoring hyper parameters from %s" % p_name)
127 | json_str = fd.readline()
128 | params.parse_json(json_str)
129 |
130 | with open(m_name) as fd:
131 | logging.info("Restoring model parameters from %s" % m_name)
132 | json_str = fd.readline()
133 | params.parse_json(json_str)
134 |
135 | return params
136 |
137 |
138 | def export_params(output_dir, name, params):
139 | if not os.path.exists(output_dir):
140 | os.makedirs(output_dir)
141 |
142 | # Save params as params.json
143 | filename = os.path.join(output_dir, name)
144 |
145 | with open(filename, "w") as fd:
146 | fd.write(params.to_json())
147 |
148 |
149 | def merge_params(params1, params2):
150 | params = utils.HParams()
151 |
152 | for (k, v) in six.iteritems(params1.values()):
153 | params.add_hparam(k, v)
154 |
155 | params_dict = params.values()
156 |
157 | for (k, v) in six.iteritems(params2.values()):
158 | if k in params_dict:
159 | # Override
160 | setattr(params, k, v)
161 | else:
162 | params.add_hparam(k, v)
163 |
164 | return params
165 |
166 |
167 | def override_params(params, args):
168 | params.model = args.model or params.model
169 | params.input = args.input or params.input
170 | params.output = args.output or params.output
171 | params.vocab = args.vocabulary or params.vocab
172 | params.parse(args.parameters)
173 |
174 | src_vocab, src_w2idx, src_idx2w = data.load_vocabulary(params.vocab[0])
175 | tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(params.vocab[1])
176 |
177 | params.vocabulary = {
178 | "source": src_vocab, "target": tgt_vocab
179 | }
180 | params.lookup = {
181 | "source": src_w2idx, "target": tgt_w2idx
182 | }
183 | params.mapping = {
184 | "source": src_idx2w, "target": tgt_idx2w
185 | }
186 |
187 | return params
188 |
189 |
190 | def collect_params(all_params, params):
191 | collected = utils.HParams()
192 |
193 | for k in six.iterkeys(params.values()):
194 | collected.add_hparam(k, getattr(all_params, k))
195 |
196 | return collected
197 |
198 |
199 | def print_variables(model):
200 | weights = {v[0]: v[1] for v in model.named_parameters()}
201 | total_size = 0
202 |
203 | for name in sorted(list(weights)):
204 | v = weights[name]
205 | print("%s %s" % (name.ljust(60), str(list(v.shape)).rjust(15)))
206 | total_size += v.nelement()
207 |
208 | print("Total trainable variables size: %d" % total_size)
209 |
210 |
211 | def save_checkpoint(step, epoch, model, optimizer, params):
212 | if dist.get_rank() == 0:
213 | state = {
214 | "step": step,
215 | "epoch": epoch,
216 | "model": model.state_dict(),
217 | "optimizer": optimizer.state_dict()
218 | }
219 | utils.save(state, params.output, params.keep_checkpoint_max)
220 |
221 |
222 | def infer_gpu_num(param_str):
223 | result = re.match(r".*device_list=\[(.*?)\].*", param_str)
224 |
225 | if not result:
226 | return 1
227 | else:
228 | dev_str = result.groups()[-1]
229 | return len(dev_str.split(","))
230 |
231 |
232 | def get_clipper(params):
233 | if params.clipping.lower() == "none":
234 | clipper = None
235 | elif params.clipping.lower() == "adaptive":
236 | clipper = optimizers.adaptive_clipper(0.95)
237 | elif params.clipping.lower() == "global_norm":
238 | clipper = optimizers.global_norm_clipper(params.clip_grad_norm)
239 | else:
240 | raise ValueError("Unknown clipper %s" % params.clipping)
241 |
242 | return clipper
243 |
244 |
245 | def get_learning_rate_schedule(params):
246 | if params.learning_rate_schedule == "linear_warmup_rsqrt_decay":
247 | schedule = optimizers.LinearWarmupRsqrtDecay(params.learning_rate,
248 | params.warmup_steps)
249 | elif params.learning_rate_schedule == "piecewise_constant_decay":
250 | schedule = optimizers.PiecewiseConstantDecay(
251 | params.learning_rate_boundaries, params.learning_rate_values)
252 | elif params.learning_rate_schedule == "linear_exponential_decay":
253 | schedule = optimizers.LinearExponentialDecay(params.learning_rate,
254 | params.warmup_steps, params.start_decay_step,
255 | params.end_decay_step,
256 | dist.get_world_size())
257 | else:
258 | raise ValueError("Unknown schedule %s" % params.learning_rate_schedule)
259 |
260 | return schedule
261 |
262 |
263 | def broadcast(model):
264 | for var in model.parameters():
265 | dist.broadcast(var.data, 0)
266 |
267 |
268 | def main(args):
269 | model_cls = models.get_model(args.model)
270 |
271 | # Import and override parameters
272 | # Priorities (low -> high):
273 | # default -> saved -> command
274 | params = default_params()
275 | params = merge_params(params, model_cls.default_params(args.hparam_set))
276 | params = import_params(args.output, args.model, params)
277 | params = override_params(params, args)
278 |
279 | # Initialize distributed utility
280 | if args.distributed:
281 | dist.init_process_group("nccl")
282 | torch.cuda.set_device(args.local_rank)
283 | else:
284 | dist.init_process_group("nccl", init_method=args.url,
285 | rank=args.local_rank,
286 | world_size=len(params.device_list))
287 | torch.cuda.set_device(params.device_list[args.local_rank])
288 | torch.set_default_tensor_type(torch.cuda.FloatTensor)
289 |
290 | # Export parameters
291 | if dist.get_rank() == 0:
292 | export_params(params.output, "params.json", params)
293 | export_params(params.output, "%s.json" % params.model,
294 | collect_params(params, model_cls.default_params()))
295 |
296 | model = model_cls(params).cuda()
297 | model.load_embedding(params.embedding)
298 |
299 | if args.half:
300 | model = model.half()
301 | torch.set_default_dtype(torch.half)
302 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
303 |
304 | model.train()
305 |
306 | # Init tensorboard
307 | summary.init(params.output, params.save_summary)
308 | schedule = get_learning_rate_schedule(params)
309 | clipper = get_clipper(params)
310 |
311 | if params.optimizer.lower() == "adam":
312 | optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
313 | beta_1=params.adam_beta1,
314 | beta_2=params.adam_beta2,
315 | epsilon=params.adam_epsilon,
316 | clipper=clipper)
317 | elif params.optimizer.lower() == "adadelta":
318 | optimizer = optimizers.AdadeltaOptimizer(
319 | learning_rate=schedule, rho=params.adadelta_rho,
320 | epsilon=params.adadelta_epsilon, clipper=clipper)
321 | else:
322 | raise ValueError("Unknown optimizer %s" % params.optimizer)
323 |
324 | if args.half:
325 | optimizer = optimizers.LossScalingOptimizer(optimizer)
326 |
327 | optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)
328 |
329 | if dist.get_rank() == 0:
330 | print_variables(model)
331 |
332 | dataset = data.get_dataset(params.input, "train", params)
333 |
334 | # Load checkpoint
335 | checkpoint = utils.latest_checkpoint(params.output)
336 |
337 | if checkpoint is not None:
338 | state = torch.load(checkpoint, map_location="cpu")
339 | step = state["step"]
340 | epoch = state["epoch"]
341 | model.load_state_dict(state["model"])
342 |
343 | if "optimizer" in state:
344 | optimizer.load_state_dict(state["optimizer"])
345 | else:
346 | step = 0
347 | epoch = 0
348 | broadcast(model)
349 |
350 | def train_fn(inputs):
351 | features, labels = inputs
352 | loss = model(features, labels)
353 | return loss
354 |
355 | counter = 0
356 | should_save = False
357 |
358 | if params.script:
359 | thread = ValidationWorker(daemon=True)
360 | thread.init(params)
361 | thread.start()
362 | else:
363 | thread = None
364 |
365 | def step_fn(features, step):
366 | t = time.time()
367 | features = data.lookup(features, "train", params)
368 | loss = train_fn(features)
369 | gradients = optimizer.compute_gradients(loss,
370 | list(model.parameters()))
371 | if params.clip_grad_norm:
372 | torch.nn.utils.clip_grad_norm_(model.parameters(),
373 | params.clip_grad_norm)
374 |
375 | optimizer.apply_gradients(zip(gradients,
376 | list(model.named_parameters())))
377 |
378 | t = time.time() - t
379 |
380 | summary.scalar("loss", loss, step, write_every_n_steps=1)
381 | summary.scalar("global_step/sec", t, step)
382 |
383 | print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
384 | (epoch + 1, step, float(loss), t))
385 |
386 | try:
387 | while True:
388 | for features in dataset:
389 | if counter % params.update_cycle == 0:
390 | step += 1
391 | utils.set_global_step(step)
392 | should_save = True
393 |
394 | counter += 1
395 | step_fn(features, step)
396 |
397 | if step % params.save_checkpoint_steps == 0:
398 | if should_save:
399 | save_checkpoint(step, epoch, model, optimizer, params)
400 | should_save = False
401 |
402 | if step >= params.train_steps:
403 | if should_save:
404 | save_checkpoint(step, epoch, model, optimizer, params)
405 |
406 | if dist.get_rank() == 0:
407 | summary.close()
408 |
409 | return
410 |
411 | epoch += 1
412 | finally:
413 | if thread is not None:
414 | thread.stop()
415 | thread.join()
416 |
417 |
418 | # Wrap main function
419 | def process_fn(rank, args):
420 | local_args = copy.copy(args)
421 | local_args.local_rank = rank
422 | main(local_args)
423 |
424 |
425 | if __name__ == "__main__":
426 | parsed_args = parse_args()
427 |
428 | if parsed_args.distributed:
429 | main(parsed_args)
430 | else:
431 | # Pick a free port
432 | with socket.socket() as s:
433 | s.bind(("localhost", 0))
434 | port = s.getsockname()[1]
435 | url = "tcp://localhost:" + str(port)
436 | parsed_args.url = url
437 |
438 | world_size = infer_gpu_num(parsed_args.parameters)
439 |
440 | if world_size > 1:
441 | torch.multiprocessing.spawn(process_fn, args=(parsed_args,),
442 | nprocs=world_size)
443 | else:
444 | process_fn(0, parsed_args)
445 |
--------------------------------------------------------------------------------
/tagger/data/__init__.py:
--------------------------------------------------------------------------------
1 | from tagger.data.dataset import get_dataset
2 | from tagger.data.vocab import load_vocabulary, lookup
3 | from tagger.data.embedding import load_glove_embedding
4 |
--------------------------------------------------------------------------------
/tagger/data/dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import queue
9 | import torch
10 | import threading
11 | import tensorflow as tf
12 |
13 |
14 | _QUEUE = None
15 | _THREAD = None
16 | _LOCK = threading.Lock()
17 |
18 |
19 | def build_input_fn(filename, mode, params):
20 | def train_input_fn():
21 | dataset = tf.data.TextLineDataset(filename)
22 | dataset = dataset.prefetch(params.buffer_size)
23 | dataset = dataset.shuffle(params.buffer_size)
24 |
25 | # Split "|||"
26 | dataset = dataset.map(
27 | lambda x: tf.strings.split([x], sep="|||", maxsplit=2),
28 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
29 | dataset = dataset.map(
30 | lambda x: (x.values[0], x.values[1]),
31 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
32 | dataset = dataset.map(
33 | lambda x, y: (tf.strings.split([x]).values,
34 | tf.strings.split([y]).values),
35 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
36 | dataset = dataset.map(
37 | lambda x, y: ({
38 | "preds": tf.strings.to_number(x[0], tf.int32),
39 | "inputs": tf.strings.lower(x[1:])
40 | }, y),
41 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
42 | dataset = dataset.map(
43 | lambda x, y: ({
44 | "preds": tf.one_hot(x["preds"], tf.shape(x["inputs"])[0],
45 | dtype=tf.int32),
46 | "inputs": x["inputs"]
47 | }, y),
48 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
49 |
50 | def bucket_boundaries(max_length, min_length=8, step=8):
51 | x = min_length
52 | boundaries = []
53 |
54 | while x <= max_length:
55 | boundaries.append(x + 1)
56 | x += step
57 |
58 | return boundaries
59 |
60 | batch_size = params.batch_size
61 | max_length = (params.max_length // 8) * 8
62 | min_length = params.min_length
63 | boundaries = bucket_boundaries(max_length)
64 | batch_sizes = [max(1, batch_size // (x - 1))
65 | if not params.fixed_batch_size else batch_size
66 | for x in boundaries] + [1]
67 |
68 | def element_length_func(x, y):
69 | return tf.shape(x["inputs"])[0]
70 |
71 | def valid_size(x, y):
72 | size = element_length_func(x, y)
73 | return tf.logical_and(size >= min_length, size <= max_length)
74 |
75 | transformation_fn = tf.data.experimental.bucket_by_sequence_length(
76 | element_length_func,
77 | boundaries,
78 | batch_sizes,
79 | padded_shapes=({
80 | "inputs": tf.TensorShape([None]),
81 | "preds": tf.TensorShape([None]),
82 | }, tf.TensorShape([None])),
83 | padding_values=({
84 | "inputs": params.pad,
85 | "preds": 0,
86 | }, params.pad),
87 | pad_to_bucket_boundary=True)
88 |
89 | dataset = dataset.filter(valid_size)
90 | dataset = dataset.apply(transformation_fn)
91 |
92 | return dataset
93 |
94 |
95 | def infer_input_fn():
96 | dataset = tf.data.TextLineDataset(filename)
97 |
98 | # Split "|||"
99 | dataset = dataset.map(
100 | lambda x: tf.strings.split([x], sep="|||", maxsplit=2),
101 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
102 | dataset = dataset.map(
103 | lambda x: (x.values[0], x.values[1]),
104 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
105 | dataset = dataset.map(
106 | lambda x, y: (tf.strings.split([x]).values,
107 | tf.strings.split([y]).values),
108 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
109 | dataset = dataset.map(
110 | lambda x, y: ({
111 | "preds": tf.strings.to_number(x[0], tf.int32),
112 | "inputs": tf.strings.lower(x[1:])
113 | }, y),
114 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
115 | dataset = dataset.map(
116 | lambda x, y: ({
117 | "preds": tf.one_hot(x["preds"], tf.shape(x["inputs"])[0],
118 | dtype=tf.int32),
119 | "inputs": x["inputs"]
120 | }, y),
121 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
122 |
123 | dataset = dataset.padded_batch(
124 | params.decode_batch_size,
125 | padded_shapes=({
126 | "inputs": tf.TensorShape([None]),
127 | "preds": tf.TensorShape([None]),
128 | }, tf.TensorShape([None])),
129 | padding_values=({
130 | "inputs": params.pad,
131 | "preds": 0,
132 | }, params.pad),
133 | )
134 |
135 | return dataset
136 |
137 | if mode == "train":
138 | return train_input_fn
139 | else:
140 | return infer_input_fn
141 |
142 |
143 | class DatasetWorker(threading.Thread):
144 |
145 | def init(self, dataset):
146 | self._dataset = dataset
147 | self._stop = False
148 |
149 | def run(self):
150 | global _QUEUE
151 | global _LOCK
152 |
153 | while not self._stop:
154 | for feature in self._dataset:
155 | _QUEUE.put(feature)
156 |
157 | def stop(self):
158 | self._stop = True
159 |
160 |
161 | class Dataset(object):
162 |
163 | def __iter__(self):
164 | return self
165 |
166 | def __next__(self):
167 | global _QUEUE
168 | return _QUEUE.get()
169 |
170 | def stop(self):
171 | global _THREAD
172 | _THREAD.stop()
173 | _THREAD.join()
174 |
175 |
176 | def get_dataset(filenames, mode, params):
177 | global _QUEUE
178 | global _THREAD
179 |
180 | input_fn = build_input_fn(filenames, mode, params)
181 |
182 | with tf.device("/cpu:0"):
183 | dataset = input_fn()
184 |
185 | if mode != "train":
186 | return dataset
187 | else:
188 | _QUEUE = queue.Queue(100)
189 | thread = DatasetWorker(daemon=True)
190 | thread.init(dataset)
191 | thread.start()
192 | _THREAD = thread
193 | return Dataset()
194 |
--------------------------------------------------------------------------------
/tagger/data/embedding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import numpy as np
9 |
10 |
11 | def load_glove_embedding(filename, vocab=None):
12 | fd = open(filename, "r")
13 | emb = {}
14 | fan_out = 0
15 |
16 | for line in fd:
17 | items = line.strip().split()
18 | word = items[0].encode("utf-8")
19 | value = [float(item) for item in items[1:]]
20 | fan_out = len(value)
21 | emb[word] = np.array(value, "float32")
22 |
23 | if not vocab:
24 | return emb
25 |
26 | ivoc = {}
27 |
28 | for item in vocab:
29 | ivoc[vocab[item]] = item
30 |
31 | new_emb = np.zeros([len(ivoc), fan_out], "float32")
32 |
33 | for i in ivoc:
34 | word = ivoc[i]
35 | if word not in emb:
36 | fan_in = len(ivoc)
37 | scale = 3.0 / max(1.0, (fan_in + fan_out) / 2.0)
38 | new_emb[i] = np.random.uniform(-scale, scale, [fan_out])
39 | else:
40 | new_emb[i] = emb[word]
41 |
42 | return new_emb
43 |
--------------------------------------------------------------------------------
/tagger/data/vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import numpy as np
10 |
11 |
12 | def _lookup(x, vocab, embedding=None, feature_size=0):
13 | x = x.tolist()
14 | y = []
15 | unk_mask = []
16 | embeddings = []
17 |
18 | for _, batch in enumerate(x):
19 | ids = []
20 | mask = []
21 | emb = []
22 |
23 | for _, v in enumerate(batch):
24 | if v in vocab:
25 | ids.append(vocab[v])
26 | mask.append(1.0)
27 |
28 | if embedding is not None:
29 | emb.append(np.zeros([feature_size]))
30 | else:
31 | ids.append(2)
32 |
33 | if embedding is not None and v in embedding:
34 | mask.append(0.0)
35 | emb.append(embedding[v])
36 | else:
37 | mask.append(1.0)
38 | emb.append(np.zeros([feature_size]))
39 |
40 | y.append(ids)
41 | unk_mask.append(mask)
42 | embeddings.append(emb)
43 |
44 | ids = torch.LongTensor(np.array(y, dtype="int32")).cuda()
45 | mask = torch.Tensor(np.array(unk_mask, dtype="float32")).cuda()
46 |
47 | if embedding is not None:
48 | emb = torch.Tensor(np.array(embeddings, dtype="float32")).cuda()
49 | else:
50 | emb = None
51 |
52 | return ids, mask, emb
53 |
54 |
55 | def load_vocabulary(filename):
56 | vocab = []
57 | with open(filename, "rb") as fd:
58 | for line in fd:
59 | vocab.append(line.strip())
60 |
61 | word2idx = {}
62 | idx2word = {}
63 |
64 | for idx, word in enumerate(vocab):
65 | word2idx[word] = idx
66 | idx2word[idx] = word
67 |
68 | return vocab, word2idx, idx2word
69 |
70 |
71 | def lookup(inputs, mode, params, embedding=None):
72 | if mode == "train":
73 | features, labels = inputs
74 | preds, seqs = features["preds"], features["inputs"]
75 | preds = torch.LongTensor(preds.numpy()).cuda()
76 | seqs = seqs.numpy()
77 | labels = labels.numpy()
78 |
79 | seqs, _, _ = _lookup(seqs, params.lookup["source"])
80 | labels, _, _ = _lookup(labels, params.lookup["target"])
81 |
82 | features = {
83 | "preds": preds,
84 | "inputs": seqs
85 | }
86 |
87 | return features, labels
88 | else:
89 | features, _ = inputs
90 | preds, seqs = features["preds"], features["inputs"]
91 | preds = torch.LongTensor(preds.numpy()).cuda()
92 | seqs = seqs.numpy()
93 |
94 | seqs, unk_mask, emb = _lookup(seqs, params.lookup["source"], embedding,
95 | params.feature_size)
96 |
97 | features = {
98 | "preds": preds,
99 | "inputs": seqs,
100 | "mask": unk_mask
101 | }
102 |
103 | if emb is not None:
104 | features["embedding"] = emb
105 |
106 | return features
107 |
--------------------------------------------------------------------------------
/tagger/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tagger.models.deepatt
9 |
10 |
11 | def get_model(name):
12 | name = name.lower()
13 |
14 | if name == "deepatt":
15 | return tagger.models.deepatt.DeepAtt
16 | else:
17 | raise LookupError("Unknown model %s" % name)
18 |
--------------------------------------------------------------------------------
/tagger/models/deepatt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 | import torch.nn as nn
11 |
12 | import tagger.utils as utils
13 | import tagger.modules as modules
14 |
15 | from tagger.data import load_glove_embedding
16 |
17 |
18 | class AttentionSubLayer(modules.Module):
19 |
20 | def __init__(self, params, name="attention"):
21 | super(AttentionSubLayer, self).__init__(name=name)
22 |
23 | with utils.scope(name):
24 | self.attention = modules.MultiHeadAttention(
25 | params.hidden_size, params.num_heads, params.attention_dropout)
26 | self.layer_norm = modules.LayerNorm(params.hidden_size)
27 |
28 | self.dropout = params.residual_dropout
29 |
30 | def forward(self, x, bias):
31 | y = self.attention(x, bias)
32 | y = nn.functional.dropout(y, self.dropout, self.training)
33 |
34 | return self.layer_norm(x + y)
35 |
36 |
37 | class FFNSubLayer(modules.Module):
38 |
39 | def __init__(self, params, dtype=None, name="ffn_layer"):
40 | super(FFNSubLayer, self).__init__(name=name)
41 |
42 | with utils.scope(name):
43 | self.ffn_layer = modules.FeedForward(params.hidden_size,
44 | params.filter_size,
45 | dropout=params.relu_dropout)
46 | self.layer_norm = modules.LayerNorm(params.hidden_size)
47 | self.dropout = params.residual_dropout
48 |
49 | def forward(self, x):
50 | y = self.ffn_layer(x)
51 | y = nn.functional.dropout(y, self.dropout, self.training)
52 |
53 | return self.layer_norm(x + y)
54 |
55 |
56 | class DeepAttEncoderLayer(modules.Module):
57 |
58 | def __init__(self, params, name="layer"):
59 | super(DeepAttEncoderLayer, self).__init__(name=name)
60 |
61 | with utils.scope(name):
62 | self.self_attention = AttentionSubLayer(params)
63 | self.feed_forward = FFNSubLayer(params)
64 |
65 | def forward(self, x, bias):
66 | x = self.feed_forward(x)
67 | x = self.self_attention(x, bias)
68 | return x
69 |
70 |
71 | class DeepAttEncoder(modules.Module):
72 |
73 | def __init__(self, params, name="encoder"):
74 | super(DeepAttEncoder, self).__init__(name=name)
75 |
76 | with utils.scope(name):
77 | self.layers = nn.ModuleList([
78 | DeepAttEncoderLayer(params, name="layer_%d" % i)
79 | for i in range(params.num_hidden_layers)])
80 |
81 | def forward(self, x, bias):
82 | for layer in self.layers:
83 | x = layer(x, bias)
84 | return x
85 |
86 |
87 | class DeepAtt(modules.Module):
88 |
89 | def __init__(self, params, name="deepatt"):
90 | super(DeepAtt, self).__init__(name=name)
91 | self.params = params
92 |
93 | with utils.scope(name):
94 | self.build_embedding(params)
95 | self.encoding = modules.PositionalEmbedding()
96 | self.encoder = DeepAttEncoder(params)
97 | self.classifier = modules.Affine(params.hidden_size,
98 | len(params.vocabulary["target"]),
99 | name="softmax")
100 |
101 | self.criterion = modules.SmoothedCrossEntropyLoss(
102 | params.label_smoothing)
103 | self.dropout = params.residual_dropout
104 | self.hidden_size = params.hidden_size
105 | self.reset_parameters()
106 |
107 | def build_embedding(self, params):
108 | vocab_size = len(params.vocabulary["source"])
109 |
110 | self.embedding = torch.nn.Parameter(
111 | torch.empty([vocab_size, params.feature_size]))
112 | self.weights = torch.nn.Parameter(
113 | torch.empty([2, params.feature_size]))
114 | self.bias = torch.nn.Parameter(torch.zeros([params.hidden_size]))
115 | self.add_name(self.embedding, "embedding")
116 | self.add_name(self.weights, "weights")
117 | self.add_name(self.bias, "bias")
118 |
119 | def reset_parameters(self):
120 | nn.init.normal_(self.embedding, mean=0.0,
121 | std=self.params.feature_size ** -0.5)
122 | nn.init.normal_(self.weights, mean=0.0,
123 | std=self.params.feature_size ** -0.5)
124 | nn.init.normal_(self.classifier.weight, mean=0.0,
125 | std=self.params.hidden_size ** -0.5)
126 | nn.init.zeros_(self.classifier.bias)
127 |
128 | def encode(self, features):
129 | seq = features["inputs"]
130 | pred = features["preds"]
131 | mask = torch.ne(seq, 0).float().cuda()
132 | enc_attn_bias = self.masking_bias(mask)
133 |
134 | inputs = torch.nn.functional.embedding(seq, self.embedding)
135 |
136 | if "embedding" in features and not self.training:
137 | embedding = features["embedding"]
138 | unk_mask = features["mask"].to(mask)[:, :, None]
139 | inputs = inputs * unk_mask + (1.0 - unk_mask) * embedding
140 |
141 | preds = torch.nn.functional.embedding(pred, self.weights)
142 | inputs = torch.cat([inputs, preds], axis=-1)
143 | inputs = inputs * (self.hidden_size ** 0.5)
144 | inputs = inputs + self.bias
145 |
146 | inputs = nn.functional.dropout(self.encoding(inputs), self.dropout,
147 | self.training)
148 |
149 | enc_attn_bias = enc_attn_bias.to(inputs)
150 | encoder_output = self.encoder(inputs, enc_attn_bias)
151 | logits = self.classifier(encoder_output)
152 |
153 | return logits
154 |
155 | def argmax_decode(self, features):
156 | logits = self.encode(features)
157 | return torch.argmax(logits, -1)
158 |
159 | def forward(self, features, labels):
160 | mask = torch.ne(features["inputs"], 0).float().cuda()
161 | logits = self.encode(features)
162 | loss = self.criterion(logits, labels)
163 | mask = mask.to(logits)
164 |
165 | return torch.sum(loss * mask) / torch.sum(mask)
166 |
167 | def load_embedding(self, path):
168 | if not path:
169 | return
170 | emb = load_glove_embedding(path, self.params.lookup["source"])
171 |
172 | with torch.no_grad():
173 | self.embedding.copy_(torch.tensor(emb))
174 |
175 | @staticmethod
176 | def masking_bias(mask, inf=-1e9):
177 | ret = (1.0 - mask) * inf
178 | return torch.unsqueeze(torch.unsqueeze(ret, 1), 1)
179 |
180 | @staticmethod
181 | def base_params():
182 | params = utils.HParams(
183 |             pad="<pad>",
184 |             bos="<bos>",
185 |             eos="<eos>",
186 |             unk="<unk>",
187 | feature_size=100,
188 | hidden_size=200,
189 | filter_size=800,
190 | num_heads=8,
191 | num_hidden_layers=10,
192 | attention_dropout=0.0,
193 | residual_dropout=0.1,
194 | relu_dropout=0.0,
195 | label_smoothing=0.1,
196 | clip_grad_norm=0.0
197 | )
198 |
199 | return params
200 |
201 | @staticmethod
202 | def default_params(name=None):
203 | return DeepAtt.base_params()
204 |
--------------------------------------------------------------------------------
/tagger/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from tagger.modules.attention import MultiHeadAttention
2 | from tagger.modules.embedding import PositionalEmbedding
3 | from tagger.modules.feed_forward import FeedForward
4 | from tagger.modules.layer_norm import LayerNorm
5 | from tagger.modules.losses import SmoothedCrossEntropyLoss
6 | from tagger.modules.module import Module
7 | from tagger.modules.affine import Affine
8 | from tagger.modules.recurrent import LSTMCell, GRUCell, HighwayLSTMCell, DynamicLSTMCell
9 |
--------------------------------------------------------------------------------
/tagger/modules/affine.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 | import torch.nn as nn
11 |
12 | import tagger.utils as utils
13 | from tagger.modules.module import Module
14 |
15 |
16 | class Affine(Module):
17 |
18 | def __init__(self, in_features, out_features, bias=True, name="affine"):
19 | super(Affine, self).__init__(name=name)
20 | self.in_features = in_features
21 | self.out_features = out_features
22 |
23 | with utils.scope(name):
24 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
25 | self.add_name(self.weight, "weight")
26 | if bias:
27 | self.bias = nn.Parameter(torch.Tensor(out_features))
28 | self.add_name(self.bias, "bias")
29 | else:
30 | self.register_parameter('bias', None)
31 |
32 | self.reset_parameters()
33 |
34 | def reset_parameters(self):
35 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
36 | if self.bias is not None:
37 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
38 | bound = 1 / math.sqrt(fan_in)
39 | nn.init.uniform_(self.bias, -bound, bound)
40 |
41 | def orthogonal_initialize(self, gain=1.0):
42 | nn.init.orthogonal_(self.weight, gain)
43 | nn.init.zeros_(self.bias)
44 |
45 | def forward(self, input):
46 | return nn.functional.linear(input, self.weight, self.bias)
47 |
48 | def extra_repr(self):
49 | return 'in_features={}, out_features={}, bias={}'.format(
50 | self.in_features, self.out_features, self.bias is not None
51 | )
52 |
--------------------------------------------------------------------------------
/tagger/modules/attention.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 | import tagger.utils as utils
11 |
12 | from tagger.modules.module import Module
13 | from tagger.modules.affine import Affine
14 |
15 |
16 | class MultiHeadAttention(Module):
17 |
18 | def __init__(self, hidden_size, num_heads, dropout=0.0,
19 | name="multihead_attention"):
20 | super(MultiHeadAttention, self).__init__(name=name)
21 |
22 | self.num_heads = num_heads
23 | self.hidden_size = hidden_size
24 | self.dropout = dropout
25 |
26 | with utils.scope(name):
27 | self.qkv_transform = Affine(hidden_size, 3 * hidden_size,
28 | name="qkv_transform")
29 | self.o_transform = Affine(hidden_size, hidden_size,
30 | name="o_transform")
31 |
32 | self.reset_parameters()
33 |
34 | def forward(self, query, bias):
35 | qkv = self.qkv_transform(query)
36 | q, k, v = torch.split(qkv, self.hidden_size, dim=-1)
37 |
38 | # split heads
39 | qh = self.split_heads(q, self.num_heads)
40 | kh = self.split_heads(k, self.num_heads)
41 | vh = self.split_heads(v, self.num_heads)
42 |
43 | # scale query
44 | qh = qh * (self.hidden_size // self.num_heads) ** -0.5
45 |
46 | # dot-product attention
47 | kh = torch.transpose(kh, -2, -1)
48 | logits = torch.matmul(qh, kh)
49 |
50 | if bias is not None:
51 | logits = logits + bias
52 |
53 | weights = torch.nn.functional.dropout(torch.softmax(logits, dim=-1),
54 | p=self.dropout,
55 | training=self.training)
56 |
57 | x = torch.matmul(weights, vh)
58 |
59 | # combine heads
60 | output = self.o_transform(self.combine_heads(x))
61 |
62 | return output
63 |
64 | def reset_parameters(self, initializer="orthogonal"):
65 | if initializer == "orthogonal":
66 | self.qkv_transform.orthogonal_initialize()
67 | self.o_transform.orthogonal_initialize()
68 | else:
69 | # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size)
70 | nn.init.xavier_uniform_(self.qkv_transform.weight)
71 | nn.init.xavier_uniform_(self.o_transform.weight)
72 | nn.init.constant_(self.qkv_transform.bias, 0.0)
73 | nn.init.constant_(self.o_transform.bias, 0.0)
74 |
75 | @staticmethod
76 | def split_heads(x, heads):
77 | batch = x.shape[0]
78 | length = x.shape[1]
79 | channels = x.shape[2]
80 |
81 | y = torch.reshape(x, [batch, length, heads, channels // heads])
82 | return torch.transpose(y, 2, 1)
83 |
84 | @staticmethod
85 | def combine_heads(x):
86 | batch = x.shape[0]
87 | heads = x.shape[1]
88 | length = x.shape[2]
89 | channels = x.shape[3]
90 |
91 | y = torch.transpose(x, 2, 1)
92 |
93 | return torch.reshape(y, [batch, length, heads * channels])
94 |
--------------------------------------------------------------------------------
/tagger/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 |
11 |
12 | class PositionalEmbedding(torch.nn.Module):
13 |
14 | def __init__(self):
15 | super(PositionalEmbedding, self).__init__()
16 |
17 | def forward(self, inputs):
18 | if inputs.dim() != 3:
19 | raise ValueError("The rank of input must be 3.")
20 |
21 | length = inputs.shape[1]
22 | channels = inputs.shape[2]
23 | half_dim = channels // 2
24 |
25 | positions = torch.arange(length, dtype=inputs.dtype,
26 | device=inputs.device)
27 | dimensions = torch.arange(half_dim, dtype=inputs.dtype,
28 | device=inputs.device)
29 |
30 | scale = math.log(10000.0) / float(half_dim - 1)
31 | dimensions.mul_(-scale).exp_()
32 |
33 | scaled_time = positions.unsqueeze(1) * dimensions.unsqueeze(0)
34 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
35 | dim=1)
36 |
37 | if channels % 2 == 1:
38 | pad = torch.zeros([signal.shape[0], 1], dtype=inputs.dtype,
39 | device=inputs.device)
40 | signal = torch.cat([signal, pad], axis=1)
41 |
42 | return inputs + torch.reshape(signal, [1, -1, channels]).to(inputs)
43 |
--------------------------------------------------------------------------------
/tagger/modules/feed_forward.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 | import tagger.utils as utils
11 |
12 | from tagger.modules.module import Module
13 | from tagger.modules.affine import Affine
14 |
15 |
16 | class FeedForward(Module):
17 |
18 | def __init__(self, input_size, hidden_size, output_size=None, dropout=0.0,
19 | name="feed_forward"):
20 | super(FeedForward, self).__init__(name=name)
21 |
22 | self.input_size = input_size
23 | self.hidden_size = hidden_size
24 | self.output_size = output_size or input_size
25 | self.dropout = dropout
26 |
27 | with utils.scope(name):
28 | self.input_transform = Affine(input_size, hidden_size,
29 | name="input_transform")
30 | self.output_transform = Affine(hidden_size, self.output_size,
31 | name="output_transform")
32 |
33 | self.reset_parameters()
34 |
35 | def forward(self, x):
36 | h = nn.functional.relu(self.input_transform(x))
37 | h = nn.functional.dropout(h, self.dropout, self.training)
38 | return self.output_transform(h)
39 |
40 | def reset_parameters(self, initializer="orthogonal"):
41 | if initializer == "orthogonal":
42 | self.input_transform.orthogonal_initialize()
43 | self.output_transform.orthogonal_initialize()
44 | else:
45 | nn.init.xavier_uniform_(self.input_transform.weight)
46 | nn.init.xavier_uniform_(self.output_transform.weight)
47 | nn.init.constant_(self.input_transform.bias, 0.0)
48 | nn.init.constant_(self.output_transform.bias, 0.0)
49 |
--------------------------------------------------------------------------------
/tagger/modules/layer_norm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import numbers
9 | import torch
10 | import torch.nn as nn
11 | import tagger.utils as utils
12 |
13 | from tagger.modules.module import Module
14 |
15 |
16 | class LayerNorm(Module):
17 |
18 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True,
19 | name="layer_norm"):
20 | super(LayerNorm, self).__init__(name=name)
21 | if isinstance(normalized_shape, numbers.Integral):
22 | normalized_shape = (normalized_shape,)
23 | self.normalized_shape = tuple(normalized_shape)
24 | self.eps = eps
25 | self.elementwise_affine = elementwise_affine
26 |
27 | with utils.scope(name):
28 | if self.elementwise_affine:
29 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
30 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
31 | self.add_name(self.weight, "weight")
32 | self.add_name(self.bias, "bias")
33 | else:
34 | self.register_parameter('weight', None)
35 | self.register_parameter('bias', None)
36 | self.reset_parameters()
37 |
38 | def reset_parameters(self):
39 | if self.elementwise_affine:
40 | nn.init.ones_(self.weight)
41 | nn.init.zeros_(self.bias)
42 |
43 | def forward(self, input):
44 | return nn.functional.layer_norm(
45 | input, self.normalized_shape, self.weight, self.bias, self.eps)
46 |
47 | def extra_repr(self):
48 | return '{normalized_shape}, eps={eps}, ' \
49 | 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
50 |
--------------------------------------------------------------------------------
/tagger/modules/losses.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 |
11 |
12 | class SmoothedCrossEntropyLoss(torch.nn.Module):
13 |
14 | def __init__(self, smoothing=0.0, normalize=True):
15 | super(SmoothedCrossEntropyLoss, self).__init__()
16 | self.smoothing = smoothing
17 | self.normalize = normalize
18 |
19 | def forward(self, logits, labels):
20 | shape = labels.shape
21 | logits = torch.reshape(logits, [-1, logits.shape[-1]])
22 | labels = torch.reshape(labels, [-1])
23 |
24 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
25 | batch_idx = torch.arange(labels.shape[0], device=logits.device)
26 | loss = log_probs[batch_idx, labels]
27 |
28 | if not self.smoothing:
29 | return -torch.reshape(loss, shape)
30 |
31 | n = logits.shape[-1] - 1.0
32 | p = 1.0 - self.smoothing
33 | q = self.smoothing / n
34 |
35 | if log_probs.dtype != torch.float16:
36 | sum_probs = torch.sum(log_probs, dim=-1)
37 | loss = p * loss + q * (sum_probs - loss)
38 | else:
39 | # Prevent FP16 overflow
40 | sum_probs = torch.sum(log_probs.to(torch.float32), dim=-1)
41 | loss = loss.to(torch.float32)
42 | loss = p * loss + q * (sum_probs - loss)
43 | loss = loss.to(torch.float16)
44 |
45 | loss = -torch.reshape(loss, shape)
46 |
47 | if self.normalize:
48 | normalizing = -(p * math.log(p) + n * q * math.log(q + 1e-20))
49 | return loss - normalizing
50 | else:
51 | return loss
52 |
--------------------------------------------------------------------------------
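
Usage note: a minimal sketch of the label-smoothed loss above; the batch size, sequence length, and label count are placeholders. The module returns a per-token loss with the same shape as the labels, so the caller is expected to mask and reduce it.

    import torch
    from tagger.modules.losses import SmoothedCrossEntropyLoss

    criterion = SmoothedCrossEntropyLoss(smoothing=0.1)
    logits = torch.randn(2, 5, 8)            # [batch, time, num_labels]
    labels = torch.randint(0, 8, (2, 5))     # [batch, time]
    loss = criterion(logits, labels)         # [batch, time], one value per token
    print(loss.mean())
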
/tagger/modules/module.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | import tagger.utils as utils
12 |
13 |
14 | class Module(nn.Module):
15 |
16 | def __init__(self, name=""):
17 | super(Module, self).__init__()
18 | scope = utils.get_scope()
19 | self._name = scope + "/" + name if scope else name
20 |
21 | def add_name(self, tensor, name):
22 | tensor.tensor_name = utils.unique_name(name)
23 |
24 | @property
25 | def name(self):
26 | return self._name
27 |
--------------------------------------------------------------------------------
/tagger/modules/recurrent.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2020 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | import tagger.utils as utils
12 |
13 | from tagger.modules.module import Module
14 | from tagger.modules.affine import Affine
15 | from tagger.modules.layer_norm import LayerNorm
16 |
17 |
18 | class GRUCell(Module):
19 |
20 | def __init__(self, input_size, output_size, normalization=False,
21 | name="gru"):
22 | super(GRUCell, self).__init__(name=name)
23 |
24 | self.input_size = input_size
25 | self.output_size = output_size
26 |
27 | with utils.scope(name):
28 | self.reset_gate = Affine(input_size + output_size, output_size,
29 | bias=False, name="reset_gate")
30 | self.update_gate = Affine(input_size + output_size, output_size,
31 | bias=False, name="update_gate")
32 | self.transform = Affine(input_size + output_size, output_size,
33 | name="transform")
34 |
35 | def forward(self, x, h):
36 | r = torch.sigmoid(self.reset_gate(torch.cat([x, h], -1)))
37 | u = torch.sigmoid(self.update_gate(torch.cat([x, h], -1)))
38 | c = self.transform(torch.cat([x, r * h], -1))
39 |
40 |         new_h = (1.0 - u) * h + u * torch.tanh(c)
41 |
42 | return new_h, new_h
43 |
44 | def init_state(self, batch_size, dtype, device):
45 | h = torch.zeros([batch_size, self.output_size], dtype=dtype,
46 | device=device)
47 | return h
48 |
49 | def mask_state(self, h, prev_h, mask):
50 | mask = mask[:, None]
51 | new_h = mask * h + (1.0 - mask) * prev_h
52 | return new_h
53 |
54 |     def reset_parameters(self, initializer="uniform"):
55 |         if initializer == "uniform_scaling":
56 |             for m in (self.reset_gate, self.update_gate, self.transform):
57 |                 nn.init.xavier_uniform_(m.weight)
58 |         elif initializer == "uniform":
59 |             for m in (self.reset_gate, self.update_gate, self.transform):
60 |                 nn.init.uniform_(m.weight, -0.08, 0.08)
61 |         else:
62 |             raise ValueError("Unknown initializer %s" % initializer)
63 |
64 |
65 | class LSTMCell(Module):
66 |
67 | def __init__(self, input_size, output_size, normalization=False,
68 | activation=torch.tanh, name="lstm"):
69 | super(LSTMCell, self).__init__(name=name)
70 |
71 | self.input_size = input_size
72 | self.output_size = output_size
73 | self.activation = activation
74 |
75 | with utils.scope(name):
76 | self.gates = Affine(input_size + output_size, 4 * output_size,
77 | name="gates")
78 | if normalization:
79 | self.layer_norm = LayerNorm([4, output_size])
80 | else:
81 | self.layer_norm = None
82 |
83 | self.reset_parameters()
84 |
85 | def forward(self, x, state):
86 | c, h = state
87 |
88 | gates = self.gates(torch.cat([x, h], 1))
89 |
90 | if self.layer_norm is not None:
91 | combined = self.layer_norm(
92 | torch.reshape(gates, [-1, 4, self.output_size]))
93 | else:
94 | combined = torch.reshape(gates, [-1, 4, self.output_size])
95 |
96 | i, j, f, o = torch.unbind(combined, 1)
97 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
98 |
99 | new_c = f * c + i * torch.tanh(j)
100 |
101 | if self.activation is None:
102 | # Do not use tanh activation
103 | new_h = o * new_c
104 | else:
105 | new_h = o * self.activation(new_c)
106 |
107 | return new_h, (new_c, new_h)
108 |
109 | def init_state(self, batch_size, dtype, device):
110 | c = torch.zeros([batch_size, self.output_size], dtype=dtype,
111 | device=device)
112 | h = torch.zeros([batch_size, self.output_size], dtype=dtype,
113 | device=device)
114 | return c, h
115 |
116 | def mask_state(self, state, prev_state, mask):
117 | c, h = state
118 | prev_c, prev_h = prev_state
119 | mask = mask[:, None]
120 | new_c = mask * c + (1.0 - mask) * prev_c
121 | new_h = mask * h + (1.0 - mask) * prev_h
122 | return new_c, new_h
123 |
124 | def reset_parameters(self, initializer="orthogonal"):
125 | if initializer == "uniform_scaling":
126 | nn.init.xavier_uniform_(self.gates.weight)
127 | nn.init.constant_(self.gates.bias, 0.0)
128 | elif initializer == "uniform":
129 | nn.init.uniform_(self.gates.weight, -0.04, 0.04)
130 | nn.init.uniform_(self.gates.bias, -0.04, 0.04)
131 | elif initializer == "orthogonal":
132 | self.gates.orthogonal_initialize()
133 | else:
134 |             raise ValueError("Unknown initializer %s" % initializer)
135 |
136 |
137 |
138 | class HighwayLSTMCell(Module):
139 |
140 | def __init__(self, input_size, output_size, name="lstm"):
141 | super(HighwayLSTMCell, self).__init__(name=name)
142 |
143 | self.input_size = input_size
144 | self.output_size = output_size
145 |
146 | with utils.scope(name):
147 | self.gates = Affine(input_size + output_size, 5 * output_size,
148 | name="gates")
149 | self.trans = Affine(input_size, output_size, name="trans")
150 |
151 | self.reset_parameters()
152 |
153 | def forward(self, x, state):
154 | c, h = state
155 |
156 | gates = self.gates(torch.cat([x, h], 1))
157 | combined = torch.reshape(gates, [-1, 5, self.output_size])
158 | i, j, f, o, t = torch.unbind(combined, 1)
159 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
160 | t = torch.sigmoid(t)
161 |
162 | new_c = f * c + i * torch.tanh(j)
163 | tmp_h = o * torch.tanh(new_c)
164 | new_h = t * tmp_h + (1.0 - t) * self.trans(x)
165 |
166 | return new_h, (new_c, new_h)
167 |
168 | def init_state(self, batch_size, dtype, device):
169 | c = torch.zeros([batch_size, self.output_size], dtype=dtype,
170 | device=device)
171 | h = torch.zeros([batch_size, self.output_size], dtype=dtype,
172 | device=device)
173 | return c, h
174 |
175 | def mask_state(self, state, prev_state, mask):
176 | c, h = state
177 | prev_c, prev_h = prev_state
178 | mask = mask[:, None]
179 | new_c = mask * c + (1.0 - mask) * prev_c
180 | new_h = mask * h + (1.0 - mask) * prev_h
181 | return new_c, new_h
182 |
183 | def reset_parameters(self, initializer="orthogonal"):
184 | if initializer == "uniform_scaling":
185 | nn.init.xavier_uniform_(self.gates.weight)
186 | nn.init.constant_(self.gates.bias, 0.0)
187 | elif initializer == "uniform":
188 | nn.init.uniform_(self.gates.weight, -0.04, 0.04)
189 | nn.init.uniform_(self.gates.bias, -0.04, 0.04)
190 | elif initializer == "orthogonal":
191 | self.gates.orthogonal_initialize()
192 | self.trans.orthogonal_initialize()
193 | else:
194 |             raise ValueError("Unknown initializer %s" % initializer)
195 |
196 |
197 | class DynamicLSTMCell(Module):
198 |
199 | def __init__(self, input_size, output_size, k=2, num_cells=4, name="lstm"):
200 | super(DynamicLSTMCell, self).__init__(name=name)
201 |
202 | self.input_size = input_size
203 | self.output_size = output_size
204 | self.num_cells = num_cells
205 | self.k = k
206 |
207 | with utils.scope(name):
208 | self.gates = Affine(input_size + output_size,
209 | 4 * output_size * num_cells,
210 | name="gates")
211 | self.topk_gate = Affine(input_size + output_size,
212 | num_cells, name="controller")
213 |
214 |
215 | self.reset_parameters()
216 |
217 | @staticmethod
218 | def top_k_softmax(logits, k, n):
219 | top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))
220 |
221 | top_k_logits = top_logits[:, :k]
222 | top_k_indices = top_indices[:, :k]
223 |
224 | probs = torch.softmax(top_k_logits, dim=-1)
225 | batch = top_k_logits.shape[0]
226 | k = top_k_logits.shape[1]
227 |
228 | # Flat to 1D
229 | indices_flat = torch.reshape(top_k_indices, [-1])
230 |         indices_flat = indices_flat + (torch.arange(
231 |             batch * k, device=logits.device) // k) * n
232 |
233 | tensor = torch.zeros([batch * n], dtype=logits.dtype,
234 | device=logits.device)
235 | tensor = tensor.scatter_add(0, indices_flat.long(),
236 | torch.reshape(probs, [-1]))
237 |
238 | return torch.reshape(tensor, [batch, n])
239 |
240 | def forward(self, x, state):
241 | c, h = state
242 | feats = torch.cat([x, h], dim=-1)
243 |
244 | logits = self.topk_gate(feats)
245 | # [batch, num_cells]
246 | gate = self.top_k_softmax(logits, self.k, self.num_cells)
247 |
248 | # [batch, 4 * num_cells * dim]
249 | combined = self.gates(feats)
250 | combined = torch.reshape(combined,
251 | [-1, self.num_cells, 4, self.output_size])
252 |
253 | i, j, f, o = torch.unbind(combined, 2)
254 | i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
255 |
256 | # [batch, num_cells, dim]
257 | new_c = f * c[:, None, :] + i * torch.tanh(j)
258 | new_h = o * torch.tanh(new_c)
259 |
260 | gate = gate[:, None, :]
261 | new_c = torch.matmul(gate, new_c)
262 | new_h = torch.matmul(gate, new_h)
263 |
264 | new_c = torch.squeeze(new_c, 1)
265 | new_h = torch.squeeze(new_h, 1)
266 |
267 | return new_h, (new_c, new_h)
268 |
269 | def init_state(self, batch_size, dtype, device):
270 | c = torch.zeros([batch_size, self.output_size], dtype=dtype,
271 | device=device)
272 | h = torch.zeros([batch_size, self.output_size], dtype=dtype,
273 | device=device)
274 | return c, h
275 |
276 | def mask_state(self, state, prev_state, mask):
277 | c, h = state
278 | prev_c, prev_h = prev_state
279 | mask = mask[:, None]
280 | new_c = mask * c + (1.0 - mask) * prev_c
281 | new_h = mask * h + (1.0 - mask) * prev_h
282 | return new_c, new_h
283 |
284 | def reset_parameters(self, initializer="orthogonal"):
285 | if initializer == "uniform_scaling":
286 | nn.init.xavier_uniform_(self.gates.weight)
287 | nn.init.constant_(self.gates.bias, 0.0)
288 | elif initializer == "uniform":
289 | nn.init.uniform_(self.gates.weight, -0.04, 0.04)
290 | nn.init.uniform_(self.gates.bias, -0.04, 0.04)
291 | elif initializer == "orthogonal":
292 | weight = self.gates.weight.view(
293 | [self.input_size + self.output_size, self.num_cells,
294 | 4 * self.output_size])
295 | nn.init.orthogonal_(weight, 1.0)
296 | nn.init.constant_(self.gates.bias, 0.0)
297 | else:
298 |             raise ValueError("Unknown initializer %s" % initializer)
299 |
--------------------------------------------------------------------------------
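
Usage note: a minimal sketch of stepping one of the recurrent cells above over a padded batch. The shapes, the 1.0/0.0 mask convention, and the manual time loop are assumptions based on the cell and mask_state interfaces, not code taken from the repository.

    import torch
    from tagger.modules.recurrent import LSTMCell

    cell = LSTMCell(input_size=16, output_size=32)
    x = torch.randn(4, 7, 16)            # [batch, time, input_size]
    mask = torch.ones(4, 7)              # 1.0 for real tokens, 0.0 for padding
    state = cell.init_state(4, x.dtype, x.device)
    outputs = []

    for t in range(x.shape[1]):
        prev_state = state
        output, state = cell(x[:, t], state)
        # Keep the previous state wherever this position is padding.
        state = cell.mask_state(state, prev_state, mask[:, t])
        outputs.append(output)

    hidden = torch.stack(outputs, dim=1)  # [batch, time, output_size]
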
/tagger/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from tagger.optimizers.optimizers import AdamOptimizer
2 | from tagger.optimizers.optimizers import AdadeltaOptimizer
3 | from tagger.optimizers.optimizers import MultiStepOptimizer
4 | from tagger.optimizers.optimizers import LossScalingOptimizer
5 | from tagger.optimizers.schedules import LinearWarmupRsqrtDecay
6 | from tagger.optimizers.schedules import PiecewiseConstantDecay
7 | from tagger.optimizers.clipping import (
8 | adaptive_clipper, global_norm_clipper, value_clipper)
9 |
--------------------------------------------------------------------------------
/tagger/optimizers/clipping.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2020 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 |
10 |
11 | def global_norm_clipper(value):
12 | def clip_fn(gradients, grad_norm):
13 | if not float(value) or grad_norm < value:
14 | return False, gradients
15 |
16 | scale = value / grad_norm
17 |
18 | gradients = [grad.data.mul_(scale)
19 | if grad is not None else None for grad in gradients]
20 |
21 | return False, gradients
22 |
23 | return clip_fn
24 |
25 |
26 | def value_clipper(clip_min, clip_max):
27 | def clip_fn(gradients, grad_norm):
28 | gradients = [
29 | grad.data.clamp_(clip_min, clip_max)
30 | if grad is not None else None for grad in gradients]
31 |
32 |         return False, gradients
33 |
34 | return clip_fn
35 |
36 |
37 | def adaptive_clipper(rho):
38 | norm_avg = 0.0
39 | norm_stddev = 0.0
40 | log_norm_avg = 0.0
41 | log_norm_sqr = 0.0
42 |
43 | def clip_fn(gradients, grad_norm):
44 | nonlocal norm_avg
45 | nonlocal norm_stddev
46 | nonlocal log_norm_avg
47 | nonlocal log_norm_sqr
48 |
49 | norm = grad_norm
50 | log_norm = math.log(norm)
51 |
52 | avg = rho * norm_avg + (1.0 - rho) * norm
53 | log_avg = rho * log_norm_avg + (1.0 - rho) * log_norm
54 | log_sqr = rho * log_norm_sqr + (1.0 - rho) * (log_norm ** 2)
55 |         stddev = max(log_sqr - (log_avg ** 2), 0.0) ** 0.5
56 |
57 | norm_avg = avg
58 | log_norm_avg = log_avg
59 | log_norm_sqr = log_sqr
60 |         norm_stddev = rho * norm_stddev + (1.0 - rho) * stddev
61 |
62 | reject = False
63 |
64 | if norm > norm_avg + 4 * math.exp(norm_stddev):
65 | reject = True
66 |
67 | return reject, gradients
68 |
69 | return clip_fn
70 |
--------------------------------------------------------------------------------
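
Usage note: a minimal sketch of attaching one of the clippers above to an optimizer; the threshold and learning rate are arbitrary. Each clipper is a closure that apply_gradients calls as clip_fn(gradients, grad_norm) and that returns a (reject, gradients) pair.

    from tagger.optimizers import AdamOptimizer, global_norm_clipper

    # Rescale gradients whenever their global norm exceeds 5.0.
    clipper = global_norm_clipper(5.0)
    optimizer = AdamOptimizer(learning_rate=1e-3, clipper=clipper)
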
/tagger/optimizers/optimizers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 | import torch.distributed as dist
11 | import tagger.utils as utils
12 | import tagger.utils.summary as summary
13 |
14 | from tagger.optimizers.schedules import LearningRateSchedule
15 |
16 |
17 | def _save_summary(grads_and_vars):
18 | total_norm = 0.0
19 |
20 | for grad, var in grads_and_vars:
21 | if grad is None:
22 | continue
23 |
24 | _, var = var
25 | grad_norm = grad.data.norm()
26 | total_norm += grad_norm ** 2
27 | summary.histogram(var.tensor_name, var,
28 | utils.get_global_step())
29 | summary.scalar("norm/" + var.tensor_name, var.norm(),
30 | utils.get_global_step())
31 | summary.scalar("grad_norm/" + var.tensor_name, grad_norm,
32 | utils.get_global_step())
33 |
34 | total_norm = total_norm ** 0.5
35 | summary.scalar("grad_norm", total_norm, utils.get_global_step())
36 |
37 | return float(total_norm)
38 |
39 |
40 | def _compute_grad_norm(gradients):
41 | total_norm = 0.0
42 |
43 | for grad in gradients:
44 | total_norm += float(grad.data.norm() ** 2)
45 |
46 | return float(total_norm ** 0.5)
47 |
48 |
49 | class Optimizer(object):
50 |
51 | def __init__(self, name, **kwargs):
52 | self._name = name
53 | self._iterations = 0
54 | self._slots = {}
55 |
56 | def detach_gradients(self, gradients):
57 | for grad in gradients:
58 | if grad is not None:
59 | grad.detach_()
60 |
61 | def scale_gradients(self, gradients, scale):
62 | for grad in gradients:
63 | if grad is not None:
64 | grad.mul_(scale)
65 |
66 | def sync_gradients(self, gradients, compress=True):
67 | grad_vec = torch.nn.utils.parameters_to_vector(gradients)
68 |
69 | if compress:
70 | grad_vec_half = grad_vec.half()
71 | dist.all_reduce(grad_vec_half)
72 | grad_vec = grad_vec_half.to(grad_vec)
73 | else:
74 | dist.all_reduce(grad_vec)
75 |
76 | torch.nn.utils.vector_to_parameters(grad_vec, gradients)
77 |
78 | def zero_gradients(self, gradients):
79 | for grad in gradients:
80 | if grad is not None:
81 | grad.zero_()
82 |
83 | def compute_gradients(self, loss, var_list, aggregate=False):
84 | var_list = list(var_list)
85 | grads = [v.grad if v is not None else None for v in var_list]
86 |
87 | self.detach_gradients(grads)
88 |
89 | if not aggregate:
90 | self.zero_gradients(grads)
91 |
92 | loss.backward()
93 | return [v.grad if v is not None else None for v in var_list]
94 |
95 | def apply_gradients(self, grads_and_vars):
96 | raise NotImplementedError("Not implemented")
97 |
98 | @property
99 | def iterations(self):
100 | return self._iterations
101 |
102 | def state_dict(self):
103 | raise NotImplementedError("Not implemented")
104 |
105 |     def load_state_dict(self, state):
106 | raise NotImplementedError("Not implemented")
107 |
108 |
109 | class SGDOptimizer(Optimizer):
110 |
111 | def __init__(self, learning_rate, summaries=True, name="SGD", **kwargs):
112 | super(SGDOptimizer, self).__init__(name, **kwargs)
113 | self._learning_rate = learning_rate
114 | self._summaries = summaries
115 | self._clipper = None
116 |
117 | if "clipper" in kwargs and kwargs["clipper"] is not None:
118 | self._clipper = kwargs["clipper"]
119 |
120 | def apply_gradients(self, grads_and_vars):
121 | self._iterations += 1
122 | lr = self._learning_rate
123 | grads, var_list = list(zip(*grads_and_vars))
124 |
125 | if self._summaries:
126 | grad_norm = _save_summary(zip(grads, var_list))
127 | else:
128 | grad_norm = _compute_grad_norm(grads)
129 |
130 | if self._clipper is not None:
131 | reject, grads = self._clipper(grads, grad_norm)
132 |
133 | if reject:
134 | return
135 |
136 | for grad, var in zip(grads, var_list):
137 | if grad is None:
138 | continue
139 |
140 | # Convert if grad is not FP32
141 | grad = grad.data.float()
142 | _, var = var
143 | step_size = lr
144 |
145 | if var.dtype == torch.float32:
146 | var.data.add_(-step_size, grad)
147 | else:
148 | fp32_var = var.data.float()
149 | fp32_var.add_(-step_size, grad)
150 | var.data.copy_(fp32_var)
151 |
152 | def state_dict(self):
153 | state = {
154 | "iterations": self._iterations,
155 | }
156 |
157 | if not isinstance(self._learning_rate, LearningRateSchedule):
158 | state["learning_rate"] = self._learning_rate
159 |
160 | return state
161 |
162 | def load_state_dict(self, state):
163 | self._learning_rate = state.get("learning_rate", self._learning_rate)
164 | self._iterations = state.get("iterations", self._iterations)
165 |
166 |
167 | class AdamOptimizer(Optimizer):
168 |
169 | def __init__(self, learning_rate=0.01, beta_1=0.9, beta_2=0.999,
170 | epsilon=1e-7, name="Adam", **kwargs):
171 | super(AdamOptimizer, self).__init__(name, **kwargs)
172 | self._learning_rate = learning_rate
173 | self._beta_1 = beta_1
174 | self._beta_2 = beta_2
175 | self._epsilon = epsilon
176 | self._summaries = True
177 | self._clipper = None
178 |
179 | if "summaries" in kwargs and not kwargs["summaries"]:
180 | self._summaries = False
181 |
182 | if "clipper" in kwargs and kwargs["clipper"] is not None:
183 | self._clipper = kwargs["clipper"]
184 |
185 | def apply_gradients(self, grads_and_vars):
186 | self._iterations += 1
187 | lr = self._learning_rate
188 | beta_1 = self._beta_1
189 | beta_2 = self._beta_2
190 | epsilon = self._epsilon
191 | grads, var_list = list(zip(*grads_and_vars))
192 |
193 | if self._summaries:
194 | grad_norm = _save_summary(zip(grads, var_list))
195 | else:
196 | grad_norm = _compute_grad_norm(grads)
197 |
198 | if self._clipper is not None:
199 | reject, grads = self._clipper(grads, grad_norm)
200 |
201 | if reject:
202 | return
203 |
204 | for grad, var in zip(grads, var_list):
205 | if grad is None:
206 | continue
207 |
208 | # Convert if grad is not FP32
209 | grad = grad.data.float()
210 | name, var = var
211 |
212 | if self._slots.get(name, None) is None:
213 | self._slots[name] = {}
214 | self._slots[name]["m"] = torch.zeros_like(var.data,
215 | dtype=torch.float32)
216 | self._slots[name]["v"] = torch.zeros_like(var.data,
217 | dtype=torch.float32)
218 |
219 | m, v = self._slots[name]["m"], self._slots[name]["v"]
220 |
221 | bias_corr_1 = 1 - beta_1 ** self._iterations
222 | bias_corr_2 = 1 - beta_2 ** self._iterations
223 |
224 | m.mul_(beta_1).add_(1 - beta_1, grad)
225 | v.mul_(beta_2).addcmul_(1 - beta_2, grad, grad)
226 | denom = (v.sqrt() / math.sqrt(bias_corr_2)).add_(epsilon)
227 |
228 | if isinstance(lr, LearningRateSchedule):
229 | lr = lr(self._iterations)
230 |
231 | step_size = lr / bias_corr_1
232 |
233 | if var.dtype == torch.float32:
234 | var.data.addcdiv_(-step_size, m, denom)
235 | else:
236 | fp32_var = var.data.float()
237 | fp32_var.addcdiv_(-step_size, m, denom)
238 | var.data.copy_(fp32_var)
239 |
240 | def state_dict(self):
241 | state = {
242 | "beta_1": self._beta_1,
243 | "beta_2": self._beta_2,
244 | "epsilon": self._epsilon,
245 | "iterations": self._iterations,
246 | "slot": self._slots
247 | }
248 |
249 | if not isinstance(self._learning_rate, LearningRateSchedule):
250 | state["learning_rate"] = self._learning_rate
251 |
252 | return state
253 |
254 | def load_state_dict(self, state):
255 | self._learning_rate = state.get("learning_rate", self._learning_rate)
256 | self._beta_1 = state.get("beta_1", self._beta_1)
257 | self._beta_2 = state.get("beta_2", self._beta_2)
258 | self._epsilon = state.get("epsilon", self._epsilon)
259 | self._iterations = state.get("iterations", self._iterations)
260 |
261 | slots = state.get("slot", {})
262 | self._slots = {}
263 |
264 | for key in slots:
265 | m, v = slots[key]["m"], slots[key]["v"]
266 | self._slots[key] = {}
267 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32)
268 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32)
269 | self._slots[key]["m"].copy_(m)
270 | self._slots[key]["v"].copy_(v)
271 |
272 |
273 | class AdadeltaOptimizer(Optimizer):
274 |
275 | def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-07,
276 | name="Adadelta", **kwargs):
277 | super(AdadeltaOptimizer, self).__init__(name, **kwargs)
278 | self._learning_rate = learning_rate
279 | self._rho = rho
280 | self._epsilon = epsilon
281 | self._summaries = True
282 |
283 | if "summaries" in kwargs and not kwargs["summaries"]:
284 | self._summaries = False
285 |
286 | if "clipper" in kwargs and kwargs["clipper"] is not None:
287 | self._clipper = kwargs["clipper"]
288 |
289 | def apply_gradients(self, grads_and_vars):
290 | self._iterations += 1
291 | lr = self._learning_rate
292 | rho = self._rho
293 | epsilon = self._epsilon
294 |
295 | grads, var_list = list(zip(*grads_and_vars))
296 |
297 | if self._summaries:
298 | grad_norm = _save_summary(zip(grads, var_list))
299 | else:
300 | grad_norm = _compute_grad_norm(grads)
301 |
302 | if self._clipper is not None:
303 | reject, grads = self._clipper(grads, grad_norm)
304 |
305 | if reject:
306 | return
307 |
308 | for grad, var in zip(grads, var_list):
309 | if grad is None:
310 | continue
311 |
312 | # Convert if grad is not FP32
313 | grad = grad.data.float()
314 | name, var = var
315 |
316 | if self._slots.get(name, None) is None:
317 | self._slots[name] = {}
318 | self._slots[name]["m"] = torch.zeros_like(var.data,
319 | dtype=torch.float32)
320 | self._slots[name]["v"] = torch.zeros_like(var.data,
321 | dtype=torch.float32)
322 |
323 | square_avg = self._slots[name]["m"]
324 | acc_delta = self._slots[name]["v"]
325 |
326 | if isinstance(lr, LearningRateSchedule):
327 | lr = lr(self._iterations)
328 |
329 | square_avg.mul_(rho).addcmul_(1 - rho, grad, grad)
330 | std = square_avg.add(epsilon).sqrt_()
331 | delta = acc_delta.add(epsilon).sqrt_().div_(std).mul_(grad)
332 | acc_delta.mul_(rho).addcmul_(1 - rho, delta, delta)
333 |
334 | if var.dtype == torch.float32:
335 | var.data.add_(-lr, delta)
336 | else:
337 | fp32_var = var.data.float()
338 | fp32_var.add_(-lr, delta)
339 | var.data.copy_(fp32_var)
340 |
341 | def state_dict(self):
342 | state = {
343 | "rho": self._rho,
344 | "epsilon": self._epsilon,
345 | "iterations": self._iterations,
346 | "slot": self._slots
347 | }
348 |
349 | if not isinstance(self._learning_rate, LearningRateSchedule):
350 | state["learning_rate"] = self._learning_rate
351 |
352 | return state
353 |
354 | def load_state_dict(self, state):
355 | self._learning_rate = state.get("learning_rate", self._learning_rate)
356 | self._rho = state.get("rho", self._rho)
357 | self._epsilon = state.get("epsilon", self._epsilon)
358 | self._iterations = state.get("iterations", self._iterations)
359 |
360 | slots = state.get("slot", {})
361 | self._slots = {}
362 |
363 | for key in slots:
364 | m, v = slots[key]["m"], slots[key]["v"]
365 | self._slots[key] = {}
366 | self._slots[key]["m"] = torch.zeros(m.shape, dtype=torch.float32)
367 | self._slots[key]["v"] = torch.zeros(v.shape, dtype=torch.float32)
368 | self._slots[key]["m"].copy_(m)
369 | self._slots[key]["v"].copy_(v)
370 |
371 |
372 | class LossScalingOptimizer(Optimizer):
373 |
374 | def __init__(self, optimizer, scale=2.0**7, increment_period=2000,
375 | multiplier=2.0, name="LossScalingOptimizer", **kwargs):
376 | super(LossScalingOptimizer, self).__init__(name, **kwargs)
377 | self._optimizer = optimizer
378 | self._scale = scale
379 | self._increment_period = increment_period
380 | self._multiplier = multiplier
381 | self._num_good_steps = 0
382 | self._summaries = True
383 |
384 | if "summaries" in kwargs and not kwargs["summaries"]:
385 | self._summaries = False
386 |
387 | def _update_if_finite_grads(self):
388 | if self._num_good_steps + 1 > self._increment_period:
389 | self._scale *= self._multiplier
390 | self._scale = min(self._scale, 2.0**16)
391 | self._num_good_steps = 0
392 | else:
393 | self._num_good_steps += 1
394 |
395 | def _update_if_not_finite_grads(self):
396 | self._scale = max(self._scale / self._multiplier, 1)
397 |
398 | def compute_gradients(self, loss, var_list, aggregate=False):
399 | var_list = list(var_list)
400 | grads = [v.grad if v is not None else None for v in var_list]
401 |
402 | self.detach_gradients(grads)
403 |
404 | if not aggregate:
405 | self.zero_gradients(grads)
406 |
407 | loss = loss * self._scale
408 | loss.backward()
409 |
410 | return [v.grad if v is not None else None for v in var_list]
411 |
412 | def apply_gradients(self, grads_and_vars):
413 | self._iterations += 1
414 | grads, var_list = list(zip(*grads_and_vars))
415 | new_grads = []
416 |
417 | if self._summaries:
418 | summary.scalar("optimizer/scale", self._scale,
419 | utils.get_global_step())
420 |
421 | for grad in grads:
422 | if grad is None:
423 | new_grads.append(None)
424 | continue
425 |
426 | norm = grad.data.norm()
427 |
428 | if not torch.isfinite(norm):
429 | self._update_if_not_finite_grads()
430 | return
431 | else:
432 | # Rescale gradients
433 | new_grads.append(grad.data.float().mul_(1.0 / self._scale))
434 |
435 | self._update_if_finite_grads()
436 | self._optimizer.apply_gradients(zip(new_grads, var_list))
437 |
438 | def state_dict(self):
439 | state = {
440 | "scale": self._scale,
441 | "increment_period": self._increment_period,
442 | "multiplier": self._multiplier,
443 | "num_good_steps": self._num_good_steps,
444 | "optimizer": self._optimizer.state_dict()
445 | }
446 | return state
447 |
448 | def load_state_dict(self, state):
449 | self._scale = state.get("scale", self._scale)
450 | self._increment_period = state.get("increment_period",
451 | self._increment_period)
452 | self._multiplier = state.get("multiplier", self._multiplier)
453 | self._num_good_steps = state.get("num_good_steps",
454 | self._num_good_steps)
455 | self._optimizer.load_state_dict(state.get("optimizer", {}))
456 |
457 |
458 | class MultiStepOptimizer(Optimizer):
459 |
460 | def __init__(self, optimizer, n=1, compress=True,
461 | name="MultiStepOptimizer", **kwargs):
462 | super(MultiStepOptimizer, self).__init__(name, **kwargs)
463 | self._n = n
464 | self._optimizer = optimizer
465 | self._compress = compress
466 |
467 | def compute_gradients(self, loss, var_list, aggregate=False):
468 | if self._iterations % self._n == 0:
469 | return self._optimizer.compute_gradients(loss, var_list, aggregate)
470 | else:
471 | return self._optimizer.compute_gradients(loss, var_list, True)
472 |
473 | def apply_gradients(self, grads_and_vars):
474 | size = dist.get_world_size()
475 | grads, var_list = list(zip(*grads_and_vars))
476 | self._iterations += 1
477 |
478 | if self._n == 1:
479 | if size > 1:
480 | self.sync_gradients(grads, compress=self._compress)
481 | self.scale_gradients(grads, 1.0 / size)
482 |
483 | self._optimizer.apply_gradients(zip(grads, var_list))
484 | else:
485 | if self._iterations % self._n != 0:
486 | return
487 |
488 | if size > 1:
489 | self.sync_gradients(grads, compress=self._compress)
490 |
491 | self.scale_gradients(grads, 1.0 / (self._n * size))
492 | self._optimizer.apply_gradients(zip(grads, var_list))
493 |
494 | def state_dict(self):
495 | state = {
496 | "n": self._n,
497 | "iterations": self._iterations,
498 | "compress": self._compress,
499 | "optimizer": self._optimizer.state_dict()
500 | }
501 | return state
502 |
503 | def load_state_dict(self, state):
504 | self._n = state.get("n", self._n)
505 | self._iterations = state.get("iterations", self._iterations)
506 |         self._compress = state.get("compress", self._compress)
507 | self._optimizer.load_state_dict(state.get("optimizer", {}))
508 |
--------------------------------------------------------------------------------
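
Usage note: a sketch of how the optimizer classes above are meant to compose; model, batch and loss_fn are placeholders, and MultiStepOptimizer assumes torch.distributed has already been initialized because it queries the world size. apply_gradients expects (grad, (name, parameter)) pairs.

    from tagger.optimizers import AdamOptimizer, MultiStepOptimizer
    from tagger.optimizers import LinearWarmupRsqrtDecay

    schedule = LinearWarmupRsqrtDecay(learning_rate=1e-3, warmup_steps=4000)
    optimizer = MultiStepOptimizer(AdamOptimizer(learning_rate=schedule), n=2)

    # One hypothetical update step:
    # loss = loss_fn(model(batch))
    # named_params = list(model.named_parameters())
    # grads = optimizer.compute_gradients(loss, [v for _, v in named_params])
    # optimizer.apply_gradients(zip(grads, named_params))
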
/tagger/optimizers/schedules.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 |
9 | import tagger.utils as utils
10 | import tagger.utils.summary as summary
11 |
12 |
13 | class LearningRateSchedule(object):
14 |
15 | def __call__(self, step):
16 | raise NotImplementedError("Not implemented.")
17 |
18 | def get_config(self):
19 | raise NotImplementedError("Not implemented.")
20 |
21 | @classmethod
22 | def from_config(cls, config):
23 | return cls(**config)
24 |
25 |
26 |
27 | class LinearWarmupRsqrtDecay(LearningRateSchedule):
28 |
29 | def __init__(self, learning_rate, warmup_steps, initial_learning_rate=0.0,
30 | summary=True):
31 | super(LinearWarmupRsqrtDecay, self).__init__()
32 |
33 | if not initial_learning_rate:
34 | initial_learning_rate = learning_rate / warmup_steps
35 |
36 | self._initial_learning_rate = initial_learning_rate
37 | self._maximum_learning_rate = learning_rate
38 | self._warmup_steps = warmup_steps
39 | self._summary = summary
40 |
41 | def __call__(self, step):
42 | if step <= self._warmup_steps:
43 | lr_step = self._maximum_learning_rate - self._initial_learning_rate
44 | lr_step /= self._warmup_steps
45 | lr = self._initial_learning_rate + lr_step * step
46 | else:
47 | step = step / self._warmup_steps
48 | lr = self._maximum_learning_rate * (step ** -0.5)
49 |
50 | if self._summary:
51 | summary.scalar("learning_rate", lr, utils.get_global_step())
52 |
53 | return lr
54 |
55 | def get_config(self):
56 | return {
57 | "learning_rate": self._maximum_learning_rate,
58 | "initial_learning_rate": self._initial_learning_rate,
59 | "warmup_steps": self._warmup_steps
60 | }
61 |
62 |
63 | class PiecewiseConstantDecay(LearningRateSchedule):
64 |
65 | def __init__(self, boundaries, values, summary=True, name=None):
66 | super(PiecewiseConstantDecay, self).__init__()
67 |
68 | if len(boundaries) != len(values) - 1:
69 | raise ValueError("The length of boundaries should be 1"
70 | " less than the length of values")
71 |
72 | self._boundaries = boundaries
73 | self._values = values
74 | self._summary = summary
75 |
76 | def __call__(self, step):
77 | boundaries = self._boundaries
78 | values = self._values
79 | learning_rate = values[0]
80 |
81 | if step <= boundaries[0]:
82 | learning_rate = values[0]
83 | elif step > boundaries[-1]:
84 | learning_rate = values[-1]
85 | else:
86 | for low, high, v in zip(boundaries[:-1], boundaries[1:],
87 | values[1:-1]):
88 |
89 | if step > low and step <= high:
90 | learning_rate = v
91 | break
92 |
93 | if self._summary:
94 | summary.scalar("learning_rate", learning_rate,
95 | utils.get_global_step())
96 |
97 | return learning_rate
98 |
99 | def get_config(self):
100 | return {
101 | "boundaries": self._boundaries,
102 | "values": self._values,
103 | }
104 |
105 |
106 | class LinearExponentialDecay(LearningRateSchedule):
107 |
108 | def __init__(self, learning_rate, warmup_steps, start_decay_step,
109 | end_decay_step, n, summary=True):
110 | super(LinearExponentialDecay, self).__init__()
111 |
112 | self._learning_rate = learning_rate
113 | self._warmup_steps = warmup_steps
114 | self._start_decay_step = start_decay_step
115 | self._end_decay_step = end_decay_step
116 | self._n = n
117 | self._summary = summary
118 |
119 | def __call__(self, step):
120 | # See reference: The Best of Both Worlds: Combining Recent Advances
121 | # in Neural Machine Translation
122 | n = self._n
123 | p = self._warmup_steps / n
124 | s = n * self._start_decay_step
125 | e = n * self._end_decay_step
126 |
127 | learning_rate = self._learning_rate
128 |
129 | learning_rate *= min(
130 | 1.0 + (n - 1) * step / float(n * p),
131 | n,
132 | n * ((2 * n) ** (float(s - n * step) / float(e - s))))
133 |
134 | if self._summary:
135 | summary.scalar("learning_rate", learning_rate,
136 | utils.get_global_step())
137 |
138 | return learning_rate
139 |
140 | def get_config(self):
141 | return {
142 | "learning_rate": self._learning_rate,
143 | "warmup_steps": self._warmup_steps,
144 | "start_decay_step": self._start_decay_step,
145 | "end_decay_step": self._end_decay_step,
146 | }
147 |
--------------------------------------------------------------------------------
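
Usage note: a small sketch of how the schedules above behave when called with a step number; the values are illustrative.

    from tagger.optimizers.schedules import LinearWarmupRsqrtDecay
    from tagger.optimizers.schedules import PiecewiseConstantDecay

    warmup = LinearWarmupRsqrtDecay(learning_rate=1e-3, warmup_steps=100,
                                    summary=False)
    print(warmup(50))       # mid-warmup, roughly half the maximum rate
    print(warmup(400))      # past warmup: 1e-3 * (400 / 100) ** -0.5 = 5e-4

    piecewise = PiecewiseConstantDecay([1000, 2000], [1e-3, 5e-4, 1e-4],
                                       summary=False)
    print(piecewise(1500))  # 5e-4
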
/tagger/scripts/build_vocab.py:
--------------------------------------------------------------------------------
1 | # build_vocab.py
2 | # author: Playinf
3 | # email: playinf@stu.xmu.edu.cn
4 |
5 | import argparse
6 | import collections
7 |
8 |
9 | def count_items(filename, lower=False):
10 | counter = collections.Counter()
11 | label_counter = collections.Counter()
12 |
13 | with open(filename, "r") as fd:
14 | for line in fd:
15 | words, labels = line.strip().split("|||")
16 | words = words.strip().split()
17 | labels = labels.strip().split()
18 |
19 | if lower:
20 | words = [item.lower() for item in words[1:]]
21 | else:
22 | words = words[1:]
23 |
24 | counter.update(words)
25 | label_counter.update(labels)
26 |
27 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
28 | words, counts = list(zip(*count_pairs))
29 | count_pairs = sorted(label_counter.items(), key=lambda x: (-x[1], x[0]))
30 | labels, _ = list(zip(*count_pairs))
31 |
32 | return words, labels, counts
33 |
34 |
35 | def special_tokens(string):
36 | if not string:
37 | return []
38 | else:
39 | return string.strip().split(":")
40 |
41 |
42 | def save_vocab(name, vocab):
43 | if name.split(".")[-1] != "txt":
44 | name = name + ".txt"
45 |
46 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0]))
47 | words, ids = list(zip(*pairs))
48 |
49 | with open(name, "w") as f:
50 | for word in words:
51 | f.write(word + "\n")
52 |
53 |
54 | def write_vocab(name, vocab):
55 | with open(name, "w") as f:
56 | for word in vocab:
57 | f.write(word + "\n")
58 |
59 |
60 | def parse_args():
61 | msg = "build vocabulary"
62 | parser = argparse.ArgumentParser(description=msg)
63 |
64 | msg = "input corpus"
65 | parser.add_argument("corpus", help=msg)
66 |     msg = "output directory for vocab.txt and label.txt"
67 | parser.add_argument("output", default="vocab.txt", help=msg)
68 | msg = "limit"
69 | parser.add_argument("--limit", default=0, type=int, help=msg)
70 | msg = "add special token, separated by colon"
71 | parser.add_argument("--special", type=str, default=":",
72 | help=msg)
73 | msg = "use lowercase"
74 | parser.add_argument("--lower", action="store_true", help=msg)
75 |
76 | return parser.parse_args()
77 |
78 |
79 | def main(args):
80 | vocab = {}
81 | limit = args.limit
82 | count = 0
83 |
84 | words, labels, counts = count_items(args.corpus, args.lower)
85 | special = special_tokens(args.special)
86 |
87 | for token in special:
88 | vocab[token] = len(vocab)
89 |
90 | for word, freq in zip(words, counts):
91 | if limit and len(vocab) >= limit:
92 | break
93 |
94 | if word in vocab:
95 | print("warning: found duplicate token %s, ignored" % word)
96 | continue
97 |
98 | vocab[word] = len(vocab)
99 | count += freq
100 |
101 | save_vocab(args.output + "/vocab.txt", vocab)
102 | write_vocab(args.output + "/label.txt", labels)
103 |
104 | print("total words: %d" % sum(counts))
105 | print("unique words: %d" % len(words))
106 | print("vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts)))
107 |
108 |
109 | if __name__ == "__main__":
110 | main(parse_args())
111 |
--------------------------------------------------------------------------------
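
Usage note: a sketch of how build_vocab.py appears to be used. The corpus path and special tokens are placeholders; count_items expects one "words ||| labels" sentence per line and skips the first word field.

    from tagger.scripts.build_vocab import count_items

    # Command-line form (paths and special tokens are hypothetical):
    #   python tagger/scripts/build_vocab.py train.txt vocab_dir \
    #       --limit 100000 --special "<pad>:<unk>" --lower
    words, labels, counts = count_items("train.txt", lower=True)
    print(len(words), len(labels), sum(counts))
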
/tagger/scripts/convert_to_conll.py:
--------------------------------------------------------------------------------
1 | # convert_to_conll.py
2 | # author: Playinf
3 | # email: playinf@stu.xmu.edu.cn
4 |
5 | import sys
6 |
7 |
8 | def convert_bio(labels):
9 | n = len(labels)
10 | tags = []
11 |
12 | tag = []
13 | count = 0
14 |
15 | # B I*
16 | for label in labels:
17 | count += 1
18 |
19 | if count == n:
20 | next_l = None
21 | else:
22 | next_l = labels[count]
23 |
24 | if label == "O":
25 | if tag:
26 | tags.append(tag)
27 | tag = []
28 | tags.append([label])
29 | continue
30 |
31 | tag.append(label[2:])
32 |
33 | if not next_l or next_l[0] == "B":
34 | tags.append(tag)
35 | tag = []
36 |
37 | new_tag = []
38 |
39 | for tag in tags:
40 | if len(tag) == 1:
41 | if tag[0] == "O":
42 | new_tag.append("*")
43 | else:
44 | new_tag.append("(" + tag[0] + "*)")
45 | continue
46 |
47 | label = tag[0]
48 | n = len(tag)
49 |
50 | for i in range(n):
51 | if i == 0:
52 | new_tag.append("(" + label + "*")
53 | elif i == n - 1:
54 | new_tag.append("*)")
55 | else:
56 | new_tag.append("*")
57 |
58 | return new_tag
59 |
60 |
61 | def print_sentence_to_conll(fout, tokens, labels):
62 | for label_column in labels:
63 | assert len(label_column) == len(tokens)
64 | for i in range(len(tokens)):
65 | fout.write(tokens[i].ljust(15))
66 | for label_column in labels:
67 | fout.write(label_column[i].rjust(15))
68 | fout.write("\n")
69 | fout.write("\n")
70 |
71 |
72 | def print_to_conll(pred_labels, gold_props_file, output_filename):
73 | fout = open(output_filename, 'w')
74 | seq_ptr = 0
75 | num_props_for_sentence = 0
76 | tokens_buf = []
77 |
78 | for line in open(gold_props_file, 'r'):
79 | line = line.strip()
80 | if line == "" and len(tokens_buf) > 0:
81 | print_sentence_to_conll(
82 | fout,
83 | tokens_buf,
84 | pred_labels[seq_ptr:seq_ptr+num_props_for_sentence]
85 | )
86 | seq_ptr += num_props_for_sentence
87 | tokens_buf = []
88 | num_props_for_sentence = 0
89 | else:
90 | info = line.split()
91 | num_props_for_sentence = len(info) - 1
92 | tokens_buf.append(info[0])
93 |
94 | # Output last sentence.
95 | if len(tokens_buf) > 0:
96 | print_sentence_to_conll(
97 | fout,
98 | tokens_buf,
99 | pred_labels[seq_ptr:seq_ptr+num_props_for_sentence]
100 | )
101 |
102 | fout.close()
103 |
104 |
105 | if __name__ == "__main__":
106 | all_labels = []
107 | with open(sys.argv[1]) as fd:
108 | for text_line in fd:
109 | labs = text_line.strip().split()
110 | labs = convert_bio(labs)
111 | all_labels.append(labs)
112 |
113 | print_to_conll(all_labels, sys.argv[2], sys.argv[3])
114 |
--------------------------------------------------------------------------------
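
Usage note: a small sketch of what convert_bio does to one BIO label sequence; the labels are illustrative. The script itself takes three command-line arguments: a file of predicted label sequences, the gold props file, and the output path.

    from tagger.scripts.convert_to_conll import convert_bio

    labels = ["B-A0", "I-A0", "B-V", "B-A1", "I-A1", "I-A1", "O"]
    print(convert_bio(labels))
    # ['(A0*', '*)', '(V*)', '(A1*', '*', '*)', '*']
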
/tagger/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from tagger.utils.hparams import HParams
2 | from tagger.utils.checkpoint import save, latest_checkpoint, best_checkpoint
3 | from tagger.utils.scope import scope, get_scope, unique_name
4 | from tagger.utils.misc import get_global_step, set_global_step
5 |
--------------------------------------------------------------------------------
/tagger/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import glob
10 | import torch
11 |
12 |
13 | def oldest_checkpoint(path):
14 | names = glob.glob(os.path.join(path, "*.pt"))
15 |
16 | if not names:
17 | return None
18 |
19 | oldest_counter = 10000000
20 | checkpoint_name = names[0]
21 |
22 | for name in names:
23 | counter = name.rstrip(".pt").split("-")[-1]
24 |
25 | if not counter.isdigit():
26 | continue
27 | else:
28 | counter = int(counter)
29 |
30 | if counter < oldest_counter:
31 | checkpoint_name = name
32 | oldest_counter = counter
33 |
34 | return checkpoint_name
35 |
36 |
37 | def best_checkpoint(path):
38 | if not os.path.exists(os.path.join(path, "checkpoint")):
39 | return latest_checkpoint(path)
40 |
41 | with open(os.path.join(path, "checkpoint")) as fd:
42 | line = fd.readline()
43 | name = line.strip().split()[-1][1:-1]
44 |
45 | return os.path.join(path, name)
46 |
47 |
48 | def latest_checkpoint(path):
49 | names = glob.glob(os.path.join(path, "*.pt"))
50 |
51 | if not names:
52 | return None
53 |
54 | latest_counter = 0
55 | checkpoint_name = names[0]
56 |
57 | for name in names:
58 | counter = name.rstrip(".pt").split("-")[-1]
59 |
60 | if not counter.isdigit():
61 | continue
62 | else:
63 | counter = int(counter)
64 |
65 | if counter > latest_counter:
66 | checkpoint_name = name
67 | latest_counter = counter
68 |
69 | return checkpoint_name
70 |
71 |
72 | def save(state, path, max_to_keep=None):
73 | checkpoints = glob.glob(os.path.join(path, "*.pt"))
74 |
75 | if max_to_keep and len(checkpoints) >= max_to_keep:
76 | checkpoint = oldest_checkpoint(path)
77 | os.remove(checkpoint)
78 |
79 | if not checkpoints:
80 | counter = 1
81 | else:
82 | checkpoint = latest_checkpoint(path)
83 | counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1
84 |
85 | checkpoint = os.path.join(path, "model-%d.pt" % counter)
86 | print("Saving checkpoint: %s" % checkpoint)
87 | torch.save(state, checkpoint)
88 |
--------------------------------------------------------------------------------
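
Usage note: a minimal sketch of the checkpoint helpers above; the directory is a placeholder and is assumed to exist already.

    import torch
    from tagger.utils.checkpoint import save, latest_checkpoint, best_checkpoint

    state = {"step": 100, "model": {}}           # placeholder state dict
    save(state, "train_output", max_to_keep=5)   # writes train_output/model-<n>.pt

    print(latest_checkpoint("train_output"))     # highest-numbered model-*.pt
    print(best_checkpoint("train_output"))       # falls back to latest_checkpoint
                                                 # when no "checkpoint" file exists
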
/tagger/utils/hparams.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 | # Modified from TensorFlow (tf.contrib.training.HParams)
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import json
10 | import logging
11 | import re
12 | import six
13 |
14 |
15 | def parse_values(values, type_map):
16 | ret = {}
17 |     param_re = re.compile(r"(?P<name>[a-zA-Z][\w]*)\s*=\s*"
18 |                           r"((?P<val>[^,\[]*)|\[(?P<vals>[^\]]*)\])($|,)")
19 | pos = 0
20 |
21 | while pos < len(values):
22 | m = param_re.match(values, pos)
23 |
24 | if not m:
25 | raise ValueError(
26 | "Malformed hyperparameter value: %s" % values[pos:])
27 |
28 | # Check that there is a comma between parameters and move past it.
29 | pos = m.end()
30 | # Parse the values.
31 | m_dict = m.groupdict()
32 | name = m_dict["name"]
33 |
34 | if name not in type_map:
35 | raise ValueError("Unknown hyperparameter type for %s" % name)
36 |
37 | def parse_fail():
38 | raise ValueError("Could not parse hparam %s in %s" % (name, values))
39 |
40 | if type_map[name] == bool:
41 | def parse_bool(value):
42 | if value == "true":
43 | return True
44 | elif value == "false":
45 | return False
46 | else:
47 | try:
48 | return bool(int(value))
49 | except ValueError:
50 | parse_fail()
51 | parse = parse_bool
52 | else:
53 | parse = type_map[name]
54 |
55 |
56 | if m_dict["val"] is not None:
57 | try:
58 | ret[name] = parse(m_dict["val"])
59 | except ValueError:
60 | parse_fail()
61 | elif m_dict["vals"] is not None:
62 | elements = filter(None, re.split("[ ,]", m_dict["vals"]))
63 | try:
64 | ret[name] = [parse(e) for e in elements]
65 | except ValueError:
66 | parse_fail()
67 | else:
68 | parse_fail()
69 |
70 | return ret
71 |
72 |
73 | class HParams(object):
74 |
75 | def __init__(self, **kwargs):
76 | self._hparam_types = {}
77 |
78 | for name, value in six.iteritems(kwargs):
79 | self.add_hparam(name, value)
80 |
81 | def add_hparam(self, name, value):
82 | if getattr(self, name, None) is not None:
83 | raise ValueError("Hyperparameter name is reserved: %s" % name)
84 | if isinstance(value, (list, tuple)):
85 | if not value:
86 | raise ValueError("Multi-valued hyperparameters cannot be"
87 | " empty: %s" % name)
88 | self._hparam_types[name] = (type(value[0]), True)
89 | else:
90 | self._hparam_types[name] = (type(value), False)
91 | setattr(self, name, value)
92 |
93 | def parse(self, values):
94 | type_map = dict()
95 |
96 | for name, t in six.iteritems(self._hparam_types):
97 | param_type, _ = t
98 | type_map[name] = param_type
99 |
100 | values_map = parse_values(values, type_map)
101 | return self._set_from_map(values_map)
102 |
103 | def _set_from_map(self, values_map):
104 | for name, value in six.iteritems(values_map):
105 | if name not in self._hparam_types:
106 | logging.debug("%s not found in hparams." % name)
107 | continue
108 |
109 | _, is_list = self._hparam_types[name]
110 |
111 | if isinstance(value, list):
112 | if not is_list:
113 | raise ValueError("Must not pass a list for single-valued "
114 | "parameter: %s" % name)
115 | setattr(self, name, value)
116 | else:
117 | if is_list:
118 | raise ValueError("Must pass a list for multi-valued "
119 | "parameter: %s" % name)
120 | setattr(self, name, value)
121 | return self
122 |
123 | def to_json(self):
124 | return json.dumps(self.values())
125 |
126 | def parse_json(self, values_json):
127 | values_map = json.loads(values_json)
128 | return self._set_from_map(values_map)
129 |
130 | def values(self):
131 | return {n: getattr(self, n) for n in six.iterkeys(self._hparam_types)}
132 |
133 | def __str__(self):
134 |         return str(sorted(six.iteritems(self.values())))
135 |
--------------------------------------------------------------------------------
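
Usage note: a minimal sketch of HParams; the hyperparameter names and values are arbitrary. parse() applies comma-separated "name=value" overrides using the types of the existing entries.

    from tagger.utils.hparams import HParams

    params = HParams(hidden_size=512, dropout=0.1, layer_list=[1, 2, 3])
    params.parse("hidden_size=1024,dropout=0.2,layer_list=[4,5]")
    print(params.hidden_size, params.dropout, params.layer_list)
    # 1024 0.2 [4, 5]
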
/tagger/utils/misc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | _GLOBAL_STEP = 0
9 |
10 |
11 | def get_global_step():
12 | return _GLOBAL_STEP
13 |
14 |
15 | def set_global_step(step):
16 | global _GLOBAL_STEP
17 | _GLOBAL_STEP = step
18 |
--------------------------------------------------------------------------------
/tagger/utils/scope.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2019 The THUMT Authors
3 | # Modified from TensorFlow (tf.name_scope)
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import re
10 | import contextlib
11 |
12 | # global variable
13 | _NAME_STACK = ""
14 | _NAMES_IN_USE = {}
15 | _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
16 | _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
17 |
18 |
19 | def unique_name(name, mark_as_used=True):
20 | global _NAME_STACK
21 |
22 | if _NAME_STACK:
23 | name = _NAME_STACK + "/" + name
24 |
25 | i = _NAMES_IN_USE.get(name, 0)
26 |
27 | if mark_as_used:
28 | _NAMES_IN_USE[name] = i + 1
29 |
30 | if i > 0:
31 | base_name = name
32 |
33 | while name in _NAMES_IN_USE:
34 | name = "%s_%d" % (base_name, i)
35 | i += 1
36 |
37 | if mark_as_used:
38 | _NAMES_IN_USE[name] = 1
39 |
40 | return name
41 |
42 |
43 | @contextlib.contextmanager
44 | def scope(name):
45 | global _NAME_STACK
46 |
47 | if name:
48 | if _NAME_STACK:
49 | # check name
50 | if not _VALID_SCOPE_NAME_REGEX.match(name):
51 | raise ValueError("'%s' is not a valid scope name" % name)
52 | else:
53 | # check name strictly
54 | if not _VALID_OP_NAME_REGEX.match(name):
55 | raise ValueError("'%s' is not a valid scope name" % name)
56 |
57 | try:
58 | old_stack = _NAME_STACK
59 |
60 | if not name:
61 | new_stack = None
62 | elif name and name[-1] == "/":
63 | new_stack = name[:-1]
64 | else:
65 | new_stack = unique_name(name)
66 |
67 | _NAME_STACK = new_stack
68 |
69 | yield "" if new_stack is None else new_stack + "/"
70 | finally:
71 | _NAME_STACK = old_stack
72 |
73 |
74 | def get_scope():
75 | return _NAME_STACK
76 |
--------------------------------------------------------------------------------
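
Usage note: a small sketch of how the name-scope helpers above compose names; they share global state, so the exact suffixes depend on what has already been named in the process.

    import tagger.utils as utils

    with utils.scope("encoder"):
        with utils.scope("layer_0"):
            print(utils.get_scope())            # encoder/layer_0
            print(utils.unique_name("weight"))  # encoder/layer_0/weight
            print(utils.unique_name("weight"))  # encoder/layer_0/weight_1
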
/tagger/utils/summary.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2017-2020 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import queue
9 | import threading
10 | import torch
11 |
12 | import torch.distributed as dist
13 | import torch.utils.tensorboard as tensorboard
14 |
15 | _SUMMARY_WRITER = None
16 | _QUEUE = None
17 | _THREAD = None
18 |
19 |
20 | class SummaryWorker(threading.Thread):
21 |
22 | def run(self):
23 | global _QUEUE
24 |
25 | while True:
26 | item = _QUEUE.get()
27 | name, kwargs = item
28 |
29 | if name == "stop":
30 | break
31 |
32 | self.write_summary(name, **kwargs)
33 |
34 | def write_summary(self, name, **kwargs):
35 | if name == "scalar":
36 | _SUMMARY_WRITER.add_scalar(**kwargs)
37 | elif name == "histogram":
38 | _SUMMARY_WRITER.add_histogram(**kwargs)
39 |
40 | def stop(self):
41 | global _QUEUE
42 | _QUEUE.put(("stop", None))
43 | self.join()
44 |
45 |
46 | def init(log_dir, enable=True):
47 | global _SUMMARY_WRITER
48 | global _QUEUE
49 | global _THREAD
50 |
51 | if enable and dist.get_rank() == 0:
52 | _SUMMARY_WRITER = tensorboard.SummaryWriter(log_dir)
53 | _QUEUE = queue.Queue()
54 | thread = SummaryWorker(daemon=True)
55 | thread.start()
56 | _THREAD = thread
57 |
58 |
59 | def scalar(tag, scalar_value, global_step=None, walltime=None,
60 | write_every_n_steps=100):
61 |
62 | if _SUMMARY_WRITER is not None:
63 | if global_step % write_every_n_steps == 0:
64 | scalar_value = float(scalar_value)
65 | kwargs = dict(tag=tag, scalar_value=scalar_value,
66 | global_step=global_step, walltime=walltime)
67 | _QUEUE.put(("scalar", kwargs))
68 |
69 |
70 | def histogram(tag, values, global_step=None, bins="tensorflow", walltime=None,
71 | max_bins=None, write_every_n_steps=100):
72 |
73 | if _SUMMARY_WRITER is not None:
74 | if global_step % write_every_n_steps == 0:
75 | values = values.detach().cpu()
76 | kwargs = dict(tag=tag, values=values, global_step=global_step,
77 | bins=bins, walltime=walltime, max_bins=max_bins)
78 | _QUEUE.put(("histogram", kwargs))
79 |
80 |
81 | def close():
82 | if _SUMMARY_WRITER is not None:
83 | _THREAD.stop()
84 | _SUMMARY_WRITER.close()
85 |
--------------------------------------------------------------------------------
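
Usage note: a minimal sketch of the summary helpers above. They are silent no-ops until init() is called, and init() itself only activates on rank 0 of an initialized torch.distributed group, so the init/close calls are left commented out here.

    import tagger.utils.summary as summary
    from tagger.utils.misc import set_global_step

    # summary.init("train_output/log", enable=True)  # needs dist.init_process_group

    for step in range(1, 501):
        set_global_step(step)
        # Written at most once every write_every_n_steps (default 100) steps.
        summary.scalar("loss", 0.5, step)

    # summary.close()
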
/tagger/utils/validation.py:
--------------------------------------------------------------------------------
1 | # validation.py
2 | # author: Playinf
3 | # email: playinf@stu.xmu.edu.cn
4 |
5 | import os
6 | import time
7 | import threading
8 | import subprocess
9 |
10 | from tagger.utils.checkpoint import latest_checkpoint
11 |
12 |
13 | def get_current_model(filename):
14 | try:
15 | with open(filename) as fd:
16 | line = fd.readline()
17 | if not line:
18 | return None
19 |
20 | name = line.strip().split(":")[1]
21 | return name.strip()[1:-1]
22 | except:
23 | return None
24 |
25 |
26 | def read_record(filename):
27 | record = []
28 |
29 | try:
30 | with open(filename) as fd:
31 | for line in fd:
32 | line = line.strip().split(":")
33 | val = float(line[0])
34 | name = line[1].strip()[1:-1]
35 | record.append((val, name))
36 | except:
37 | pass
38 |
39 | return record
40 |
41 |
42 | def write_record(filename, record):
43 | # sort
44 | sorted_record = sorted(record, key=lambda x: -x[0])
45 |
46 | with open(filename, "w") as fd:
47 | for item in sorted_record:
48 | val, name = item
49 | fd.write("%f: \"%s\"\n" % (val, name))
50 |
51 |
52 | def write_checkpoint(filename, record):
53 | # sort
54 | sorted_record = sorted(record, key=lambda x: -x[0])
55 |
56 | with open(filename, "w") as fd:
57 | fd.write("model_checkpoint_path: \"%s\"\n" % sorted_record[0][1])
58 | for item in sorted_record:
59 | val, name = item
60 | fd.write("all_model_checkpoint_paths: \"%s\"\n" % name)
61 |
62 |
63 | def add_to_record(record, item, capacity):
64 | added = None
65 | removed = None
66 | models = {}
67 |
68 | for (val, name) in record:
69 | models[name] = val
70 |
71 | if len(record) < capacity:
72 | if item[1] not in models:
73 | added = item[1]
74 | record.append(item)
75 | else:
76 | sorted_record = sorted(record, key=lambda x: -x[0])
77 | worst_score = sorted_record[-1][0]
78 | current_score = item[0]
79 |
80 | if current_score >= worst_score:
81 | if item[1] not in models:
82 | added = item[1]
83 | removed = sorted_record[-1][1]
84 | record = sorted_record[:-1] + [item]
85 |
86 | return added, removed, record
87 |
88 |
89 | class ValidationWorker(threading.Thread):
90 |
91 | def init(self, params):
92 | self._params = params
93 | self._stop = False
94 |
95 | def run(self):
96 | params = self._params
97 | best_dir = params.output + "/best"
98 | last_checkpoint = None
99 |
100 | # create directory
101 | if not os.path.exists(best_dir):
102 | os.mkdir(best_dir)
103 | record = []
104 | else:
105 | record = read_record(best_dir + "/top")
106 |
107 | while not self._stop:
108 | try:
109 | time.sleep(params.frequency)
110 | model_name = latest_checkpoint(params.output)
111 |
112 | if model_name is None:
113 | continue
114 |
115 | if model_name == last_checkpoint:
116 | continue
117 |
118 | last_checkpoint = model_name
119 |
120 | model_name = model_name.split("/")[-1]
121 | # prediction and evaluation
122 | child = subprocess.Popen("bash %s" % params.script,
123 | shell=True, stdout=subprocess.PIPE,
124 | stderr=subprocess.PIPE)
125 | info = child.communicate()[0]
126 |
127 | if not info:
128 | continue
129 |
130 | info = info.strip().split(b"\n")
131 | overall = None
132 |
133 | for line in info[::-1]:
134 | if line.find(b"Overall") > 0:
135 | overall = line
136 | break
137 |
138 | if not overall:
139 | continue
140 |
141 | f_score = float(overall.strip().split()[-1])
142 |
143 | # save best model
144 | item = (f_score, model_name)
145 | added, removed, record = add_to_record(record, item,
146 | params.keep_top_k)
147 | log_fd = open(best_dir + "/log", "a")
148 | log_fd.write("%s: %f\n" % (model_name, f_score))
149 | log_fd.close()
150 |
151 | if added is not None:
152 | model_path = params.output + "/" + model_name + "*"
153 | # copy model
154 | os.system("cp %s %s" % (model_path, best_dir))
155 | # update checkpoint
156 | write_record(best_dir + "/top", record)
157 | write_checkpoint(best_dir + "/checkpoint", record)
158 |
159 | if removed is not None:
160 | # remove old model
161 | model_name = params.output + "/best/" + removed + "*"
162 | os.system("rm %s" % model_name)
163 | except Exception as e:
164 | print(e)
165 |
166 | def stop(self):
167 | self._stop = True
168 |
--------------------------------------------------------------------------------
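
Usage note: a sketch of how ValidationWorker appears to be driven; the HParams fields shown here (output, frequency, script, keep_top_k) are the ones read in run(), and the evaluation script is a user-supplied shell script whose output contains an "Overall" line ending with the F1 score.

    from tagger.utils.hparams import HParams
    from tagger.utils.validation import ValidationWorker

    params = HParams(output="train_output", frequency=600,
                     script="run_eval.sh", keep_top_k=5)

    worker = ValidationWorker(daemon=True)
    worker.init(params)
    worker.start()
    # ... training loop ...
    worker.stop()
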