├── 代码解释 ├── data ├── PAGE │ └── 111 └── LM │ └── pre.py ├── bert ├── __init__.py ├── tokenization.py └── bert_model_for_PAGE.py ├── results ├── PAGE │ └── 111 └── bert │ └── 111 ├── model.jpg ├── .idea ├── encodings.xml ├── vcs.xml ├── misc.xml ├── modules.xml ├── bert-transformer-vae.iml ├── inspectionProfiles │ └── Project_Default.xml ├── webServers.xml ├── deployment.xml └── workspace.xml ├── 运行命令 ├── README.md ├── test.py ├── beamsearch.py ├── test_PAGE.py ├── hparams.py ├── prepro.py ├── utils.py ├── train_TPAGE.py ├── model.py ├── bert_transformer_vae_for_PAGE.py ├── modules.py └── data_load.py /代码解释: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/PAGE/111: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/PAGE/111: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/bert/111: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuhaiming1996/BERT-T2T/HEAD/model.jpg -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /运行命令: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | python train_TPAGE.py \ 6 | --train=data/PAGE/train.txt \ 7 | --eval=data/PAGE/eval.txt \ 8 | --init_checkpoint_bert=results/bert/bert_model.ckpt \ 9 | --batch_size=32 \ 10 | --eval_batch_size=32 \ 11 | --num_epochs_PAGE=10 \ 12 | --maxlen_vae_Encoder=80 \ 13 | --maxlen_vae_Decoder_en=40\ 14 | --maxlen_vae_Decoder_de=40\ 15 | 16 | -------------------------------------------------------------------------------- /.idea/bert-transformer-vae.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 14 | 
-------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 22 | 23 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/LM/pre.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class PreHelper: 4 | def __init__(self,filePath_r,filePath_w_train,filePath_w_eval): 5 | self.filePath_r=filePath_r 6 | self.filePath_w_train=filePath_w_train 7 | self.filePath_w_eval = filePath_w_eval 8 | 9 | def proprocess(self): 10 | sens = [] 11 | num = 0 12 | with open(self.filePath_r,mode="r",encoding="utf-8") as fr: 13 | for line in fr: 14 | num += 1 15 | line = line.strip() 16 | if line != "" and len(line)>8: 17 | sens.append(line.strip()) 18 | if num%1000000 == 0: 19 | print("数据正在提取中,请耐心等待!!!") 20 | 21 | random.shuffle(sens) 22 | print("读写完成,开始写入文件") 23 | sens_eval = sens[:20000] 24 | sens_train = sens[20000:] 25 | with open(self.filePath_w_eval,mode="w",encoding="utf-8") as fw: 26 | for line in sens_eval: 27 | line = line.strip() 28 | fw.write("---xhm---".join([line,line])+"\n") 29 | 30 | 31 | with open(self.filePath_w_train,mode="w",encoding="utf-8") as fw: 32 | for line in sens_train: 33 | line = line.strip() 34 | fw.write("---xhm---".join([line,line])+"\n") 35 | 36 | 37 | 38 | if __name__=="__main__": 39 | preHelper=PreHelper(filePath_r="sens_LM_raw.txt",filePath_w_train="train.txt",filePath_w_eval="eval.txt") 40 | preHelper.proprocess() 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 介绍 2 | NLP中,对于生成问题如NMT,QA, Paraphrase 任务来说通常会存在生成多样性不足的问题, 3 | 通常我们会采用beamSearch来增加多样性。但是beamSeach 生成的句子还是有很大的相似度,无法满足项目落地需求。 4 | 我采用了这篇[A Deep Generative Framework for Paraphrase Generation](https://arxiv.org/abs/1709.05074) 5 | 的基于CVAE的结构思想构造了一个模型,试图解决生成任务的多样性。 6 | 7 | 8 | 9 | ## 模型结构图 10 | 提示:请先看这篇论文[A Deep Generative Framework for Paraphrase Generation](https://arxiv.org/abs/1709.05074) 11 | 的思想和结构,再看我下面的这个模型结构图 12 | ![image](https://github.com/xuhaiming1996/BERT-T2T/blob/master/model.jpg) 13 | 14 | ## 文件说明 15 | ### /data/PAGE 训练语料 16 | train.txt 格式:id---xhm--src---xhm--tgt 17 | 18 | eval.txt 格式:id---xhm--src---xhm--tgt 19 | 20 | test.txt 格式:id---xhm--src---xhm--tgt 21 | 22 | ### results 23 | #### /results/bert 24 | 该文件是预训练的好中文bert模型,大家可以去[这里](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)下载,解压后放在这里 25 | #### /results/PAGE 26 | 该文件夹是复述模型保存路径 27 | 28 | 29 | 30 | ### 运行命令 31 | 模型训练使用的是tf.data.* API 从tfrecord文件中构造的迭代器(感慨一下:非常强大的API.建议大家都采用这种方式) 32 | 33 | python train_TPAGE.py \ 34 | --train=data/PAGE/train.txt \ 35 | --eval=data/PAGE/eval.txt \ 36 | --init_checkpoint_bert=results/bert/bert_model.ckpt \ 37 | --batch_size=32 \ 38 | --eval_batch_size=32 \ 39 | --num_epochs_PAGE=10 \ 40 | --maxlen_vae_Encoder=80 \ 41 | --maxlen_vae_Decoder_en=40\ 42 | --maxlen_vae_Decoder_de=40\ 43 | 44 | #### 温馨一刻 45 | 大家若对于KL 
loss的计算公式有疑问,请看这里的[公式推导](https://blog.csdn.net/qq_32806793/article/details/95652645) 你就会明白代码为啥这样写了 46 | 47 | KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1])) 48 | 49 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | from data_load import get_batch 10 | from model import Transformer 11 | from hparams import Hparams 12 | from utils import get_hypotheses, calc_bleu, postprocess, load_hparams 13 | import logging 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | logging.info("# hparams") 18 | hparams = Hparams() 19 | parser = hparams.parser 20 | hp = parser.parse_args() 21 | load_hparams(hp, hp.ckpt) 22 | 23 | logging.info("# Prepare test batches") 24 | test_batches, num_test_batches, num_test_samples = get_batch(hp.test1, hp.test1, 25 | 100000, 100000, 26 | hp.vocab, hp.test_batch_size, 27 | shuffle=False) 28 | iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes) 29 | xs, ys = iter.get_next() 30 | 31 | test_init_op = iter.make_initializer(test_batches) 32 | 33 | logging.info("# Load model") 34 | m = Transformer(hp) 35 | y_hat, _ = m.eval(xs, ys) 36 | 37 | logging.info("# Session") 38 | with tf.Session() as sess: 39 | ckpt_ = tf.train.latest_checkpoint(hp.ckpt) 40 | ckpt = hp.ckpt if ckpt_ is None else ckpt_ # None: ckpt is a file. otherwise dir. 41 | saver = tf.train.Saver() 42 | 43 | saver.restore(sess, ckpt) 44 | 45 | sess.run(test_init_op) 46 | 47 | logging.info("# get hypotheses") 48 | hypotheses = get_hypotheses(num_test_batches, num_test_samples, sess, y_hat, m.idx2token) 49 | 50 | logging.info("# write results") 51 | model_output = ckpt.split("/")[-1] 52 | if not os.path.exists(hp.testdir): os.makedirs(hp.testdir) 53 | translation = os.path.join(hp.testdir, model_output) 54 | with open(translation, 'w') as fout: 55 | fout.write("\n".join(hypotheses)) 56 | 57 | logging.info("# calc bleu score and append it to translation") 58 | calc_bleu(hp.test2, translation) 59 | 60 | -------------------------------------------------------------------------------- /beamsearch.py: -------------------------------------------------------------------------------- 1 | def beam_search(x, sess, g, batch_size=hp.batch_size): 2 | inputs = np.reshape(np.transpose(np.array([x] * hp.beam_size), (1, 0, 2)), 3 | (hp.beam_size * batch_size, hp.max_len)) 4 | preds = np.zeros((batch_size, hp.beam_size, hp.y_max_len), np.int32) 5 | prob_product = np.zeros((batch_size, hp.beam_size)) 6 | stc_length = np.ones((batch_size, hp.beam_size)) 7 | 8 | for j in range(hp.y_max_len): 9 | _probs, _preds = sess.run( 10 | g.preds, {g.x: inputs, g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))}) 11 | j_probs = np.reshape(_probs[:, j, :], (batch_size, hp.beam_size, hp.beam_size)) 12 | j_preds = np.reshape(_preds[:, j, :], (batch_size, hp.beam_size, hp.beam_size)) 13 | if j == 0: 14 | preds[:, :, j] = j_preds[:, 0, :] 15 | prob_product += np.log(j_probs[:, 0, :]) 16 | else: 17 | add_or_not = np.asarray(np.logical_or.reduce([j_preds > hp.end_id]), dtype=np.int) 18 | tmp_stc_length = np.expand_dims(stc_length, axis=-1) + add_or_not 19 | tmp_stc_length = np.reshape(tmp_stc_length, (batch_size, hp.beam_size * hp.beam_size)) 20 | 21 | this_probs = 
np.expand_dims(prob_product, axis=-1) + np.log(j_probs) * add_or_not 22 | this_probs = np.reshape(this_probs, (batch_size, hp.beam_size * hp.beam_size)) 23 | selected = np.argsort(this_probs / tmp_stc_length, axis=1)[:, -hp.beam_size:] 24 | 25 | tmp_preds = np.concatenate([np.expand_dims(preds, axis=2)] * hp.beam_size, axis=2) 26 | tmp_preds[:, :, :, j] = j_preds[:, :, :] 27 | tmp_preds = np.reshape(tmp_preds, (batch_size, hp.beam_size * hp.beam_size, hp.y_max_len)) 28 | 29 | for batch_idx in range(batch_size): 30 | prob_product[batch_idx] = this_probs[batch_idx, selected[batch_idx]] 31 | preds[batch_idx] = tmp_preds[batch_idx, selected[batch_idx]] 32 | stc_length[batch_idx] = tmp_stc_length[batch_idx, selected[batch_idx]] 33 | 34 | final_selected = np.argmax(prob_product / stc_length, axis=1) 35 | final_preds = [] 36 | for batch_idx in range(batch_size): 37 | final_preds.append(preds[batch_idx, final_selected[batch_idx]]) 38 | 39 | return final_preds -------------------------------------------------------------------------------- /test_PAGE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python3 3 | 4 | import tensorflow as tf 5 | 6 | from bert_transformer_vae_for_PAGE import VaeModel 7 | 8 | from data_load import get_batch_for_train_or_dev_or_test,saveForTfRecord 9 | from utils import save_hparams, get_hypotheses 10 | import os 11 | from hparams import Hparams 12 | import logging 13 | os.environ['CUDA_VISIBLE_DEVICES']= '5' 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | 17 | logging.info("# hparams") 18 | hparams = Hparams() 19 | parser = hparams.parser 20 | hp = parser.parse_args() 21 | save_hparams(hp, hp.PAGEdir) 22 | 23 | 24 | logging.info("# 许海明提醒你: 这里需要准备tfRecord") 25 | logging.info("# 许海明提醒你: 这里需要准备tfRecord") 26 | logging.info("# 许海明提醒你: 这里需要准备tfRecord") 27 | logging.info("# 许海明提醒你: 这里需要准备tfRecord") 28 | 29 | 30 | 31 | saveForTfRecord(hp.test, 32 | hp.maxlen_vae_Encoder, 33 | hp.maxlen_vae_Decoder_en, 34 | hp.maxlen_vae_Decoder_de, 35 | hp.vocab, 36 | output_file="./data/PAGE/test.tf_record", 37 | is_test=False) 38 | 39 | 40 | test_batches, num_test_batches, num_test_samples = get_batch_for_train_or_dev_or_test(hp.test, 41 | hp.maxlen_vae_Encoder, 42 | hp.maxlen_vae_Decoder_en, 43 | hp.maxlen_vae_Decoder_de, 44 | hp.test_batch_size, 45 | input_file="./data/PAGE/test.tf_record" , 46 | is_training=False, 47 | is_test= False) 48 | 49 | # create a iterator of the correct shape and type 50 | iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes) 51 | features=iter.get_next() 52 | 53 | xs = (features["input_ids_vae_encoder"], features["input_masks_vae_encoder"], features["segment_ids_vae_encoder"], features["sent1"], features["sent2"]) 54 | ys = (features["input_ids_vae_decoder_enc"], features["input_ids_vae_decoder_dec"], features["output_ids_vae_decoder_dec"], features["sent1"], features["sent2"]) 55 | 56 | 57 | test_init_op = iter.make_initializer(test_batches) 58 | 59 | 60 | logging.info("# Load model") 61 | m = VaeModel(hp) 62 | 63 | y_hat, _ = m.eval(xs, ys, mode="TPAGE") 64 | # y_hat = m.infer(xs, ys) 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | with tf.Session() as sess: 79 | ckpt = tf.train.latest_checkpoint(hp.PAGEdir) 80 | saver = tf.train.Saver(max_to_keep=hp.num_epochs_LM) 81 | saver.restore(sess, ckpt) 82 | sess.run(test_init_op) 83 | 84 | 85 | logging.info("# get hypotheses") 86 | hypotheses = 
get_hypotheses(num_test_batches, num_test_samples, sess, y_hat, m.idx2token) 87 | 88 | logging.info("# write results") 89 | model_output = "test_2.txt" 90 | translation = os.path.join(hp.PAGEdir,model_output) 91 | with open(translation, mode='w',encoding="utf-8") as fout: 92 | fout.write("\n".join(hypotheses)) 93 | 94 | 95 | 96 | 97 | logging.info("Done") 98 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class Hparams: 4 | parser = argparse.ArgumentParser() 5 | 6 | 7 | 8 | 9 | ## files 10 | parser.add_argument('--train', help="训练数据") 11 | parser.add_argument('--eval', help="验证数据") 12 | parser.add_argument('--test', help="测试数据") 13 | 14 | 15 | ## vocabulary 16 | parser.add_argument('--vocab', default='results/bert/vocab.txt', 17 | help="vocabulary file path") 18 | 19 | parser.add_argument('--init_checkpoint_LM', help="语言模型的初始路径-fine_tune") 20 | parser.add_argument('--init_checkpoint_bert', help="bert模型的初始路径-fine_tune") 21 | parser.add_argument('--init_checkpoint_PAGE', help="复述模型的初始路径-fine_tune") 22 | 23 | 24 | # training scheme 25 | parser.add_argument('--batch_size', default=32, type=int) 26 | parser.add_argument('--eval_batch_size', default=32, type=int) 27 | parser.add_argument('--test_batch_size', default=128, type=int) 28 | 29 | parser.add_argument('--lr', default=5e-5, type=float, help="learning rate") 30 | parser.add_argument('--warmup_steps', default=100, type=int) 31 | 32 | parser.add_argument('--num_epochs_LM', default=20, type=int, help="语言模型训练的epoch") 33 | parser.add_argument('--num_epochs_PAGE', default=20, type=int, help="复述模型训练的epoch") 34 | 35 | parser.add_argument('--LMdir', default="results/LM", help="这是语言模型的的路径") 36 | parser.add_argument('--PAGEdir', default="results/PAGE", help="这是复述模型的的路径") 37 | 38 | # model 39 | parser.add_argument('--d_model', default=768, type=int, 40 | help="hidden dimension of encoder/decoder") 41 | parser.add_argument('--d_ff', default=2048, type=int, 42 | help="hidden dimension of feedforward layer") 43 | parser.add_argument('--num_blocks', default=4, type=int, 44 | help="number of encoder/decoder blocks") 45 | parser.add_argument('--num_heads', default=8, type=int, 46 | help="number of attention heads") 47 | 48 | parser.add_argument('--dropout_rate', default=0.1, type=float) 49 | parser.add_argument('--smoothing', default=0.1, type=float, 50 | help="label smoothing rate") 51 | 52 | 53 | 54 | 55 | parser.add_argument('--maxlen_vae_Encoder', default=80, type=int, 56 | help="VAE编码器的最大的长度 len(sen1)+len(sen2)") 57 | parser.add_argument('--maxlen_vae_Decoder_en', default=40, type=int, 58 | help="源句子的最大长度") 59 | parser.add_argument('--maxlen_vae_Decoder_de', default=40, type=int, 60 | help="复述句的最大长度") 61 | 62 | 63 | 64 | # vae Z 65 | parser.add_argument('--z_dim', default=768, type=int, 66 | help="VAE中那个Z的维度") 67 | 68 | 69 | 70 | # bert的参数 71 | parser.add_argument('--attention_probs_dropout_prob', default=0.1, type=float) 72 | parser.add_argument('--hidden_dropout_prob', default=0.1, type=float) 73 | parser.add_argument('--initializer_range', default=0.02, type=float) 74 | 75 | 76 | parser.add_argument('--hidden_size', default=768, type=int) 77 | parser.add_argument('--intermediate_size', default=3072, type=int) 78 | parser.add_argument('--max_position_embeddings', default=512, type=int) 79 | parser.add_argument('--num_attention_heads', default=12, type=int) 80 | 
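    # Note (assumption): the BERT hyperparameters in this block (hidden_size=768, 12 layers,
    # 12 heads, vocab_size=21128, max_position_embeddings=512, intermediate_size=3072) appear
    # to mirror the released chinese_L-12_H-768_A-12 checkpoint that the README links to; since
    # hp is passed straight to BertModel as its config, these values presumably need to match
    # the downloaded bert_config.json for the pretrained weights to restore cleanly.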
parser.add_argument('--num_hidden_layers', default=12, type=int) 81 | parser.add_argument('--pooler_fc_size', default=768, type=int) 82 | parser.add_argument('--pooler_num_attention_heads', default=12, type=int) 83 | 84 | 85 | parser.add_argument('--pooler_num_fc_layers', default=3, type=int) 86 | parser.add_argument('--pooler_size_per_head', default=128, type=int) 87 | parser.add_argument('--type_vocab_size', default=2, type=int) 88 | parser.add_argument('--vocab_size', default=21128, type=int) 89 | 90 | parser.add_argument('--directionality', default="bidi") 91 | parser.add_argument('--hidden_act', default="gelu") 92 | parser.add_argument('--pooler_type', default="first_token_transform") 93 | 94 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python3 3 | ''' 4 | Feb. 2019 by kyubyong park. 5 | kbpark.linguist@gmail.com. 6 | https://www.github.com/kyubyong/transformer. 7 | 8 | Preprocess the iwslt 2016 datasets. 9 | ''' 10 | 11 | import os 12 | import errno 13 | import sentencepiece as spm 14 | import re 15 | from hparams import Hparams 16 | import logging 17 | 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | def prepro(hp): 21 | """Load raw data -> Preprocessing -> Segmenting with sentencepice 22 | hp: hyperparams. argparse. 23 | """ 24 | logging.info("# Check if raw files exist") 25 | train1 = "iwslt2016/de-en/train.tags.de-en.de" 26 | train2 = "iwslt2016/de-en/train.tags.de-en.en" 27 | eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml" 28 | eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml" 29 | test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml" 30 | test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml" 31 | for f in (train1, train2, eval1, eval2, test1, test2): 32 | if not os.path.isfile(f): 33 | raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f) 34 | 35 | logging.info("# Preprocessing") 36 | # train 37 | _prepro = lambda x: [line.strip() for line in open(x, 'r').read().split("\n") \ 38 | if not line.startswith("<")] 39 | prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2) 40 | assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match." 41 | 42 | # eval 43 | _prepro = lambda x: [re.sub("<[^>]+>", "", line).strip() \ 44 | for line in open(x, 'r').read().split("\n") \ 45 | if line.startswith("ntk', dec, weights) # (N, T2, vocab_size) 117 | y_hat = tf.to_int32(tf.argmax(logits, axis=-1)) 118 | 119 | return logits, y_hat, y, sents2 120 | 121 | def train(self, xs, ys): 122 | ''' 123 | Returns 124 | loss: scalar. 125 | train_op: training operation 126 | global_step: scalar. 
127 | summaries: training summary node 128 | ''' 129 | # forward 130 | memory, sents1 = self.encode(xs) 131 | logits, preds, y, sents2 = self.decode(ys, memory) 132 | 133 | # train scheme 134 | y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size)) 135 | ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_) 136 | nonpadding = tf.to_float(tf.not_equal(y, self.token2idx[""])) # 0: 137 | loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7) 138 | 139 | global_step = tf.train.get_or_create_global_step() 140 | lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps) 141 | optimizer = tf.train.AdamOptimizer(lr) 142 | train_op = optimizer.minimize(loss, global_step=global_step) 143 | 144 | tf.summary.scalar('lr', lr) 145 | tf.summary.scalar("loss", loss) 146 | tf.summary.scalar("global_step", global_step) 147 | 148 | summaries = tf.summary.merge_all() 149 | 150 | return loss, train_op, global_step, summaries 151 | 152 | def eval(self, xs, ys): 153 | '''Predicts autoregressively 154 | At inference, input ys is ignored. 155 | Returns 156 | y_hat: (N, T2) 157 | ''' 158 | decoder_inputs, y, y_seqlen, sents2 = ys 159 | 160 | decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx[""] 161 | ys = (decoder_inputs, y, y_seqlen, sents2) 162 | 163 | memory, sents1 = self.encode(xs, False) 164 | 165 | logging.info("Inference graph is being built. Please be patient.") 166 | for _ in tqdm(range(self.hp.maxlen2)): 167 | logits, y_hat, y, sents2 = self.decode(ys, memory, False) 168 | if tf.reduce_sum(y_hat, 1) == self.token2idx[""]: break 169 | 170 | _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1) 171 | ys = (_decoder_inputs, y, y_seqlen, sents2) 172 | 173 | # monitor a random sample 174 | n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32) 175 | sent1 = sents1[n] 176 | pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token) 177 | sent2 = sents2[n] 178 | 179 | tf.summary.text("sent1", sent1) 180 | tf.summary.text("pred", pred) 181 | tf.summary.text("sent2", sent2) 182 | summaries = tf.summary.merge_all() 183 | 184 | return y_hat, summaries 185 | 186 | -------------------------------------------------------------------------------- /bert_transformer_vae_for_PAGE.py: -------------------------------------------------------------------------------- 1 | from bert import bert_model_for_PAGE 2 | from bert.bert_model_for_PAGE import get_shape_list,layer_norm 3 | import tensorflow as tf 4 | from modules import ff, multihead_attention, label_smoothing, noam_scheme 5 | from utils import convert_idx_to_token_tensor 6 | from tqdm import tqdm 7 | import logging 8 | 9 | from data_load import load_vocab 10 | 11 | class VaeModel: 12 | 13 | ''' 14 | 这是模型的主题框架 15 | ''' 16 | def __init__(self,hp): 17 | self.hp = hp 18 | self.token2idx, self.idx2token = load_vocab(hp.vocab) 19 | 20 | 21 | # 变积分自动编码器的编码器 22 | def encoder_vae(self, xs, training, mode): 23 | 24 | ''' 25 | :param xs: 26 | :param training: 27 | :param mode: TPAGE表示训练复述模型 LM 表示训练transformer的解码器 PPAGE:代表预测的时候 28 | return: 29 | ''' 30 | 31 | input_ids_vae_encoder, input_masks_vae_encoder, segment_ids_vae_encoder, sents1, sents2 = xs 32 | 33 | if mode == "TPAGE" or mode == "PPAGE": 34 | # 这是整体训练 vae 35 | encoder_vae = bert_model_for_PAGE.BertModel( 36 | config=self.hp, 37 | is_training=training, 38 | input_ids=input_ids_vae_encoder, 39 | input_mask=input_masks_vae_encoder, 40 | token_type_ids=segment_ids_vae_encoder) 41 | 42 | self.embeddings = encoder_vae.get_embedding_table() 
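            # The token embedding table and full position embeddings are cached on self here so
            # that decoder_vae below can reuse BERT's embeddings for its own encoder/decoder inputs.
            # The final-layer [CLS] vector (position 0 of the sequence output) then summarizes the
            # (src, tgt) pair and is projected to 2*z_dim: the first half parameterizes the posterior
            # mean, the second half (through softplus, plus 1e-6) the posterior stddev.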
43 | self.full_position_embeddings = encoder_vae.get_full_position_embeddings() 44 | pattern_para = tf.squeeze(encoder_vae.get_sequence_output()[:, 0:1, :], axis=1) 45 | gaussian_params = tf.layers.dense(pattern_para, 2 * self.hp.z_dim) 46 | mean = gaussian_params[:, :self.hp.z_dim] 47 | stddev = 1e-6 + tf.nn.softplus(gaussian_params[:, self.hp.z_dim:]) 48 | else: 49 | raise("您好,你选择的mode 出现错误") 50 | 51 | return mean, stddev 52 | 53 | 54 | 55 | # 变积分自动编码器的解码器 56 | def decoder_vae(self, ys, z, training, mode): 57 | input_ids_vae_decoder_enc, input_ids_vae_decoder_dec,output_ids_vae_decoder_dec, sents1, sents2 = ys 58 | if mode == "TPAGE" or mode == "PPAGE": 59 | with tf.variable_scope("decoder_enc_vae", reuse=tf.AUTO_REUSE): 60 | enc = tf.nn.embedding_lookup(self.embeddings, input_ids_vae_decoder_enc) # (N, T1, d_model) 61 | input_shape = get_shape_list(enc, expected_rank=3) 62 | seq_length = input_shape[1] 63 | width = input_shape[2] 64 | output = enc 65 | assert_op = tf.assert_less_equal(seq_length, self.hp.max_position_embeddings) 66 | with tf.control_dependencies([assert_op]): 67 | position_embeddings = tf.slice(self.full_position_embeddings, [0, 0], 68 | [seq_length, -1]) 69 | num_dims = len(output.shape.as_list()) 70 | position_broadcast_shape = [] 71 | for _ in range(num_dims - 2): 72 | position_broadcast_shape.append(1) 73 | position_broadcast_shape.extend([seq_length, width]) 74 | position_embeddings = tf.reshape(position_embeddings, 75 | position_broadcast_shape) 76 | 77 | output += position_embeddings # 添加位置信息 78 | 79 | enc = output 80 | ## Blocks 81 | for i in range(self.hp.num_blocks): 82 | with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE): 83 | # self-attention 84 | enc = multihead_attention(seq_k=input_ids_vae_decoder_enc, 85 | seq_q=input_ids_vae_decoder_enc, 86 | queries=enc, 87 | keys=enc, 88 | values=enc, 89 | num_heads=self.hp.num_heads, 90 | dropout_rate=self.hp.dropout_rate, 91 | training=training, 92 | causality=False, 93 | scope="self_attention") 94 | # feed forward 95 | enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model]) 96 | 97 | 98 | memory = enc 99 | 100 | 101 | # 下面是VAE的decoder中的 decoder 102 | with tf.variable_scope("decoder_dec_vae", reuse=tf.AUTO_REUSE): 103 | dec = tf.nn.embedding_lookup(self.embeddings, input_ids_vae_decoder_dec) # (N, T, d_model) 104 | input_shape = get_shape_list(dec, expected_rank=3) 105 | seq_length = input_shape[1] 106 | width = input_shape[2] 107 | output = dec 108 | assert_op = tf.assert_less_equal(seq_length, self.hp.max_position_embeddings) 109 | with tf.control_dependencies([assert_op]): 110 | position_embeddings = tf.slice(self.full_position_embeddings, [0, 0], 111 | [seq_length, -1]) 112 | num_dims = len(output.shape.as_list()) 113 | position_broadcast_shape = [] 114 | for _ in range(num_dims - 2): 115 | position_broadcast_shape.append(1) 116 | position_broadcast_shape.extend([seq_length, width]) 117 | position_embeddings = tf.reshape(position_embeddings, 118 | position_broadcast_shape) 119 | 120 | 121 | output += position_embeddings # 添加位置信息 122 | 123 | dec=output 124 | 125 | # 在这里加上 规则模式 采用的是concate的方式!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
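                # Conditioning the decoder on the latent code: z has shape (N, z_dim), so it is
                # expanded to (N, 1, z_dim), tiled across the time axis to (N, T, z_dim), concatenated
                # with the token embeddings along the last dimension, and projected back to d_model by
                # a dense layer, so every decoding step sees the same sampled latent vector.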
126 | z = tf.expand_dims(z, axis=1) 127 | z = tf.tile(z, multiples=[1, seq_length, 1]) 128 | dec = tf.concat([dec, z], axis=-1) 129 | dec = tf.layers.dense(dec, self.hp.d_model) 130 | # Blocks 131 | for i in range(self.hp.num_blocks): 132 | with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE): 133 | # Masked self-attention (Note that causality is True at this time) 134 | dec = multihead_attention(seq_k=input_ids_vae_decoder_dec, 135 | seq_q=input_ids_vae_decoder_dec, 136 | queries=dec, 137 | keys=dec, 138 | values=dec, 139 | num_heads=self.hp.num_heads, 140 | dropout_rate=self.hp.dropout_rate, 141 | training=training, 142 | causality=True, 143 | scope="self_attention") 144 | 145 | # Vanilla attention 146 | dec = multihead_attention(seq_k=input_ids_vae_decoder_enc, 147 | seq_q=input_ids_vae_decoder_dec, 148 | queries=dec, 149 | keys=memory, 150 | values=memory, 151 | num_heads=self.hp.num_heads, 152 | dropout_rate=self.hp.dropout_rate, 153 | training=training, 154 | causality=False, 155 | scope="vanilla_attention") 156 | ### Feed Forward 157 | dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model]) 158 | 159 | 160 | 161 | # Final linear projection (embedding weights are shared) 162 | weights = tf.transpose(self.embeddings) # (d_model, vocab_size) 163 | # 这里为了适应迎合 强制加上一个 FF 164 | logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size) 165 | y_hat = tf.to_int32(tf.argmax(logits, axis=-1)) 166 | return logits, y_hat, output_ids_vae_decoder_dec, sents2 167 | 168 | 169 | 170 | def train(self, xs, ys,mode): 171 | ''' 172 | Returns 173 | loss: scalar. 174 | train_op: training operation 175 | global_step: scalar. 176 | summaries: training summary node 177 | ''' 178 | # forward 179 | 180 | mu, sigma = self.encoder_vae(xs, training=True, mode=mode) 181 | if mode == "TPAGE" or mode == "PPAGE": 182 | # 表示 训练VAE 183 | # 这里提醒自己一下 将embeding 全部设为True 184 | z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32) 185 | else: 186 | raise ("许海明在这里提醒你:出现非法mode") 187 | 188 | 189 | 190 | logits, preds, y, sents2 = self.decoder_vae(ys, z,training=True,mode=mode) 191 | 192 | # train scheme 193 | 194 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y) 195 | nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["[PAD]"])) # 0: 196 | loss_decoder = tf.reduce_sum(ce * nonpadding) / tf.to_float(get_shape_list(xs[0],expected_rank=2)[0]) 197 | 198 | # 这里加上KL loss 199 | if mode == "TPAGE": 200 | KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1])) 201 | else: 202 | KL_loss = 0.0 203 | 204 | loss = loss_decoder + KL_loss 205 | 206 | 207 | global_step = tf.train.get_or_create_global_step() 208 | lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps) 209 | optimizer = tf.train.AdamOptimizer(lr) 210 | train_op = optimizer.minimize(loss, global_step=global_step) 211 | 212 | # # monitor a random sample 213 | n = tf.random_uniform((), 0, tf.shape(preds)[0] - 1, tf.int32) 214 | print_demo=(xs[0][n], 215 | y[n], 216 | preds[n]) 217 | 218 | 219 | tf.summary.scalar('lr', lr) 220 | tf.summary.scalar("KL_loss", KL_loss) 221 | tf.summary.scalar("loss_decoder", loss_decoder) 222 | tf.summary.scalar("loss", loss) 223 | tf.summary.scalar("global_step", global_step) 224 | 225 | summaries = tf.summary.merge_all() 226 | 227 | return loss, train_op, global_step, summaries, print_demo 228 | 229 | 230 | 231 | def eval(self, xs,ys,mode): 232 | 233 | mu, sigma = self.encoder_vae(xs, 
training=False, mode=mode) #这里主要是为了获取 embeding 其他的都没用 234 | 235 | if mode == "TPAGE" or mode == "PPAGE": 236 | # 表示 训练VAE 237 | # 这里提醒自己一下 将embeding 全部设为True 238 | z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32) 239 | 240 | else: 241 | raise ("许海明在这里提醒你:出现非法mode") 242 | 243 | 244 | 245 | # z = tf.random_normal([get_shape_list(xs[0], expected_rank=2)[0], self.hp.z_dim]) #自动生成采样因子 246 | 247 | input_ids_vae_decoder_enc, input_ids_vae_decoder_dec, output_ids_vae_decoder_dec, sents1, sents2 = ys 248 | 249 | decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx[""] 250 | ys = (input_ids_vae_decoder_enc, decoder_inputs, output_ids_vae_decoder_dec, sents1, sents2) 251 | 252 | logging.info("Inference graph is being built. Please be patient.") 253 | for _ in tqdm(range(self.hp.maxlen_vae_Decoder_de)): 254 | logits, y_hat, y, sents2 = self.decoder_vae(ys, z, training=False, mode=mode) 255 | if tf.reduce_sum(y_hat, 1) == self.token2idx["[PAD]"]: 256 | break 257 | 258 | _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1) 259 | ys = (input_ids_vae_decoder_enc, _decoder_inputs, output_ids_vae_decoder_dec, sents1, sents2) 260 | 261 | # monitor a random sample 262 | n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32) 263 | sent1 = sents1[n] 264 | pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token) 265 | sent2 = sents2[n] 266 | 267 | tf.summary.text("sent1", sent1) 268 | tf.summary.text("pred", pred) 269 | tf.summary.text("sent2", sent2) 270 | summaries = tf.summary.merge_all() 271 | 272 | return y_hat, summaries -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python3 3 | 4 | from bert.bert_model_for_PAGE import get_shape_list 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | def ln(inputs, epsilon = 1e-8, scope="ln"): 10 | '''Applies layer normalization. See https://arxiv.org/abs/1607.06450. 11 | inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`. 12 | epsilon: A floating number. A very small number for preventing ZeroDivision Error. 13 | scope: Optional scope for `variable_scope`. 14 | 15 | Returns: 16 | A tensor with the same shape and data dtype as `inputs`. 17 | ''' 18 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 19 | inputs_shape = inputs.get_shape() 20 | params_shape = inputs_shape[-1:] 21 | 22 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 23 | beta= tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer()) 24 | gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer()) 25 | normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) ) 26 | outputs = gamma * normalized + beta 27 | 28 | return outputs 29 | 30 | def get_token_embeddings(vocab_size, num_units, zero_pad=True): 31 | '''Constructs token embedding matrix. 32 | Note that the column of index 0's are set to zeros. 33 | vocab_size: scalar. V. 34 | num_units: embedding dimensionalty. E. 35 | zero_pad: Boolean. If True, all the values of the first row (id = 0) should be constant zero 36 | To apply query/key masks easily, zero pad is turned on. 
37 | 38 | Returns 39 | weight variable: (V, E) 40 | ''' 41 | with tf.variable_scope("shared_weight_matrix"): 42 | embeddings = tf.get_variable('weight_mat', 43 | dtype=tf.float32, 44 | shape=(vocab_size, num_units), 45 | initializer=tf.contrib.layers.xavier_initializer()) 46 | if zero_pad: 47 | embeddings = tf.concat((tf.zeros(shape=[1, num_units]), 48 | embeddings[1:, :]), 0) 49 | return embeddings 50 | 51 | def scaled_dot_product_attention(seq_k, seq_q, Q, K, V, 52 | causality=False, dropout_rate=0., 53 | training=True, 54 | num_heads=8, 55 | scope="scaled_dot_product_attention"): 56 | '''See 3.2.1. 57 | Q: Packed queries. 3d tensor. [N, T_q, d_k]. 58 | K: Packed keys. 3d tensor. [N, T_k, d_k]. 59 | V: Packed values. 3d tensor. [N, T_k, d_v]. 60 | causality: If True, applies masking for future blinding 61 | dropout_rate: A floating point number of [0, 1]. 62 | training: boolean for controlling droput 63 | scope: Optional scope for `variable_scope`. 64 | ''' 65 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 66 | d_k = Q.get_shape().as_list()[-1] 67 | 68 | # dot product 69 | outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) # (N, T_q, T_k) 70 | 71 | # scale 72 | outputs /= d_k ** 0.5 73 | 74 | # key masking 75 | outputs = mask(seq_k, seq_q, outputs, num_heads,type="key") 76 | 77 | # causality or future blinding masking 78 | if causality: 79 | outputs = mask(seq_k, seq_q, outputs, num_heads,type="future") 80 | 81 | # softmax 82 | outputs = tf.nn.softmax(outputs) 83 | 84 | attention = tf.transpose(outputs, [0, 2, 1]) 85 | tf.summary.image("attention", tf.expand_dims(attention[:1], -1)) 86 | 87 | # query masking 88 | outputs = mask(seq_k, seq_q, outputs,num_heads, type="query") 89 | 90 | # dropout 91 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training) 92 | 93 | # weighted sum (context vectors) 94 | outputs = tf.matmul(outputs, V) # (N, T_q, d_v) 95 | 96 | return outputs 97 | 98 | def mask(seq_k, seq_q, inputs,num_heads, type=None): 99 | 100 | padding_num = -2 ** 32 + 1 101 | if type in ("k", "key", "keys"): 102 | # Generate masks 103 | masks = get_attn_key_pad_mask(seq_k=seq_k, seq_q=seq_q, PAD_ID=0) 104 | # Apply masks to inputs 105 | masks = tf.tile(masks,[num_heads, 1, 1]) 106 | paddings = tf.ones_like(inputs) * padding_num 107 | outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # (N, T_q, T_k) 108 | elif type in ("q", "query", "queries"): 109 | # Generate masks 110 | masks = get_non_pad_mask(seq_k=seq_k, seq_q=seq_q, PAD_ID=0) 111 | masks = tf.tile(masks, [num_heads, 1, 1]) 112 | # Apply masks to inputs 113 | outputs = inputs*masks 114 | 115 | elif type in ("f", "future", "right"): 116 | masks = get_subsequent_mask(inputs) 117 | # masks = tf.tile(masks, [num_heads, 1, 1]) 118 | 119 | paddings = tf.ones_like(masks) * padding_num 120 | outputs = tf.where(tf.equal(masks, 0), paddings, inputs) 121 | else: 122 | print("Check if you entered type correctly!") 123 | 124 | return outputs 125 | 126 | 127 | 128 | 129 | 130 | 131 | def get_non_pad_mask(seq_k, seq_q, PAD_ID): 132 | ''' 133 | 134 | :param seq: [batch_size,len] 135 | :return: 136 | ''' 137 | ''' 138 | masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q) 139 | masks = tf.expand_dims(masks, -1) # (N, T_q, 1) 140 | masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k) 141 | ''' 142 | masks = tf.cast(tf.math.not_equal(seq_q, PAD_ID), tf.float32) 143 | masks = tf.expand_dims(masks, -1) # (N, T_q, 1) 144 | masks = tf.tile(masks, [1, 1, get_shape_list(seq_k)[1]]) # (N, T_q, 
T_k) 145 | return masks 146 | 147 | 148 | def get_attn_key_pad_mask(seq_k,seq_q,PAD_ID): 149 | ''' For masking out the padding part of key sequence. ''' 150 | # print(get_shape_list(seq_k,expected_rank=2)) 151 | 152 | padding_mask = tf.cast(tf.math.not_equal(seq_k,PAD_ID),tf.float32) 153 | masks = tf.expand_dims(padding_mask, 1) # (N, 1, T_k) 154 | masks = tf.tile(masks, [1, get_shape_list(seq_q,expected_rank=2)[1], 1]) # (N, T_q, T_k) 155 | # print(get_shape_list(masks,expected_rank=3)) 156 | return masks 157 | 158 | 159 | 160 | 161 | def get_subsequent_mask(inputs): 162 | ''' For masking out the subsequent info. ''' 163 | 164 | diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k) 165 | tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k) 166 | masks = tf.tile(tf.expand_dims(tril, 0), [get_shape_list(inputs)[0], 1, 1]) # (N, T_q, T_k) 167 | 168 | return masks 169 | 170 | 171 | 172 | 173 | def multihead_attention(seq_k, 174 | seq_q, 175 | queries, keys, values, 176 | num_heads=8, 177 | dropout_rate=0, 178 | training=True, 179 | causality=False, 180 | scope="multihead_attention"): 181 | '''Applies multihead attention. See 3.2.2 182 | queries: A 3d tensor with shape of [N, T_q, d_model]. 183 | keys: A 3d tensor with shape of [N, T_k, d_model]. 184 | values: A 3d tensor with shape of [N, T_k, d_model]. 185 | num_heads: An int. Number of heads. 186 | dropout_rate: A floating point number. 187 | training: Boolean. Controller of mechanism for dropout. 188 | causality: Boolean. If true, units that reference the future are masked. 189 | scope: Optional scope for `variable_scope`. 190 | 191 | Returns 192 | A 3d tensor with shape of (N, T_q, C) 193 | ''' 194 | d_model = queries.get_shape().as_list()[-1] 195 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 196 | # Linear projections 197 | Q = tf.layers.dense(queries, d_model, use_bias=False) # (N, T_q, d_model) 198 | K = tf.layers.dense(keys, d_model, use_bias=False) # (N, T_k, d_model) 199 | V = tf.layers.dense(values, d_model, use_bias=False) # (N, T_k, d_model) 200 | 201 | # Split and concat 202 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h) 203 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h) 204 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h) 205 | 206 | # Attention 207 | outputs = scaled_dot_product_attention(seq_k, seq_q, Q_, K_, V_, causality, dropout_rate, training,num_heads) 208 | 209 | # Restore shape 210 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, d_model) 211 | 212 | # Residual connection 213 | outputs += queries 214 | 215 | # Normalize 216 | outputs = ln(outputs) 217 | 218 | return outputs 219 | 220 | def ff(inputs, num_units, scope="positionwise_feedforward"): 221 | '''position-wise feed forward net. See 3.3 222 | 223 | inputs: A 3d tensor with shape of [N, T, C]. 224 | num_units: A list of two integers. 225 | scope: Optional scope for `variable_scope`. 
226 | 227 | Returns: 228 | A 3d tensor with the same shape and dtype as inputs 229 | ''' 230 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 231 | # Inner layer 232 | outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu) 233 | 234 | # Outer layer 235 | outputs = tf.layers.dense(outputs, num_units[1]) 236 | 237 | # Residual connection 238 | outputs += inputs 239 | 240 | # Normalize 241 | outputs = ln(outputs) 242 | 243 | return outputs 244 | 245 | def label_smoothing(inputs, epsilon=0.1): 246 | '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567. 247 | inputs: 3d tensor. [N, T, V], where V is the number of vocabulary. 248 | epsilon: Smoothing rate. 249 | 250 | For example, 251 | 252 | ``` 253 | import tensorflow as tf 254 | inputs = tf.convert_to_tensor([[[0, 0, 1], 255 | [0, 1, 0], 256 | [1, 0, 0]], 257 | 258 | [[1, 0, 0], 259 | [1, 0, 0], 260 | [0, 1, 0]]], tf.float32) 261 | 262 | outputs = label_smoothing(inputs) 263 | 264 | with tf.Session() as sess: 265 | print(sess.run([outputs])) 266 | 267 | >> 268 | [array([[[ 0.03333334, 0.03333334, 0.93333334], 269 | [ 0.03333334, 0.93333334, 0.03333334], 270 | [ 0.93333334, 0.03333334, 0.03333334]], 271 | 272 | [[ 0.93333334, 0.03333334, 0.03333334], 273 | [ 0.93333334, 0.03333334, 0.03333334], 274 | [ 0.03333334, 0.93333334, 0.03333334]]], dtype=float32)] 275 | ``` 276 | ''' 277 | V = inputs.get_shape().as_list()[-1] # number of channels 278 | return ((1-epsilon) * inputs) + (epsilon / V) 279 | 280 | def positional_encoding(inputs, 281 | maxlen, 282 | masking=True, 283 | scope="positional_encoding"): 284 | '''Sinusoidal Positional_Encoding. See 3.5 285 | inputs: 3d tensor. (N, T, E) 286 | maxlen: scalar. Must be >= T 287 | masking: Boolean. If True, padding positions are set to zeros. 288 | scope: Optional scope for `variable_scope`. 289 | 290 | returns 291 | 3d tensor that has the same shape as inputs. 292 | ''' 293 | 294 | E = inputs.get_shape().as_list()[-1] # static 295 | N, T = tf.shape(inputs)[0], tf.shape(inputs)[1] # dynamic 296 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 297 | # position indices 298 | position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1]) # (N, T) 299 | 300 | # First part of the PE function: sin and cos argument 301 | position_enc = np.array([ 302 | [pos / np.power(10000, (i-i%2)/E) for i in range(E)] 303 | for pos in range(maxlen)]) 304 | 305 | # Second part, apply the cosine to even columns and sin to odds. 306 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i 307 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 308 | position_enc = tf.convert_to_tensor(position_enc, tf.float32) # (maxlen, E) 309 | 310 | # lookup 311 | outputs = tf.nn.embedding_lookup(position_enc, position_ind) 312 | 313 | # masks 314 | if masking: 315 | outputs = tf.where(tf.equal(inputs, 0), inputs, outputs) 316 | 317 | return tf.to_float(outputs) 318 | 319 | def noam_scheme(init_lr, global_step, warmup_steps=4000.): 320 | '''Noam scheme learning rate decay 321 | init_lr: initial learning rate. scalar. 322 | global_step: scalar. 323 | warmup_steps: scalar. During warmup_steps, learning rate increases 324 | until it reaches init_lr. 
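    Concretely, lr = init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5):
    a linear warm-up that reaches init_lr at step == warmup_steps, followed by
    inverse-square-root decay.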
325 | ''' 326 | step = tf.cast(global_step + 1, dtype=tf.float32) 327 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5) -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with open(vocab_file, mode="r",encoding="utf-8") as fr: 126 | for token in fr: 127 | if not token: 128 | break 129 | token = token.strip() 130 | vocab[token] = index 131 | index += 1 132 | return vocab 133 | 134 | 135 | def convert_by_vocab(vocab, items): 136 | """Converts a sequence of [tokens|ids] using the vocab.""" 137 | output = [] 138 | for item in items: 139 | output.append(vocab[item]) 140 | return output 141 | 142 | 143 | def convert_tokens_to_ids(vocab, tokens): 144 | return convert_by_vocab(vocab, tokens) 145 | 146 | 147 | def convert_ids_to_tokens(inv_vocab, ids): 148 | return convert_by_vocab(inv_vocab, ids) 149 | 150 | 151 | def whitespace_tokenize(text): 152 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 153 | text = text.strip() 154 | if not text: 155 | return [] 156 | tokens = text.split() 157 | return tokens 158 | 159 | 160 | class FullTokenizer(object): 161 | """Runs end-to-end tokenziation.""" 162 | 163 | def __init__(self, vocab_file, do_lower_case=True): 164 | self.vocab = load_vocab(vocab_file) 165 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 168 | 169 | def tokenize(self, text): 170 | split_tokens = [] 171 | for token in self.basic_tokenizer.tokenize(text): 172 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 173 | split_tokens.append(sub_token) 174 | 175 | return split_tokens 176 | 177 | def convert_tokens_to_ids(self, tokens): 178 | return convert_by_vocab(self.vocab, tokens) 179 | 180 | def convert_ids_to_tokens(self, ids): 181 | return convert_by_vocab(self.inv_vocab, ids) 182 | 183 | 184 | class BasicTokenizer(object): 185 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 186 | 187 | def 
__init__(self, do_lower_case=True): 188 | """Constructs a BasicTokenizer. 189 | 190 | Args: 191 | do_lower_case: Whether to lower case the input. 192 | """ 193 | self.do_lower_case = do_lower_case 194 | 195 | def tokenize(self, text): 196 | """Tokenizes a piece of text.""" 197 | text = convert_to_unicode(text) 198 | text = self._clean_text(text) 199 | 200 | # This was added on November 1st, 2018 for the multilingual and Chinese 201 | # models. This is also applied to the English models now, but it doesn't 202 | # matter since the English models were not trained on any Chinese data 203 | # and generally don't have any Chinese data in them (there are Chinese 204 | # characters in the vocabulary because Wikipedia does have some Chinese 205 | # words in the English Wikipedia.). 206 | text = self._tokenize_chinese_chars(text) 207 | 208 | orig_tokens = whitespace_tokenize(text) 209 | split_tokens = [] 210 | for token in orig_tokens: 211 | if self.do_lower_case: 212 | token = token.lower() 213 | token = self._run_strip_accents(token) 214 | split_tokens.extend(self._run_split_on_punc(token)) 215 | 216 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 217 | return output_tokens 218 | 219 | def _run_strip_accents(self, text): 220 | """Strips accents from a piece of text.""" 221 | text = unicodedata.normalize("NFD", text) 222 | output = [] 223 | for char in text: 224 | cat = unicodedata.category(char) 225 | if cat == "Mn": 226 | continue 227 | output.append(char) 228 | return "".join(output) 229 | 230 | def _run_split_on_punc(self, text): 231 | """Splits punctuation on a piece of text.""" 232 | chars = list(text) 233 | i = 0 234 | start_new_word = True 235 | output = [] 236 | while i < len(chars): 237 | char = chars[i] 238 | if _is_punctuation(char): 239 | output.append([char]) 240 | start_new_word = True 241 | else: 242 | if start_new_word: 243 | output.append([]) 244 | start_new_word = False 245 | output[-1].append(char) 246 | i += 1 247 | 248 | return ["".join(x) for x in output] 249 | 250 | def _tokenize_chinese_chars(self, text): 251 | """Adds whitespace around any CJK character.""" 252 | output = [] 253 | for char in text: 254 | cp = ord(char) 255 | if self._is_chinese_char(cp): 256 | output.append(" ") 257 | output.append(char) 258 | output.append(" ") 259 | else: 260 | output.append(char) 261 | return "".join(output) 262 | 263 | def _is_chinese_char(self, cp): 264 | """Checks whether CP is the codepoint of a CJK character.""" 265 | # This defines a "chinese character" as anything in the CJK Unicode block: 266 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 267 | # 268 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 269 | # despite its name. The modern Korean Hangul alphabet is a different block, 270 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 271 | # space-separated words, so they are not treated specially and handled 272 | # like the all of the other languages. 
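    # The ranges checked below are: CJK Unified Ideographs (4E00-9FFF), Extensions A-E
    # (3400-4DBF, 20000-2A6DF, 2A700-2B73F, 2B740-2B81F, 2B820-2CEAF), and the
    # CJK Compatibility Ideographs blocks (F900-FAFF, 2F800-2FA1F).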
273 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 274 | (cp >= 0x3400 and cp <= 0x4DBF) or # 275 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 276 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 277 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 278 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 279 | (cp >= 0xF900 and cp <= 0xFAFF) or # 280 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 281 | return True 282 | 283 | return False 284 | 285 | def _clean_text(self, text): 286 | """Performs invalid character removal and whitespace cleanup on text.""" 287 | output = [] 288 | for char in text: 289 | cp = ord(char) 290 | if cp == 0 or cp == 0xfffd or _is_control(char): 291 | continue 292 | if _is_whitespace(char): 293 | output.append(" ") 294 | else: 295 | output.append(char) 296 | return "".join(output) 297 | 298 | 299 | class WordpieceTokenizer(object): 300 | """Runs WordPiece tokenziation.""" 301 | 302 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 303 | self.vocab = vocab 304 | self.unk_token = unk_token 305 | self.max_input_chars_per_word = max_input_chars_per_word 306 | 307 | def tokenize(self, text): 308 | """Tokenizes a piece of text into its word pieces. 309 | 310 | This uses a greedy longest-match-first algorithm to perform tokenization 311 | using the given vocabulary. 312 | 313 | For example: 314 | input = "unaffable" 315 | output = ["un", "##aff", "##able"] 316 | 317 | Args: 318 | text: A single token or whitespace separated tokens. This should have 319 | already been passed through `BasicTokenizer. 320 | 321 | Returns: 322 | A list of wordpiece tokens. 323 | """ 324 | 325 | text = convert_to_unicode(text) 326 | 327 | output_tokens = [] 328 | for token in whitespace_tokenize(text): 329 | chars = list(token) 330 | if len(chars) > self.max_input_chars_per_word: 331 | output_tokens.append(self.unk_token) 332 | continue 333 | 334 | is_bad = False 335 | start = 0 336 | sub_tokens = [] 337 | while start < len(chars): 338 | end = len(chars) 339 | cur_substr = None 340 | while start < end: 341 | substr = "".join(chars[start:end]) 342 | if start > 0: 343 | substr = "##" + substr 344 | if substr in self.vocab: 345 | cur_substr = substr 346 | break 347 | end -= 1 348 | if cur_substr is None: 349 | is_bad = True 350 | break 351 | sub_tokens.append(cur_substr) 352 | start = end 353 | 354 | if is_bad: 355 | output_tokens.append(self.unk_token) 356 | else: 357 | output_tokens.extend(sub_tokens) 358 | return output_tokens 359 | 360 | 361 | def _is_whitespace(char): 362 | """Checks whether `chars` is a whitespace character.""" 363 | # \t, \n, and \r are technically contorl characters but we treat them 364 | # as whitespace since they are generally considered as such. 365 | if char == " " or char == "\t" or char == "\n" or char == "\r": 366 | return True 367 | cat = unicodedata.category(char) 368 | if cat == "Zs": 369 | return True 370 | return False 371 | 372 | 373 | def _is_control(char): 374 | """Checks whether `chars` is a control character.""" 375 | # These are technically control characters but we count them as whitespace 376 | # characters. 377 | if char == "\t" or char == "\n" or char == "\r": 378 | return False 379 | cat = unicodedata.category(char) 380 | if cat in ("Cc", "Cf"): 381 | return True 382 | return False 383 | 384 | 385 | def _is_punctuation(char): 386 | """Checks whether `chars` is a punctuation character.""" 387 | cp = ord(char) 388 | # We treat all non-letter/number ASCII as punctuation. 
389 | # Characters such as "^", "$", and "`" are not in the Unicode 390 | # Punctuation class but we treat them as punctuation anyways, for 391 | # consistency. 392 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 393 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 394 | return True 395 | cat = unicodedata.category(char) 396 | if cat.startswith("P"): 397 | return True 398 | return False 399 | -------------------------------------------------------------------------------- /data_load.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python3 3 | 4 | import tensorflow as tf 5 | from utils import calc_num_batches 6 | import copy 7 | import collections 8 | from bert import tokenization 9 | 10 | 11 | 12 | def load_vocab(vocab_fpath): 13 | '''Loads vocabulary file and returns idx<->token maps 14 | vocab_fpath: string. vocabulary file path. 15 | Note that these are reserved 16 | 0: , 1: , 2: , 3: 17 | 注意 这里 18 | 0:pad 19 | [CLS]的位置 我让他学习复述规则 20 | [UNK] 21 | [SEP] 两个句子分隔符号 同时也是句子结束符号 22 | [MASK] 23 | 一个句子的开始符号 24 | Returns 25 | two dictionaries. 26 | ''' 27 | vocab = collections.OrderedDict() 28 | index = 0 29 | with open(vocab_fpath, mode="r", encoding="utf-8") as fr: 30 | for token in fr: 31 | print(token) 32 | if not token: 33 | break 34 | token = token.strip() 35 | vocab[token] = index 36 | index += 1 37 | 38 | token2idx = {token: idx for idx, token in enumerate(vocab)} 39 | idx2token = {idx: token for idx, token in enumerate(vocab)} 40 | return token2idx, idx2token 41 | 42 | 43 | 44 | def load_data_for_test(fpath): 45 | ''' 46 | 47 | :param fpath:文件的路径 48 | :param sep 分隔符 49 | :return: 50 | ''' 51 | sents1, sents2 = [], [] 52 | with open(fpath, mode='r',encoding="utf-8") as fr: 53 | for line in fr: 54 | line = line.strip() 55 | if line is not None and line != "": 56 | assert len(line.strip())>3 57 | sents1.append(line.strip()) 58 | sents2=copy.deepcopy(sents1) 59 | return sents1, sents2 60 | 61 | 62 | 63 | 64 | def load_data_for_train_or_eval(fpath, sep="---xhm--"): 65 | ''' 66 | 67 | :param fpath:文件的路径 68 | :param sep 分隔符 69 | :return: 70 | ''' 71 | sents1, sents2 = [], [] 72 | num = 0 73 | with open(fpath, mode='r',encoding="utf-8") as fr: 74 | for line in fr: 75 | num+=1 76 | if num%1000000==0: 77 | print("数据正在预处理请你耐心等待。。。。",num) 78 | line = line.strip() 79 | if line is not None and line != "": 80 | sens = line.split(sep) 81 | if len(sens)!=3: 82 | print("语料格式错误的数据:",line) 83 | continue 84 | else: 85 | # assert len(sens[0].strip())>=3 86 | # assert len(sens[1].strip())>=3 87 | # 88 | if len(sens[1].strip())<=3 or len(sens[2].strip())<=3: 89 | print("数据长度不合格",line) 90 | continue 91 | sents1.append(sens[1].strip()) 92 | sents2.append(sens[2].strip()) 93 | return sents1, sents2 94 | 95 | 96 | 97 | 98 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 99 | """Truncates a sequence pair in place to the maximum length.""" 100 | 101 | # This is a simple heuristic which will always truncate the longer sequence 102 | # one token at a time. This makes more sense than truncating an equal percent 103 | # of tokens from each, since if one sequence is very short then each token 104 | # that's truncated likely contains more information than a longer sequence. 
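    # For example, with max_length=5, tokens_a=[a, b, c, d] and tokens_b=[x, y, z],
    # the longer list is popped from first: the result is tokens_a=[a, b, c], tokens_b=[x, y].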
105 | while True: 106 | total_length = len(tokens_a) + len(tokens_b) 107 | if total_length <= max_length: 108 | break 109 | if len(tokens_a) > len(tokens_b): 110 | tokens_a.pop() 111 | else: 112 | tokens_b.pop() 113 | 114 | 115 | 116 | # 许海明 117 | class InputFeatures(object): 118 | """A single set of features of data.""" 119 | ''' 120 | yield (input_ids_vae_encoder, input_masks_vae_encoder, segment_ids_vae_encoder, sent1, sent2), \ 121 | (input_ids_vae_decoder_enc, input_ids_vae_decoder_dec, output_ids_vae_decoder_dec, sent1, sent2) 122 | ''' 123 | 124 | 125 | def __init__(self, 126 | input_ids_vae_encoder, 127 | input_masks_vae_encoder, 128 | segment_ids_vae_encoder, 129 | input_ids_vae_decoder_enc, 130 | input_ids_vae_decoder_dec, 131 | output_ids_vae_decoder_dec, 132 | sent1, 133 | sent2, 134 | is_real_example=True): 135 | 136 | self.input_ids_vae_encoder=input_ids_vae_encoder 137 | self.input_masks_vae_encoder=input_masks_vae_encoder 138 | self.segment_ids_vae_encoder=segment_ids_vae_encoder 139 | self.input_ids_vae_decoder_enc=input_ids_vae_decoder_enc 140 | self.input_ids_vae_decoder_dec=input_ids_vae_decoder_dec 141 | self.output_ids_vae_decoder_dec=output_ids_vae_decoder_dec 142 | self.sent1=sent1 143 | self.sent2=sent2 144 | self.is_real_example = is_real_example 145 | 146 | 147 | 148 | 149 | # 许海明 150 | def convert_single_example(ex_index,sent1, sent2, 151 | maxlen_vae_Encoder, 152 | maxlen_vae_Decoder_en, 153 | maxlen_vae_Decoder_de, tokenizer): 154 | 155 | 156 | 157 | sent1_c = copy.deepcopy(sent1) 158 | sent2_c = copy.deepcopy(sent2) 159 | 160 | ''' 161 | # VAE 的编码器的输入 简单就是一个bert的两句话的合在一起 162 | ''' 163 | tokens_sent1 = tokenizer.tokenize(sent1) 164 | tokens_sent2 = tokenizer.tokenize(sent2) 165 | _truncate_seq_pair(tokens_sent1, tokens_sent2, maxlen_vae_Encoder - 3) 166 | tokens_vae_encoder = [] 167 | segment_ids_vae_encoder = [] 168 | 169 | tokens_vae_encoder.append("[CLS]") 170 | segment_ids_vae_encoder.append(0) 171 | 172 | for token in tokens_sent1: 173 | tokens_vae_encoder.append(token) 174 | segment_ids_vae_encoder.append(0) 175 | 176 | tokens_vae_encoder.append("[SEP]") 177 | segment_ids_vae_encoder.append(0) 178 | 179 | for token in tokens_sent2: 180 | tokens_vae_encoder.append(token) 181 | segment_ids_vae_encoder.append(1) 182 | 183 | tokens_vae_encoder.append("[SEP]") 184 | segment_ids_vae_encoder.append(1) 185 | 186 | input_ids_vae_encoder = tokenizer.convert_tokens_to_ids(tokens_vae_encoder) 187 | input_masks_vae_encoder = [1] * len(input_ids_vae_encoder) 188 | # print("\n\ntokens_vae_encoder",tokens_vae_encoder) 189 | # print("\n\ninput_ids_vae_encoder",input_ids_vae_encoder) 190 | while len(input_ids_vae_encoder) < maxlen_vae_Encoder: 191 | input_ids_vae_encoder.append(0) 192 | input_masks_vae_encoder.append(0) 193 | segment_ids_vae_encoder.append(0) 194 | 195 | assert len(input_ids_vae_encoder) == maxlen_vae_Encoder 196 | assert len(input_masks_vae_encoder) == maxlen_vae_Encoder 197 | assert len(segment_ids_vae_encoder) == maxlen_vae_Encoder 198 | 199 | # vae的解码器的编码器输入 用符号 3 200 | tokens_sent3 = tokenizer.tokenize(sent1_c) 201 | 202 | if len(tokens_sent3) > maxlen_vae_Decoder_en - 1: 203 | tokens_sent3 = tokens_sent3[0:(maxlen_vae_Decoder_en - 1)] 204 | 205 | tokens_vae_decoder_enc = [] 206 | for token in tokens_sent3: 207 | tokens_vae_decoder_enc.append(token) 208 | # segment_ids_vae_decoder_enc.append(0) 209 | tokens_vae_decoder_enc.append("[SEP]") 210 | 211 | # segment_ids_vae_decoder_enc.append(0) 212 | input_ids_vae_decoder_enc = 
tokenizer.convert_tokens_to_ids(tokens_vae_decoder_enc) 213 | # print("\n\ntokens_vae_decoder_enc:",tokens_vae_decoder_enc) 214 | # print("input_ids_vae_decoder_enc:", input_ids_vae_decoder_enc) 215 | 216 | # input_mask_vae_decoder_enc = [1] * len(input_ids_vae_decoder_enc) 217 | while len(input_ids_vae_decoder_enc) < maxlen_vae_Decoder_en: 218 | input_ids_vae_decoder_enc.append(0) 219 | # input_mask_vae_decoder_enc.append(0) 220 | # segment_ids_vae_decoder_enc.append(0) 221 | 222 | # vae解码器的解码器的输入和输出 用符号 4 223 | # 这是训练 PAGE模型的代码 224 | tokens_sent4 = tokenizer.tokenize(sent2_c) 225 | 226 | if len(tokens_sent4) > maxlen_vae_Decoder_de - 1: 227 | tokens_sent4 = tokens_sent4[0:(maxlen_vae_Decoder_de - 1)] 228 | 229 | tokens_sent4_input = [""] 230 | tokens_sent4_input.extend(copy.copy(tokens_sent4)) 231 | tokens_sent4_output = copy.copy(tokens_sent4) 232 | tokens_sent4_output.append("[SEP]") 233 | 234 | input_ids_vae_decoder_dec = tokenizer.convert_tokens_to_ids(tokens_sent4_input) 235 | output_ids_vae_decoder_dec = tokenizer.convert_tokens_to_ids(tokens_sent4_output) 236 | 237 | 238 | while len(input_ids_vae_decoder_dec) < maxlen_vae_Decoder_de: 239 | input_ids_vae_decoder_dec.append(0) 240 | output_ids_vae_decoder_dec.append(0) 241 | # input_mask_vae_decoder_dec.append(0) 242 | 243 | assert len(input_ids_vae_decoder_dec) == maxlen_vae_Decoder_de 244 | assert len(output_ids_vae_decoder_dec) == maxlen_vae_Decoder_de 245 | # assert len(input_mask_vae_decoder_dec) == maxlen_vae_Decoder_de 246 | 247 | # x = encode(sent1, "x", token2idx) 248 | # y = encode(sent2, "y", token2idx) 249 | # decoder_input, y = y[:-1], y[1:] 250 | # 251 | # x_seqlen, y_seqlen = len(x), len(y) 252 | if ex_index<=3: 253 | print("*** Example ***") 254 | print("guid: %s" % (ex_index)) 255 | print("\n\ntokens_vae_encoder",tokens_vae_encoder) 256 | print("input_ids_vae_encoder",input_ids_vae_encoder) 257 | 258 | 259 | print("\n\ntokens_vae_decoder_enc:", tokens_vae_decoder_enc) 260 | print("input_ids_vae_decoder_enc:", input_ids_vae_decoder_enc) 261 | 262 | 263 | 264 | print("\n\n tokens_input_vae_decoder_dec", tokens_sent4_input) 265 | print("input_ids_vae_decoder_dec", input_ids_vae_decoder_dec) 266 | 267 | print("\n\n tokens_output_vae_decoder_dec", tokens_sent4_output) 268 | print("output_ids_vae_decoder_dec", output_ids_vae_decoder_dec) 269 | 270 | feature = InputFeatures( input_ids_vae_encoder=input_ids_vae_encoder, 271 | input_masks_vae_encoder=input_masks_vae_encoder, 272 | segment_ids_vae_encoder=segment_ids_vae_encoder, 273 | input_ids_vae_decoder_enc=input_ids_vae_decoder_enc, 274 | input_ids_vae_decoder_dec=input_ids_vae_decoder_dec, 275 | output_ids_vae_decoder_dec=output_ids_vae_decoder_dec, 276 | sent1=sent1, 277 | sent2=sent2, 278 | is_real_example=True) 279 | return feature 280 | 281 | 282 | 283 | 284 | # 许海明 285 | def file_based_convert_examples_to_features(sentS1, 286 | sentS2, 287 | maxlen_vae_Encoder, 288 | maxlen_vae_Decoder_en, 289 | maxlen_vae_Decoder_de, 290 | tokenizer, 291 | output_file): 292 | 293 | writer = tf.python_io.TFRecordWriter(output_file) 294 | 295 | for (ex_index, (sent1,sent2)) in enumerate(zip(sentS1,sentS2)): 296 | if ex_index % 500000 == 0: 297 | print("Writing example %d of %d" % (ex_index, len(sentS1))) 298 | 299 | feature = convert_single_example(ex_index,sent1, sent2, 300 | maxlen_vae_Encoder, 301 | maxlen_vae_Decoder_en, 302 | maxlen_vae_Decoder_de, tokenizer) 303 | 304 | def create_int_feature(values): 305 | f = 
tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 306 | return f 307 | 308 | def create_bytes_feature(value): 309 | """Returns a bytes_list from a string / byte.""" 310 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()])) 311 | 312 | features = collections.OrderedDict() 313 | 314 | features["input_ids_vae_encoder"] = create_int_feature(feature.input_ids_vae_encoder) 315 | features["input_masks_vae_encoder"] = create_int_feature(feature.input_masks_vae_encoder) 316 | features["segment_ids_vae_encoder"] = create_int_feature(feature.segment_ids_vae_encoder) 317 | features["input_ids_vae_decoder_enc"] = create_int_feature(feature.input_ids_vae_decoder_enc) 318 | features["input_ids_vae_decoder_dec"] = create_int_feature(feature.input_ids_vae_decoder_dec) 319 | features["output_ids_vae_decoder_dec"] = create_int_feature(feature.output_ids_vae_decoder_dec) 320 | 321 | features["sent1"] = create_bytes_feature(feature.sent1) 322 | features["sent2"] = create_bytes_feature(feature.sent2) 323 | 324 | 325 | features["is_real_example"] = create_int_feature( [int(feature.is_real_example)] ) 326 | 327 | 328 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 329 | writer.write(tf_example.SerializeToString()) 330 | writer.close() 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | def file_based_input_fn_builder(input_file, 341 | maxlen_vae_Encoder, 342 | maxlen_vae_Decoder_en, 343 | maxlen_vae_Decoder_de, 344 | is_training): 345 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 346 | 347 | name_to_features = { 348 | "input_ids_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64), 349 | "input_masks_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64), 350 | "segment_ids_vae_encoder": tf.FixedLenFeature([maxlen_vae_Encoder], tf.int64), 351 | "input_ids_vae_decoder_enc": tf.FixedLenFeature([maxlen_vae_Decoder_en], tf.int64), 352 | "input_ids_vae_decoder_dec": tf.FixedLenFeature([maxlen_vae_Decoder_de], tf.int64), 353 | "output_ids_vae_decoder_dec": tf.FixedLenFeature([maxlen_vae_Decoder_de], tf.int64), 354 | "sent1": tf.FixedLenFeature([], tf.string), 355 | "sent2": tf.FixedLenFeature([], tf.string), 356 | "is_real_example": tf.FixedLenFeature([], tf.int64), 357 | } 358 | 359 | def _decode_record(record, name_to_features): 360 | """Decodes a record to a TensorFlow example.""" 361 | example = tf.parse_single_example(record, name_to_features) 362 | 363 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 364 | # So cast all int64 to int32. 
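# (tf.to_int32(t) used below is TF 1.x API; on newer TensorFlow versions the equivalent call is tf.cast(t, tf.int32).)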
365 | for name in list(example.keys()): 366 | t = example[name] 367 | if t.dtype == tf.int64: 368 | t = tf.to_int32(t) 369 | example[name] = t 370 | 371 | return example 372 | 373 | def input_fn(batch_size): 374 | """The actual input function.""" 375 | d = tf.data.TFRecordDataset(input_file) 376 | if is_training: 377 | d = d.repeat() 378 | d = d.shuffle(buffer_size=100) 379 | 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size)) 384 | 385 | return d 386 | 387 | return input_fn 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | def saveForTfRecord(fpath, 399 | maxlen_vae_Encoder, 400 | maxlen_vae_Decoder_en, 401 | maxlen_vae_Decoder_de, 402 | vocab_fpath, 403 | output_file, 404 | is_test=False): 405 | ''' 406 | 407 | :param fpath: 408 | :param maxlen_vae_Encoder: 409 | :param maxlen_vae_Decoder_en: 410 | :param maxlen_vae_Decoder_de: 411 | :param vocab_fpath: 412 | :param output_dir: 这个是输出保存路径 413 | :param is_test: 是否是测试数据 414 | :return: 415 | ''' 416 | if is_test: 417 | # 是测试流程 418 | sents1, sents2 = load_data_for_test(fpath) 419 | else: 420 | sents1, sents2 = load_data_for_train_or_eval(fpath) 421 | # sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2) 422 | 423 | print("读取完成") 424 | 425 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_fpath, do_lower_case=True) 426 | # output_file = os.path.join(output_dir, "train.tf_record") 427 | file_based_convert_examples_to_features(sents1, 428 | sents2, 429 | maxlen_vae_Encoder, 430 | maxlen_vae_Decoder_en, 431 | maxlen_vae_Decoder_de, 432 | tokenizer, 433 | output_file) 434 | 435 | 436 | 437 | 438 | 439 | 440 | def get_batch_for_train_or_dev_or_test( fpath, 441 | maxlen_vae_Encoder, 442 | maxlen_vae_Decoder_en, 443 | maxlen_vae_Decoder_de, 444 | batch_size, 445 | input_file, 446 | is_training, 447 | is_test): 448 | ''' 449 | 450 | Returns 451 | batches 452 | num_batches: number of mini-batches 453 | num_samples 454 | ''' 455 | 456 | if is_test: 457 | # 是测试流程 458 | sents1, sents2 = load_data_for_test(fpath) 459 | else: 460 | sents1, sents2 = load_data_for_train_or_eval(fpath) 461 | # sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2) 462 | 463 | input_fn=file_based_input_fn_builder(input_file, 464 | maxlen_vae_Encoder, 465 | maxlen_vae_Decoder_en, 466 | maxlen_vae_Decoder_de, 467 | is_training) 468 | 469 | batches=input_fn(batch_size) 470 | num_batches = calc_num_batches(len(sents1), batch_size) 471 | return batches, num_batches, len(sents1) 472 | 473 | 474 | 475 | 476 | 477 | 478 | -------------------------------------------------------------------------------- /bert/bert_model_for_PAGE.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
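# Note: this file is adapted from Google's public BERT modeling code; the most visible local change is that embedding_postprocessor() additionally returns the full position-embedding table, exposed via BertModel.get_full_position_embeddings().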
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 47 | 48 | Args: 49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices. 69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 
109 | 110 | Example usage: 111 | 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 117 | 118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 120 | 121 | model = modeling.BertModel(config=config, is_training=True, 122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 123 | 124 | label_embeddings = tf.get_variable(...) 125 | pooled_output = model.get_pooled_output() 126 | logits = tf.matmul(pooled_output, label_embeddings) 127 | ... 128 | ``` 129 | """ 130 | 131 | def __init__(self, 132 | config, 133 | is_training, 134 | input_ids, 135 | input_mask=None, 136 | token_type_ids=None, 137 | use_one_hot_embeddings=False, 138 | scope=None): 139 | """Constructor for BertModel. 140 | 141 | Args: 142 | config: `BertConfig` instance. 143 | is_training: bool. true for training model, false for eval model. Controls 144 | whether dropout will be applied. 145 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 149 | embeddings or tf.embedding_lookup() for the word embeddings. 150 | scope: (optional) variable scope. Defaults to "bert". 151 | 152 | Raises: 153 | ValueError: The config is invalid or one of the input tensor shapes 154 | is invalid. 155 | """ 156 | config = copy.deepcopy(config) 157 | if not is_training: 158 | config.hidden_dropout_prob = 0.0 159 | config.attention_probs_dropout_prob = 0.0 160 | 161 | input_shape = get_shape_list(input_ids, expected_rank=2) 162 | batch_size = input_shape[0] 163 | seq_length = input_shape[1] 164 | 165 | if input_mask is None: 166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 167 | 168 | if token_type_ids is None: 169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 170 | 171 | with tf.variable_scope(scope, default_name="bert"): 172 | with tf.variable_scope("embeddings"): 173 | # Perform embedding lookup on the word ids. 174 | (self.embedding_output, self.embedding_table) = embedding_lookup( 175 | input_ids=input_ids, 176 | vocab_size=config.vocab_size, 177 | embedding_size=config.hidden_size, 178 | initializer_range=config.initializer_range, 179 | word_embedding_name="word_embeddings", 180 | use_one_hot_embeddings=use_one_hot_embeddings) 181 | 182 | # Add positional embeddings and token type embeddings, then layer 183 | # normalize and perform dropout. 
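# Shape notes: embedding_output stays [batch_size, seq_length, hidden_size]; full_position_embeddings is the whole [max_position_embeddings, hidden_size] table, which this local copy returns alongside the processed embeddings.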
184 | self.embedding_output,self.full_position_embeddings = embedding_postprocessor( 185 | input_tensor=self.embedding_output, 186 | use_token_type=True, 187 | token_type_ids=token_type_ids, 188 | token_type_vocab_size=config.type_vocab_size, 189 | token_type_embedding_name="token_type_embeddings", 190 | use_position_embeddings=True, 191 | position_embedding_name="position_embeddings", 192 | initializer_range=config.initializer_range, 193 | max_position_embeddings=config.max_position_embeddings, 194 | dropout_prob=config.hidden_dropout_prob) 195 | 196 | with tf.variable_scope("encoder"): 197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 198 | # mask of shape [batch_size, seq_length, seq_length] which is used 199 | # for the attention scores. 200 | attention_mask = create_attention_mask_from_input_mask( 201 | input_ids, input_mask) 202 | 203 | # Run the stacked transformer. 204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 205 | self.all_encoder_layers = transformer_model( 206 | input_tensor=self.embedding_output, 207 | attention_mask=attention_mask, 208 | hidden_size=config.hidden_size, 209 | num_hidden_layers=config.num_hidden_layers, 210 | num_attention_heads=config.num_attention_heads, 211 | intermediate_size=config.intermediate_size, 212 | intermediate_act_fn=get_activation(config.hidden_act), 213 | hidden_dropout_prob=config.hidden_dropout_prob, 214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 215 | initializer_range=config.initializer_range, 216 | do_return_all_layers=True) 217 | 218 | self.sequence_output = self.all_encoder_layers[-1] 219 | # The "pooler" converts the encoded sequence tensor of shape 220 | # [batch_size, seq_length, hidden_size] to a tensor of shape 221 | # [batch_size, hidden_size]. This is necessary for segment-level 222 | # (or segment-pair-level) classification tasks where we need a fixed 223 | # dimensional representation of the segment. 224 | with tf.variable_scope("pooler"): 225 | # We "pool" the model by simply taking the hidden state corresponding 226 | # to the first token. We assume that this has been pre-trained 227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 228 | self.pooled_output = tf.layers.dense( 229 | first_token_tensor, 230 | config.hidden_size, 231 | activation=tf.tanh, 232 | kernel_initializer=create_initializer(config.initializer_range)) 233 | 234 | def get_pooled_output(self): 235 | return self.pooled_output 236 | 237 | def get_sequence_output(self): 238 | """Gets final hidden layer of encoder. 239 | 240 | Returns: 241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 242 | to the final hidden of the transformer encoder. 243 | """ 244 | return self.sequence_output 245 | 246 | def get_all_encoder_layers(self): 247 | return self.all_encoder_layers 248 | 249 | def get_embedding_output(self): 250 | """Gets output of the embedding lookup (i.e., input to the transformer). 251 | 252 | Returns: 253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 254 | to the output of the embedding layer, after summing the word 255 | embeddings with the positional embeddings and the token type embeddings, 256 | then performing layer normalization. This is the input to the transformer. 
257 | """ 258 | return self.embedding_output 259 | 260 | def get_embedding_table(self): 261 | return self.embedding_table 262 | 263 | def get_full_position_embeddings(self): 264 | return self.full_position_embeddings 265 | 266 | def gelu(x): 267 | """Gaussian Error Linear Unit. 268 | 269 | This is a smoother version of the RELU. 270 | Original paper: https://arxiv.org/abs/1606.08415 271 | Args: 272 | x: float Tensor to perform activation. 273 | 274 | Returns: 275 | `x` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.tanh( 278 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 279 | return x * cdf 280 | 281 | 282 | def get_activation(activation_string): 283 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 284 | 285 | Args: 286 | activation_string: String name of the activation function. 287 | 288 | Returns: 289 | A Python function corresponding to the activation function. If 290 | `activation_string` is None, empty, or "linear", this will return None. 291 | If `activation_string` is not a string, it will return `activation_string`. 292 | 293 | Raises: 294 | ValueError: The `activation_string` does not correspond to a known 295 | activation. 296 | """ 297 | 298 | # We assume that anything that"s not a string is already an activation 299 | # function, so we just return it. 300 | if not isinstance(activation_string, six.string_types): 301 | return activation_string 302 | 303 | if not activation_string: 304 | return None 305 | 306 | act = activation_string.lower() 307 | if act == "linear": 308 | return None 309 | elif act == "relu": 310 | return tf.nn.relu 311 | elif act == "gelu": 312 | return gelu 313 | elif act == "tanh": 314 | return tf.tanh 315 | else: 316 | raise ValueError("Unsupported activation: %s" % act) 317 | 318 | 319 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 320 | """Compute the union of the current variables and checkpoint variables.""" 321 | assignment_map = {} 322 | initialized_variable_names = {} 323 | 324 | name_to_variable = collections.OrderedDict() 325 | for var in tvars: 326 | name = var.name 327 | m = re.match("^(.*):\\d+$", name) 328 | if m is not None: 329 | name = m.group(1) 330 | name_to_variable[name] = var 331 | 332 | init_vars = tf.train.list_variables(init_checkpoint) 333 | 334 | assignment_map = collections.OrderedDict() 335 | for x in init_vars: 336 | (name, var) = (x[0], x[1]) 337 | if name not in name_to_variable: 338 | continue 339 | assignment_map[name] = name 340 | initialized_variable_names[name] = 1 341 | initialized_variable_names[name + ":0"] = 1 342 | 343 | return (assignment_map, initialized_variable_names) 344 | 345 | 346 | def dropout(input_tensor, dropout_prob): 347 | """Perform dropout. 348 | 349 | Args: 350 | input_tensor: float Tensor. 351 | dropout_prob: Python float. The probability of dropping out a value (NOT of 352 | *keeping* a dimension as in `tf.nn.dropout`). 353 | 354 | Returns: 355 | A version of `input_tensor` with dropout applied. 
356 | """ 357 | if dropout_prob is None or dropout_prob == 0.0: 358 | return input_tensor 359 | 360 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 361 | return output 362 | 363 | 364 | def layer_norm(input_tensor, name=None): 365 | """Run layer normalization on the last dimension of the tensor.""" 366 | return tf.contrib.layers.layer_norm( 367 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 368 | 369 | 370 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 371 | """Runs layer normalization followed by dropout.""" 372 | output_tensor = layer_norm(input_tensor, name) 373 | output_tensor = dropout(output_tensor, dropout_prob) 374 | return output_tensor 375 | 376 | 377 | def create_initializer(initializer_range=0.02): 378 | """Creates a `truncated_normal_initializer` with the given range.""" 379 | return tf.truncated_normal_initializer(stddev=initializer_range) 380 | 381 | 382 | def embedding_lookup(input_ids, 383 | vocab_size, 384 | embedding_size=128, 385 | initializer_range=0.02, 386 | word_embedding_name="word_embeddings", 387 | use_one_hot_embeddings=False): 388 | """Looks up words embeddings for id tensor. 389 | 390 | Args: 391 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 392 | ids. 393 | vocab_size: int. Size of the embedding vocabulary. 394 | embedding_size: int. Width of the word embeddings. 395 | initializer_range: float. Embedding initialization range. 396 | word_embedding_name: string. Name of the embedding table. 397 | use_one_hot_embeddings: bool. If True, use one-hot method for word 398 | embeddings. If False, use `tf.gather()`. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | if use_one_hot_embeddings: 419 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 420 | output = tf.matmul(one_hot_input_ids, embedding_table) 421 | else: 422 | output = tf.gather(embedding_table, flat_input_ids) 423 | 424 | input_shape = get_shape_list(input_ids) 425 | 426 | output = tf.reshape(output, 427 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 428 | return (output, embedding_table) 429 | 430 | 431 | def embedding_postprocessor(input_tensor, 432 | use_token_type=False, 433 | token_type_ids=None, 434 | token_type_vocab_size=16, 435 | token_type_embedding_name="token_type_embeddings", 436 | use_position_embeddings=True, 437 | position_embedding_name="position_embeddings", 438 | initializer_range=0.02, 439 | max_position_embeddings=512, 440 | dropout_prob=0.1): 441 | """Performs various post-processing on a word embedding tensor. 442 | 443 | Args: 444 | input_tensor: float Tensor of shape [batch_size, seq_length, 445 | embedding_size]. 446 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 447 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 448 | Must be specified if `use_token_type` is True. 
449 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 450 | token_type_embedding_name: string. The name of the embedding table variable 451 | for token type ids. 452 | use_position_embeddings: bool. Whether to add position embeddings for the 453 | position of each token in the sequence. 454 | position_embedding_name: string. The name of the embedding table variable 455 | for positional embeddings. 456 | initializer_range: float. Range of the weight initialization. 457 | max_position_embeddings: int. Maximum sequence length that might ever be 458 | used with this model. This can be longer than the sequence length of 459 | input_tensor, but cannot be shorter. 460 | dropout_prob: float. Dropout probability applied to the final output tensor. 461 | 462 | Returns: 463 | float tensor with same shape as `input_tensor`. 464 | 465 | Raises: 466 | ValueError: One of the tensor shapes or input values is invalid. 467 | """ 468 | input_shape = get_shape_list(input_tensor, expected_rank=3) 469 | batch_size = input_shape[0] 470 | seq_length = input_shape[1] 471 | width = input_shape[2] 472 | 473 | output = input_tensor 474 | 475 | if use_token_type: 476 | if token_type_ids is None: 477 | raise ValueError("`token_type_ids` must be specified if" 478 | "`use_token_type` is True.") 479 | token_type_table = tf.get_variable( 480 | name=token_type_embedding_name, 481 | shape=[token_type_vocab_size, width], 482 | initializer=create_initializer(initializer_range)) 483 | # This vocab will be small so we always do one-hot here, since it is always 484 | # faster for a small vocabulary. 485 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 486 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 487 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 488 | token_type_embeddings = tf.reshape(token_type_embeddings, 489 | [batch_size, seq_length, width]) 490 | output += token_type_embeddings 491 | 492 | if use_position_embeddings: 493 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 494 | with tf.control_dependencies([assert_op]): 495 | full_position_embeddings = tf.get_variable( 496 | name=position_embedding_name, 497 | shape=[max_position_embeddings, width], 498 | initializer=create_initializer(initializer_range)) 499 | 500 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 501 | [seq_length, -1]) 502 | num_dims = len(output.shape.as_list()) 503 | 504 | position_broadcast_shape = [] 505 | for _ in range(num_dims - 2): 506 | position_broadcast_shape.append(1) 507 | position_broadcast_shape.extend([seq_length, width]) 508 | position_embeddings = tf.reshape(position_embeddings, 509 | position_broadcast_shape) 510 | output += position_embeddings 511 | 512 | output = layer_norm_and_dropout(output, dropout_prob) 513 | return output,full_position_embeddings 514 | 515 | 516 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 517 | """Create 3D attention mask from a 2D tensor mask. 518 | 519 | Args: 520 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 521 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 522 | 523 | Returns: 524 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 
525 | """ 526 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 527 | batch_size = from_shape[0] 528 | from_seq_length = from_shape[1] 529 | 530 | to_shape = get_shape_list(to_mask, expected_rank=2) 531 | to_seq_length = to_shape[1] 532 | 533 | to_mask = tf.cast( 534 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 535 | 536 | # We don't assume that `from_tensor` is a mask (although it could be). We 537 | # don't actually care if we attend *from* padding tokens (only *to* padding) 538 | # tokens so we create a tensor of all ones. 539 | # 540 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 541 | broadcast_ones = tf.ones( 542 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 543 | 544 | # Here we broadcast along two dimensions to create the mask. 545 | mask = broadcast_ones * to_mask 546 | 547 | return mask 548 | 549 | 550 | def attention_layer(from_tensor, 551 | to_tensor, 552 | attention_mask=None, 553 | num_attention_heads=1, 554 | size_per_head=512, 555 | query_act=None, 556 | key_act=None, 557 | value_act=None, 558 | attention_probs_dropout_prob=0.0, 559 | initializer_range=0.02, 560 | do_return_2d_tensor=False, 561 | batch_size=None, 562 | from_seq_length=None, 563 | to_seq_length=None): 564 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 565 | 566 | This is an implementation of multi-headed attention based on "Attention 567 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 568 | this is self-attention. Each timestep in `from_tensor` attends to the 569 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 570 | 571 | This function first projects `from_tensor` into a "query" tensor and 572 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 573 | of tensors of length `num_attention_heads`, where each tensor is of shape 574 | [batch_size, seq_length, size_per_head]. 575 | 576 | Then, the query and key tensors are dot-producted and scaled. These are 577 | softmaxed to obtain attention probabilities. The value tensors are then 578 | interpolated by these probabilities, then concatenated back to a single 579 | tensor and returned. 580 | 581 | In practice, the multi-headed attention are done with transposes and 582 | reshapes rather than actual separate tensors. 583 | 584 | Args: 585 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 586 | from_width]. 587 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 588 | attention_mask: (optional) int32 Tensor of shape [batch_size, 589 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 590 | attention scores will effectively be set to -infinity for any positions in 591 | the mask that are 0, and will be unchanged for positions that are 1. 592 | num_attention_heads: int. Number of attention heads. 593 | size_per_head: int. Size of each attention head. 594 | query_act: (optional) Activation function for the query transform. 595 | key_act: (optional) Activation function for the key transform. 596 | value_act: (optional) Activation function for the value transform. 597 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 598 | attention probabilities. 599 | initializer_range: float. Range of the weight initializer. 600 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 601 | * from_seq_length, num_attention_heads * size_per_head]. 
If False, the 602 | output will be of shape [batch_size, from_seq_length, num_attention_heads 603 | * size_per_head]. 604 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 605 | of the 3D version of the `from_tensor` and `to_tensor`. 606 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 607 | of the 3D version of the `from_tensor`. 608 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 609 | of the 3D version of the `to_tensor`. 610 | 611 | Returns: 612 | float Tensor of shape [batch_size, from_seq_length, 613 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 614 | true, this will be of shape [batch_size * from_seq_length, 615 | num_attention_heads * size_per_head]). 616 | 617 | Raises: 618 | ValueError: Any of the arguments or tensor shapes are invalid. 619 | """ 620 | 621 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 622 | seq_length, width): 623 | output_tensor = tf.reshape( 624 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 625 | 626 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 627 | return output_tensor 628 | 629 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 630 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 631 | 632 | if len(from_shape) != len(to_shape): 633 | raise ValueError( 634 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 635 | 636 | if len(from_shape) == 3: 637 | batch_size = from_shape[0] 638 | from_seq_length = from_shape[1] 639 | to_seq_length = to_shape[1] 640 | elif len(from_shape) == 2: 641 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 642 | raise ValueError( 643 | "When passing in rank 2 tensors to attention_layer, the values " 644 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 645 | "must all be specified.") 646 | 647 | # Scalar dimensions referenced here: 648 | # B = batch size (number of sequences) 649 | # F = `from_tensor` sequence length 650 | # T = `to_tensor` sequence length 651 | # N = `num_attention_heads` 652 | # H = `size_per_head` 653 | 654 | from_tensor_2d = reshape_to_matrix(from_tensor) 655 | to_tensor_2d = reshape_to_matrix(to_tensor) 656 | 657 | # `query_layer` = [B*F, N*H] 658 | query_layer = tf.layers.dense( 659 | from_tensor_2d, 660 | num_attention_heads * size_per_head, 661 | activation=query_act, 662 | name="query", 663 | kernel_initializer=create_initializer(initializer_range)) 664 | 665 | # `key_layer` = [B*T, N*H] 666 | key_layer = tf.layers.dense( 667 | to_tensor_2d, 668 | num_attention_heads * size_per_head, 669 | activation=key_act, 670 | name="key", 671 | kernel_initializer=create_initializer(initializer_range)) 672 | 673 | # `value_layer` = [B*T, N*H] 674 | value_layer = tf.layers.dense( 675 | to_tensor_2d, 676 | num_attention_heads * size_per_head, 677 | activation=value_act, 678 | name="value", 679 | kernel_initializer=create_initializer(initializer_range)) 680 | 681 | # `query_layer` = [B, N, F, H] 682 | query_layer = transpose_for_scores(query_layer, batch_size, 683 | num_attention_heads, from_seq_length, 684 | size_per_head) 685 | 686 | # `key_layer` = [B, N, T, H] 687 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 688 | to_seq_length, size_per_head) 689 | 690 | # Take the dot product between "query" and "key" to get the raw 691 | # attention scores. 
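# i.e., after the scaling below, attention_scores[b, n, f, t] = <query[b, n, f, :], key[b, n, t, :]> / sqrt(size_per_head), the scaled dot-product attention of "Attention Is All You Need".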
692 | # `attention_scores` = [B, N, F, T] 693 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 694 | attention_scores = tf.multiply(attention_scores, 695 | 1.0 / math.sqrt(float(size_per_head))) 696 | 697 | if attention_mask is not None: 698 | # `attention_mask` = [B, 1, F, T] 699 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 700 | 701 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 702 | # masked positions, this operation will create a tensor which is 0.0 for 703 | # positions we want to attend and -10000.0 for masked positions. 704 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 705 | 706 | # Since we are adding it to the raw scores before the softmax, this is 707 | # effectively the same as removing these entirely. 708 | attention_scores += adder 709 | 710 | # Normalize the attention scores to probabilities. 711 | # `attention_probs` = [B, N, F, T] 712 | attention_probs = tf.nn.softmax(attention_scores) 713 | 714 | # This is actually dropping out entire tokens to attend to, which might 715 | # seem a bit unusual, but is taken from the original Transformer paper. 716 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 717 | 718 | # `value_layer` = [B, T, N, H] 719 | value_layer = tf.reshape( 720 | value_layer, 721 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 722 | 723 | # `value_layer` = [B, N, T, H] 724 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 725 | 726 | # `context_layer` = [B, N, F, H] 727 | context_layer = tf.matmul(attention_probs, value_layer) 728 | 729 | # `context_layer` = [B, F, N, H] 730 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 731 | 732 | if do_return_2d_tensor: 733 | # `context_layer` = [B*F, N*H] 734 | context_layer = tf.reshape( 735 | context_layer, 736 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 737 | else: 738 | # `context_layer` = [B, F, N*H] 739 | context_layer = tf.reshape( 740 | context_layer, 741 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 742 | 743 | return context_layer 744 | 745 | 746 | def transformer_model(input_tensor, 747 | attention_mask=None, 748 | hidden_size=768, 749 | num_hidden_layers=12, 750 | num_attention_heads=12, 751 | intermediate_size=3072, 752 | intermediate_act_fn=gelu, 753 | hidden_dropout_prob=0.1, 754 | attention_probs_dropout_prob=0.1, 755 | initializer_range=0.02, 756 | do_return_all_layers=False): 757 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 758 | 759 | This is almost an exact implementation of the original Transformer encoder. 760 | 761 | See the original paper: 762 | https://arxiv.org/abs/1706.03762 763 | 764 | Also see: 765 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 766 | 767 | Args: 768 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 769 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 770 | seq_length], with 1 for positions that can be attended to and 0 in 771 | positions that should not be. 772 | hidden_size: int. Hidden size of the Transformer. 773 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 774 | num_attention_heads: int. Number of attention heads in the Transformer. 775 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 776 | forward) layer. 777 | intermediate_act_fn: function. 
The non-linear activation function to apply 778 | to the output of the intermediate/feed-forward layer. 779 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 780 | attention_probs_dropout_prob: float. Dropout probability of the attention 781 | probabilities. 782 | initializer_range: float. Range of the initializer (stddev of truncated 783 | normal). 784 | do_return_all_layers: Whether to also return all layers or just the final 785 | layer. 786 | 787 | Returns: 788 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 789 | hidden layer of the Transformer. 790 | 791 | Raises: 792 | ValueError: A Tensor shape or parameter is invalid. 793 | """ 794 | if hidden_size % num_attention_heads != 0: 795 | raise ValueError( 796 | "The hidden size (%d) is not a multiple of the number of attention " 797 | "heads (%d)" % (hidden_size, num_attention_heads)) 798 | 799 | attention_head_size = int(hidden_size / num_attention_heads) 800 | input_shape = get_shape_list(input_tensor, expected_rank=3) 801 | batch_size = input_shape[0] 802 | seq_length = input_shape[1] 803 | input_width = input_shape[2] 804 | 805 | # The Transformer performs sum residuals on all layers so the input needs 806 | # to be the same as the hidden size. 807 | if input_width != hidden_size: 808 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 809 | (input_width, hidden_size)) 810 | 811 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 812 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 813 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 814 | # help the optimizer. 815 | prev_output = reshape_to_matrix(input_tensor) 816 | 817 | all_layer_outputs = [] 818 | for layer_idx in range(num_hidden_layers): 819 | with tf.variable_scope("layer_%d" % layer_idx): 820 | layer_input = prev_output 821 | 822 | with tf.variable_scope("attention"): 823 | attention_heads = [] 824 | with tf.variable_scope("self"): 825 | attention_head = attention_layer( 826 | from_tensor=layer_input, 827 | to_tensor=layer_input, 828 | attention_mask=attention_mask, 829 | num_attention_heads=num_attention_heads, 830 | size_per_head=attention_head_size, 831 | attention_probs_dropout_prob=attention_probs_dropout_prob, 832 | initializer_range=initializer_range, 833 | do_return_2d_tensor=True, 834 | batch_size=batch_size, 835 | from_seq_length=seq_length, 836 | to_seq_length=seq_length) 837 | attention_heads.append(attention_head) 838 | 839 | attention_output = None 840 | if len(attention_heads) == 1: 841 | attention_output = attention_heads[0] 842 | else: 843 | # In the case where we have other sequences, we just concatenate 844 | # them to the self-attention head before the projection. 845 | attention_output = tf.concat(attention_heads, axis=-1) 846 | 847 | # Run a linear projection of `hidden_size` then add a residual 848 | # with `layer_input`. 849 | with tf.variable_scope("output"): 850 | attention_output = tf.layers.dense( 851 | attention_output, 852 | hidden_size, 853 | kernel_initializer=create_initializer(initializer_range)) 854 | attention_output = dropout(attention_output, hidden_dropout_prob) 855 | attention_output = layer_norm(attention_output + layer_input) 856 | 857 | # The activation is only applied to the "intermediate" hidden layer. 
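# i.e., the position-wise feed-forward block FFN(x) = W2 * act(W1 * x + b1) + b2 (with W1: [hidden_size, intermediate_size], W2: [intermediate_size, hidden_size]), followed below by dropout, a residual connection, and layer norm.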
858 | with tf.variable_scope("intermediate"): 859 | intermediate_output = tf.layers.dense( 860 | attention_output, 861 | intermediate_size, 862 | activation=intermediate_act_fn, 863 | kernel_initializer=create_initializer(initializer_range)) 864 | 865 | # Down-project back to `hidden_size` then add the residual. 866 | with tf.variable_scope("output"): 867 | layer_output = tf.layers.dense( 868 | intermediate_output, 869 | hidden_size, 870 | kernel_initializer=create_initializer(initializer_range)) 871 | layer_output = dropout(layer_output, hidden_dropout_prob) 872 | layer_output = layer_norm(layer_output + attention_output) 873 | prev_output = layer_output 874 | all_layer_outputs.append(layer_output) 875 | 876 | if do_return_all_layers: 877 | final_outputs = [] 878 | for layer_output in all_layer_outputs: 879 | final_output = reshape_from_matrix(layer_output, input_shape) 880 | final_outputs.append(final_output) 881 | return final_outputs 882 | else: 883 | final_output = reshape_from_matrix(prev_output, input_shape) 884 | return final_output 885 | 886 | 887 | def get_shape_list(tensor, expected_rank=None, name=None): 888 | """Returns a list of the shape of tensor, preferring static dimensions. 889 | 890 | Args: 891 | tensor: A tf.Tensor object to find the shape of. 892 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 893 | specified and the `tensor` has a different rank, and exception will be 894 | thrown. 895 | name: Optional name of the tensor for the error message. 896 | 897 | Returns: 898 | A list of dimensions of the shape of tensor. All static dimensions will 899 | be returned as python integers, and dynamic dimensions will be returned 900 | as tf.Tensor scalars. 901 | """ 902 | if name is None: 903 | name = tensor.name 904 | 905 | if expected_rank is not None: 906 | assert_rank(tensor, expected_rank, name) 907 | 908 | shape = tensor.shape.as_list() 909 | 910 | non_static_indexes = [] 911 | for (index, dim) in enumerate(shape): 912 | if dim is None: 913 | non_static_indexes.append(index) 914 | 915 | if not non_static_indexes: 916 | return shape 917 | 918 | dyn_shape = tf.shape(tensor) 919 | for index in non_static_indexes: 920 | shape[index] = dyn_shape[index] 921 | return shape 922 | 923 | 924 | def reshape_to_matrix(input_tensor): 925 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 926 | ndims = input_tensor.shape.ndims 927 | if ndims < 2: 928 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 929 | (input_tensor.shape)) 930 | if ndims == 2: 931 | return input_tensor 932 | 933 | width = input_tensor.shape[-1] 934 | output_tensor = tf.reshape(input_tensor, [-1, width]) 935 | return output_tensor 936 | 937 | 938 | def reshape_from_matrix(output_tensor, orig_shape_list): 939 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 940 | if len(orig_shape_list) == 2: 941 | return output_tensor 942 | 943 | output_shape = get_shape_list(output_tensor) 944 | 945 | orig_dims = orig_shape_list[0:-1] 946 | width = output_shape[-1] 947 | 948 | return tf.reshape(output_tensor, orig_dims + [width]) 949 | 950 | 951 | def assert_rank(tensor, expected_rank, name=None): 952 | """Raises an exception if the tensor rank is not of the expected rank. 953 | 954 | Args: 955 | tensor: A tf.Tensor to check the rank of. 956 | expected_rank: Python integer or list of integers, expected rank. 957 | name: Optional name of the tensor for the error message. 
958 | 959 | Raises: 960 | ValueError: If the expected shape doesn't match the actual shape. 961 | """ 962 | if name is None: 963 | name = tensor.name 964 | 965 | expected_rank_dict = {} 966 | if isinstance(expected_rank, six.integer_types): 967 | expected_rank_dict[expected_rank] = True 968 | else: 969 | for x in expected_rank: 970 | expected_rank_dict[x] = True 971 | 972 | actual_rank = tensor.shape.ndims 973 | if actual_rank not in expected_rank_dict: 974 | scope_name = tf.get_variable_scope().name 975 | raise ValueError( 976 | "For the tensor `%s` in scope `%s`, the actual rank " 977 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 978 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 979 | -------------------------------------------------------------------------------- /.idea/workspace.xml: --------------------------------------------------------------------------------
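For reference, a minimal sketch of how the data pipeline defined in data_load.py fits together. The paths, sequence lengths, and batch size below are illustrative assumptions (they are not taken from train_TPAGE.py):

    # illustrative sketch only -- assumed paths and hyper-parameters
    from data_load import saveForTfRecord, get_batch_for_train_or_dev_or_test

    VOCAB_FPATH = "results/bert/vocab.txt"         # assumed location of the BERT vocab file
    TRAIN_TXT = "data/PAGE/train.txt"              # three "---xhm--"-separated fields per line; fields 2 and 3 are the sentence pair
    TRAIN_TFRECORD = "data/PAGE/train.tf_record"   # assumed output path

    # 1) Tokenize each (sent1, sent2) pair with the BERT FullTokenizer and
    #    serialize the padded id sequences into a TFRecord file.
    saveForTfRecord(TRAIN_TXT,
                    maxlen_vae_Encoder=80,
                    maxlen_vae_Decoder_en=40,
                    maxlen_vae_Decoder_de=40,
                    vocab_fpath=VOCAB_FPATH,
                    output_file=TRAIN_TFRECORD,
                    is_test=False)

    # 2) Build a shuffled, repeating tf.data iterator over that TFRecord file.
    batches, num_batches, num_samples = get_batch_for_train_or_dev_or_test(
        TRAIN_TXT, 80, 40, 40,
        batch_size=32,
        input_file=TRAIN_TFRECORD,
        is_training=True,
        is_test=False)

    features = batches.make_one_shot_iterator().get_next()   # TF 1.x Dataset API, as used above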