├── 代码解释
├── data
│   ├── PAGE
│   │   └── 111
│   └── LM
│       └── pre.py
├── bert
│   ├── __init__.py
│   ├── tokenization.py
│   └── bert_model_for_PAGE.py
├── results
│   ├── PAGE
│   │   └── 111
│   └── bert
│       └── 111
├── model.jpg
├── .idea
│   ├── encodings.xml
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── bert-transformer-vae.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── webServers.xml
│   ├── deployment.xml
│   └── workspace.xml
├── 运行命令
├── README.md
├── test.py
├── beamsearch.py
├── test_PAGE.py
├── hparams.py
├── prepro.py
├── utils.py
├── train_TPAGE.py
├── model.py
├── bert_transformer_vae_for_PAGE.py
├── modules.py
└── data_load.py
/代码解释:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/PAGE/111:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/results/PAGE/111:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/results/bert/111:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuhaiming1996/BERT-T2T/HEAD/model.jpg
--------------------------------------------------------------------------------
/运行命令:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | python train_TPAGE.py \
6 | --train=data/PAGE/train.txt \
7 | --eval=data/PAGE/eval.txt \
8 | --init_checkpoint_bert=results/bert/bert_model.ckpt \
9 | --batch_size=32 \
10 | --eval_batch_size=32 \
11 | --num_epochs_PAGE=10 \
12 | --maxlen_vae_Encoder=80 \
13 | --maxlen_vae_Decoder_en=40 \
14 | --maxlen_vae_Decoder_de=40
15 |
16 |
--------------------------------------------------------------------------------
/data/LM/pre.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | class PreHelper:
4 | def __init__(self,filePath_r,filePath_w_train,filePath_w_eval):
5 | self.filePath_r=filePath_r
6 | self.filePath_w_train=filePath_w_train
7 | self.filePath_w_eval = filePath_w_eval
8 |
9 | def preprocess(self):
10 | sens = []
11 | num = 0
12 | with open(self.filePath_r,mode="r",encoding="utf-8") as fr:
13 | for line in fr:
14 | num += 1
15 | line = line.strip()
16 | if line != "" and len(line)>8:
17 | sens.append(line.strip())
18 | if num%1000000 == 0:
19 | print("数据正在提取中,请耐心等待!!!")
20 |
21 | random.shuffle(sens)
22 | print("读写完成,开始写入文件")
23 | sens_eval = sens[:20000]
24 | sens_train = sens[20000:]
25 | with open(self.filePath_w_eval,mode="w",encoding="utf-8") as fw:
26 | for line in sens_eval:
27 | line = line.strip()
28 | fw.write("---xhm---".join([line,line])+"\n")
29 |
30 |
31 | with open(self.filePath_w_train,mode="w",encoding="utf-8") as fw:
32 | for line in sens_train:
33 | line = line.strip()
34 | fw.write("---xhm---".join([line,line])+"\n")
35 |
36 |
37 |
38 | if __name__=="__main__":
39 | preHelper=PreHelper(filePath_r="sens_LM_raw.txt",filePath_w_train="train.txt",filePath_w_eval="eval.txt")
40 | preHelper.preprocess()
41 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | In NLP, generation tasks such as NMT, QA, and paraphrase generation usually suffer from a lack of diversity in their outputs.
3 | Beam search is commonly used to add diversity, but the sentences it produces are still very similar to one another, which is not good enough for a production system.
4 | Following the CVAE-based architecture of [A Deep Generative Framework for Paraphrase Generation](https://arxiv.org/abs/1709.05074),
5 | I built a model that tries to improve the diversity of generated outputs.
6 |
7 |
8 |
9 | ## Model architecture
10 | Tip: read the ideas and the architecture of [A Deep Generative Framework for Paraphrase Generation](https://arxiv.org/abs/1709.05074) first,
11 | then look at my architecture diagram below.
12 | ![model architecture](model.jpg)
13 |
14 | ## File descriptions
15 | ### /data/PAGE (training corpus)
16 | train.txt format: id---xhm--src---xhm--tgt
17 |
18 | eval.txt format: id---xhm--src---xhm--tgt
19 |
20 | test.txt format: id---xhm--src---xhm--tgt
21 |
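A minimal parsing sketch (a hypothetical helper, not a file in this repo) showing how one line in this format splits into its three fields with the same "---xhm--" separator that data_load.py uses:

    # hypothetical helper for one line of data/PAGE/*.txt
    def parse_page_line(line, sep="---xhm--"):
        parts = line.strip().split(sep)
        if len(parts) != 3:   # data_load.py likewise skips malformed lines
            return None
        pid, src, tgt = (p.strip() for p in parts)
        return pid, src, tgt

    print(parse_page_line("0---xhm--今天天气很好---xhm--今天的天气真不错"))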
22 | ### results
23 | #### /results/bert
24 | This directory holds the pre-trained Chinese BERT model. You can download it [here](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip), then unzip it and put the files here.
25 | #### /results/PAGE
26 | This directory is where the paraphrase model checkpoints are saved.
27 |
28 |
29 |
30 | ### Run command
31 | Training builds its input iterators with the tf.data.* API from TFRecord files (a quick aside: it is a very powerful API, and I recommend everyone use it this way).
32 |
33 | python train_TPAGE.py \
34 | --train=data/PAGE/train.txt \
35 | --eval=data/PAGE/eval.txt \
36 | --init_checkpoint_bert=results/bert/bert_model.ckpt \
37 | --batch_size=32 \
38 | --eval_batch_size=32 \
39 | --num_epochs_PAGE=10 \
40 | --maxlen_vae_Encoder=80 \
41 | --maxlen_vae_Decoder_en=40 \
42 | --maxlen_vae_Decoder_de=40
43 |
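As a rough illustration of the tf.data + TFRecord pipeline mentioned above (a minimal TF 1.x sketch with made-up feature names, not the exact code in data_load.py):

    import tensorflow as tf

    def make_dataset(tfrecord_path, batch_size, maxlen):
        # illustrative feature spec; the real one in data_load.py has more fields
        name_to_features = {
            "input_ids": tf.FixedLenFeature([maxlen], tf.int64),
            "input_mask": tf.FixedLenFeature([maxlen], tf.int64),
        }
        ds = tf.data.TFRecordDataset(tfrecord_path)
        ds = ds.map(lambda rec: tf.parse_single_example(rec, name_to_features))
        ds = ds.shuffle(1000).repeat().batch(batch_size)
        return ds

    # the training/test scripts then build a reinitializable iterator from such a dataset:
    # iter = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
    # features = iter.get_next()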
44 | #### A friendly tip
45 | If you have questions about the KL loss formula, see this [derivation](https://blog.csdn.net/qq_32806793/article/details/95652645) and you will understand why the code is written the way it is.
46 |
47 | KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1]))
48 |
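For reference, the line above is the closed-form KL divergence between the approximate posterior q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior p(z) = N(0, I), averaged over the batch:

$$ D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\operatorname{diag}(\sigma^2))\,\big\|\,\mathcal{N}(0,I)\big) = \frac{1}{2}\sum_{i=1}^{d}\left(\mu_i^2+\sigma_i^2-\log\sigma_i^2-1\right) $$

The 1e-8 inside tf.log is only there for numerical stability, and tf.reduce_mean averages the per-sample KL over the batch dimension.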
49 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/python3
3 |
4 |
5 | import os
6 |
7 | import tensorflow as tf
8 |
9 | from data_load import get_batch
10 | from model import Transformer
11 | from hparams import Hparams
12 | from utils import get_hypotheses, calc_bleu, postprocess, load_hparams
13 | import logging
14 |
15 | logging.basicConfig(level=logging.INFO)
16 |
17 | logging.info("# hparams")
18 | hparams = Hparams()
19 | parser = hparams.parser
20 | hp = parser.parse_args()
21 | load_hparams(hp, hp.ckpt)
22 |
23 | logging.info("# Prepare test batches")
24 | test_batches, num_test_batches, num_test_samples = get_batch(hp.test1, hp.test1,
25 | 100000, 100000,
26 | hp.vocab, hp.test_batch_size,
27 | shuffle=False)
28 | iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes)
29 | xs, ys = iter.get_next()
30 |
31 | test_init_op = iter.make_initializer(test_batches)
32 |
33 | logging.info("# Load model")
34 | m = Transformer(hp)
35 | y_hat, _ = m.eval(xs, ys)
36 |
37 | logging.info("# Session")
38 | with tf.Session() as sess:
39 | ckpt_ = tf.train.latest_checkpoint(hp.ckpt)
40 | ckpt = hp.ckpt if ckpt_ is None else ckpt_ # None: ckpt is a file. otherwise dir.
41 | saver = tf.train.Saver()
42 |
43 | saver.restore(sess, ckpt)
44 |
45 | sess.run(test_init_op)
46 |
47 | logging.info("# get hypotheses")
48 | hypotheses = get_hypotheses(num_test_batches, num_test_samples, sess, y_hat, m.idx2token)
49 |
50 | logging.info("# write results")
51 | model_output = ckpt.split("/")[-1]
52 | if not os.path.exists(hp.testdir): os.makedirs(hp.testdir)
53 | translation = os.path.join(hp.testdir, model_output)
54 | with open(translation, 'w') as fout:
55 | fout.write("\n".join(hypotheses))
56 |
57 | logging.info("# calc bleu score and append it to translation")
58 | calc_bleu(hp.test2, translation)
59 |
60 |
--------------------------------------------------------------------------------
/beamsearch.py:
--------------------------------------------------------------------------------
1 | def beam_search(x, sess, g, batch_size=hp.batch_size):
2 | inputs = np.reshape(np.transpose(np.array([x] * hp.beam_size), (1, 0, 2)),
3 | (hp.beam_size * batch_size, hp.max_len))
4 | preds = np.zeros((batch_size, hp.beam_size, hp.y_max_len), np.int32)
5 | prob_product = np.zeros((batch_size, hp.beam_size))
6 | stc_length = np.ones((batch_size, hp.beam_size))
7 |
8 | for j in range(hp.y_max_len):
9 | _probs, _preds = sess.run(
10 | g.preds, {g.x: inputs, g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))})
11 | j_probs = np.reshape(_probs[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
12 | j_preds = np.reshape(_preds[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
13 | if j == 0:
14 | preds[:, :, j] = j_preds[:, 0, :]
15 | prob_product += np.log(j_probs[:, 0, :])
16 | else:
17 | add_or_not = np.asarray(np.logical_or.reduce([j_preds > hp.end_id]), dtype=np.int)
18 | tmp_stc_length = np.expand_dims(stc_length, axis=-1) + add_or_not
19 | tmp_stc_length = np.reshape(tmp_stc_length, (batch_size, hp.beam_size * hp.beam_size))
20 |
21 | this_probs = np.expand_dims(prob_product, axis=-1) + np.log(j_probs) * add_or_not
22 | this_probs = np.reshape(this_probs, (batch_size, hp.beam_size * hp.beam_size))
23 | selected = np.argsort(this_probs / tmp_stc_length, axis=1)[:, -hp.beam_size:]
24 |
25 | tmp_preds = np.concatenate([np.expand_dims(preds, axis=2)] * hp.beam_size, axis=2)
26 | tmp_preds[:, :, :, j] = j_preds[:, :, :]
27 | tmp_preds = np.reshape(tmp_preds, (batch_size, hp.beam_size * hp.beam_size, hp.y_max_len))
28 |
29 | for batch_idx in range(batch_size):
30 | prob_product[batch_idx] = this_probs[batch_idx, selected[batch_idx]]
31 | preds[batch_idx] = tmp_preds[batch_idx, selected[batch_idx]]
32 | stc_length[batch_idx] = tmp_stc_length[batch_idx, selected[batch_idx]]
33 |
34 | final_selected = np.argmax(prob_product / stc_length, axis=1)
35 | final_preds = []
36 | for batch_idx in range(batch_size):
37 | final_preds.append(preds[batch_idx, final_selected[batch_idx]])
38 |
39 | return final_preds
--------------------------------------------------------------------------------
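The beam_search above ranks candidates by length-normalized log-probability (this_probs / tmp_stc_length). A tiny standalone illustration of why that normalization matters (plain-Python sketch, not code from this repo):

    import math

    # token-level probabilities of two hypothetical candidates
    hyp_short = [0.5, 0.5]
    hyp_long = [0.6, 0.6, 0.6, 0.6]

    def logprob(ps):
        return sum(math.log(p) for p in ps)

    print(logprob(hyp_short), logprob(hyp_long))        # raw sums: the short one wins (-1.39 vs -2.04)
    print(logprob(hyp_short) / len(hyp_short),
          logprob(hyp_long) / len(hyp_long))            # length-normalized: the long one wins (-0.69 vs -0.51)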
/test_PAGE.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 |
4 | import tensorflow as tf
5 |
6 | from bert_transformer_vae_for_PAGE import VaeModel
7 |
8 | from data_load import get_batch_for_train_or_dev_or_test,saveForTfRecord
9 | from utils import save_hparams, get_hypotheses
10 | import os
11 | from hparams import Hparams
12 | import logging
13 | os.environ['CUDA_VISIBLE_DEVICES']= '5'
14 | logging.basicConfig(level=logging.INFO)
15 |
16 |
17 | logging.info("# hparams")
18 | hparams = Hparams()
19 | parser = hparams.parser
20 | hp = parser.parse_args()
21 | save_hparams(hp, hp.PAGEdir)
22 |
23 |
24 | logging.info("# 许海明提醒你: 这里需要准备tfRecord")
25 | logging.info("# 许海明提醒你: 这里需要准备tfRecord")
26 | logging.info("# 许海明提醒你: 这里需要准备tfRecord")
27 | logging.info("# 许海明提醒你: 这里需要准备tfRecord")
28 |
29 |
30 |
31 | saveForTfRecord(hp.test,
32 | hp.maxlen_vae_Encoder,
33 | hp.maxlen_vae_Decoder_en,
34 | hp.maxlen_vae_Decoder_de,
35 | hp.vocab,
36 | output_file="./data/PAGE/test.tf_record",
37 | is_test=False)
38 |
39 |
40 | test_batches, num_test_batches, num_test_samples = get_batch_for_train_or_dev_or_test(hp.test,
41 | hp.maxlen_vae_Encoder,
42 | hp.maxlen_vae_Decoder_en,
43 | hp.maxlen_vae_Decoder_de,
44 | hp.test_batch_size,
45 | input_file="./data/PAGE/test.tf_record" ,
46 | is_training=False,
47 | is_test= False)
48 |
49 | # create an iterator of the correct shape and type
50 | iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes)
51 | features=iter.get_next()
52 |
53 | xs = (features["input_ids_vae_encoder"], features["input_masks_vae_encoder"], features["segment_ids_vae_encoder"], features["sent1"], features["sent2"])
54 | ys = (features["input_ids_vae_decoder_enc"], features["input_ids_vae_decoder_dec"], features["output_ids_vae_decoder_dec"], features["sent1"], features["sent2"])
55 |
56 |
57 | test_init_op = iter.make_initializer(test_batches)
58 |
59 |
60 | logging.info("# Load model")
61 | m = VaeModel(hp)
62 |
63 | y_hat, _ = m.eval(xs, ys, mode="TPAGE")
64 | # y_hat = m.infer(xs, ys)
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 | with tf.Session() as sess:
79 | ckpt = tf.train.latest_checkpoint(hp.PAGEdir)
80 | saver = tf.train.Saver(max_to_keep=hp.num_epochs_LM)
81 | saver.restore(sess, ckpt)
82 | sess.run(test_init_op)
83 |
84 |
85 | logging.info("# get hypotheses")
86 | hypotheses = get_hypotheses(num_test_batches, num_test_samples, sess, y_hat, m.idx2token)
87 |
88 | logging.info("# write results")
89 | model_output = "test_2.txt"
90 | translation = os.path.join(hp.PAGEdir,model_output)
91 | with open(translation, mode='w',encoding="utf-8") as fout:
92 | fout.write("\n".join(hypotheses))
93 |
94 |
95 |
96 |
97 | logging.info("Done")
98 |
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | class Hparams:
4 | parser = argparse.ArgumentParser()
5 |
6 |
7 |
8 |
9 | ## files
10 | parser.add_argument('--train', help="training data")
11 | parser.add_argument('--eval', help="evaluation data")
12 | parser.add_argument('--test', help="test data")
13 |
14 |
15 | ## vocabulary
16 | parser.add_argument('--vocab', default='results/bert/vocab.txt',
17 | help="vocabulary file path")
18 |
19 | parser.add_argument('--init_checkpoint_LM', help="initial checkpoint of the language model (for fine-tuning)")
20 | parser.add_argument('--init_checkpoint_bert', help="initial checkpoint of the BERT model (for fine-tuning)")
21 | parser.add_argument('--init_checkpoint_PAGE', help="initial checkpoint of the paraphrase model (for fine-tuning)")
22 |
23 |
24 | # training scheme
25 | parser.add_argument('--batch_size', default=32, type=int)
26 | parser.add_argument('--eval_batch_size', default=32, type=int)
27 | parser.add_argument('--test_batch_size', default=128, type=int)
28 |
29 | parser.add_argument('--lr', default=5e-5, type=float, help="learning rate")
30 | parser.add_argument('--warmup_steps', default=100, type=int)
31 |
32 | parser.add_argument('--num_epochs_LM', default=20, type=int, help="number of training epochs for the language model")
33 | parser.add_argument('--num_epochs_PAGE', default=20, type=int, help="number of training epochs for the paraphrase model")
34 |
35 | parser.add_argument('--LMdir', default="results/LM", help="save directory of the language model")
36 | parser.add_argument('--PAGEdir', default="results/PAGE", help="save directory of the paraphrase model")
37 |
38 | # model
39 | parser.add_argument('--d_model', default=768, type=int,
40 | help="hidden dimension of encoder/decoder")
41 | parser.add_argument('--d_ff', default=2048, type=int,
42 | help="hidden dimension of feedforward layer")
43 | parser.add_argument('--num_blocks', default=4, type=int,
44 | help="number of encoder/decoder blocks")
45 | parser.add_argument('--num_heads', default=8, type=int,
46 | help="number of attention heads")
47 |
48 | parser.add_argument('--dropout_rate', default=0.1, type=float)
49 | parser.add_argument('--smoothing', default=0.1, type=float,
50 | help="label smoothing rate")
51 |
52 |
53 |
54 |
55 | parser.add_argument('--maxlen_vae_Encoder', default=80, type=int,
56 | help="maximum input length of the VAE encoder, len(sen1)+len(sen2)")
57 | parser.add_argument('--maxlen_vae_Decoder_en', default=40, type=int,
58 | help="maximum length of the source sentence")
59 | parser.add_argument('--maxlen_vae_Decoder_de', default=40, type=int,
60 | help="maximum length of the paraphrase sentence")
61 |
62 |
63 |
64 | # VAE latent variable z
65 | parser.add_argument('--z_dim', default=768, type=int,
66 | help="dimensionality of the latent variable z in the VAE")
67 |
68 |
69 |
70 | # BERT parameters
71 | parser.add_argument('--attention_probs_dropout_prob', default=0.1, type=float)
72 | parser.add_argument('--hidden_dropout_prob', default=0.1, type=float)
73 | parser.add_argument('--initializer_range', default=0.02, type=float)
74 |
75 |
76 | parser.add_argument('--hidden_size', default=768, type=int)
77 | parser.add_argument('--intermediate_size', default=3072, type=int)
78 | parser.add_argument('--max_position_embeddings', default=512, type=int)
79 | parser.add_argument('--num_attention_heads', default=12, type=int)
80 | parser.add_argument('--num_hidden_layers', default=12, type=int)
81 | parser.add_argument('--pooler_fc_size', default=768, type=int)
82 | parser.add_argument('--pooler_num_attention_heads', default=12, type=int)
83 |
84 |
85 | parser.add_argument('--pooler_num_fc_layers', default=3, type=int)
86 | parser.add_argument('--pooler_size_per_head', default=128, type=int)
87 | parser.add_argument('--type_vocab_size', default=2, type=int)
88 | parser.add_argument('--vocab_size', default=21128, type=int)
89 |
90 | parser.add_argument('--directionality', default="bidi")
91 | parser.add_argument('--hidden_act', default="gelu")
92 | parser.add_argument('--pooler_type', default="first_token_transform")
93 |
94 |
--------------------------------------------------------------------------------
/prepro.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | Feb. 2019 by kyubyong park.
5 | kbpark.linguist@gmail.com.
6 | https://www.github.com/kyubyong/transformer.
7 |
8 | Preprocess the iwslt 2016 datasets.
9 | '''
10 |
11 | import os
12 | import errno
13 | import sentencepiece as spm
14 | import re
15 | from hparams import Hparams
16 | import logging
17 |
18 | logging.basicConfig(level=logging.INFO)
19 |
20 | def prepro(hp):
21 | """Load raw data -> Preprocessing -> Segmenting with sentencepice
22 | hp: hyperparams. argparse.
23 | """
24 | logging.info("# Check if raw files exist")
25 | train1 = "iwslt2016/de-en/train.tags.de-en.de"
26 | train2 = "iwslt2016/de-en/train.tags.de-en.en"
27 | eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml"
28 | eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml"
29 | test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml"
30 | test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml"
31 | for f in (train1, train2, eval1, eval2, test1, test2):
32 | if not os.path.isfile(f):
33 | raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f)
34 |
35 | logging.info("# Preprocessing")
36 | # train
37 | _prepro = lambda x: [line.strip() for line in open(x, 'r').read().split("\n") \
38 | if not line.startswith("<")]
39 | prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2)
40 | assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match."
41 |
42 | # eval
43 | _prepro = lambda x: [re.sub("<[^>]+>", "", line).strip() \
44 | for line in open(x, 'r').read().split("\n") \
45 | if line.startswith("ntk', dec, weights) # (N, T2, vocab_size)
117 | y_hat = tf.to_int32(tf.argmax(logits, axis=-1))
118 |
119 | return logits, y_hat, y, sents2
120 |
121 | def train(self, xs, ys):
122 | '''
123 | Returns
124 | loss: scalar.
125 | train_op: training operation
126 | global_step: scalar.
127 | summaries: training summary node
128 | '''
129 | # forward
130 | memory, sents1 = self.encode(xs)
131 | logits, preds, y, sents2 = self.decode(ys, memory)
132 |
133 | # train scheme
134 | y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
135 | ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
136 | nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"])) # 0: <pad>
137 | loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)
138 |
139 | global_step = tf.train.get_or_create_global_step()
140 | lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
141 | optimizer = tf.train.AdamOptimizer(lr)
142 | train_op = optimizer.minimize(loss, global_step=global_step)
143 |
144 | tf.summary.scalar('lr', lr)
145 | tf.summary.scalar("loss", loss)
146 | tf.summary.scalar("global_step", global_step)
147 |
148 | summaries = tf.summary.merge_all()
149 |
150 | return loss, train_op, global_step, summaries
151 |
152 | def eval(self, xs, ys):
153 | '''Predicts autoregressively
154 | At inference, input ys is ignored.
155 | Returns
156 | y_hat: (N, T2)
157 | '''
158 | decoder_inputs, y, y_seqlen, sents2 = ys
159 |
160 | decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
161 | ys = (decoder_inputs, y, y_seqlen, sents2)
162 |
163 | memory, sents1 = self.encode(xs, False)
164 |
165 | logging.info("Inference graph is being built. Please be patient.")
166 | for _ in tqdm(range(self.hp.maxlen2)):
167 | logits, y_hat, y, sents2 = self.decode(ys, memory, False)
168 | if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
169 |
170 | _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
171 | ys = (_decoder_inputs, y, y_seqlen, sents2)
172 |
173 | # monitor a random sample
174 | n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
175 | sent1 = sents1[n]
176 | pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
177 | sent2 = sents2[n]
178 |
179 | tf.summary.text("sent1", sent1)
180 | tf.summary.text("pred", pred)
181 | tf.summary.text("sent2", sent2)
182 | summaries = tf.summary.merge_all()
183 |
184 | return y_hat, summaries
185 |
186 |
--------------------------------------------------------------------------------
/bert_transformer_vae_for_PAGE.py:
--------------------------------------------------------------------------------
1 | from bert import bert_model_for_PAGE
2 | from bert.bert_model_for_PAGE import get_shape_list,layer_norm
3 | import tensorflow as tf
4 | from modules import ff, multihead_attention, label_smoothing, noam_scheme
5 | from utils import convert_idx_to_token_tensor
6 | from tqdm import tqdm
7 | import logging
8 |
9 | from data_load import load_vocab
10 |
11 | class VaeModel:
12 |
13 | '''
14 | This is the main framework of the model.
15 | '''
16 | def __init__(self,hp):
17 | self.hp = hp
18 | self.token2idx, self.idx2token = load_vocab(hp.vocab)
19 |
20 |
21 | # encoder of the variational autoencoder
22 | def encoder_vae(self, xs, training, mode):
23 |
24 | '''
25 | :param xs:
26 | :param training:
27 | :param mode: TPAGE = train the paraphrase model; LM = train the transformer decoder; PPAGE = prediction
28 | :return:
29 | '''
30 |
31 | input_ids_vae_encoder, input_masks_vae_encoder, segment_ids_vae_encoder, sents1, sents2 = xs
32 |
33 | if mode == "TPAGE" or mode == "PPAGE":
34 | # this trains the whole VAE
35 | encoder_vae = bert_model_for_PAGE.BertModel(
36 | config=self.hp,
37 | is_training=training,
38 | input_ids=input_ids_vae_encoder,
39 | input_mask=input_masks_vae_encoder,
40 | token_type_ids=segment_ids_vae_encoder)
41 |
42 | self.embeddings = encoder_vae.get_embedding_table()
43 | self.full_position_embeddings = encoder_vae.get_full_position_embeddings()
44 | pattern_para = tf.squeeze(encoder_vae.get_sequence_output()[:, 0:1, :], axis=1)
45 | gaussian_params = tf.layers.dense(pattern_para, 2 * self.hp.z_dim)
46 | mean = gaussian_params[:, :self.hp.z_dim]
47 | stddev = 1e-6 + tf.nn.softplus(gaussian_params[:, self.hp.z_dim:])
48 | else:
49 | raise("您好,你选择的mode 出现错误")
50 |
51 | return mean, stddev
52 |
53 |
54 |
55 | # decoder of the variational autoencoder
56 | def decoder_vae(self, ys, z, training, mode):
57 | input_ids_vae_decoder_enc, input_ids_vae_decoder_dec,output_ids_vae_decoder_dec, sents1, sents2 = ys
58 | if mode == "TPAGE" or mode == "PPAGE":
59 | with tf.variable_scope("decoder_enc_vae", reuse=tf.AUTO_REUSE):
60 | enc = tf.nn.embedding_lookup(self.embeddings, input_ids_vae_decoder_enc) # (N, T1, d_model)
61 | input_shape = get_shape_list(enc, expected_rank=3)
62 | seq_length = input_shape[1]
63 | width = input_shape[2]
64 | output = enc
65 | assert_op = tf.assert_less_equal(seq_length, self.hp.max_position_embeddings)
66 | with tf.control_dependencies([assert_op]):
67 | position_embeddings = tf.slice(self.full_position_embeddings, [0, 0],
68 | [seq_length, -1])
69 | num_dims = len(output.shape.as_list())
70 | position_broadcast_shape = []
71 | for _ in range(num_dims - 2):
72 | position_broadcast_shape.append(1)
73 | position_broadcast_shape.extend([seq_length, width])
74 | position_embeddings = tf.reshape(position_embeddings,
75 | position_broadcast_shape)
76 |
77 | output += position_embeddings # add positional information
78 |
79 | enc = output
80 | ## Blocks
81 | for i in range(self.hp.num_blocks):
82 | with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
83 | # self-attention
84 | enc = multihead_attention(seq_k=input_ids_vae_decoder_enc,
85 | seq_q=input_ids_vae_decoder_enc,
86 | queries=enc,
87 | keys=enc,
88 | values=enc,
89 | num_heads=self.hp.num_heads,
90 | dropout_rate=self.hp.dropout_rate,
91 | training=training,
92 | causality=False,
93 | scope="self_attention")
94 | # feed forward
95 | enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
96 |
97 |
98 | memory = enc
99 |
100 |
101 | # below is the decoder part of the VAE's decoder
102 | with tf.variable_scope("decoder_dec_vae", reuse=tf.AUTO_REUSE):
103 | dec = tf.nn.embedding_lookup(self.embeddings, input_ids_vae_decoder_dec) # (N, T, d_model)
104 | input_shape = get_shape_list(dec, expected_rank=3)
105 | seq_length = input_shape[1]
106 | width = input_shape[2]
107 | output = dec
108 | assert_op = tf.assert_less_equal(seq_length, self.hp.max_position_embeddings)
109 | with tf.control_dependencies([assert_op]):
110 | position_embeddings = tf.slice(self.full_position_embeddings, [0, 0],
111 | [seq_length, -1])
112 | num_dims = len(output.shape.as_list())
113 | position_broadcast_shape = []
114 | for _ in range(num_dims - 2):
115 | position_broadcast_shape.append(1)
116 | position_broadcast_shape.extend([seq_length, width])
117 | position_embeddings = tf.reshape(position_embeddings,
118 | position_broadcast_shape)
119 |
120 |
121 | output += position_embeddings # add positional information
122 |
123 | dec=output
124 |
125 | # inject the latent pattern vector z here, using concatenation
126 | z = tf.expand_dims(z, axis=1)
127 | z = tf.tile(z, multiples=[1, seq_length, 1])
128 | dec = tf.concat([dec, z], axis=-1)
129 | dec = tf.layers.dense(dec, self.hp.d_model)
130 | # Blocks
131 | for i in range(self.hp.num_blocks):
132 | with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
133 | # Masked self-attention (Note that causality is True at this time)
134 | dec = multihead_attention(seq_k=input_ids_vae_decoder_dec,
135 | seq_q=input_ids_vae_decoder_dec,
136 | queries=dec,
137 | keys=dec,
138 | values=dec,
139 | num_heads=self.hp.num_heads,
140 | dropout_rate=self.hp.dropout_rate,
141 | training=training,
142 | causality=True,
143 | scope="self_attention")
144 |
145 | # Vanilla attention
146 | dec = multihead_attention(seq_k=input_ids_vae_decoder_enc,
147 | seq_q=input_ids_vae_decoder_dec,
148 | queries=dec,
149 | keys=memory,
150 | values=memory,
151 | num_heads=self.hp.num_heads,
152 | dropout_rate=self.hp.dropout_rate,
153 | training=training,
154 | causality=False,
155 | scope="vanilla_attention")
156 | ### Feed Forward
157 | dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])
158 |
159 |
160 |
161 | # Final linear projection (embedding weights are shared)
162 | weights = tf.transpose(self.embeddings) # (d_model, vocab_size)
163 | # note: to make things fit, an FF layer is forced in here
164 | logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size)
165 | y_hat = tf.to_int32(tf.argmax(logits, axis=-1))
166 | return logits, y_hat, output_ids_vae_decoder_dec, sents2
167 |
168 |
169 |
170 | def train(self, xs, ys,mode):
171 | '''
172 | Returns
173 | loss: scalar.
174 | train_op: training operation
175 | global_step: scalar.
176 | summaries: training summary node
177 | '''
178 | # forward
179 |
180 | mu, sigma = self.encoder_vae(xs, training=True, mode=mode)
181 | if mode == "TPAGE" or mode == "PPAGE":
182 | # training the VAE
183 | # note to self: set all the embeddings to trainable (True)
184 | z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)
185 | else:
186 | raise ("许海明在这里提醒你:出现非法mode")
187 |
188 |
189 |
190 | logits, preds, y, sents2 = self.decoder_vae(ys, z,training=True,mode=mode)
191 |
192 | # train scheme
193 |
194 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
195 | nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["[PAD]"])) # 0:
196 | loss_decoder = tf.reduce_sum(ce * nonpadding) / tf.to_float(get_shape_list(xs[0],expected_rank=2)[0])
197 |
198 | # add the KL loss here
199 | if mode == "TPAGE":
200 | KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1]))
201 | else:
202 | KL_loss = 0.0
203 |
204 | loss = loss_decoder + KL_loss
205 |
206 |
207 | global_step = tf.train.get_or_create_global_step()
208 | lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
209 | optimizer = tf.train.AdamOptimizer(lr)
210 | train_op = optimizer.minimize(loss, global_step=global_step)
211 |
212 | # # monitor a random sample
213 | n = tf.random_uniform((), 0, tf.shape(preds)[0] - 1, tf.int32)
214 | print_demo=(xs[0][n],
215 | y[n],
216 | preds[n])
217 |
218 |
219 | tf.summary.scalar('lr', lr)
220 | tf.summary.scalar("KL_loss", KL_loss)
221 | tf.summary.scalar("loss_decoder", loss_decoder)
222 | tf.summary.scalar("loss", loss)
223 | tf.summary.scalar("global_step", global_step)
224 |
225 | summaries = tf.summary.merge_all()
226 |
227 | return loss, train_op, global_step, summaries, print_demo
228 |
229 |
230 |
231 | def eval(self, xs,ys,mode):
232 |
233 | mu, sigma = self.encoder_vae(xs, training=False, mode=mode) # this call is mainly to obtain the embeddings; nothing else from it is used
234 |
235 | if mode == "TPAGE" or mode == "PPAGE":
236 | # training the VAE
237 | # note to self: set all the embeddings to trainable (True)
238 | z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)
239 |
240 | else:
241 | raise ("许海明在这里提醒你:出现非法mode")
242 |
243 |
244 |
245 | # z = tf.random_normal([get_shape_list(xs[0], expected_rank=2)[0], self.hp.z_dim]) # sample z directly from the prior
246 |
247 | input_ids_vae_decoder_enc, input_ids_vae_decoder_dec, output_ids_vae_decoder_dec, sents1, sents2 = ys
248 |
249 | decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx[""]
250 | ys = (input_ids_vae_decoder_enc, decoder_inputs, output_ids_vae_decoder_dec, sents1, sents2)
251 |
252 | logging.info("Inference graph is being built. Please be patient.")
253 | for _ in tqdm(range(self.hp.maxlen_vae_Decoder_de)):
254 | logits, y_hat, y, sents2 = self.decoder_vae(ys, z, training=False, mode=mode)
255 | if tf.reduce_sum(y_hat, 1) == self.token2idx["[PAD]"]:
256 | break
257 |
258 | _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
259 | ys = (input_ids_vae_decoder_enc, _decoder_inputs, output_ids_vae_decoder_dec, sents1, sents2)
260 |
261 | # monitor a random sample
262 | n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
263 | sent1 = sents1[n]
264 | pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
265 | sent2 = sents2[n]
266 |
267 | tf.summary.text("sent1", sent1)
268 | tf.summary.text("pred", pred)
269 | tf.summary.text("sent2", sent2)
270 | summaries = tf.summary.merge_all()
271 |
272 | return y_hat, summaries
--------------------------------------------------------------------------------
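A side note on the sampling step z = mu + sigma * eps used in VaeModel above: it is the standard VAE reparameterization trick, and drawing several eps for the same input at inference time is how the CVAE is meant to produce diverse paraphrases. A tiny NumPy sketch of the idea (illustrative only, not code from this repo):

    import numpy as np

    def sample_z(mu, sigma, n_samples=3, seed=0):
        """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
        rng = np.random.RandomState(seed)
        eps = rng.randn(n_samples, mu.shape[-1])
        return mu + sigma * eps   # each row is a different latent code, hence a different paraphrase

    mu = np.zeros(4)
    sigma = 0.5 * np.ones(4)
    print(sample_z(mu, sigma))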
/modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 |
4 | from bert.bert_model_for_PAGE import get_shape_list
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | def ln(inputs, epsilon = 1e-8, scope="ln"):
10 | '''Applies layer normalization. See https://arxiv.org/abs/1607.06450.
11 | inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`.
12 | epsilon: A floating number. A very small number for preventing ZeroDivision Error.
13 | scope: Optional scope for `variable_scope`.
14 |
15 | Returns:
16 | A tensor with the same shape and data dtype as `inputs`.
17 | '''
18 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
19 | inputs_shape = inputs.get_shape()
20 | params_shape = inputs_shape[-1:]
21 |
22 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
23 | beta= tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer())
24 | gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer())
25 | normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) )
26 | outputs = gamma * normalized + beta
27 |
28 | return outputs
29 |
30 | def get_token_embeddings(vocab_size, num_units, zero_pad=True):
31 | '''Constructs token embedding matrix.
32 | Note that the column of index 0's are set to zeros.
33 | vocab_size: scalar. V.
34 | num_units: embedding dimensionalty. E.
35 | zero_pad: Boolean. If True, all the values of the first row (id = 0) should be constant zero
36 | To apply query/key masks easily, zero pad is turned on.
37 |
38 | Returns
39 | weight variable: (V, E)
40 | '''
41 | with tf.variable_scope("shared_weight_matrix"):
42 | embeddings = tf.get_variable('weight_mat',
43 | dtype=tf.float32,
44 | shape=(vocab_size, num_units),
45 | initializer=tf.contrib.layers.xavier_initializer())
46 | if zero_pad:
47 | embeddings = tf.concat((tf.zeros(shape=[1, num_units]),
48 | embeddings[1:, :]), 0)
49 | return embeddings
50 |
51 | def scaled_dot_product_attention(seq_k, seq_q, Q, K, V,
52 | causality=False, dropout_rate=0.,
53 | training=True,
54 | num_heads=8,
55 | scope="scaled_dot_product_attention"):
56 | '''See 3.2.1.
57 | Q: Packed queries. 3d tensor. [N, T_q, d_k].
58 | K: Packed keys. 3d tensor. [N, T_k, d_k].
59 | V: Packed values. 3d tensor. [N, T_k, d_v].
60 | causality: If True, applies masking for future blinding
61 | dropout_rate: A floating point number of [0, 1].
62 | training: boolean for controlling droput
63 | scope: Optional scope for `variable_scope`.
64 | '''
65 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
66 | d_k = Q.get_shape().as_list()[-1]
67 |
68 | # dot product
69 | outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) # (N, T_q, T_k)
70 |
71 | # scale
72 | outputs /= d_k ** 0.5
73 |
74 | # key masking
75 | outputs = mask(seq_k, seq_q, outputs, num_heads,type="key")
76 |
77 | # causality or future blinding masking
78 | if causality:
79 | outputs = mask(seq_k, seq_q, outputs, num_heads,type="future")
80 |
81 | # softmax
82 | outputs = tf.nn.softmax(outputs)
83 |
84 | attention = tf.transpose(outputs, [0, 2, 1])
85 | tf.summary.image("attention", tf.expand_dims(attention[:1], -1))
86 |
87 | # query masking
88 | outputs = mask(seq_k, seq_q, outputs,num_heads, type="query")
89 |
90 | # dropout
91 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
92 |
93 | # weighted sum (context vectors)
94 | outputs = tf.matmul(outputs, V) # (N, T_q, d_v)
95 |
96 | return outputs
97 |
98 | def mask(seq_k, seq_q, inputs,num_heads, type=None):
99 |
100 | padding_num = -2 ** 32 + 1
101 | if type in ("k", "key", "keys"):
102 | # Generate masks
103 | masks = get_attn_key_pad_mask(seq_k=seq_k, seq_q=seq_q, PAD_ID=0)
104 | # Apply masks to inputs
105 | masks = tf.tile(masks,[num_heads, 1, 1])
106 | paddings = tf.ones_like(inputs) * padding_num
107 | outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # (N, T_q, T_k)
108 | elif type in ("q", "query", "queries"):
109 | # Generate masks
110 | masks = get_non_pad_mask(seq_k=seq_k, seq_q=seq_q, PAD_ID=0)
111 | masks = tf.tile(masks, [num_heads, 1, 1])
112 | # Apply masks to inputs
113 | outputs = inputs*masks
114 |
115 | elif type in ("f", "future", "right"):
116 | masks = get_subsequent_mask(inputs)
117 | # masks = tf.tile(masks, [num_heads, 1, 1])
118 |
119 | paddings = tf.ones_like(masks) * padding_num
120 | outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
121 | else:
122 | print("Check if you entered type correctly!")
123 |
124 | return outputs
125 |
126 |
127 |
128 |
129 |
130 |
131 | def get_non_pad_mask(seq_k, seq_q, PAD_ID):
132 | '''
133 |
134 | :param seq: [batch_size,len]
135 | :return:
136 | '''
137 | '''
138 | masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q)
139 | masks = tf.expand_dims(masks, -1) # (N, T_q, 1)
140 | masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k)
141 | '''
142 | masks = tf.cast(tf.math.not_equal(seq_q, PAD_ID), tf.float32)
143 | masks = tf.expand_dims(masks, -1) # (N, T_q, 1)
144 | masks = tf.tile(masks, [1, 1, get_shape_list(seq_k)[1]]) # (N, T_q, T_k)
145 | return masks
146 |
147 |
148 | def get_attn_key_pad_mask(seq_k,seq_q,PAD_ID):
149 | ''' For masking out the padding part of key sequence. '''
150 | # print(get_shape_list(seq_k,expected_rank=2))
151 |
152 | padding_mask = tf.cast(tf.math.not_equal(seq_k,PAD_ID),tf.float32)
153 | masks = tf.expand_dims(padding_mask, 1) # (N, 1, T_k)
154 | masks = tf.tile(masks, [1, get_shape_list(seq_q,expected_rank=2)[1], 1]) # (N, T_q, T_k)
155 | # print(get_shape_list(masks,expected_rank=3))
156 | return masks
157 |
158 |
159 |
160 |
161 | def get_subsequent_mask(inputs):
162 | ''' For masking out the subsequent info. '''
163 |
164 | diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k)
165 | tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
166 | masks = tf.tile(tf.expand_dims(tril, 0), [get_shape_list(inputs)[0], 1, 1]) # (N, T_q, T_k)
167 |
168 | return masks
169 |
170 |
171 |
172 |
173 | def multihead_attention(seq_k,
174 | seq_q,
175 | queries, keys, values,
176 | num_heads=8,
177 | dropout_rate=0,
178 | training=True,
179 | causality=False,
180 | scope="multihead_attention"):
181 | '''Applies multihead attention. See 3.2.2
182 | queries: A 3d tensor with shape of [N, T_q, d_model].
183 | keys: A 3d tensor with shape of [N, T_k, d_model].
184 | values: A 3d tensor with shape of [N, T_k, d_model].
185 | num_heads: An int. Number of heads.
186 | dropout_rate: A floating point number.
187 | training: Boolean. Controller of mechanism for dropout.
188 | causality: Boolean. If true, units that reference the future are masked.
189 | scope: Optional scope for `variable_scope`.
190 |
191 | Returns
192 | A 3d tensor with shape of (N, T_q, C)
193 | '''
194 | d_model = queries.get_shape().as_list()[-1]
195 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
196 | # Linear projections
197 | Q = tf.layers.dense(queries, d_model, use_bias=False) # (N, T_q, d_model)
198 | K = tf.layers.dense(keys, d_model, use_bias=False) # (N, T_k, d_model)
199 | V = tf.layers.dense(values, d_model, use_bias=False) # (N, T_k, d_model)
200 |
201 | # Split and concat
202 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)
203 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)
204 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)
205 |
206 | # Attention
207 | outputs = scaled_dot_product_attention(seq_k, seq_q, Q_, K_, V_, causality, dropout_rate, training,num_heads)
208 |
209 | # Restore shape
210 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, d_model)
211 |
212 | # Residual connection
213 | outputs += queries
214 |
215 | # Normalize
216 | outputs = ln(outputs)
217 |
218 | return outputs
219 |
220 | def ff(inputs, num_units, scope="positionwise_feedforward"):
221 | '''position-wise feed forward net. See 3.3
222 |
223 | inputs: A 3d tensor with shape of [N, T, C].
224 | num_units: A list of two integers.
225 | scope: Optional scope for `variable_scope`.
226 |
227 | Returns:
228 | A 3d tensor with the same shape and dtype as inputs
229 | '''
230 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
231 | # Inner layer
232 | outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu)
233 |
234 | # Outer layer
235 | outputs = tf.layers.dense(outputs, num_units[1])
236 |
237 | # Residual connection
238 | outputs += inputs
239 |
240 | # Normalize
241 | outputs = ln(outputs)
242 |
243 | return outputs
244 |
245 | def label_smoothing(inputs, epsilon=0.1):
246 | '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567.
247 | inputs: 3d tensor. [N, T, V], where V is the number of vocabulary.
248 | epsilon: Smoothing rate.
249 |
250 | For example,
251 |
252 | ```
253 | import tensorflow as tf
254 | inputs = tf.convert_to_tensor([[[0, 0, 1],
255 | [0, 1, 0],
256 | [1, 0, 0]],
257 |
258 | [[1, 0, 0],
259 | [1, 0, 0],
260 | [0, 1, 0]]], tf.float32)
261 |
262 | outputs = label_smoothing(inputs)
263 |
264 | with tf.Session() as sess:
265 | print(sess.run([outputs]))
266 |
267 | >>
268 | [array([[[ 0.03333334, 0.03333334, 0.93333334],
269 | [ 0.03333334, 0.93333334, 0.03333334],
270 | [ 0.93333334, 0.03333334, 0.03333334]],
271 |
272 | [[ 0.93333334, 0.03333334, 0.03333334],
273 | [ 0.93333334, 0.03333334, 0.03333334],
274 | [ 0.03333334, 0.93333334, 0.03333334]]], dtype=float32)]
275 | ```
276 | '''
277 | V = inputs.get_shape().as_list()[-1] # number of channels
278 | return ((1-epsilon) * inputs) + (epsilon / V)
279 |
280 | def positional_encoding(inputs,
281 | maxlen,
282 | masking=True,
283 | scope="positional_encoding"):
284 | '''Sinusoidal Positional_Encoding. See 3.5
285 | inputs: 3d tensor. (N, T, E)
286 | maxlen: scalar. Must be >= T
287 | masking: Boolean. If True, padding positions are set to zeros.
288 | scope: Optional scope for `variable_scope`.
289 |
290 | returns
291 | 3d tensor that has the same shape as inputs.
292 | '''
293 |
294 | E = inputs.get_shape().as_list()[-1] # static
295 | N, T = tf.shape(inputs)[0], tf.shape(inputs)[1] # dynamic
296 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
297 | # position indices
298 | position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1]) # (N, T)
299 |
300 | # First part of the PE function: sin and cos argument
301 | position_enc = np.array([
302 | [pos / np.power(10000, (i-i%2)/E) for i in range(E)]
303 | for pos in range(maxlen)])
304 |
305 | # Second part, apply the cosine to even columns and sin to odds.
306 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i
307 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1
308 | position_enc = tf.convert_to_tensor(position_enc, tf.float32) # (maxlen, E)
309 |
310 | # lookup
311 | outputs = tf.nn.embedding_lookup(position_enc, position_ind)
312 |
313 | # masks
314 | if masking:
315 | outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)
316 |
317 | return tf.to_float(outputs)
318 |
319 | def noam_scheme(init_lr, global_step, warmup_steps=4000.):
320 | '''Noam scheme learning rate decay
321 | init_lr: initial learning rate. scalar.
322 | global_step: scalar.
323 | warmup_steps: scalar. During warmup_steps, learning rate increases
324 | until it reaches init_lr.
325 | '''
326 | step = tf.cast(global_step + 1, dtype=tf.float32)
327 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)
--------------------------------------------------------------------------------
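A quick numeric check of the noam_scheme defined in modules.py above (plain-Python sketch, assuming lr=5e-5 and warmup_steps=100 as in hparams.py): the learning rate rises roughly linearly during warmup, reaches the initial value when global_step + 1 equals warmup_steps, then decays like 1/sqrt(step).

    def noam_scheme_py(init_lr, global_step, warmup_steps=100.0):
        step = float(global_step + 1)
        return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

    for gs in [0, 49, 99, 999, 9999]:
        print(gs, noam_scheme_py(5e-5, gs))
    # 0: 5e-07, 49: 2.5e-05, 99: 5e-05, 999: ~1.58e-05, 9999: ~5e-06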
/bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with open(vocab_file, mode="r",encoding="utf-8") as fr:
126 | for token in fr:
127 | if not token:
128 | break
129 | token = token.strip()
130 | vocab[token] = index
131 | index += 1
132 | return vocab
133 |
134 |
135 | def convert_by_vocab(vocab, items):
136 | """Converts a sequence of [tokens|ids] using the vocab."""
137 | output = []
138 | for item in items:
139 | output.append(vocab[item])
140 | return output
141 |
142 |
143 | def convert_tokens_to_ids(vocab, tokens):
144 | return convert_by_vocab(vocab, tokens)
145 |
146 |
147 | def convert_ids_to_tokens(inv_vocab, ids):
148 | return convert_by_vocab(inv_vocab, ids)
149 |
150 |
151 | def whitespace_tokenize(text):
152 | """Runs basic whitespace cleaning and splitting on a piece of text."""
153 | text = text.strip()
154 | if not text:
155 | return []
156 | tokens = text.split()
157 | return tokens
158 |
159 |
160 | class FullTokenizer(object):
161 | """Runs end-to-end tokenziation."""
162 |
163 | def __init__(self, vocab_file, do_lower_case=True):
164 | self.vocab = load_vocab(vocab_file)
165 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
168 |
169 | def tokenize(self, text):
170 | split_tokens = []
171 | for token in self.basic_tokenizer.tokenize(text):
172 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
173 | split_tokens.append(sub_token)
174 |
175 | return split_tokens
176 |
177 | def convert_tokens_to_ids(self, tokens):
178 | return convert_by_vocab(self.vocab, tokens)
179 |
180 | def convert_ids_to_tokens(self, ids):
181 | return convert_by_vocab(self.inv_vocab, ids)
182 |
183 |
184 | class BasicTokenizer(object):
185 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
186 |
187 | def __init__(self, do_lower_case=True):
188 | """Constructs a BasicTokenizer.
189 |
190 | Args:
191 | do_lower_case: Whether to lower case the input.
192 | """
193 | self.do_lower_case = do_lower_case
194 |
195 | def tokenize(self, text):
196 | """Tokenizes a piece of text."""
197 | text = convert_to_unicode(text)
198 | text = self._clean_text(text)
199 |
200 | # This was added on November 1st, 2018 for the multilingual and Chinese
201 | # models. This is also applied to the English models now, but it doesn't
202 | # matter since the English models were not trained on any Chinese data
203 | # and generally don't have any Chinese data in them (there are Chinese
204 | # characters in the vocabulary because Wikipedia does have some Chinese
205 | # words in the English Wikipedia.).
206 | text = self._tokenize_chinese_chars(text)
207 |
208 | orig_tokens = whitespace_tokenize(text)
209 | split_tokens = []
210 | for token in orig_tokens:
211 | if self.do_lower_case:
212 | token = token.lower()
213 | token = self._run_strip_accents(token)
214 | split_tokens.extend(self._run_split_on_punc(token))
215 |
216 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
217 | return output_tokens
218 |
219 | def _run_strip_accents(self, text):
220 | """Strips accents from a piece of text."""
221 | text = unicodedata.normalize("NFD", text)
222 | output = []
223 | for char in text:
224 | cat = unicodedata.category(char)
225 | if cat == "Mn":
226 | continue
227 | output.append(char)
228 | return "".join(output)
229 |
230 | def _run_split_on_punc(self, text):
231 | """Splits punctuation on a piece of text."""
232 | chars = list(text)
233 | i = 0
234 | start_new_word = True
235 | output = []
236 | while i < len(chars):
237 | char = chars[i]
238 | if _is_punctuation(char):
239 | output.append([char])
240 | start_new_word = True
241 | else:
242 | if start_new_word:
243 | output.append([])
244 | start_new_word = False
245 | output[-1].append(char)
246 | i += 1
247 |
248 | return ["".join(x) for x in output]
249 |
250 | def _tokenize_chinese_chars(self, text):
251 | """Adds whitespace around any CJK character."""
252 | output = []
253 | for char in text:
254 | cp = ord(char)
255 | if self._is_chinese_char(cp):
256 | output.append(" ")
257 | output.append(char)
258 | output.append(" ")
259 | else:
260 | output.append(char)
261 | return "".join(output)
262 |
263 | def _is_chinese_char(self, cp):
264 | """Checks whether CP is the codepoint of a CJK character."""
265 | # This defines a "chinese character" as anything in the CJK Unicode block:
266 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
267 | #
268 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
269 | # despite its name. The modern Korean Hangul alphabet is a different block,
270 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
271 | # space-separated words, so they are not treated specially and handled
272 | # like the all of the other languages.
273 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
274 | (cp >= 0x3400 and cp <= 0x4DBF) or #
275 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
276 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
277 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
278 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
279 | (cp >= 0xF900 and cp <= 0xFAFF) or #
280 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
281 | return True
282 |
283 | return False
284 |
285 | def _clean_text(self, text):
286 | """Performs invalid character removal and whitespace cleanup on text."""
287 | output = []
288 | for char in text:
289 | cp = ord(char)
290 | if cp == 0 or cp == 0xfffd or _is_control(char):
291 | continue
292 | if _is_whitespace(char):
293 | output.append(" ")
294 | else:
295 | output.append(char)
296 | return "".join(output)
297 |
298 |
299 | class WordpieceTokenizer(object):
300 | """Runs WordPiece tokenziation."""
301 |
302 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
303 | self.vocab = vocab
304 | self.unk_token = unk_token
305 | self.max_input_chars_per_word = max_input_chars_per_word
306 |
307 | def tokenize(self, text):
308 | """Tokenizes a piece of text into its word pieces.
309 |
310 | This uses a greedy longest-match-first algorithm to perform tokenization
311 | using the given vocabulary.
312 |
313 | For example:
314 | input = "unaffable"
315 | output = ["un", "##aff", "##able"]
316 |
317 | Args:
318 | text: A single token or whitespace separated tokens. This should have
319 | already been passed through `BasicTokenizer.
320 |
321 | Returns:
322 | A list of wordpiece tokens.
323 | """
324 |
325 | text = convert_to_unicode(text)
326 |
327 | output_tokens = []
328 | for token in whitespace_tokenize(text):
329 | chars = list(token)
330 | if len(chars) > self.max_input_chars_per_word:
331 | output_tokens.append(self.unk_token)
332 | continue
333 |
334 | is_bad = False
335 | start = 0
336 | sub_tokens = []
337 | while start < len(chars):
338 | end = len(chars)
339 | cur_substr = None
340 | while start < end:
341 | substr = "".join(chars[start:end])
342 | if start > 0:
343 | substr = "##" + substr
344 | if substr in self.vocab:
345 | cur_substr = substr
346 | break
347 | end -= 1
348 | if cur_substr is None:
349 | is_bad = True
350 | break
351 | sub_tokens.append(cur_substr)
352 | start = end
353 |
354 | if is_bad:
355 | output_tokens.append(self.unk_token)
356 | else:
357 | output_tokens.extend(sub_tokens)
358 | return output_tokens
359 |
360 |
361 | def _is_whitespace(char):
362 | """Checks whether `chars` is a whitespace character."""
363 | # \t, \n, and \r are technically contorl characters but we treat them
364 | # as whitespace since they are generally considered as such.
365 | if char == " " or char == "\t" or char == "\n" or char == "\r":
366 | return True
367 | cat = unicodedata.category(char)
368 | if cat == "Zs":
369 | return True
370 | return False
371 |
372 |
373 | def _is_control(char):
374 | """Checks whether `chars` is a control character."""
375 | # These are technically control characters but we count them as whitespace
376 | # characters.
377 | if char == "\t" or char == "\n" or char == "\r":
378 | return False
379 | cat = unicodedata.category(char)
380 | if cat in ("Cc", "Cf"):
381 | return True
382 | return False
383 |
384 |
385 | def _is_punctuation(char):
386 | """Checks whether `chars` is a punctuation character."""
387 | cp = ord(char)
388 | # We treat all non-letter/number ASCII as punctuation.
389 | # Characters such as "^", "$", and "`" are not in the Unicode
390 | # Punctuation class but we treat them as punctuation anyways, for
391 | # consistency.
392 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
393 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
394 | return True
395 | cat = unicodedata.category(char)
396 | if cat.startswith("P"):
397 | return True
398 | return False
399 |
--------------------------------------------------------------------------------
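The `tokenize` method of `WordpieceTokenizer` above is a greedy longest-match-first search over the vocabulary. Below is a minimal, self-contained sketch of that inner loop, using a made-up toy vocabulary (the real class additionally handles unicode conversion and the per-word length cutoff):

toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}

def wordpiece(token, vocab, unk="[UNK]"):
    pieces, start = [], 0
    while start < len(token):
        end = len(token)
        cur = None
        while start < end:
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub      # continuation pieces carry the ## prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1                  # shrink the candidate from the right
        if cur is None:
            return [unk]              # nothing matched: the whole word becomes [UNK]
        pieces.append(cur)
        start = end
    return pieces

print(wordpiece("unaffable", toy_vocab))  # -> ['un', '##aff', '##able']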
/data_load.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/python3
3 |
4 | import tensorflow as tf
5 | from utils import calc_num_batches
6 | import copy
7 | import collections
8 | from bert import tokenization
9 |
10 |
11 |
12 | def load_vocab(vocab_fpath):
13 | '''Loads vocabulary file and returns idx<->token maps
14 | vocab_fpath: string. vocabulary file path.
15 | Note that these are reserved
16 | 0: , 1: , 2: , 3:
17 | Note the special entries here:
18 | 0: pad
19 | [CLS]: its position is used to learn the paraphrase rule
20 | [UNK]
21 | [SEP]: separates the two sentences, and also serves as the end-of-sentence symbol
22 | [MASK]
23 | start-of-sentence symbol
24 | Returns
25 | two dictionaries.
26 | '''
27 | vocab = collections.OrderedDict()
28 | index = 0
29 | with open(vocab_fpath, mode="r", encoding="utf-8") as fr:
30 | for token in fr:
31 | print(token)
32 | if not token:
33 | break
34 | token = token.strip()
35 | vocab[token] = index
36 | index += 1
37 |
38 | token2idx = {token: idx for idx, token in enumerate(vocab)}
39 | idx2token = {idx: token for idx, token in enumerate(vocab)}
40 | return token2idx, idx2token
41 |
42 |
43 |
44 | def load_data_for_test(fpath):
45 | '''
46 |
47 | :param fpath: path to the data file
48 | :return: sents1 and sents2, where sents2 is a deep copy of sents1
49 |
50 | '''
51 | sents1, sents2 = [], []
52 | with open(fpath, mode='r',encoding="utf-8") as fr:
53 | for line in fr:
54 | line = line.strip()
55 | if line is not None and line != "":
56 | assert len(line.strip())>3
57 | sents1.append(line.strip())
58 | sents2=copy.deepcopy(sents1)
59 | return sents1, sents2
60 |
61 |
62 |
63 |
64 | def load_data_for_train_or_eval(fpath, sep="---xhm--"):
65 | '''
66 |
67 | :param fpath: path to the data file
68 | :param sep: field separator between the columns of a line
69 | :return: two lists of sentences (sents1, sents2)
70 | '''
71 | sents1, sents2 = [], []
72 | num = 0
73 | with open(fpath, mode='r',encoding="utf-8") as fr:
74 | for line in fr:
75 | num+=1
76 | if num%1000000==0:
77 | print("Preprocessing data, please wait patiently....", num)
78 | line = line.strip()
79 | if line is not None and line != "":
80 | sens = line.split(sep)
81 | if len(sens)!=3:
82 | print("Corpus line with wrong format:", line)
83 | continue
84 | else:
85 | # assert len(sens[0].strip())>=3
86 | # assert len(sens[1].strip())>=3
87 | #
88 | if len(sens[1].strip())<=3 or len(sens[2].strip())<=3:
89 | print("Line too short, skipping:", line)
90 | continue
91 | sents1.append(sens[1].strip())
92 | sents2.append(sens[2].strip())
93 | return sents1, sents2
94 |
95 |
96 |
97 |
98 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
99 | """Truncates a sequence pair in place to the maximum length."""
100 |
101 | # This is a simple heuristic which will always truncate the longer sequence
102 | # one token at a time. This makes more sense than truncating an equal percent
103 | # of tokens from each, since if one sequence is very short then each token
104 | # that's truncated likely contains more information than a longer sequence.
105 | while True:
106 | total_length = len(tokens_a) + len(tokens_b)
107 | if total_length <= max_length:
108 | break
109 | if len(tokens_a) > len(tokens_b):
110 | tokens_a.pop()
111 | else:
112 | tokens_b.pop()
113 |
114 |
115 |
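# Worked example of the truncation heuristic above (toy tokens, illustration only):
#   _truncate_seq_pair(["a1","a2","a3","a4","a5"], ["b1","b2"], max_length=5)
#   iteration 1: total 7 > 5, tokens_a is longer -> pop "a5"
#   iteration 2: total 6 > 5, tokens_a is longer -> pop "a4"
#   total 5 <= 5 -> stop; tokens_a == ["a1","a2","a3"], tokens_b is unchanged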
116 | # 许海明
117 | class InputFeatures(object):
118 | """A single set of features of data."""
119 | '''
120 | yield (input_ids_vae_encoder, input_masks_vae_encoder, segment_ids_vae_encoder, sent1, sent2), \
121 | (input_ids_vae_decoder_enc, input_ids_vae_decoder_dec, output_ids_vae_decoder_dec, sent1, sent2)
122 | '''
123 |
124 |
125 | def __init__(self,
126 | input_ids_vae_encoder,
127 | input_masks_vae_encoder,
128 | segment_ids_vae_encoder,
129 | input_ids_vae_decoder_enc,
130 | input_ids_vae_decoder_dec,
131 | output_ids_vae_decoder_dec,
132 | sent1,
133 | sent2,
134 | is_real_example=True):
135 |
136 | self.input_ids_vae_encoder=input_ids_vae_encoder
137 | self.input_masks_vae_encoder=input_masks_vae_encoder
138 | self.segment_ids_vae_encoder=segment_ids_vae_encoder
139 | self.input_ids_vae_decoder_enc=input_ids_vae_decoder_enc
140 | self.input_ids_vae_decoder_dec=input_ids_vae_decoder_dec
141 | self.output_ids_vae_decoder_dec=output_ids_vae_decoder_dec
142 | self.sent1=sent1
143 | self.sent2=sent2
144 | self.is_real_example = is_real_example
145 |
146 |
147 |
148 |
149 | # 许海明
150 | def convert_single_example(ex_index,sent1, sent2,
151 | maxlen_vae_Encoder,
152 | maxlen_vae_Decoder_en,
153 | maxlen_vae_Decoder_de, tokenizer):
154 |
155 |
156 |
157 | sent1_c = copy.deepcopy(sent1)
158 | sent2_c = copy.deepcopy(sent2)
159 |
160 | '''
161 | # The VAE encoder input: simply the two sentences packed together, BERT-style
162 | '''
163 | tokens_sent1 = tokenizer.tokenize(sent1)
164 | tokens_sent2 = tokenizer.tokenize(sent2)
165 | _truncate_seq_pair(tokens_sent1, tokens_sent2, maxlen_vae_Encoder - 3)
166 | tokens_vae_encoder = []
167 | segment_ids_vae_encoder = []
168 |
169 | tokens_vae_encoder.append("[CLS]")
170 | segment_ids_vae_encoder.append(0)
171 |
172 | for token in tokens_sent1:
173 | tokens_vae_encoder.append(token)
174 | segment_ids_vae_encoder.append(0)
175 |
176 | tokens_vae_encoder.append("[SEP]")
177 | segment_ids_vae_encoder.append(0)
178 |
179 | for token in tokens_sent2:
180 | tokens_vae_encoder.append(token)
181 | segment_ids_vae_encoder.append(1)
182 |
183 | tokens_vae_encoder.append("[SEP]")
184 | segment_ids_vae_encoder.append(1)
185 |
186 | input_ids_vae_encoder = tokenizer.convert_tokens_to_ids(tokens_vae_encoder)
187 | input_masks_vae_encoder = [1] * len(input_ids_vae_encoder)
188 | # print("\n\ntokens_vae_encoder",tokens_vae_encoder)
189 | # print("\n\ninput_ids_vae_encoder",input_ids_vae_encoder)
190 | while len(input_ids_vae_encoder) < maxlen_vae_Encoder:
191 | input_ids_vae_encoder.append(0)
192 | input_masks_vae_encoder.append(0)
193 | segment_ids_vae_encoder.append(0)
194 |
195 | assert len(input_ids_vae_encoder) == maxlen_vae_Encoder
196 | assert len(input_masks_vae_encoder) == maxlen_vae_Encoder
197 | assert len(segment_ids_vae_encoder) == maxlen_vae_Encoder
198 |
199 | # Input to the encoder half of the VAE decoder (sentence 3)
200 | tokens_sent3 = tokenizer.tokenize(sent1_c)
201 |
202 | if len(tokens_sent3) > maxlen_vae_Decoder_en - 1:
203 | tokens_sent3 = tokens_sent3[0:(maxlen_vae_Decoder_en - 1)]
204 |
205 | tokens_vae_decoder_enc = []
206 | for token in tokens_sent3:
207 | tokens_vae_decoder_enc.append(token)
208 | # segment_ids_vae_decoder_enc.append(0)
209 | tokens_vae_decoder_enc.append("[SEP]")
210 |
211 | # segment_ids_vae_decoder_enc.append(0)
212 | input_ids_vae_decoder_enc = tokenizer.convert_tokens_to_ids(tokens_vae_decoder_enc)
213 | # print("\n\ntokens_vae_decoder_enc:",tokens_vae_decoder_enc)
214 | # print("input_ids_vae_decoder_enc:", input_ids_vae_decoder_enc)
215 |
216 | # input_mask_vae_decoder_enc = [1] * len(input_ids_vae_decoder_enc)
217 | while len(input_ids_vae_decoder_enc) < maxlen_vae_Decoder_en:
218 | input_ids_vae_decoder_enc.append(0)
219 | # input_mask_vae_decoder_enc.append(0)
220 | # segment_ids_vae_decoder_enc.append(0)
221 |
222 | # Input and output of the decoder half of the VAE decoder (sentence 4)
223 | # This is the code for training the PAGE model
224 | tokens_sent4 = tokenizer.tokenize(sent2_c)
225 |
226 | if len(tokens_sent4) > maxlen_vae_Decoder_de - 1:
227 | tokens_sent4 = tokens_sent4[0:(maxlen_vae_Decoder_de - 1)]
228 |
229 | tokens_sent4_input = [""]
230 | tokens_sent4_input.extend(copy.copy(tokens_sent4))
231 | tokens_sent4_output = copy.copy(tokens_sent4)
232 | tokens_sent4_output.append("[SEP]")
233 |
234 | input_ids_vae_decoder_dec = tokenizer.convert_tokens_to_ids(tokens_sent4_input)
235 | output_ids_vae_decoder_dec = tokenizer.convert_tokens_to_ids(tokens_sent4_output)
236 |
237 |
238 | while len(input_ids_vae_decoder_dec) < maxlen_vae_Decoder_de:
239 | input_ids_vae_decoder_dec.append(0)
240 | output_ids_vae_decoder_dec.append(0)
241 | # input_mask_vae_decoder_dec.append(0)
242 |
243 | assert len(input_ids_vae_decoder_dec) == maxlen_vae_Decoder_de
244 | assert len(output_ids_vae_decoder_dec) == maxlen_vae_Decoder_de
245 | # assert len(input_mask_vae_decoder_dec) == maxlen_vae_Decoder_de
246 |
247 | # x = encode(sent1, "x", token2idx)
248 | # y = encode(sent2, "y", token2idx)
249 | # decoder_input, y = y[:-1], y[1:]
250 | #
251 | # x_seqlen, y_seqlen = len(x), len(y)
252 | if ex_index<=3:
253 | print("*** Example ***")
254 | print("guid: %s" % (ex_index))
255 | print("\n\ntokens_vae_encoder",tokens_vae_encoder)
256 | print("input_ids_vae_encoder",input_ids_vae_encoder)
257 |
258 |
259 | print("\n\ntokens_vae_decoder_enc:", tokens_vae_decoder_enc)
260 | print("input_ids_vae_decoder_enc:", input_ids_vae_decoder_enc)
261 |
262 |
263 |
264 | print("\n\n tokens_input_vae_decoder_dec", tokens_sent4_input)
265 | print("input_ids_vae_decoder_dec", input_ids_vae_decoder_dec)
266 |
267 | print("\n\n tokens_output_vae_decoder_dec", tokens_sent4_output)
268 | print("output_ids_vae_decoder_dec", output_ids_vae_decoder_dec)
269 |
270 | feature = InputFeatures( input_ids_vae_encoder=input_ids_vae_encoder,
271 | input_masks_vae_encoder=input_masks_vae_encoder,
272 | segment_ids_vae_encoder=segment_ids_vae_encoder,
273 | input_ids_vae_decoder_enc=input_ids_vae_decoder_enc,
274 | input_ids_vae_decoder_dec=input_ids_vae_decoder_dec,
275 | output_ids_vae_decoder_dec=output_ids_vae_decoder_dec,
276 | sent1=sent1,
277 | sent2=sent2,
278 | is_real_example=True)
279 | return feature
280 |
281 |
282 |
283 |
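# Shape of the features built by convert_single_example for one (sent1, sent2) pair
# (tokens are placeholders, not real WordPieces):
#   VAE-encoder tokens : [CLS] a1 a2 [SEP] b1 b2 b3 [SEP] pad pad ...   (length maxlen_vae_Encoder)
#   segment ids        :   0   0  0    0   1  1  1    1    0   0  ...
#   input mask         :   1   1  1    1   1  1  1    1    0   0  ...
#   decoder-enc input  : a1 a2 [SEP] pad ...                            (sent1 again, length maxlen_vae_Decoder_en)
#   decoder-dec input  : ""  b1 b2 b3 pad ...    output: b1 b2 b3 [SEP] pad ...   (shifted by one)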
284 | # 许海明
285 | def file_based_convert_examples_to_features(sentS1,
286 | sentS2,
287 | maxlen_vae_Encoder,
288 | maxlen_vae_Decoder_en,
289 | maxlen_vae_Decoder_de,
290 | tokenizer,
291 | output_file):
292 |
293 | writer = tf.python_io.TFRecordWriter(output_file)
294 |
295 | for (ex_index, (sent1,sent2)) in enumerate(zip(sentS1,sentS2)):
296 | if ex_index % 500000 == 0:
297 | print("Writing example %d of %d" % (ex_index, len(sentS1)))
298 |
299 | feature = convert_single_example(ex_index,sent1, sent2,
300 | maxlen_vae_Encoder,
301 | maxlen_vae_Decoder_en,
302 | maxlen_vae_Decoder_de, tokenizer)
303 |
304 | def create_int_feature(values):
305 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
306 | return f
307 |
308 | def create_bytes_feature(value):
309 | """Returns a bytes_list from a string / byte."""
310 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))
311 |
312 | features = collections.OrderedDict()
313 |
314 | features["input_ids_vae_encoder"] = create_int_feature(feature.input_ids_vae_encoder)
315 | features["input_masks_vae_encoder"] = create_int_feature(feature.input_masks_vae_encoder)
316 | features["segment_ids_vae_encoder"] = create_int_feature(feature.segment_ids_vae_encoder)
317 | features["input_ids_vae_decoder_enc"] = create_int_feature(feature.input_ids_vae_decoder_enc)
318 | features["input_ids_vae_decoder_dec"] = create_int_feature(feature.input_ids_vae_decoder_dec)
319 | features["output_ids_vae_decoder_dec"] = create_int_feature(feature.output_ids_vae_decoder_dec)
320 |
321 | features["sent1"] = create_bytes_feature(feature.sent1)
322 | features["sent2"] = create_bytes_feature(feature.sent2)
323 |
324 |
325 | features["is_real_example"] = create_int_feature( [int(feature.is_real_example)] )
326 |
327 |
328 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
329 | writer.write(tf_example.SerializeToString())
330 | writer.close()
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 | def file_based_input_fn_builder(input_file,
341 | maxlen_vae_Encoder,
342 | maxlen_vae_Decoder_en,
343 | maxlen_vae_Decoder_de,
344 | is_training):
345 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
346 |
347 | name_to_features = {
348 | "input_ids_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64),
349 | "input_masks_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64),
350 | "segment_ids_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64),
351 | "input_ids_vae_decoder_enc": tf.FixedLenFeature([maxlen_vae_Decoder_en], tf.int64),
352 | "input_ids_vae_decoder_dec": tf.FixedLenFeature([maxlen_vae_Decoder_de], tf.int64),
353 | "output_ids_vae_decoder_dec": tf.FixedLenFeature([maxlen_vae_Decoder_de], tf.int64),
354 | "sent1": tf.FixedLenFeature([], tf.string),
355 | "sent2": tf.FixedLenFeature([], tf.string),
356 | "is_real_example": tf.FixedLenFeature([], tf.int64),
357 | }
358 |
359 | def _decode_record(record, name_to_features):
360 | """Decodes a record to a TensorFlow example."""
361 | example = tf.parse_single_example(record, name_to_features)
362 |
363 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
364 | # So cast all int64 to int32.
365 | for name in list(example.keys()):
366 | t = example[name]
367 | if t.dtype == tf.int64:
368 | t = tf.to_int32(t)
369 | example[name] = t
370 |
371 | return example
372 |
373 | def input_fn(batch_size):
374 | """The actual input function."""
375 | d = tf.data.TFRecordDataset(input_file)
376 | if is_training:
377 | d = d.repeat()
378 | d = d.shuffle(buffer_size=100)
379 |
380 | d = d.apply(
381 | tf.contrib.data.map_and_batch(
382 | lambda record: _decode_record(record, name_to_features),
383 | batch_size=batch_size))
384 |
385 | return d
386 |
387 | return input_fn
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 | def saveForTfRecord(fpath,
399 | maxlen_vae_Encoder,
400 | maxlen_vae_Decoder_en,
401 | maxlen_vae_Decoder_de,
402 | vocab_fpath,
403 | output_file,
404 | is_test=False):
405 | '''
406 |
407 | :param fpath:
408 | :param maxlen_vae_Encoder:
409 | :param maxlen_vae_Decoder_en:
410 | :param maxlen_vae_Decoder_de:
411 | :param vocab_fpath:
412 | :param output_file: path where the output TFRecord file is saved
413 | :param is_test: whether this is test data
414 | :return:
415 | '''
416 | if is_test:
417 | # this is the test flow
418 | sents1, sents2 = load_data_for_test(fpath)
419 | else:
420 | sents1, sents2 = load_data_for_train_or_eval(fpath)
421 | # sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
422 |
423 | print("Finished reading the data")
424 |
425 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_fpath, do_lower_case=True)
426 | # output_file = os.path.join(output_dir, "train.tf_record")
427 | file_based_convert_examples_to_features(sents1,
428 | sents2,
429 | maxlen_vae_Encoder,
430 | maxlen_vae_Decoder_en,
431 | maxlen_vae_Decoder_de,
432 | tokenizer,
433 | output_file)
434 |
435 |
436 |
437 |
438 |
439 |
440 | def get_batch_for_train_or_dev_or_test( fpath,
441 | maxlen_vae_Encoder,
442 | maxlen_vae_Decoder_en,
443 | maxlen_vae_Decoder_de,
444 | batch_size,
445 | input_file,
446 | is_training,
447 | is_test):
448 | '''
449 |
450 | Returns
451 | batches
452 | num_batches: number of mini-batches
453 | num_samples
454 | '''
455 |
456 | if is_test:
457 | # this is the test flow
458 | sents1, sents2 = load_data_for_test(fpath)
459 | else:
460 | sents1, sents2 = load_data_for_train_or_eval(fpath)
461 | # sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
462 |
463 | input_fn=file_based_input_fn_builder(input_file,
464 | maxlen_vae_Encoder,
465 | maxlen_vae_Decoder_en,
466 | maxlen_vae_Decoder_de,
467 | is_training)
468 |
469 | batches=input_fn(batch_size)
470 | num_batches = calc_num_batches(len(sents1), batch_size)
471 | return batches, num_batches, len(sents1)
472 |
473 |
474 |
475 |
476 |
477 |
478 |
--------------------------------------------------------------------------------
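Taken together, data_load.py is used in two passes: a one-off pass that tokenizes the raw text and serializes it into a TFRecord file, and a training-time pass that turns that file into a batched tf.data pipeline. A rough usage sketch follows; the paths and sizes are placeholders, only the function names and signatures come from the file above:

import data_load

# 1) one-off: tokenize the raw pairs and write them out as TFRecords
data_load.saveForTfRecord(fpath="data/PAGE/train.txt",          # placeholder path
                          maxlen_vae_Encoder=80,
                          maxlen_vae_Decoder_en=40,
                          maxlen_vae_Decoder_de=40,
                          vocab_fpath="results/bert/vocab.txt",  # placeholder path
                          output_file="data/PAGE/train.tf_record",
                          is_test=False)

# 2) at training time: build a batched dataset from the TFRecord file
batches, num_batches, num_samples = data_load.get_batch_for_train_or_dev_or_test(
    fpath="data/PAGE/train.txt",
    maxlen_vae_Encoder=80,
    maxlen_vae_Decoder_en=40,
    maxlen_vae_Decoder_de=40,
    batch_size=32,
    input_file="data/PAGE/train.tf_record",
    is_training=True,
    is_test=False)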
/bert/bert_model_for_PAGE.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import numpy as np
27 | import six
28 | import tensorflow as tf
29 |
30 |
31 | class BertConfig(object):
32 | """Configuration for `BertModel`."""
33 |
34 | def __init__(self,
35 | vocab_size,
36 | hidden_size=768,
37 | num_hidden_layers=12,
38 | num_attention_heads=12,
39 | intermediate_size=3072,
40 | hidden_act="gelu",
41 | hidden_dropout_prob=0.1,
42 | attention_probs_dropout_prob=0.1,
43 | max_position_embeddings=512,
44 | type_vocab_size=16,
45 | initializer_range=0.02):
46 | """Constructs BertConfig.
47 |
48 | Args:
49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
50 | hidden_size: Size of the encoder layers and the pooler layer.
51 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
52 | num_attention_heads: Number of attention heads for each attention layer in
53 | the Transformer encoder.
54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
55 | layer in the Transformer encoder.
56 | hidden_act: The non-linear activation function (function or string) in the
57 | encoder and pooler.
58 | hidden_dropout_prob: The dropout probability for all fully connected
59 | layers in the embeddings, encoder, and pooler.
60 | attention_probs_dropout_prob: The dropout ratio for the attention
61 | probabilities.
62 | max_position_embeddings: The maximum sequence length that this model might
63 | ever be used with. Typically set this to something large just in case
64 | (e.g., 512 or 1024 or 2048).
65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
66 | `BertModel`.
67 | initializer_range: The stdev of the truncated_normal_initializer for
68 | initializing all weight matrices.
69 | """
70 | self.vocab_size = vocab_size
71 | self.hidden_size = hidden_size
72 | self.num_hidden_layers = num_hidden_layers
73 | self.num_attention_heads = num_attention_heads
74 | self.hidden_act = hidden_act
75 | self.intermediate_size = intermediate_size
76 | self.hidden_dropout_prob = hidden_dropout_prob
77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
78 | self.max_position_embeddings = max_position_embeddings
79 | self.type_vocab_size = type_vocab_size
80 | self.initializer_range = initializer_range
81 |
82 | @classmethod
83 | def from_dict(cls, json_object):
84 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
85 | config = BertConfig(vocab_size=None)
86 | for (key, value) in six.iteritems(json_object):
87 | config.__dict__[key] = value
88 | return config
89 |
90 | @classmethod
91 | def from_json_file(cls, json_file):
92 | """Constructs a `BertConfig` from a json file of parameters."""
93 | with tf.gfile.GFile(json_file, "r") as reader:
94 | text = reader.read()
95 | return cls.from_dict(json.loads(text))
96 |
97 | def to_dict(self):
98 | """Serializes this instance to a Python dictionary."""
99 | output = copy.deepcopy(self.__dict__)
100 | return output
101 |
102 | def to_json_string(self):
103 | """Serializes this instance to a JSON string."""
104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
105 |
106 |
107 | class BertModel(object):
108 | """BERT model ("Bidirectional Encoder Representations from Transformers").
109 |
110 | Example usage:
111 |
112 | ```python
113 | # Already been converted into WordPiece token ids
114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
117 |
118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
120 |
121 | model = modeling.BertModel(config=config, is_training=True,
122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
123 |
124 | label_embeddings = tf.get_variable(...)
125 | pooled_output = model.get_pooled_output()
126 | logits = tf.matmul(pooled_output, label_embeddings)
127 | ...
128 | ```
129 | """
130 |
131 | def __init__(self,
132 | config,
133 | is_training,
134 | input_ids,
135 | input_mask=None,
136 | token_type_ids=None,
137 | use_one_hot_embeddings=False,
138 | scope=None):
139 | """Constructor for BertModel.
140 |
141 | Args:
142 | config: `BertConfig` instance.
143 | is_training: bool. true for training model, false for eval model. Controls
144 | whether dropout will be applied.
145 | input_ids: int32 Tensor of shape [batch_size, seq_length].
146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
149 | embeddings or tf.embedding_lookup() for the word embeddings.
150 | scope: (optional) variable scope. Defaults to "bert".
151 |
152 | Raises:
153 | ValueError: The config is invalid or one of the input tensor shapes
154 | is invalid.
155 | """
156 | config = copy.deepcopy(config)
157 | if not is_training:
158 | config.hidden_dropout_prob = 0.0
159 | config.attention_probs_dropout_prob = 0.0
160 |
161 | input_shape = get_shape_list(input_ids, expected_rank=2)
162 | batch_size = input_shape[0]
163 | seq_length = input_shape[1]
164 |
165 | if input_mask is None:
166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
167 |
168 | if token_type_ids is None:
169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
170 |
171 | with tf.variable_scope(scope, default_name="bert"):
172 | with tf.variable_scope("embeddings"):
173 | # Perform embedding lookup on the word ids.
174 | (self.embedding_output, self.embedding_table) = embedding_lookup(
175 | input_ids=input_ids,
176 | vocab_size=config.vocab_size,
177 | embedding_size=config.hidden_size,
178 | initializer_range=config.initializer_range,
179 | word_embedding_name="word_embeddings",
180 | use_one_hot_embeddings=use_one_hot_embeddings)
181 |
182 | # Add positional embeddings and token type embeddings, then layer
183 | # normalize and perform dropout.
184 | self.embedding_output,self.full_position_embeddings = embedding_postprocessor(
185 | input_tensor=self.embedding_output,
186 | use_token_type=True,
187 | token_type_ids=token_type_ids,
188 | token_type_vocab_size=config.type_vocab_size,
189 | token_type_embedding_name="token_type_embeddings",
190 | use_position_embeddings=True,
191 | position_embedding_name="position_embeddings",
192 | initializer_range=config.initializer_range,
193 | max_position_embeddings=config.max_position_embeddings,
194 | dropout_prob=config.hidden_dropout_prob)
195 |
196 | with tf.variable_scope("encoder"):
197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
198 | # mask of shape [batch_size, seq_length, seq_length] which is used
199 | # for the attention scores.
200 | attention_mask = create_attention_mask_from_input_mask(
201 | input_ids, input_mask)
202 |
203 | # Run the stacked transformer.
204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
205 | self.all_encoder_layers = transformer_model(
206 | input_tensor=self.embedding_output,
207 | attention_mask=attention_mask,
208 | hidden_size=config.hidden_size,
209 | num_hidden_layers=config.num_hidden_layers,
210 | num_attention_heads=config.num_attention_heads,
211 | intermediate_size=config.intermediate_size,
212 | intermediate_act_fn=get_activation(config.hidden_act),
213 | hidden_dropout_prob=config.hidden_dropout_prob,
214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
215 | initializer_range=config.initializer_range,
216 | do_return_all_layers=True)
217 |
218 | self.sequence_output = self.all_encoder_layers[-1]
219 | # The "pooler" converts the encoded sequence tensor of shape
220 | # [batch_size, seq_length, hidden_size] to a tensor of shape
221 | # [batch_size, hidden_size]. This is necessary for segment-level
222 | # (or segment-pair-level) classification tasks where we need a fixed
223 | # dimensional representation of the segment.
224 | with tf.variable_scope("pooler"):
225 | # We "pool" the model by simply taking the hidden state corresponding
226 | # to the first token. We assume that this has been pre-trained
227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
228 | self.pooled_output = tf.layers.dense(
229 | first_token_tensor,
230 | config.hidden_size,
231 | activation=tf.tanh,
232 | kernel_initializer=create_initializer(config.initializer_range))
233 |
234 | def get_pooled_output(self):
235 | return self.pooled_output
236 |
237 | def get_sequence_output(self):
238 | """Gets final hidden layer of encoder.
239 |
240 | Returns:
241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
242 | to the final hidden of the transformer encoder.
243 | """
244 | return self.sequence_output
245 |
246 | def get_all_encoder_layers(self):
247 | return self.all_encoder_layers
248 |
249 | def get_embedding_output(self):
250 | """Gets output of the embedding lookup (i.e., input to the transformer).
251 |
252 | Returns:
253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
254 | to the output of the embedding layer, after summing the word
255 | embeddings with the positional embeddings and the token type embeddings,
256 | then performing layer normalization. This is the input to the transformer.
257 | """
258 | return self.embedding_output
259 |
260 | def get_embedding_table(self):
261 | return self.embedding_table
262 |
263 | def get_full_position_embeddings(self):
264 | return self.full_position_embeddings
265 |
266 | def gelu(x):
267 | """Gaussian Error Linear Unit.
268 |
269 | This is a smoother version of the RELU.
270 | Original paper: https://arxiv.org/abs/1606.08415
271 | Args:
272 | x: float Tensor to perform activation.
273 |
274 | Returns:
275 | `x` with the GELU activation applied.
276 | """
277 | cdf = 0.5 * (1.0 + tf.tanh(
278 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
279 | return x * cdf
280 |
281 |
282 | def get_activation(activation_string):
283 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
284 |
285 | Args:
286 | activation_string: String name of the activation function.
287 |
288 | Returns:
289 | A Python function corresponding to the activation function. If
290 | `activation_string` is None, empty, or "linear", this will return None.
291 | If `activation_string` is not a string, it will return `activation_string`.
292 |
293 | Raises:
294 | ValueError: The `activation_string` does not correspond to a known
295 | activation.
296 | """
297 |
298 | # We assume that anything that's not a string is already an activation
299 | # function, so we just return it.
300 | if not isinstance(activation_string, six.string_types):
301 | return activation_string
302 |
303 | if not activation_string:
304 | return None
305 |
306 | act = activation_string.lower()
307 | if act == "linear":
308 | return None
309 | elif act == "relu":
310 | return tf.nn.relu
311 | elif act == "gelu":
312 | return gelu
313 | elif act == "tanh":
314 | return tf.tanh
315 | else:
316 | raise ValueError("Unsupported activation: %s" % act)
317 |
318 |
319 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
320 | """Compute the union of the current variables and checkpoint variables."""
321 | assignment_map = {}
322 | initialized_variable_names = {}
323 |
324 | name_to_variable = collections.OrderedDict()
325 | for var in tvars:
326 | name = var.name
327 | m = re.match("^(.*):\\d+$", name)
328 | if m is not None:
329 | name = m.group(1)
330 | name_to_variable[name] = var
331 |
332 | init_vars = tf.train.list_variables(init_checkpoint)
333 |
334 | assignment_map = collections.OrderedDict()
335 | for x in init_vars:
336 | (name, var) = (x[0], x[1])
337 | if name not in name_to_variable:
338 | continue
339 | assignment_map[name] = name
340 | initialized_variable_names[name] = 1
341 | initialized_variable_names[name + ":0"] = 1
342 |
343 | return (assignment_map, initialized_variable_names)
344 |
345 |
346 | def dropout(input_tensor, dropout_prob):
347 | """Perform dropout.
348 |
349 | Args:
350 | input_tensor: float Tensor.
351 | dropout_prob: Python float. The probability of dropping out a value (NOT of
352 | *keeping* a dimension as in `tf.nn.dropout`).
353 |
354 | Returns:
355 | A version of `input_tensor` with dropout applied.
356 | """
357 | if dropout_prob is None or dropout_prob == 0.0:
358 | return input_tensor
359 |
360 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
361 | return output
362 |
363 |
364 | def layer_norm(input_tensor, name=None):
365 | """Run layer normalization on the last dimension of the tensor."""
366 | return tf.contrib.layers.layer_norm(
367 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
368 |
369 |
370 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
371 | """Runs layer normalization followed by dropout."""
372 | output_tensor = layer_norm(input_tensor, name)
373 | output_tensor = dropout(output_tensor, dropout_prob)
374 | return output_tensor
375 |
376 |
377 | def create_initializer(initializer_range=0.02):
378 | """Creates a `truncated_normal_initializer` with the given range."""
379 | return tf.truncated_normal_initializer(stddev=initializer_range)
380 |
381 |
382 | def embedding_lookup(input_ids,
383 | vocab_size,
384 | embedding_size=128,
385 | initializer_range=0.02,
386 | word_embedding_name="word_embeddings",
387 | use_one_hot_embeddings=False):
388 | """Looks up words embeddings for id tensor.
389 |
390 | Args:
391 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
392 | ids.
393 | vocab_size: int. Size of the embedding vocabulary.
394 | embedding_size: int. Width of the word embeddings.
395 | initializer_range: float. Embedding initialization range.
396 | word_embedding_name: string. Name of the embedding table.
397 | use_one_hot_embeddings: bool. If True, use one-hot method for word
398 | embeddings. If False, use `tf.gather()`.
399 |
400 | Returns:
401 | float Tensor of shape [batch_size, seq_length, embedding_size].
402 | """
403 | # This function assumes that the input is of shape [batch_size, seq_length,
404 | # num_inputs].
405 | #
406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
407 | # reshape to [batch_size, seq_length, 1].
408 | if input_ids.shape.ndims == 2:
409 | input_ids = tf.expand_dims(input_ids, axis=[-1])
410 |
411 | embedding_table = tf.get_variable(
412 | name=word_embedding_name,
413 | shape=[vocab_size, embedding_size],
414 | initializer=create_initializer(initializer_range))
415 |
416 |
417 | flat_input_ids = tf.reshape(input_ids, [-1])
418 | if use_one_hot_embeddings:
419 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
420 | output = tf.matmul(one_hot_input_ids, embedding_table)
421 | else:
422 | output = tf.gather(embedding_table, flat_input_ids)
423 |
424 | input_shape = get_shape_list(input_ids)
425 |
426 | output = tf.reshape(output,
427 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
428 | return (output, embedding_table)
429 |
430 |
431 | def embedding_postprocessor(input_tensor,
432 | use_token_type=False,
433 | token_type_ids=None,
434 | token_type_vocab_size=16,
435 | token_type_embedding_name="token_type_embeddings",
436 | use_position_embeddings=True,
437 | position_embedding_name="position_embeddings",
438 | initializer_range=0.02,
439 | max_position_embeddings=512,
440 | dropout_prob=0.1):
441 | """Performs various post-processing on a word embedding tensor.
442 |
443 | Args:
444 | input_tensor: float Tensor of shape [batch_size, seq_length,
445 | embedding_size].
446 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
447 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
448 | Must be specified if `use_token_type` is True.
449 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
450 | token_type_embedding_name: string. The name of the embedding table variable
451 | for token type ids.
452 | use_position_embeddings: bool. Whether to add position embeddings for the
453 | position of each token in the sequence.
454 | position_embedding_name: string. The name of the embedding table variable
455 | for positional embeddings.
456 | initializer_range: float. Range of the weight initialization.
457 | max_position_embeddings: int. Maximum sequence length that might ever be
458 | used with this model. This can be longer than the sequence length of
459 | input_tensor, but cannot be shorter.
460 | dropout_prob: float. Dropout probability applied to the final output tensor.
461 |
462 | Returns:
463 | a float tensor with the same shape as `input_tensor`, plus the full position embedding table.
464 |
465 | Raises:
466 | ValueError: One of the tensor shapes or input values is invalid.
467 | """
468 | input_shape = get_shape_list(input_tensor, expected_rank=3)
469 | batch_size = input_shape[0]
470 | seq_length = input_shape[1]
471 | width = input_shape[2]
472 |
473 | output = input_tensor
474 |
475 | if use_token_type:
476 | if token_type_ids is None:
477 | raise ValueError("`token_type_ids` must be specified if "
478 | "`use_token_type` is True.")
479 | token_type_table = tf.get_variable(
480 | name=token_type_embedding_name,
481 | shape=[token_type_vocab_size, width],
482 | initializer=create_initializer(initializer_range))
483 | # This vocab will be small so we always do one-hot here, since it is always
484 | # faster for a small vocabulary.
485 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
486 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
487 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
488 | token_type_embeddings = tf.reshape(token_type_embeddings,
489 | [batch_size, seq_length, width])
490 | output += token_type_embeddings
491 |
492 | if use_position_embeddings:
493 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
494 | with tf.control_dependencies([assert_op]):
495 | full_position_embeddings = tf.get_variable(
496 | name=position_embedding_name,
497 | shape=[max_position_embeddings, width],
498 | initializer=create_initializer(initializer_range))
499 |
500 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
501 | [seq_length, -1])
502 | num_dims = len(output.shape.as_list())
503 |
504 | position_broadcast_shape = []
505 | for _ in range(num_dims - 2):
506 | position_broadcast_shape.append(1)
507 | position_broadcast_shape.extend([seq_length, width])
508 | position_embeddings = tf.reshape(position_embeddings,
509 | position_broadcast_shape)
510 | output += position_embeddings
511 |
512 | output = layer_norm_and_dropout(output, dropout_prob)
513 | return output,full_position_embeddings
514 |
515 |
516 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
517 | """Create 3D attention mask from a 2D tensor mask.
518 |
519 | Args:
520 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
521 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
522 |
523 | Returns:
524 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
525 | """
526 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
527 | batch_size = from_shape[0]
528 | from_seq_length = from_shape[1]
529 |
530 | to_shape = get_shape_list(to_mask, expected_rank=2)
531 | to_seq_length = to_shape[1]
532 |
533 | to_mask = tf.cast(
534 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
535 |
536 | # We don't assume that `from_tensor` is a mask (although it could be). We
537 | # don't actually care if we attend *from* padding tokens (only *to* padding)
538 | # tokens so we create a tensor of all ones.
539 | #
540 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
541 | broadcast_ones = tf.ones(
542 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
543 |
544 | # Here we broadcast along two dimensions to create the mask.
545 | mask = broadcast_ones * to_mask
546 |
547 | return mask
548 |
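# Tiny numeric illustration of the broadcast above (made-up values, batch_size=1, seq_length=3):
#   to_mask (reshaped)  = [[[1, 1, 0]]]              shape [1, 1, 3]
#   broadcast_ones      = [[[1], [1], [1]]]          shape [1, 3, 1]
#   mask = ones * mask  = [[[1, 1, 0],
#                           [1, 1, 0],
#                           [1, 1, 0]]]              shape [1, 3, 3]
#   i.e. every "from" position may attend to the first two "to" positions, never the padded third.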
549 |
550 | def attention_layer(from_tensor,
551 | to_tensor,
552 | attention_mask=None,
553 | num_attention_heads=1,
554 | size_per_head=512,
555 | query_act=None,
556 | key_act=None,
557 | value_act=None,
558 | attention_probs_dropout_prob=0.0,
559 | initializer_range=0.02,
560 | do_return_2d_tensor=False,
561 | batch_size=None,
562 | from_seq_length=None,
563 | to_seq_length=None):
564 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
565 |
566 | This is an implementation of multi-headed attention based on "Attention
567 | is all you Need". If `from_tensor` and `to_tensor` are the same, then
568 | this is self-attention. Each timestep in `from_tensor` attends to the
569 | corresponding sequence in `to_tensor`, and returns a fixed-width vector.
570 |
571 | This function first projects `from_tensor` into a "query" tensor and
572 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
573 | of tensors of length `num_attention_heads`, where each tensor is of shape
574 | [batch_size, seq_length, size_per_head].
575 |
576 | Then, the query and key tensors are dot-producted and scaled. These are
577 | softmaxed to obtain attention probabilities. The value tensors are then
578 | interpolated by these probabilities, then concatenated back to a single
579 | tensor and returned.
580 |
581 | In practice, the multi-headed attention is done with transposes and
582 | reshapes rather than actual separate tensors.
583 |
584 | Args:
585 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
586 | from_width].
587 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
588 | attention_mask: (optional) int32 Tensor of shape [batch_size,
589 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
590 | attention scores will effectively be set to -infinity for any positions in
591 | the mask that are 0, and will be unchanged for positions that are 1.
592 | num_attention_heads: int. Number of attention heads.
593 | size_per_head: int. Size of each attention head.
594 | query_act: (optional) Activation function for the query transform.
595 | key_act: (optional) Activation function for the key transform.
596 | value_act: (optional) Activation function for the value transform.
597 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
598 | attention probabilities.
599 | initializer_range: float. Range of the weight initializer.
600 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
601 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
602 | output will be of shape [batch_size, from_seq_length, num_attention_heads
603 | * size_per_head].
604 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
605 | of the 3D version of the `from_tensor` and `to_tensor`.
606 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
607 | of the 3D version of the `from_tensor`.
608 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
609 | of the 3D version of the `to_tensor`.
610 |
611 | Returns:
612 | float Tensor of shape [batch_size, from_seq_length,
613 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
614 | true, this will be of shape [batch_size * from_seq_length,
615 | num_attention_heads * size_per_head]).
616 |
617 | Raises:
618 | ValueError: Any of the arguments or tensor shapes are invalid.
619 | """
620 |
621 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
622 | seq_length, width):
623 | output_tensor = tf.reshape(
624 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
625 |
626 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
627 | return output_tensor
628 |
629 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
630 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
631 |
632 | if len(from_shape) != len(to_shape):
633 | raise ValueError(
634 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
635 |
636 | if len(from_shape) == 3:
637 | batch_size = from_shape[0]
638 | from_seq_length = from_shape[1]
639 | to_seq_length = to_shape[1]
640 | elif len(from_shape) == 2:
641 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
642 | raise ValueError(
643 | "When passing in rank 2 tensors to attention_layer, the values "
644 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
645 | "must all be specified.")
646 |
647 | # Scalar dimensions referenced here:
648 | # B = batch size (number of sequences)
649 | # F = `from_tensor` sequence length
650 | # T = `to_tensor` sequence length
651 | # N = `num_attention_heads`
652 | # H = `size_per_head`
653 |
654 | from_tensor_2d = reshape_to_matrix(from_tensor)
655 | to_tensor_2d = reshape_to_matrix(to_tensor)
656 |
657 | # `query_layer` = [B*F, N*H]
658 | query_layer = tf.layers.dense(
659 | from_tensor_2d,
660 | num_attention_heads * size_per_head,
661 | activation=query_act,
662 | name="query",
663 | kernel_initializer=create_initializer(initializer_range))
664 |
665 | # `key_layer` = [B*T, N*H]
666 | key_layer = tf.layers.dense(
667 | to_tensor_2d,
668 | num_attention_heads * size_per_head,
669 | activation=key_act,
670 | name="key",
671 | kernel_initializer=create_initializer(initializer_range))
672 |
673 | # `value_layer` = [B*T, N*H]
674 | value_layer = tf.layers.dense(
675 | to_tensor_2d,
676 | num_attention_heads * size_per_head,
677 | activation=value_act,
678 | name="value",
679 | kernel_initializer=create_initializer(initializer_range))
680 |
681 | # `query_layer` = [B, N, F, H]
682 | query_layer = transpose_for_scores(query_layer, batch_size,
683 | num_attention_heads, from_seq_length,
684 | size_per_head)
685 |
686 | # `key_layer` = [B, N, T, H]
687 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
688 | to_seq_length, size_per_head)
689 |
690 | # Take the dot product between "query" and "key" to get the raw
691 | # attention scores.
692 | # `attention_scores` = [B, N, F, T]
693 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
694 | attention_scores = tf.multiply(attention_scores,
695 | 1.0 / math.sqrt(float(size_per_head)))
696 |
697 | if attention_mask is not None:
698 | # `attention_mask` = [B, 1, F, T]
699 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
700 |
701 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
702 | # masked positions, this operation will create a tensor which is 0.0 for
703 | # positions we want to attend and -10000.0 for masked positions.
704 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
705 |
706 | # Since we are adding it to the raw scores before the softmax, this is
707 | # effectively the same as removing these entirely.
708 | attention_scores += adder
709 |
710 | # Normalize the attention scores to probabilities.
711 | # `attention_probs` = [B, N, F, T]
712 | attention_probs = tf.nn.softmax(attention_scores)
713 |
714 | # This is actually dropping out entire tokens to attend to, which might
715 | # seem a bit unusual, but is taken from the original Transformer paper.
716 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
717 |
718 | # `value_layer` = [B, T, N, H]
719 | value_layer = tf.reshape(
720 | value_layer,
721 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
722 |
723 | # `value_layer` = [B, N, T, H]
724 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
725 |
726 | # `context_layer` = [B, N, F, H]
727 | context_layer = tf.matmul(attention_probs, value_layer)
728 |
729 | # `context_layer` = [B, F, N, H]
730 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
731 |
732 | if do_return_2d_tensor:
733 | # `context_layer` = [B*F, N*H]
734 | context_layer = tf.reshape(
735 | context_layer,
736 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
737 | else:
738 | # `context_layer` = [B, F, N*H]
739 | context_layer = tf.reshape(
740 | context_layer,
741 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
742 |
743 | return context_layer
744 |
745 |
746 | def transformer_model(input_tensor,
747 | attention_mask=None,
748 | hidden_size=768,
749 | num_hidden_layers=12,
750 | num_attention_heads=12,
751 | intermediate_size=3072,
752 | intermediate_act_fn=gelu,
753 | hidden_dropout_prob=0.1,
754 | attention_probs_dropout_prob=0.1,
755 | initializer_range=0.02,
756 | do_return_all_layers=False):
757 | """Multi-headed, multi-layer Transformer from "Attention is All You Need".
758 |
759 | This is almost an exact implementation of the original Transformer encoder.
760 |
761 | See the original paper:
762 | https://arxiv.org/abs/1706.03762
763 |
764 | Also see:
765 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
766 |
767 | Args:
768 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
769 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
770 | seq_length], with 1 for positions that can be attended to and 0 in
771 | positions that should not be.
772 | hidden_size: int. Hidden size of the Transformer.
773 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
774 | num_attention_heads: int. Number of attention heads in the Transformer.
775 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
776 | forward) layer.
777 | intermediate_act_fn: function. The non-linear activation function to apply
778 | to the output of the intermediate/feed-forward layer.
779 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
780 | attention_probs_dropout_prob: float. Dropout probability of the attention
781 | probabilities.
782 | initializer_range: float. Range of the initializer (stddev of truncated
783 | normal).
784 | do_return_all_layers: Whether to also return all layers or just the final
785 | layer.
786 |
787 | Returns:
788 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
789 | hidden layer of the Transformer.
790 |
791 | Raises:
792 | ValueError: A Tensor shape or parameter is invalid.
793 | """
794 | if hidden_size % num_attention_heads != 0:
795 | raise ValueError(
796 | "The hidden size (%d) is not a multiple of the number of attention "
797 | "heads (%d)" % (hidden_size, num_attention_heads))
798 |
799 | attention_head_size = int(hidden_size / num_attention_heads)
800 | input_shape = get_shape_list(input_tensor, expected_rank=3)
801 | batch_size = input_shape[0]
802 | seq_length = input_shape[1]
803 | input_width = input_shape[2]
804 |
805 | # The Transformer performs sum residuals on all layers so the input needs
806 | # to be the same as the hidden size.
807 | if input_width != hidden_size:
808 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
809 | (input_width, hidden_size))
810 |
811 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
812 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
813 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
814 | # help the optimizer.
815 | prev_output = reshape_to_matrix(input_tensor)
816 |
817 | all_layer_outputs = []
818 | for layer_idx in range(num_hidden_layers):
819 | with tf.variable_scope("layer_%d" % layer_idx):
820 | layer_input = prev_output
821 |
822 | with tf.variable_scope("attention"):
823 | attention_heads = []
824 | with tf.variable_scope("self"):
825 | attention_head = attention_layer(
826 | from_tensor=layer_input,
827 | to_tensor=layer_input,
828 | attention_mask=attention_mask,
829 | num_attention_heads=num_attention_heads,
830 | size_per_head=attention_head_size,
831 | attention_probs_dropout_prob=attention_probs_dropout_prob,
832 | initializer_range=initializer_range,
833 | do_return_2d_tensor=True,
834 | batch_size=batch_size,
835 | from_seq_length=seq_length,
836 | to_seq_length=seq_length)
837 | attention_heads.append(attention_head)
838 |
839 | attention_output = None
840 | if len(attention_heads) == 1:
841 | attention_output = attention_heads[0]
842 | else:
843 | # In the case where we have other sequences, we just concatenate
844 | # them to the self-attention head before the projection.
845 | attention_output = tf.concat(attention_heads, axis=-1)
846 |
847 | # Run a linear projection of `hidden_size` then add a residual
848 | # with `layer_input`.
849 | with tf.variable_scope("output"):
850 | attention_output = tf.layers.dense(
851 | attention_output,
852 | hidden_size,
853 | kernel_initializer=create_initializer(initializer_range))
854 | attention_output = dropout(attention_output, hidden_dropout_prob)
855 | attention_output = layer_norm(attention_output + layer_input)
856 |
857 | # The activation is only applied to the "intermediate" hidden layer.
858 | with tf.variable_scope("intermediate"):
859 | intermediate_output = tf.layers.dense(
860 | attention_output,
861 | intermediate_size,
862 | activation=intermediate_act_fn,
863 | kernel_initializer=create_initializer(initializer_range))
864 |
865 | # Down-project back to `hidden_size` then add the residual.
866 | with tf.variable_scope("output"):
867 | layer_output = tf.layers.dense(
868 | intermediate_output,
869 | hidden_size,
870 | kernel_initializer=create_initializer(initializer_range))
871 | layer_output = dropout(layer_output, hidden_dropout_prob)
872 | layer_output = layer_norm(layer_output + attention_output)
873 | prev_output = layer_output
874 | all_layer_outputs.append(layer_output)
875 |
876 | if do_return_all_layers:
877 | final_outputs = []
878 | for layer_output in all_layer_outputs:
879 | final_output = reshape_from_matrix(layer_output, input_shape)
880 | final_outputs.append(final_output)
881 | return final_outputs
882 | else:
883 | final_output = reshape_from_matrix(prev_output, input_shape)
884 | return final_output
885 |
886 |
887 | def get_shape_list(tensor, expected_rank=None, name=None):
888 | """Returns a list of the shape of tensor, preferring static dimensions.
889 |
890 | Args:
891 | tensor: A tf.Tensor object to find the shape of.
892 | expected_rank: (optional) int. The expected rank of `tensor`. If this is
893 | specified and the `tensor` has a different rank, an exception will be
894 | thrown.
895 | name: Optional name of the tensor for the error message.
896 |
897 | Returns:
898 | A list of dimensions of the shape of tensor. All static dimensions will
899 | be returned as python integers, and dynamic dimensions will be returned
900 | as tf.Tensor scalars.
901 | """
902 | if name is None:
903 | name = tensor.name
904 |
905 | if expected_rank is not None:
906 | assert_rank(tensor, expected_rank, name)
907 |
908 | shape = tensor.shape.as_list()
909 |
910 | non_static_indexes = []
911 | for (index, dim) in enumerate(shape):
912 | if dim is None:
913 | non_static_indexes.append(index)
914 |
915 | if not non_static_indexes:
916 | return shape
917 |
918 | dyn_shape = tf.shape(tensor)
919 | for index in non_static_indexes:
920 | shape[index] = dyn_shape[index]
921 | return shape
922 |
923 |
924 | def reshape_to_matrix(input_tensor):
925 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
926 | ndims = input_tensor.shape.ndims
927 | if ndims < 2:
928 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
929 | (input_tensor.shape))
930 | if ndims == 2:
931 | return input_tensor
932 |
933 | width = input_tensor.shape[-1]
934 | output_tensor = tf.reshape(input_tensor, [-1, width])
935 | return output_tensor
936 |
937 |
938 | def reshape_from_matrix(output_tensor, orig_shape_list):
939 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
940 | if len(orig_shape_list) == 2:
941 | return output_tensor
942 |
943 | output_shape = get_shape_list(output_tensor)
944 |
945 | orig_dims = orig_shape_list[0:-1]
946 | width = output_shape[-1]
947 |
948 | return tf.reshape(output_tensor, orig_dims + [width])
949 |
950 |
951 | def assert_rank(tensor, expected_rank, name=None):
952 | """Raises an exception if the tensor rank is not of the expected rank.
953 |
954 | Args:
955 | tensor: A tf.Tensor to check the rank of.
956 | expected_rank: Python integer or list of integers, expected rank.
957 | name: Optional name of the tensor for the error message.
958 |
959 | Raises:
960 | ValueError: If the expected shape doesn't match the actual shape.
961 | """
962 | if name is None:
963 | name = tensor.name
964 |
965 | expected_rank_dict = {}
966 | if isinstance(expected_rank, six.integer_types):
967 | expected_rank_dict[expected_rank] = True
968 | else:
969 | for x in expected_rank:
970 | expected_rank_dict[x] = True
971 |
972 | actual_rank = tensor.shape.ndims
973 | if actual_rank not in expected_rank_dict:
974 | scope_name = tf.get_variable_scope().name
975 | raise ValueError(
976 | "For the tensor `%s` in scope `%s`, the actual rank "
977 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
978 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
979 |
--------------------------------------------------------------------------------
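One plausible way to warm-start this graph from a pretrained BERT checkpoint via get_assignment_map_from_checkpoint, in TF 1.x style. The checkpoint path is a placeholder, and the graph is assumed to already contain a BertModel built with the same variable scopes as the checkpoint:

import tensorflow as tf
from bert import bert_model_for_PAGE as modeling

init_checkpoint = "results/bert/bert_model.ckpt"   # placeholder path

tvars = tf.trainable_variables()
assignment_map, initialized_names = modeling.get_assignment_map_from_checkpoint(
    tvars, init_checkpoint)

# Map the matching graph variables onto the checkpoint values; everything else
# keeps its normal initializer.
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

for var in tvars:
    note = ", *INIT_FROM_CKPT*" if var.name in initialized_names else ""
    print("  name = %s, shape = %s%s" % (var.name, var.shape, note))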
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------