├── .gitignore ├── Camera_Ready-Harnessing_Pre_training_Models_with_Simple_Rules.pdf ├── LICENSE ├── README.md ├── evaluate ├── PINC │ ├── __init__.py │ └── pinc.py ├── bleu │ ├── __init__.py │ └── nltk_bleu.py ├── evaluate_em.py ├── evaluate_fr.py ├── formality │ ├── __init__.py │ ├── classifier_em.py │ └── classifier_fr.py ├── tokenizer │ ├── __init__.py │ └── tokenizer.py └── utils │ ├── __init__.py │ └── tools.py ├── gpt ├── config.py ├── download_model.py ├── main.py └── src │ ├── __init__.py │ ├── beamsearch.py │ ├── concat_fine_tuning.py │ ├── encoder.py │ ├── generate_unconditional_samples.py │ ├── gpt2.py │ ├── hierarchical_attention.py │ ├── interactive_conditional_samples.py │ ├── model.py │ ├── multi_gpu_training.py │ ├── sample.py │ ├── simple_finetune.py │ └── single_gpu_serving.py ├── gyafc_model_outputs ├── em_out │ ├── formal.gpt.cat.domain_cmb.ori_rule │ ├── formal.gpt.cat_no_share.ori_rule │ ├── formal.gpt.hie.ori_rule │ ├── formal.gpt.ori │ ├── formal.gpt.ori_rule.ens │ ├── formal.gpt.rule │ └── tmp.txt └── fr_out │ ├── formal.gpt.cat.domain_cmb.ori_rule │ ├── formal.gpt.cat_no_share.ori_rule │ ├── formal.gpt.hie.ori_rule │ ├── formal.gpt.ori │ ├── formal.gpt.ori_rule.ens │ ├── formal.gpt.rule │ └── tmp.txt ├── preprocess ├── __init__.py └── tokenize_corpus.py ├── requirements.txt ├── training_data └── ori │ └── Famliy_Relationships │ └── train │ └── tmp.txt └── utils ├── __init__.py ├── cat_files.py ├── common.py ├── embedding_api.py ├── file_api.py ├── layer.py └── multi_process_tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Camera_Ready-Harnessing_Pre_training_Models_with_Simple_Rules.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimth001/formality_emnlp19/bce1d08a5d4f0f5583dbe418d1c98092d053fd50/Camera_Ready-Harnessing_Pre_training_Models_with_Simple_Rules.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yunli Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Harnessing Pre-Trained Neural Networks with Rules for Formality Style Transfer 2 | 3 | ## 1. model outputs 4 | 5 | The outputs of our methods are under the "**gyafc_model_outputs**" directory. "**em_out**" contains the results for "**Entertainment&Music**", and "**fr_out**" contains the results for "**Family&Relationships**".
6 | 7 | "**formal.gpt.ori**" is the result of "**GPT-Orig**" 8 | 9 | "**formal.gpt.rule**" is the result of "**GPT-Rule**" 10 | 11 | "**formal.gpt.ori_rules.ens**" is the result of "**GPT-Ensemble**" 12 | 13 | "**formal.gpt.cat_no_share.ori_rule**" is the result of "**GPT-CAT**" 14 | 15 | "**formal.gpt.hie.ori_rule**" is the result of "**GPT-HA**". 16 | 17 | "**formal.gpt.cat.domain_cmb.ori_rule**" is the result of "**GPT-CAT**" trained on domain combined data. 18 | 19 | ## 2. evaluation scripts 20 | 21 | We released our evaluation scripts for "**Formality**", "**BLEU**" and "**PINC**". Scripts for evluation are under the "**evaluate**" directory. Run "**evaluate_em.py**" or "**evaluate_fr.py**" can calculate the metrics for the model outputs("gyafc_model_output" should be under the "evaluate" directory). 22 | 23 | We didn't release our code for "**Meaning**" because we just use [BERT](https://github.com/google-research/bert) to fine-tune on STS. 24 | 25 | **References are not released directly because you should first get access to GYAFC dataset. See more in [Section 3.1](#contact).** 26 | 27 | ## 3. model scripts 28 | 29 | The code of our method is under "./gpt", "./utils" and "./preprocess". 30 | 31 | ### 3.1 training data
32 | 33 | The training data includes the original GYAFC dataset and the outputs of a simple rule-based system. To obtain our training data, you should first get access to the [GYAFC dataset](https://github.com/raosudha89/GYAFC-corpus). Once you have gained access to the GYAFC dataset, please forward the acknowledgment to rmwangyl@qq.com, and we will then provide access to our training data and other materials for evaluation. 34 | 35 | ### 3.2 run 36 | 37 | Please download this repo directly, then put "training_data" under './' and "gyafc_model_outputs" under './evaluate/'. Run "main.py" (under './gpt/') to execute our methods. 38 | 39 | We suggest using PyCharm to run this project. 40 | -------------------------------------------------------------------------------- /evaluate/PINC/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /evaluate/PINC/pinc.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | def get_n_gram_list(tokens,n_grams): 4 | a_list=[] 5 | len_tokens=len(tokens) 6 | for i in range(0, n_grams): 7 | for j in range(0,len_tokens-i): 8 | a_list.append(' '.join(tokens[j:j+i+1])) 9 | return a_list 10 | 11 | def cal_pinc_for_one_pair(src_tokens,gen_tokens,n_grams): 12 | src_n_gram_list=get_n_gram_list(src_tokens,n_grams=n_grams) 13 | gen_n_gram_list=get_n_gram_list(gen_tokens,n_grams=n_grams) 14 | counter=0 15 | for item in gen_n_gram_list: 16 | if item in src_n_gram_list: 17 | counter+=1 18 | if len(gen_n_gram_list)==0: 19 | return 0 20 | return 1-counter/len(gen_n_gram_list) 21 | 22 | def load_file_and_tokenize(file): 23 | sens=[] 24 | with open(file,'r',encoding='utf-8') as f: 25 | for line in f: 26 | sens.append(nltk.word_tokenize(line.strip())) 27 | return sens 28 | 29 | def cal_file_pinc(src_file,gen_file,n_grams): 30 | src_sens=load_file_and_tokenize(src_file) 31 | gen_sens=load_file_and_tokenize(gen_file) 32 | score=0 33 | assert len(src_sens)==len(gen_sens) 34 | ind=[i for i in range(0,len(src_sens))] 35 | for s,g,i in zip(src_sens,gen_sens,ind): 36 | score+=cal_pinc_for_one_pair(s,g,n_grams=n_grams) 37 | return score/len(src_sens) 38 | 39 | def evaluate_pinc(resources): 40 | def eval_factory(log_dict,re): 41 | src_file=re['input'] 42 | for key in re: 43 | log_dict[key]=cal_file_pinc(src_file,re[key],n_grams=4) 44 | eval_log={} 45 | for key in resources.keys(): 46 | eval_log[key]={} 47 | eval_factory(eval_log[key], resources[key]) 48 | return eval_log 49 | 50 | 51 | -------------------------------------------------------------------------------- /evaluate/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /evaluate/bleu/nltk_bleu.py: -------------------------------------------------------------------------------- 1 | from nltk.translate.bleu_score import corpus_bleu,sentence_bleu 2 | from nltk.translate.bleu_score import SmoothingFunction 3 | import nltk 4 | 5 | def bleu(reference_files_src_list,gen_file_src,ngrams=4,ignore_case=False): 6 | all_reference=[] 7 | for src in reference_files_src_list: 8 | with open(src,'r',encoding='utf-8') as f: 9 | one_reference=[] 10 | for line in f: 11 | if not ignore_case: 12 | one_reference.append(nltk.word_tokenize(line.strip())) 13 | else: 14 | one_reference.append(nltk.word_tokenize(line.strip().lower())) 15 |
all_reference.append(one_reference) 16 | all_reference=[[all_reference[i][j] for i in range(0,len(all_reference))] for j in range(0,len(all_reference[0]))] 17 | gen=[] 18 | with open(gen_file_src,'r',encoding='utf-8') as f: 19 | for line in f: 20 | if not ignore_case: 21 | gen.append(nltk.word_tokenize(line.strip())) 22 | else: 23 | gen.append(nltk.word_tokenize(line.strip().lower())) 24 | weight=[1.0/ngrams]*ngrams 25 | print(len(gen)) 26 | b=corpus_bleu(all_reference,gen,weights=weight) 27 | return b 28 | 29 | 30 | def get_ref_src_list(path_prefix,ref_num=4): 31 | src_list=[] 32 | for i in range(0,ref_num): 33 | src_list.append(path_prefix+str(i)) 34 | return src_list 35 | 36 | 37 | def evaluate_bleu(resources): 38 | def eval_factory(log_dict,re): 39 | ref_list=[re['ref0'],re['ref1'],re['ref2'],re['ref3']] 40 | for key in re: 41 | print(key,len(re[key])) 42 | log_dict[key]=bleu(ref_list,re[key]) 43 | eval_log={} 44 | for key in resources.keys(): 45 | eval_log[key]={} 46 | eval_factory(eval_log[key], resources[key]) 47 | return eval_log -------------------------------------------------------------------------------- /evaluate/evaluate_em.py: -------------------------------------------------------------------------------- 1 | from evaluate.bleu.nltk_bleu import evaluate_bleu 2 | from evaluate.PINC.pinc import evaluate_pinc 3 | 4 | 5 | def get_default_resources(domain='fr',to_fm=True,to_inf=False): 6 | def factory(file_dict,in_dir,out_dir,ref_dir,in_flag='informal',target_flag='formal'): 7 | file_dict['input']=in_dir+in_flag 8 | file_dict['ref0']=ref_dir+target_flag+'.ref0' 9 | file_dict['ref1'] = ref_dir+target_flag+'.ref1' 10 | file_dict['ref2'] = ref_dir+target_flag+'.ref2' 11 | file_dict['ref3'] = ref_dir+target_flag+'.ref3' 12 | file_dict['gpt_rule'] = out_dir + target_flag + '.gpt.rule' 13 | file_dict['gpt_ori'] = out_dir + target_flag + '.gpt.ori' 14 | file_dict['gpt.hie.ori_rule'] = out_dir + target_flag + '.gpt.hieori_rule' 15 | file_dict['gpt.cat_no_share.ori_rule'] = out_dir + target_flag + '.gpt.cat_no_share.ori_rule' 16 | file_dict['gpt.ori_rule.ens'] = out_dir + target_flag + '.gpt.ori_rule.ens' 17 | #file_dict['FT'] = out_dir + target_flag + '.MultiTask-tag-style' 18 | file_dict['gpt.cat.domain_cmb.ori_rule'] = out_dir + target_flag + '.gpt.cat.domain_cmb.ori_rule' 19 | 20 | resources={} 21 | data_path='./gyafc_model_outputs/' 22 | if to_fm: 23 | resources['inf2fm']={} 24 | factory(resources['inf2fm'], data_path + domain + '_in/', 25 | data_path + domain + '_out/', data_path + domain + '_refs/', 26 | in_flag='informal', target_flag='formal') 27 | if to_inf: 28 | resources['fm2inf']={} 29 | factory(resources['fm2inf'], data_path + domain + '_in/', 30 | data_path + domain + '_out/', data_path + domain + '_refs/', 31 | in_flag='formal', target_flag='informal') 32 | return resources 33 | 34 | def print_dict(d): 35 | for key in d: 36 | print(key) 37 | nd=d[key] 38 | for k in nd: 39 | print(k,nd[k]) 40 | 41 | if __name__=='__main__': 42 | re=get_default_resources(domain='em') 43 | bleu_result=evaluate_bleu(re) 44 | #formality_result=evaluate_formality(re) 45 | #meaning_result=evaluate_meaning(re) 46 | #pinc_result = evaluate_pinc(re) 47 | print_dict(bleu_result) 48 | #print_dict(formality_result) 49 | #print_dict(meaning_result) 50 | #print_dict(pinc_result) -------------------------------------------------------------------------------- /evaluate/evaluate_fr.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | 6 | from evaluate.bleu.nltk_bleu import evaluate_bleu 7 | from evaluate.formality.classifier_fr import evaluate_formality 8 | from evaluate.PINC.pinc import evaluate_pinc 9 | def get_default_resources(domain='fr',to_fm=True,to_inf=False): 10 | def factory(file_dict,in_dir,out_dir,ref_dir,in_flag='informal',target_flag='formal'): 11 | '''file_dict['rule_based']=out_dir+target_flag+'.rule_based' 12 | file_dict['pbmt']=out_dir+target_flag+'.pbmt' 13 | file_dict['nmt_baseline']=out_dir+target_flag+'.nmt_baseline' 14 | file_dict['nmt_copy']=out_dir+target_flag+'.nmt_copy' 15 | file_dict['nmt_combined']=out_dir+target_flag+'.nmt_combined''' 16 | file_dict['input']=in_dir+in_flag 17 | file_dict['ref0']=ref_dir+target_flag+'.ref0' 18 | file_dict['ref1'] = ref_dir+target_flag+'.ref1' 19 | file_dict['ref2'] = ref_dir+target_flag+'.ref2' 20 | file_dict['ref3'] = ref_dir+target_flag+'.ref3' 21 | file_dict['gpt_rule'] = out_dir + target_flag + '.gpt.rule' 22 | file_dict['gpt_ori'] = out_dir + target_flag + '.gpt.ori' 23 | file_dict['gpt.hie.ori_rule'] = out_dir + target_flag + '.gpt.hieori_rule' 24 | file_dict['gpt.cat_no_share.ori_rule'] = out_dir + target_flag + '.gpt.cat_no_share.ori_rule' 25 | file_dict['gpt.ori_rule.ens'] = out_dir + target_flag + '.gpt.ori_rule.ens' 26 | #file_dict['FT']=out_dir+target_flag+'.MultiTask-tag-style' 27 | file_dict['gpt.cat.domain_cmb.ori_rule'] = out_dir + target_flag + '.gpt.cat.domain_cmb.ori_rule' 28 | resources={} 29 | data_path='./gyafc_model_outputs/' 30 | if to_fm: 31 | resources['inf2fm']={} 32 | factory(resources['inf2fm'], data_path + domain + '_in/', 33 | data_path + domain + '_out/', data_path + domain + '_refs/', 34 | in_flag='informal', target_flag='formal') 35 | if to_inf: 36 | resources['fm2inf']={} 37 | factory(resources['fm2inf'], data_path + domain + '_in/', 38 | data_path + domain + '_out/', data_path + domain + '_refs/', 39 | in_flag='formal', target_flag='informal') 40 | return resources 41 | 42 | def print_dict(d): 43 | for key in d: 44 | print(key) 45 | nd=d[key] 46 | for k in nd: 47 | print(k,nd[k]) 48 | 49 | if __name__=='__main__': 50 | re=get_default_resources(domain='fr') 51 | bleu_result=evaluate_bleu(re) 52 | pinc_result=evaluate_pinc(re) 53 | print_dict(bleu_result) 54 | print_dict(pinc_result) -------------------------------------------------------------------------------- /evaluate/formality/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /evaluate/formality/classifier_em.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib import rnn 6 | from tensorflow.contrib.layers import xavier_initializer 7 | import time 8 | import pickle 9 | from datetime import timedelta 10 | from utils import embedding_api 11 | import random 12 | 13 | data_path='./new_exp_em/classifier/' 14 | class NNModel: 15 | def __init__(self, embedding, mode, model_path=None, vocab_hash=None): 16 | self.graph=tf.Graph() 17 | self.learning_rate=1e-3 18 | self.batch_size=256 19 | self.epoch_num=10 20 | self.dropout_keep_prob=1.0 21 | self.vocab_hash=vocab_hash 22 | self.embedding = embedding 23 | self.vocab_size = embedding.shape[0] 24 | with self.graph.as_default(): 25 | self.input_x = 
tf.placeholder(tf.int64, [None, None], name='input_x') 26 | self.x_sequence_len = tf.placeholder(tf.int64, [None], name='x_sequence_len') 27 | self.embedding_ph = tf.placeholder(tf.float32, [self.embedding.shape[0], self.embedding.shape[1]], 28 | name='embedding') 29 | self.input_y = tf.placeholder(tf.int64, [None], name='input_y') 30 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 31 | self.inputs_data=[self.input_x,self.x_sequence_len] 32 | self.save_dir='./new_exp_em/classifier/model/' 33 | self.print_per_batch=100 34 | self.require_improvement=6400 35 | if mode=='train': 36 | pass 37 | elif mode=='eval' or mode=='predict': 38 | pass 39 | else: 40 | assert False,'no this mode:'+str(mode) 41 | 42 | 43 | def __get_time_dif(self, start_time): 44 | end_time = time.time() 45 | time_dif = end_time - start_time 46 | return timedelta(seconds=int(round(time_dif))) 47 | 48 | def build_basic_rnn_model(self,rnn_unit_num=32,dense_layer_unit_num=8,class_num=2,reg_para=0.0): 49 | with self.graph.as_default(): 50 | with tf.device('/cpu:0'): 51 | word_embedding = tf.get_variable(name='embedding', shape=self.embedding.shape, dtype=tf.float32, 52 | trainable=True) 53 | self.embedding_init = word_embedding.assign(self.embedding_ph) 54 | x_embedding = tf.nn.embedding_lookup(word_embedding, self.input_x) 55 | with tf.name_scope("rnn"): 56 | fw_cell = rnn.LSTMCell(rnn_unit_num) 57 | bw_cell = rnn.LSTMCell(rnn_unit_num) 58 | out, state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x_embedding, self.x_sequence_len, 59 | dtype=tf.float32, 60 | initial_state_fw=None, initial_state_bw=None) 61 | # combined_out = tf.concat(out, axis=2) 62 | combined_state = tf.concat([state[0][1], state[1][1]], axis=1) 63 | with tf.name_scope("dense_layers"): 64 | fc = tf.layers.dense(combined_state, dense_layer_unit_num, name='fc1', activation=tf.nn.tanh, 65 | kernel_initializer=xavier_initializer(), 66 | kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_para), 67 | bias_initializer=tf.zeros_initializer(), 68 | bias_regularizer=tf.contrib.layers.l2_regularizer(reg_para)) 69 | fc = tf.contrib.layers.dropout(fc, self.keep_prob) 70 | self.logits = tf.layers.dense(fc, class_num, name='fc2', activation=tf.nn.tanh, 71 | kernel_initializer=xavier_initializer(), 72 | kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_para), 73 | bias_initializer=tf.zeros_initializer(), 74 | bias_regularizer=tf.contrib.layers.l2_regularizer(reg_para)) 75 | self.y_pred_value = tf.nn.softmax(self.logits) 76 | self.y_pred_class = tf.argmax(self.y_pred_value, 1) 77 | with tf.name_scope("optimize"): 78 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits + 1e-10, 79 | labels=self.input_y) 80 | self.loss = tf.reduce_mean(cross_entropy) 81 | self.optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 82 | with tf.name_scope("evaluate_metrics"): 83 | correct_pred = tf.equal(self.input_y, self.y_pred_class) 84 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 85 | 86 | def _evaluate_without_predict_result(self, input_data, target): 87 | batches = self.batch_iter([input_data], target, self.batch_size, shuffle=False) 88 | total_loss = 0.0 89 | total_acc = 0.0 90 | total_len = len(target) 91 | for batch_data, batch_target in batches: 92 | batch_len = len(batch_target) 93 | feed_dict = self.feed_data(inputs_data=batch_data, keep_prob=1.0, target=batch_target) 94 | loss, acc = self.session.run([self.loss, self.acc], feed_dict=feed_dict) 95 | total_loss += loss * 
batch_len 96 | total_acc += acc * batch_len 97 | return total_loss / total_len, total_acc / total_len 98 | 99 | 100 | def predict_prob(self, input_data, model_path): 101 | with self.graph.as_default(): 102 | saver = tf.train.Saver() 103 | config = tf.ConfigProto() 104 | config.gpu_options.allow_growth = True 105 | result = [] 106 | with tf.Session(config=config) as sess: 107 | self.session = sess 108 | saver.restore(sess, model_path) 109 | batches = self.batch_iter([input_data], target=None, batch_size=self.batch_size, shuffle=False) 110 | for batch_data in batches: 111 | feed_dict = self.feed_data(inputs_data=batch_data, keep_prob=1.0) 112 | pred = self.session.run([self.y_pred_value], feed_dict=feed_dict)[0] 113 | for d in pred: 114 | result.append(d) 115 | return result 116 | 117 | 118 | def feed_data(self, inputs_data, keep_prob, target=None): 119 | feed_dict = {} 120 | for i in range(len(self.inputs_data)): 121 | feed_dict[self.inputs_data[i]] = inputs_data[i] 122 | feed_dict[self.keep_prob] = keep_prob 123 | if not target is None: 124 | feed_dict[self.input_y] = target 125 | return feed_dict 126 | 127 | def evaluate(self,input_data,target,model_path): 128 | with self.graph.as_default(): 129 | saver = tf.train.Saver() 130 | config = tf.ConfigProto() 131 | config.gpu_options.allow_growth = True 132 | with tf.Session(config=config) as sess: 133 | self.session = sess 134 | saver.restore(sess, model_path) 135 | print(self._evaluate_without_predict_result(input_data, target)) 136 | 137 | def train_model(self,train_x,train_label,val_x,val_label,continue_train=False,previous_model_path=None): 138 | start_time = time.time() 139 | with self.graph.as_default(): 140 | saver = tf.train.Saver(max_to_keep=20) 141 | if not os.path.exists(self.save_dir): 142 | os.makedirs(self.save_dir) 143 | ############################################################## 144 | print(str(self.__get_time_dif(start_time)) + "trainning and evaluating...") 145 | total_batch = 0 146 | best_acc_val = 0.0 147 | last_improved = 0 148 | flag = False 149 | config = tf.ConfigProto() 150 | config.gpu_options.allow_growth = True 151 | with self.graph.as_default(): 152 | with tf.Session(config=config) as sess: 153 | self.session = sess 154 | if continue_train is False: 155 | sess.run(tf.global_variables_initializer()) 156 | sess.run(self.embedding_init, feed_dict={self.embedding_ph: self.embedding}) 157 | else: 158 | saver.restore(sess, previous_model_path) 159 | self.session.graph.finalize() 160 | for epoch in range(self.epoch_num): 161 | print("epoch:" + str(epoch + 1)) 162 | batch_train = self.batch_iter([train_x], train_label, batch_size=self.batch_size, shuffle=False) 163 | for batch_data, batch_target in batch_train: 164 | feed_dict = self.feed_data(inputs_data=batch_data, target=batch_target, 165 | keep_prob=self.dropout_keep_prob) 166 | s = self.session.run([self.optim], feed_dict=feed_dict) 167 | if total_batch == 0: 168 | saver.save(sess=self.session, save_path=self.save_dir + "model.ckpt") 169 | if total_batch > 0 and total_batch % self.print_per_batch == 0: 170 | feed_dict[self.keep_prob] = 1.0 171 | loss_train, acc_train = self.session.run([self.loss, self.acc], 172 | feed_dict=feed_dict) 173 | loss_val, acc_val = self._evaluate_without_predict_result(val_x, 174 | val_label) 175 | if acc_val > best_acc_val: 176 | best_acc_val = acc_val 177 | last_improved = total_batch 178 | saver.save(sess=self.session, 179 | save_path=self.save_dir + str(total_batch) + 'model.ckpt') 180 | improved_str = '*' 181 | else: 182 | 
improved_str = '' 183 | time_dif = self.__get_time_dif(start_time) 184 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 185 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 186 | print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, 187 | improved_str)) 188 | total_batch += 1 189 | if total_batch - last_improved > self.require_improvement: 190 | print("No optimization for a long time, auto-stopping...") 191 | flag = True 192 | break 193 | if flag: 194 | break 195 | self.session = None 196 | 197 | 198 | def batch_iter(self, input_data, target=None, batch_size=64,padding=0,shuffle=False): 199 | assert not input_data is None,"input_data is None" 200 | data_len = len(input_data[0]) 201 | num_batch = int((data_len - 1) / batch_size) + 1 202 | if shuffle: 203 | indices = np.random.permutation(np.arange(data_len)) 204 | else: 205 | indices = range(data_len) 206 | x_shuffle = [input_data[0][i] for i in indices] 207 | x_seq_len = [len(x_shuffle[i]) for i in range(len(x_shuffle))] 208 | if target is None: 209 | for i in range(num_batch): 210 | start_id = i * batch_size 211 | end_id = min((i + 1) * batch_size, data_len) 212 | batch_x = copy_list(x_shuffle[start_id:end_id]) 213 | batch_x_seq_len = x_seq_len[start_id:end_id] 214 | x_max_len = max(batch_x_seq_len) 215 | for list in batch_x: 216 | if len(list) < x_max_len: 217 | list += [padding] * (x_max_len - len(list)) 218 | yield [batch_x,batch_x_seq_len] 219 | else: 220 | y_shuffle = [target[i] for i in indices] 221 | for i in range(num_batch): 222 | start_id = i * batch_size 223 | end_id = min((i + 1) * batch_size, data_len) 224 | batch_y = y_shuffle[start_id:end_id] 225 | batch_x = copy_list(x_shuffle[start_id:end_id]) 226 | batch_x_seq_len = x_seq_len[start_id:end_id] 227 | x_max_len = max(batch_x_seq_len) 228 | for list in batch_x: 229 | if len(list) < x_max_len: 230 | list += [padding] * (x_max_len - len(list)) 231 | yield [batch_x,batch_x_seq_len],batch_y 232 | 233 | 234 | def get_file_src_list(parent_path, file_type='.txt'): 235 | files = os.listdir(parent_path) 236 | src_list = [] 237 | for file in files: 238 | absolute_path = os.path.join(parent_path, file) 239 | if os.path.isdir(absolute_path): 240 | src_list += get_file_src_list(absolute_path) 241 | elif file.endswith(file_type): 242 | src_list.append(absolute_path) 243 | return src_list 244 | 245 | 246 | def copy_list(list): 247 | new_list = [] 248 | for l in list: 249 | if type(l) == type([0]) or type(l) == type(np.array([0])): 250 | new_list.append(copy_list(l)) 251 | else: 252 | new_list.append(l) 253 | return new_list 254 | 255 | 256 | class Data: 257 | def __init__(self,x,y,ori_x=None): 258 | self.x=x 259 | self.y=y 260 | self.ori_x=ori_x 261 | def split(self): 262 | self.x=self.x.split(' ') 263 | def str2index(self,word_dict,with_unk=True): 264 | index=[] 265 | if with_unk: 266 | for s in self.x: 267 | if s in word_dict: 268 | index.append(word_dict[s]) 269 | else: 270 | index.append(len(word_dict)) 271 | else: 272 | for s in self.x: 273 | if s in word_dict: 274 | index.append(word_dict[s]) 275 | self.x=index 276 | 277 | 278 | def preprocess(informal_src_list,formal_src_list,embedding_path,output_path=None,shuffle=True): 279 | vectors,vocab_hash=embedding_api.load_word_embedding(embedding_path) 280 | all_data=[] 281 | for src in informal_src_list: 282 | with open(src,'r',encoding='utf-8') as f: 283 | for line in f: 284 | d=Data(nltk.word_tokenize(line.strip()), 0, line.strip()) 285 | d.str2index(vocab_hash,with_unk=False) 
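# Added note: with with_unk=False, str2index keeps only tokens found in the embedding vocabulary, so out-of-vocabulary words are silently dropped from d.x.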
286 | all_data.append(d) 287 | for src in formal_src_list: 288 | with open(src,'r',encoding='utf-8') as f: 289 | for line in f: 290 | d=Data(nltk.word_tokenize(line.strip()), 1, line.strip()) 291 | d.str2index(vocab_hash,with_unk=False) 292 | all_data.append(d) 293 | if shuffle: 294 | random.shuffle(all_data) 295 | if output_path is not None: 296 | pickle.dump(all_data,open(output_path,'wb'),protocol=True) 297 | return all_data 298 | 299 | 300 | def all_prepro(): 301 | train_inf_src=['./new_exp_em/classifier/informal.train.tok.bpe.len_filtered'] 302 | train_fm_src = ['./new_exp_em/classifier/formal.train.tok.bpe.len_filtered'] 303 | val_inf_src = ['./new_exp_em/classifier/informal.val.tok.bpe'] 304 | val_fm_src = ['./new_exp_em/classifier/formal.val.tok.bpe'] 305 | test_inf_src = ['./new_exp_em/classifier/informal.test.tok.bpe'] 306 | test_fm_src = ['./new_exp_em/classifier/formal.test.tok.bpe'] 307 | embedding_path='./new_exp_em/embedding/embedding.bpe.big.txt' 308 | preprocess(train_inf_src,train_fm_src,embedding_path=embedding_path, 309 | output_path='./new_exp_em/classifier/train.pkl') 310 | preprocess(val_inf_src, val_fm_src, embedding_path=embedding_path, 311 | output_path='./new_exp_em/classifier/val.pkl') 312 | preprocess(test_inf_src, test_fm_src, embedding_path=embedding_path, 313 | output_path='./new_exp_em/classifier/test.pkl') 314 | 315 | def use_nn_model(): 316 | train = pickle.load(open('./new_exp_em/classifier/train.pkl', 'rb')) 317 | val = pickle.load(open('./new_exp_em/classifier/val.pkl', 'rb')) 318 | embedding_path = './new_exp_em/embedding/embedding.bpe.big.txt' 319 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 320 | nn=NNModel(np.array(embedding),mode='train') 321 | nn.build_basic_rnn_model() 322 | nn.train_model([t.x for t in train],[t.y for t in train],[t.x for t in val],[t.y for t in val], 323 | continue_train=False, previous_model_path='./new_exp_em/classifier/model/990model.ckpt') 324 | 325 | def test(): 326 | test = pickle.load(open('./new_exp_em/classifier/test.pkl', 'rb')) 327 | embedding_path = './new_exp_em/embedding/corpus.fine_tune_embedding.epoch.10' 328 | embedding,vocab_hash = embedding_api.load_word_embedding(embedding_path) 329 | nn = NNModel(np.array(embedding),mode='eval') 330 | nn.build_basic_rnn_model() 331 | nn.evaluate([t.x for t in test],[t.y for t in test],model_path='') 332 | 333 | def predict(model_path,file_path='./new_exp_em/classifier/val.pkl',embedding_path='./new_exp_em/embedding/corpus.fine_tune_embedding.epoch.10'): 334 | test = pickle.load(open(file_path, 'rb')) 335 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 336 | nn = NNModel(np.array(embedding),mode='predict') 337 | nn.batch_size=256 338 | nn.build_basic_rnn_model() 339 | result=nn.predict_prob([t.x for t in test], model_path=model_path) 340 | return test,result 341 | 342 | 343 | def evaluate_one_formality(input_file_path,is_inf): 344 | embedding_path = './new_exp_em/embedding/embedding.bpe.big.txt' 345 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 346 | nn = NNModel(np.array(embedding), mode='eval') 347 | nn.batch_size = 128 348 | nn.build_basic_rnn_model() 349 | if is_inf: 350 | data = preprocess(informal_src_list=[input_file_path], formal_src_list=[], embedding_path=embedding_path, 351 | shuffle=False) 352 | else: 353 | data = preprocess(informal_src_list=[], formal_src_list=[input_file_path], embedding_path=embedding_path, 354 | shuffle=False) 355 | result = nn.predict_prob([t.x for t 
in data], model_path='./new_exp_em/classifier/model/600model.ckpt') 356 | score = 0 357 | if is_inf: 358 | for s in result: 359 | score += s[0] 360 | else: 361 | for s in result: 362 | score += s[1] 363 | print(score / len(data)) 364 | return score / len(data) 365 | 366 | def test_formality_score(files=None): 367 | if files is None: 368 | files = { 369 | 'informal': ['./new_exp_em/classifier/informal.test.tok.bpe'], 370 | 'formal': ['./new_exp_em/classifier/formal.test.tok.bpe'], 371 | 'rule_based': ['./data/Entertainment_Music/model_outputs/formal.rule_based.bpe'], 372 | 'pbmt': ['./data/Entertainment_Music/model_outputs/formal.pbmt.bpe'], 373 | 'nmt_baseline': ['./data/Entertainment_Music/model_outputs/formal.nmt_baseline.bpe'], 374 | 'nmt_copy': ['./data/Entertainment_Music/model_outputs/formal.nmt_copy.bpe'], 375 | 'nmt_combined': ['./data/Entertainment_Music/model_outputs/formal.nmt_combined.bpe'], 376 | } 377 | embedding_path = './new_exp_em/embedding/embedding.bpe.big.txt' 378 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 379 | nn = NNModel(np.array(embedding),mode='eval') 380 | nn.batch_size = 128 381 | nn.build_basic_rnn_model() 382 | eval_log={} 383 | for key in files.keys(): 384 | if type(files[key])==type([]): 385 | fm_files=files[key]+'.bpe' 386 | else: 387 | fm_files=[files[key]+'.bpe'] 388 | data=preprocess(informal_src_list=[],formal_src_list=fm_files,embedding_path=embedding_path,shuffle=False) 389 | result = nn.predict_prob([t.x for t in data], model_path='./new_exp_em/classifier/model/600model.ckpt') 390 | score=0 391 | for s in result: 392 | score+=s[1] 393 | print(key,score/len(data)) 394 | eval_log[key]=score/len(data) 395 | return eval_log 396 | 397 | 398 | def cal_formality_score_for_each_sentence(output_dir,files=None): 399 | if files is None: 400 | files = { 401 | 'rule_based': ['./data/Family_Relationships/bpe_outputs/formal.rule_based.bpe'], 402 | 'pbmt': ['./data/Family_Relationships/bpe_outputs/formal.pbmt.bpe'], 403 | 'nmt_baseline': ['./data/Family_Relationships/bpe_outputs/formal.nmt_baseline.bpe'], 404 | 'nmt_copy': ['./data/Family_Relationships/bpe_outputs/formal.nmt_copy.bpe'], 405 | 'nmt_combined': ['./data/Family_Relationships/bpe_outputs/formal.nmt_combined.bpe'], 406 | } 407 | embedding_path = './new_exp_em/embedding/embedding.bpe.big.txt' 408 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 409 | nn = NNModel(np.array(embedding),mode='eval') 410 | nn.batch_size = 128 411 | nn.build_basic_rnn_model() 412 | for key in files.keys(): 413 | data=preprocess(informal_src_list=[],formal_src_list=files[key],embedding_path=embedding_path,shuffle=False) 414 | result = nn.predict_prob([t.x for t in data], model_path='./new_exp_em/classifier/model/600model.ckpt') 415 | base_name=os.path.basename(files[key]) 416 | with open(os.path.join(output_dir,base_name+'.formality_score'),'w',encoding='utf-8') as fw: 417 | for r in result: 418 | fw.write(str(r[1])+'\n') 419 | 420 | 421 | def evaluate_formality(resources): 422 | eval_log={} 423 | for key in resources.keys(): 424 | eval_log[key] = test_formality_score(resources[key]) 425 | return eval_log -------------------------------------------------------------------------------- /evaluate/formality/classifier_fr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib import rnn 6 | from tensorflow.contrib.layers import 
xavier_initializer 7 | import time 8 | import pickle 9 | from datetime import timedelta 10 | import utils.embedding_api as embedding_api 11 | import random 12 | 13 | data_path='./new_exp_fr/classifier/' 14 | class NNModel: 15 | def __init__(self, embedding, mode, model_path=None, vocab_hash=None): 16 | self.graph=tf.Graph() 17 | self.learning_rate=1e-3 18 | self.batch_size=256 19 | self.epoch_num=10 20 | self.dropout_keep_prob=1.0 21 | self.vocab_hash=vocab_hash 22 | self.embedding = embedding 23 | self.vocab_size = embedding.shape[0] 24 | with self.graph.as_default(): 25 | self.input_x = tf.placeholder(tf.int64, [None, None], name='input_x') 26 | self.x_sequence_len = tf.placeholder(tf.int64, [None], name='x_sequence_len') 27 | self.embedding_ph = tf.placeholder(tf.float32, [self.embedding.shape[0], self.embedding.shape[1]], 28 | name='embedding') 29 | self.input_y = tf.placeholder(tf.int64, [None], name='input_y') 30 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 31 | self.inputs_data=[self.input_x,self.x_sequence_len] 32 | self.save_dir='./new_exp_fr/classifier/model/' 33 | self.print_per_batch=100 34 | self.require_improvement=6400 35 | if mode=='train': 36 | pass 37 | elif mode=='eval' or mode=='predict': 38 | pass 39 | else: 40 | assert False,'no this mode:'+str(mode) 41 | 42 | 43 | def __get_time_dif(self, start_time): 44 | end_time = time.time() 45 | time_dif = end_time - start_time 46 | return timedelta(seconds=int(round(time_dif))) 47 | 48 | def build_basic_rnn_model(self,rnn_unit_num=32,dense_layer_unit_num=8,class_num=2,reg_para=0.0): 49 | with self.graph.as_default(): 50 | with tf.device('/cpu:0'): 51 | word_embedding = tf.get_variable(name='embedding', shape=self.embedding.shape, dtype=tf.float32, 52 | trainable=True) 53 | self.embedding_init = word_embedding.assign(self.embedding_ph) 54 | x_embedding = tf.nn.embedding_lookup(word_embedding, self.input_x) 55 | with tf.name_scope("rnn"): 56 | fw_cell = rnn.LSTMCell(rnn_unit_num) 57 | bw_cell = rnn.LSTMCell(rnn_unit_num) 58 | out, state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x_embedding, self.x_sequence_len, 59 | dtype=tf.float32, 60 | initial_state_fw=None, initial_state_bw=None) 61 | combined_state = tf.concat([state[0][1], state[1][1]], axis=1) 62 | with tf.name_scope("dense_layers"): 63 | fc = tf.layers.dense(combined_state, dense_layer_unit_num, name='fc1', activation=tf.nn.tanh, 64 | kernel_initializer=xavier_initializer(), 65 | kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_para), 66 | bias_initializer=tf.zeros_initializer(), 67 | bias_regularizer=tf.contrib.layers.l2_regularizer(reg_para)) 68 | fc = tf.contrib.layers.dropout(fc, self.keep_prob) 69 | self.logits = tf.layers.dense(fc, class_num, name='fc2', activation=tf.nn.tanh, 70 | kernel_initializer=xavier_initializer(), 71 | kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_para), 72 | bias_initializer=tf.zeros_initializer(), 73 | bias_regularizer=tf.contrib.layers.l2_regularizer(reg_para)) 74 | self.y_pred_value = tf.nn.softmax(self.logits) 75 | self.y_pred_class = tf.argmax(self.y_pred_value, 1) 76 | with tf.name_scope("optimize"): 77 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits + 1e-10, 78 | labels=self.input_y) 79 | self.loss = tf.reduce_mean(cross_entropy) 80 | self.optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 81 | with tf.name_scope("evaluate_metrics"): 82 | correct_pred = tf.equal(self.input_y, self.y_pred_class) 83 | self.acc = 
tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 84 | 85 | def _evaluate_without_predict_result(self, input_data, target): 86 | batches = self.batch_iter([input_data], target, self.batch_size, shuffle=False) 87 | total_loss = 0.0 88 | total_acc = 0.0 89 | total_len = len(target) 90 | for batch_data, batch_target in batches: 91 | batch_len = len(batch_target) 92 | feed_dict = self.feed_data(inputs_data=batch_data, keep_prob=1.0, target=batch_target) 93 | loss, acc = self.session.run([self.loss, self.acc], feed_dict=feed_dict) 94 | total_loss += loss * batch_len 95 | total_acc += acc * batch_len 96 | return total_loss / total_len, total_acc / total_len 97 | 98 | 99 | def predict_prob(self, input_data, model_path): 100 | with self.graph.as_default(): 101 | saver = tf.train.Saver() 102 | config = tf.ConfigProto() 103 | config.gpu_options.allow_growth = True 104 | result = [] 105 | with tf.Session(config=config) as sess: 106 | self.session = sess 107 | saver.restore(sess, model_path) 108 | batches = self.batch_iter([input_data], target=None, batch_size=self.batch_size, shuffle=False) 109 | for batch_data in batches: 110 | feed_dict = self.feed_data(inputs_data=batch_data, keep_prob=1.0) 111 | pred = self.session.run([self.y_pred_value], feed_dict=feed_dict)[0] 112 | for d in pred: 113 | result.append(d) 114 | return result 115 | 116 | 117 | def feed_data(self, inputs_data, keep_prob, target=None): 118 | feed_dict = {} 119 | for i in range(len(self.inputs_data)): 120 | feed_dict[self.inputs_data[i]] = inputs_data[i] 121 | feed_dict[self.keep_prob] = keep_prob 122 | if not target is None: 123 | feed_dict[self.input_y] = target 124 | return feed_dict 125 | 126 | def evaluate(self,input_data,target,model_path): 127 | with self.graph.as_default(): 128 | saver = tf.train.Saver() 129 | config = tf.ConfigProto() 130 | config.gpu_options.allow_growth = True 131 | with tf.Session(config=config) as sess: 132 | self.session = sess 133 | saver.restore(sess, model_path) 134 | print(self._evaluate_without_predict_result(input_data, target)) 135 | 136 | def train_model(self,train_x,train_label,val_x,val_label,continue_train=False,previous_model_path=None): 137 | start_time = time.time() 138 | with self.graph.as_default(): 139 | saver = tf.train.Saver(max_to_keep=20) 140 | if not os.path.exists(self.save_dir): 141 | os.makedirs(self.save_dir) 142 | ############################################################## 143 | print(str(self.__get_time_dif(start_time)) + "trainning and evaluating...") 144 | total_batch = 0 145 | best_acc_val = 0.0 146 | last_improved = 0 147 | flag = False 148 | config = tf.ConfigProto() 149 | config.gpu_options.allow_growth = True 150 | with self.graph.as_default(): 151 | with tf.Session(config=config) as sess: 152 | self.session = sess 153 | if continue_train is False: 154 | sess.run(tf.global_variables_initializer()) 155 | sess.run(self.embedding_init, feed_dict={self.embedding_ph: self.embedding}) 156 | else: 157 | saver.restore(sess, previous_model_path) 158 | self.session.graph.finalize() 159 | for epoch in range(self.epoch_num): 160 | print("epoch:" + str(epoch + 1)) 161 | batch_train = self.batch_iter([train_x], train_label, batch_size=self.batch_size, shuffle=False) 162 | for batch_data, batch_target in batch_train: 163 | feed_dict = self.feed_data(inputs_data=batch_data, target=batch_target, 164 | keep_prob=self.dropout_keep_prob) 165 | s = self.session.run([self.optim], feed_dict=feed_dict) 166 | if total_batch == 0: 167 | saver.save(sess=self.session, save_path=self.save_dir 
+ "model.ckpt") 168 | if total_batch > 0 and total_batch % self.print_per_batch == 0: 169 | feed_dict[self.keep_prob] = 1.0 170 | loss_train, acc_train = self.session.run([self.loss, self.acc], 171 | feed_dict=feed_dict) 172 | loss_val, acc_val = self._evaluate_without_predict_result(val_x, 173 | val_label) 174 | if acc_val > best_acc_val: 175 | best_acc_val = acc_val 176 | last_improved = total_batch 177 | saver.save(sess=self.session, 178 | save_path=self.save_dir + str(total_batch) + 'model.ckpt') 179 | improved_str = '*' 180 | else: 181 | improved_str = '' 182 | time_dif = self.__get_time_dif(start_time) 183 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 184 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 185 | print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, 186 | improved_str)) 187 | total_batch += 1 188 | if total_batch - last_improved > self.require_improvement: 189 | print("No optimization for a long time, auto-stopping...") 190 | flag = True 191 | break 192 | if flag: 193 | break 194 | self.session = None 195 | 196 | 197 | 198 | 199 | def batch_iter(self, input_data, target=None, batch_size=64,padding=0,shuffle=False): 200 | assert not input_data is None,"input_data is None" 201 | data_len = len(input_data[0]) 202 | num_batch = int((data_len - 1) / batch_size) + 1 203 | if shuffle:#target must be not none 204 | indices = np.random.permutation(np.arange(data_len)) 205 | else: 206 | indices = range(data_len) 207 | x_shuffle = [input_data[0][i] for i in indices] 208 | x_seq_len = [len(x_shuffle[i]) for i in range(len(x_shuffle))] 209 | if target is None: 210 | for i in range(num_batch): 211 | start_id = i * batch_size 212 | end_id = min((i + 1) * batch_size, data_len) 213 | batch_x = copy_list(x_shuffle[start_id:end_id]) 214 | batch_x_seq_len = x_seq_len[start_id:end_id] 215 | x_max_len = max(batch_x_seq_len) 216 | for list in batch_x: 217 | if len(list) < x_max_len: 218 | list += [padding] * (x_max_len - len(list)) 219 | yield [batch_x,batch_x_seq_len] 220 | else: 221 | y_shuffle = [target[i] for i in indices] 222 | for i in range(num_batch): 223 | start_id = i * batch_size 224 | end_id = min((i + 1) * batch_size, data_len) 225 | batch_y = y_shuffle[start_id:end_id] 226 | batch_x = copy_list(x_shuffle[start_id:end_id]) 227 | batch_x_seq_len = x_seq_len[start_id:end_id] 228 | x_max_len = max(batch_x_seq_len) 229 | for list in batch_x: 230 | if len(list) < x_max_len: 231 | list += [padding] * (x_max_len - len(list)) 232 | yield [batch_x,batch_x_seq_len],batch_y 233 | 234 | 235 | def get_file_src_list(parent_path, file_type='.txt'): 236 | files = os.listdir(parent_path) 237 | src_list = [] 238 | for file in files: 239 | absolute_path = os.path.join(parent_path, file) 240 | if os.path.isdir(absolute_path): 241 | src_list += get_file_src_list(absolute_path) 242 | elif file.endswith(file_type): 243 | src_list.append(absolute_path) 244 | return src_list 245 | 246 | def copy_list(list): 247 | new_list = [] 248 | for l in list: 249 | if type(l) == type([0]) or type(l) == type(np.array([0])): 250 | new_list.append(copy_list(l)) 251 | else: 252 | new_list.append(l) 253 | return new_list 254 | 255 | class Data: 256 | def __init__(self,x,y,ori_x=None): 257 | self.x=x 258 | self.y=y 259 | self.ori_x=ori_x 260 | def split(self): 261 | self.x=self.x.split(' ') 262 | def str2index(self,word_dict,with_unk=True): 263 | index=[] 264 | if with_unk: 265 | for s in self.x: 266 | if s in word_dict: 267 | index.append(word_dict[s]) 268 | 
else: 269 | index.append(len(word_dict)) 270 | else: 271 | for s in self.x: 272 | if s in word_dict: 273 | index.append(word_dict[s]) 274 | self.x=index 275 | 276 | def preprocess(informal_src_list,formal_src_list,embedding_path,output_path=None,shuffle=True): 277 | vectors,vocab_hash=embedding_api.load_word_embedding(embedding_path) 278 | all_data=[] 279 | for src in informal_src_list: 280 | with open(src,'r',encoding='utf-8') as f: 281 | for line in f: 282 | d=Data(line.strip().split(), 0, line.strip()) 283 | d.str2index(vocab_hash,with_unk=False) 284 | all_data.append(d) 285 | for src in formal_src_list: 286 | with open(src,'r',encoding='utf-8') as f: 287 | for line in f: 288 | d=Data(line.strip().split(), 1, line.strip()) 289 | d.str2index(vocab_hash,with_unk=False) 290 | all_data.append(d) 291 | if shuffle: 292 | random.shuffle(all_data) 293 | if output_path is not None: 294 | pickle.dump(all_data,open(output_path,'wb'),protocol=True) 295 | return all_data 296 | 297 | 298 | def all_prepro(): 299 | train_inf_src=['./new_exp_fr/classifier/informal.train.tok.bpe'] 300 | train_fm_src = ['./new_exp_fr/classifier/formal.train.tok.bpe'] 301 | val_inf_src = ['./new_exp_fr/classifier/informal.val.tok.bpe'] 302 | val_fm_src = ['./new_exp_fr/classifier/formal.val.tok.bpe'] 303 | test_inf_src = ['./new_exp_fr/classifier/informal.test.tok.bpe'] 304 | test_fm_src = ['./new_exp_fr/classifier/formal.test.tok.bpe'] 305 | embedding_path='./new_exp_fr/embedding/embedding.bpe.big.txt' 306 | preprocess(train_inf_src,train_fm_src,embedding_path=embedding_path, 307 | output_path='./new_exp_fr/classifier/train.pkl') 308 | preprocess(val_inf_src, val_fm_src, embedding_path=embedding_path, 309 | output_path='./new_exp_fr/classifier/val.pkl') 310 | preprocess(test_inf_src, test_fm_src, embedding_path=embedding_path, 311 | output_path='./new_exp_fr/classifier/test.pkl') 312 | 313 | def use_nn_model(): 314 | train = pickle.load(open('./new_exp_fr/classifier/train.pkl', 'rb')) 315 | val = pickle.load(open('./new_exp_fr/classifier/val.pkl', 'rb')) 316 | embedding_path = './new_exp_fr/embedding/embedding.bpe.big.txt' 317 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 318 | nn=NNModel(np.array(embedding),mode='train') 319 | nn.build_basic_rnn_model() 320 | nn.train_model([t.x for t in train],[t.y for t in train],[t.x for t in val],[t.y for t in val], 321 | continue_train=False, previous_model_path='./new_exp_fr/classifier/model/990model.ckpt') 322 | 323 | def test(): 324 | test = pickle.load(open('./new_exp_fr/classifier/test.pkl', 'rb')) 325 | embedding_path = './new_exp_fr/embedding/corpus.fine_tune_embedding.epoch.10' 326 | embedding,vocab_hash = embedding_api.load_word_embedding(embedding_path) 327 | nn = NNModel(np.array(embedding),mode='eval') 328 | nn.build_basic_rnn_model() 329 | nn.evaluate([t.x for t in test],[t.y for t in test],model_path='') 330 | 331 | def predict(model_path,file_path='./new_exp_fr/classifier/val.pkl',embedding_path='./new_exp_fr/embedding/corpus.fine_tune_embedding.epoch.10'): 332 | test = pickle.load(open(file_path, 'rb')) 333 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 334 | nn = NNModel(np.array(embedding),mode='predict') 335 | nn.batch_size=10000 336 | nn.build_basic_rnn_model() 337 | result=nn.predict_prob([t.x for t in test], model_path=model_path) 338 | return test,result 339 | 340 | 341 | def evaluate_one_formality(input_file_path,is_inf): 342 | embedding_path = './new_exp_fr/embedding/embedding.bpe.big.txt' 343 | 
embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 344 | nn = NNModel(np.array(embedding), mode='eval') 345 | nn.batch_size = 128 346 | nn.build_basic_rnn_model() 347 | if is_inf: 348 | data = preprocess(informal_src_list=[input_file_path], formal_src_list=[], embedding_path=embedding_path, 349 | shuffle=False) 350 | else: 351 | data = preprocess(informal_src_list=[], formal_src_list=[input_file_path], embedding_path=embedding_path, 352 | shuffle=False) 353 | result = nn.predict_prob([t.x for t in data], model_path='./new_exp_fr/classifier/model/1700model.ckpt') 354 | score = 0 355 | if is_inf: 356 | for s in result: 357 | score += s[0] 358 | else: 359 | for s in result: 360 | score += s[1] 361 | print(score / len(data)) 362 | return score/len(data) 363 | 364 | def test_formality_score(files=None): 365 | if files is None: 366 | files = { 367 | 'rule_based': ['./data/Family_Relationships/bpe_outputs/formal.rule_based.bpe'], 368 | 'pbmt': ['./data/Family_Relationships/bpe_outputs/formal.pbmt.bpe'], 369 | 'nmt_baseline': ['./data/Family_Relationships/bpe_outputs/formal.nmt_baseline.bpe'], 370 | 'nmt_copy': ['./data/Family_Relationships/bpe_outputs/formal.nmt_copy.bpe'], 371 | 'nmt_combined': ['./data/Family_Relationships/bpe_outputs/formal.nmt_combined.bpe'], 372 | } 373 | embedding_path = './new_exp_fr/embedding/embedding.bpe.big.txt' 374 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 375 | nn = NNModel(np.array(embedding),mode='eval') 376 | nn.batch_size = 128 377 | nn.build_basic_rnn_model() 378 | eval_log={} 379 | for key in files.keys(): 380 | if type(files[key])==type([]): 381 | fm_files=files[key]+'.bpe' 382 | else: 383 | fm_files=[files[key]+'.bpe'] 384 | data=preprocess(informal_src_list=[],formal_src_list=fm_files,embedding_path=embedding_path,shuffle=False) 385 | result = nn.predict_prob([t.x for t in data], model_path='./new_exp_fr/classifier/model/1700model.ckpt') 386 | score=0 387 | for s in result: 388 | score+=s[1] 389 | print(key,score/len(data)) 390 | eval_log[key]=score/len(data) 391 | return eval_log 392 | 393 | 394 | def cal_formality_score_for_each_sentence(output_dir,files=None): 395 | if files is None: 396 | files = { 397 | 'rule_based': ['./data/Family_Relationships/bpe_outputs/formal.rule_based.bpe'], 398 | 'pbmt': ['./data/Family_Relationships/bpe_outputs/formal.pbmt.bpe'], 399 | 'nmt_baseline': ['./data/Family_Relationships/bpe_outputs/formal.nmt_baseline.bpe'], 400 | 'nmt_copy': ['./data/Family_Relationships/bpe_outputs/formal.nmt_copy.bpe'], 401 | 'nmt_combined': ['./data/Family_Relationships/bpe_outputs/formal.nmt_combined.bpe'], 402 | } 403 | embedding_path = './new_exp_fr/embedding/embedding.bpe.big.txt' 404 | embedding, vocab_hash = embedding_api.load_word_embedding(embedding_path) 405 | nn = NNModel(np.array(embedding),mode='eval') 406 | nn.batch_size = 128 407 | nn.build_basic_rnn_model() 408 | for key in files.keys(): 409 | data=preprocess(informal_src_list=[],formal_src_list=files[key],embedding_path=embedding_path,shuffle=False) 410 | result = nn.predict_prob([t.x for t in data], model_path='./new_exp_fr/classifier/model/1700model.ckpt') 411 | base_name=os.path.basename(files[key]) 412 | with open(os.path.join(output_dir,base_name+'.formality_score'),'w',encoding='utf-8') as fw: 413 | for r in result: 414 | fw.write(str(r[1])+'\n') 415 | 416 | 417 | def evaluate_formality(resources): 418 | eval_log={} 419 | for key in resources.keys(): 420 | eval_log[key] = test_formality_score(resources[key]) 421 | 
return eval_log -------------------------------------------------------------------------------- /evaluate/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /evaluate/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | def file_tokenize(input,output): 4 | with open(input,'r',encoding='utf-8') as f: 5 | with open(output,'w',encoding='utf-8') as fw: 6 | for line in f: 7 | fw.write(' '.join(nltk.word_tokenize(line.strip()))+'\n') 8 | 9 | -------------------------------------------------------------------------------- /evaluate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /evaluate/utils/tools.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | def load_fasttext_word_embedding(path): 4 | vectors = [] 5 | vocab_hash = {} 6 | with open(path, 'r', encoding='utf-8') as f: 7 | first_line = True 8 | for line in f: 9 | if first_line: 10 | first_line = False 11 | continue 12 | strs = line.strip().split(' ') 13 | vocab_hash[strs[0]] = len(vectors) 14 | vectors.append([float(s) for s in strs[1:]]) 15 | return vectors, vocab_hash 16 | 17 | def load_corpus_and_stat_vocab(path): 18 | freq_vocab={} 19 | corpus=[] 20 | with open(path, 'r', encoding='utf-8') as f: 21 | for line in f: 22 | line=line.strip() 23 | sens=break_sentence(line,skip=True) 24 | new_sens=[] 25 | for s in sens: 26 | words=tokenizer(s,join=False,only_split=False) 27 | for w in words: 28 | if w in freq_vocab: 29 | freq_vocab[w]+=1 30 | else: 31 | freq_vocab[w]=1 32 | new_sens.append(' '.join(words)) 33 | corpus.append(' '.join(new_sens)) 34 | return corpus,freq_vocab 35 | 36 | 37 | 38 | 39 | def break_sentence(paragraph,skip=False,punctuations=None): 40 | if skip: 41 | return [paragraph] 42 | if punctuations is None: 43 | punctuations = ['?', '?', '...', '......', '!','.',',',','] 44 | sens=[] 45 | sens.append(paragraph) 46 | new_sens = [] 47 | one_sen = '' 48 | for char in paragraph: 49 | if char in punctuations: 50 | if one_sen!='': 51 | one_sen += (' ' + char) 52 | else: 53 | one_sen+=char 54 | new_sens.append(one_sen) 55 | one_sen = '' 56 | else: 57 | one_sen+=char 58 | if one_sen!='': 59 | new_sens.append(one_sen) 60 | return new_sens 61 | 62 | def tokenizer(sentence,join=False,only_split=True): 63 | if only_split: 64 | if join: 65 | return sentence 66 | else: 67 | return sentence.split() 68 | else: 69 | if join: 70 | return ' '.join(nltk.word_tokenize(sentence)) 71 | else: 72 | return nltk.word_tokenize(sentence) 73 | 74 | def break_sen_and_tokernize(para,break_sen=False): 75 | if break_sen: 76 | return nltk.word_tokenize(' '.join(break_sentence(para))) 77 | else: 78 | return nltk.word_tokenize(para) 79 | 80 | if __name__=='__main__': 81 | #print(' a a '.strip()) 82 | #print('Do not belive in it at all.'.split('?')) 83 | a=break_sentence('This question is for girls,"Have you ever gone out with a guy that you liked, but you did not really know?".') 84 | print(a) 85 | print(' '.join(a)) 86 | #print(nltk.word_tokenize(' '.join(a))) -------------------------------------------------------------------------------- /gpt/config.py: -------------------------------------------------------------------------------- 1 | from gpt.src 
import encoder 2 | text_enc = encoder.get_encoder('./models/117M') 3 | config_path='./models/117M' 4 | -------------------------------------------------------------------------------- /gpt/download_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | from tqdm import tqdm 5 | 6 | if len(sys.argv) != 2: 7 | print('You must enter the model name as a parameter, e.g.: download_model.py 117M') 8 | sys.exit(1) 9 | 10 | model = sys.argv[1] 11 | 12 | subdir = os.path.join('models', model) 13 | if not os.path.exists(subdir): 14 | os.makedirs(subdir) 15 | subdir = subdir.replace('\\','/') # needed for Windows 16 | 17 | for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: 18 | 19 | r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True) 20 | 21 | with open(os.path.join(subdir, filename), 'wb') as f: 22 | file_size = int(r.headers["content-length"]) 23 | chunk_size = 1000 24 | with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: 25 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes 26 | for chunk in r.iter_content(chunk_size=chunk_size): 27 | f.write(chunk) 28 | pbar.update(chunk_size) -------------------------------------------------------------------------------- /gpt/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES']='7' 3 | from gpt.src.concat_fine_tuning import concat_finetuning, domain_combined 4 | from gpt.src.hierarchical_attention import HA 5 | from gpt.src.simple_finetune import simple_finetune, ensemble_test 6 | 7 | 8 | def run_simple_finetune_and_emsemble_decoding(domain): 9 | simple_finetune(domain=domain, methods='ori') 10 | simple_finetune(domain=domain, methods='rule') 11 | ensemble_test(domain=domain) 12 | 13 | 14 | def run_concat_finetune(domain): 15 | concat_finetuning(domain=domain) 16 | 17 | def run_ha(domain): 18 | HA(domain=domain) 19 | 20 | def run_all(): 21 | run_simple_finetune_and_emsemble_decoding('fr') 22 | run_simple_finetune_and_emsemble_decoding('em') 23 | run_concat_finetune('fr') 24 | run_concat_finetune('em') 25 | run_ha('fr') 26 | run_ha('em') 27 | domain_combined('fr',only_test=False) 28 | domain_combined('em',only_test=True) 29 | 30 | if __name__=='__main__': 31 | run_all() 32 | 33 | -------------------------------------------------------------------------------- /gpt/src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gpt/src/beamsearch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Beam Search 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import copy 10 | import tensorflow as tf 11 | 12 | from collections import namedtuple 13 | from tensorflow.python.util import nest 14 | from utils.common import * 15 | 16 | class BeamSearchState(namedtuple("BeamSearchState", 17 | ("inputs", "state", "finish"))): 18 | pass 19 | 20 | def _get_inference_fn(model_fns, features): 21 | def inference_fn(inputs, state): 22 | local_features = { 23 | "source1": features["source1"], 24 | "source1_length": features["source1_length"], 25 | 
"source2": features["source2"], 26 | "source2_length": features["source2_length"], 27 | # [bos_id, ...] => [..., 0] 28 | "target": tf.pad(inputs[:, 1:], [[0, 0], [0, 1]]), 29 | "target_length": tf.fill([tf.shape(inputs)[0]], 30 | tf.shape(inputs)[1]) 31 | } 32 | 33 | outputs = [] 34 | next_state = [] 35 | 36 | for (model_fn, model_state) in zip(model_fns, state): 37 | if model_state: 38 | output, new_state = model_fn(local_features, model_state) 39 | outputs.append(output) 40 | next_state.append(new_state) 41 | else: 42 | output = model_fn(local_features) 43 | outputs.append(output) 44 | next_state.append({}) 45 | 46 | # Ensemble 47 | log_prob = tf.add_n(outputs) / float(len(outputs)) 48 | 49 | return log_prob, next_state 50 | 51 | return inference_fn 52 | 53 | def ensemble_model_fn_wrapper_gpt(model_fn,input,states,hparams,scopes_for_ensemble): 54 | new_state_list=[] 55 | step_log_probs_list=[] 56 | for i in range(0, len(scopes_for_ensemble)): 57 | next_outputs = model_fn(hparams, input, past=states[i], scope=scopes_for_ensemble[i]) 58 | next_state = tf.concat([states[i], next_outputs['presents']], axis=-2) 59 | new_state_list.append(next_state) 60 | step_log_probs = next_outputs['logits'][:, 0, :] 61 | step_log_probs_list.append(step_log_probs) 62 | step_log_probs_ensemble=tf.reduce_mean(tf.stack(step_log_probs_list,axis=1),axis=1) 63 | step_log_probs_ensemble = tf.log(tf.nn.softmax(step_log_probs_ensemble)) 64 | return new_state_list, step_log_probs_ensemble 65 | 66 | 67 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha, eos_id, hparams, scopes_for_ensemble, ensemble=False, concat_state_dim=None): 68 | # Compute log probabilities 69 | seqs, log_probs = state.inputs[:2] 70 | flat_seqs = merge_first_two_dims(seqs) 71 | if ensemble: 72 | flat_state = nest.map_structure(lambda x: merge_first_two_dims(x), 73 | state.state) 74 | next_state,step_log_probs=ensemble_model_fn_wrapper_gpt(func,tf.expand_dims(flat_seqs[:, -1], axis=1),flat_state,hparams,scopes_for_ensemble=scopes_for_ensemble) 75 | else: 76 | flat_state = nest.map_structure(lambda x: merge_first_two_dims(x), 77 | state.state) 78 | next_outputs = func(hparams, tf.expand_dims(flat_seqs[:, -1], axis=1), flat_state) 79 | if concat_state_dim is not None:#none or -2 80 | next_state = nest.map_structure(lambda x:tf.concat([x,next_outputs['presents']],axis=concat_state_dim),flat_state) 81 | else: 82 | next_state = next_outputs['presents'] 83 | #next_state = tf.concat([flat_state, next_outputs['presents']], axis=-2) 84 | step_log_probs = next_outputs['logits'][:, 0, :] 85 | step_log_probs = tf.log(tf.nn.softmax(step_log_probs)) 86 | step_log_probs = split_first_two_dims(step_log_probs, batch_size, 87 | beam_size) 88 | next_state = nest.map_structure( 89 | lambda x: split_first_two_dims(x, batch_size, beam_size), 90 | next_state) 91 | curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs 92 | 93 | # Apply length penalty 94 | length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha) 95 | curr_scores = curr_log_probs / length_penalty 96 | vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1] 97 | 98 | # Select top-k candidates 99 | # [batch_size, beam_size * vocab_size] 100 | curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size]) 101 | # [batch_size, 2 * beam_size] 102 | top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size) 103 | # Shape: [batch_size, 2 * beam_size] 104 | beam_indices = top_indices // vocab_size 105 | symbol_indices = top_indices % vocab_size 
106 | # Expand sequences 107 | # [batch_size, 2 * beam_size, time] 108 | candidate_seqs = gather_2d(seqs, beam_indices) 109 | candidate_seqs = tf.concat([candidate_seqs, 110 | tf.expand_dims(symbol_indices, 2)], 2) 111 | 112 | # Expand sequences 113 | # Suppress finished sequences 114 | flags = tf.equal(symbol_indices, eos_id) 115 | # [batch, 2 * beam_size] 116 | alive_scores = top_scores + tf.to_float(flags) * tf.float32.min 117 | # [batch, beam_size] 118 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size) 119 | alive_symbols = gather_2d(symbol_indices, alive_indices) 120 | alive_indices = gather_2d(beam_indices, alive_indices) 121 | alive_seqs = gather_2d(seqs, alive_indices) 122 | # [batch_size, beam_size, time + 1] 123 | alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2) 124 | alive_state = nest.map_structure( 125 | lambda x: gather_2d(x, alive_indices), 126 | next_state) 127 | alive_log_probs = alive_scores * length_penalty 128 | 129 | # Select finished sequences 130 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish 131 | # [batch, 2 * beam_size] 132 | step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min 133 | # [batch, 3 * beam_size] 134 | fin_flags = tf.concat([prev_fin_flags, flags], axis=1) 135 | fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1) 136 | # [batch, beam_size] 137 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size) 138 | fin_flags = gather_2d(fin_flags, fin_indices) 139 | pad_seqs = tf.fill([batch_size, beam_size, 1], 140 | tf.constant(eos_id, tf.int32)) 141 | prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2) 142 | fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1) 143 | fin_seqs = gather_2d(fin_seqs, fin_indices) 144 | 145 | new_state = BeamSearchState( 146 | inputs=(alive_seqs, alive_log_probs, alive_scores), 147 | state=alive_state, 148 | finish=(fin_flags, fin_seqs, fin_scores) 149 | ) 150 | 151 | return time + 1, new_state 152 | 153 | 154 | def beam_search(func, state, batch_size, beam_size, max_length, alpha, 155 | init_seqs, eos_id, hparams, scopes_for_ensemble, ensemble=False, concat_state_dim=None): 156 | init_log_probs = tf.constant([[0.] 
+ [tf.float32.min] * (beam_size - 1)]) 157 | init_log_probs = tf.tile(init_log_probs, [batch_size, 1]) 158 | init_scores = tf.zeros_like(init_log_probs) 159 | fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32) 160 | fin_scores = tf.fill([batch_size, beam_size], tf.float32.min) 161 | fin_flags = tf.zeros([batch_size, beam_size], tf.bool) 162 | 163 | state = BeamSearchState( 164 | inputs=(init_seqs, init_log_probs, init_scores), 165 | state=state, 166 | finish=(fin_flags, fin_seqs, fin_scores) 167 | ) 168 | 169 | max_step = tf.reduce_max(max_length) 170 | 171 | def _is_finished(t, s): 172 | log_probs = s.inputs[1] 173 | finished_flags = s.finish[0] 174 | finished_scores = s.finish[2] 175 | max_lp = tf.pow(((5.0 + tf.to_float(max_step)) / 6.0), alpha) 176 | best_alive_score = log_probs[:, 0] / max_lp 177 | worst_finished_score = tf.reduce_min( 178 | finished_scores * tf.to_float(finished_flags), axis=1) 179 | add_mask = 1.0 - tf.to_float(tf.reduce_any(finished_flags, 1)) 180 | worst_finished_score += tf.float32.min * add_mask 181 | bound_is_met = tf.reduce_all(tf.greater(worst_finished_score, 182 | best_alive_score)) 183 | 184 | cond = tf.logical_and(tf.less(t, max_step), 185 | tf.logical_not(bound_is_met)) 186 | 187 | return cond 188 | 189 | def _loop_fn(t, s): 190 | outs = _beam_search_step(t, func, s, batch_size, beam_size, alpha, eos_id, hparams=hparams, scopes_for_ensemble=scopes_for_ensemble, 191 | ensemble=ensemble, concat_state_dim=concat_state_dim) 192 | return outs 193 | if type(state.state)==list: 194 | tmp=state.state 195 | state_shape_invariants=[] 196 | for item in tmp: 197 | state_shape_invariants.append(item.shape) 198 | else: 199 | state_shape_invariants=state.state.shape 200 | time = tf.constant(0, name="time") 201 | shape_invariants = BeamSearchState( 202 | inputs=(tf.TensorShape([None, None, None]), 203 | tf.TensorShape([None, None]), 204 | tf.TensorShape([None, None])), 205 | state=state_shape_invariants, 206 | finish=(tf.TensorShape([None, None]), 207 | tf.TensorShape([None, None, None]), 208 | tf.TensorShape([None, None])) 209 | ) 210 | outputs = tf.while_loop(_is_finished, _loop_fn, [time, state], 211 | shape_invariants=[tf.TensorShape([]), 212 | shape_invariants], 213 | parallel_iterations=1, 214 | back_prop=False) 215 | 216 | final_state = outputs[1] 217 | alive_seqs = final_state.inputs[0] 218 | alive_scores = final_state.inputs[2] 219 | final_flags = final_state.finish[0] 220 | final_seqs = final_state.finish[1] 221 | final_scores = final_state.finish[2] 222 | 223 | alive_seqs.set_shape([None, beam_size, None]) 224 | final_seqs.set_shape([None, beam_size, None]) 225 | 226 | final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs, 227 | alive_seqs) 228 | final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores, 229 | alive_scores) 230 | 231 | return final_seqs, final_scores 232 | 233 | 234 | def create_inference_graph(init_seqs, state, step_fn, hparams, decode_length, batch_size, beam_size, decode_alpha, eos_id, 235 | ensemble, concat_state_dim, scopes_for_ensemble=None): 236 | tiled_context_state = nest.map_structure( 237 | lambda x:tile_to_beam_size(x,beam_size), 238 | state 239 | ) 240 | tiled_init_seq=nest.map_structure( 241 | lambda x:tile_to_beam_size(x,beam_size), 242 | init_seqs 243 | ) 244 | seqs, scores = beam_search(step_fn, tiled_context_state, batch_size, beam_size, 245 | decode_length, decode_alpha, tiled_init_seq, eos_id, hparams=hparams, scopes_for_ensemble=scopes_for_ensemble, 246 | ensemble=ensemble, 
concat_state_dim=concat_state_dim) 247 | 248 | # return seqs[:, :top_beams, 1:], scores[:, :top_beams] 249 | return seqs, scores 250 | -------------------------------------------------------------------------------- /gpt/src/concat_fine_tuning.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib.seq2seq import sequence_loss 6 | from gpt.src import model 7 | from gpt.src import beamsearch 8 | import tensorflow.contrib.slim as slim 9 | from datetime import timedelta 10 | import time 11 | from gpt.src.single_gpu_serving import beam_search_generator 12 | from utils.file_api import read_file_lines,write_file_lines 13 | import random 14 | from gpt.config import * 15 | from utils.cat_files import cat_files 16 | import shutil 17 | 18 | class Encoder___(): 19 | def __init__(self,scope,hparam): 20 | if scope is None: 21 | self.scope='encoder' 22 | else: 23 | self.scope=scope 24 | self.hparam=hparam 25 | 26 | def encode(self,input,input_len,past=None): 27 | with tf.variable_scope(self.scope,reuse=tf.AUTO_REUSE): 28 | lm_output = model.model(hparams=self.hparam, X=input, past=past, reuse=tf.AUTO_REUSE) 29 | presents = lm_output['present'] 30 | presents.set_shape(model.past_shape(hparams=self.hparam, batch_size=None)) 31 | target_mask = tf.sequence_mask(input_len, maxlen=tf.shape(input)[1], dtype=tf.float32) 32 | target_mask=tf.expand_dims(target_mask,2) 33 | print(presents) 34 | encode_out=tf.transpose(presents,perm=(0,4,2,3,1,5)) 35 | ori_enc_shape=tf.shape(encode_out) 36 | encode_out=tf.reshape(encode_out,shape=(tf.shape(presents)[0],tf.shape(presents)[4],-1)) 37 | encode_out=tf.multiply(encode_out,target_mask) 38 | encode_out=tf.reshape(encode_out,shape=ori_enc_shape) 39 | encode_out=tf.transpose(encode_out,perm=(0,4,2,3,1,5)) 40 | return encode_out 41 | 42 | class Decoder___(): 43 | def __init__(self,scope,hparam): 44 | if scope is None: 45 | self.scope='decoder' 46 | else: 47 | self.scope=scope 48 | self.hparam=hparam 49 | 50 | def decode_one_step(self,hparams,tokens,past): 51 | with tf.variable_scope(self.scope): 52 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 53 | logits = lm_output['logits'] 54 | presents = lm_output['present'] 55 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=None)) 56 | return { 57 | 'logits': logits, 58 | 'presents': tf.concat([past,presents],axis=-2) 59 | } 60 | 61 | def decode_all(self,hparams,tokens,past): 62 | with tf.variable_scope(self.scope): 63 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 64 | logits = lm_output['logits'] 65 | presents = lm_output['present'] 66 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=None)) 67 | return { 68 | 'logits': logits, 69 | 'presents': tf.concat([past,presents],axis=-2) 70 | } 71 | 72 | 73 | class NMT_GPT(): 74 | def __init__(self,input_num,config_path): 75 | self.hparams = model.default_hparams() 76 | self.config_path = config_path 77 | with open(os.path.join(self.config_path, 'hparams.json')) as f: 78 | self.hparams.override_from_dict(json.load(f)) 79 | self.input_num=input_num 80 | self.text_enc = encoder.get_encoder(self.config_path) 81 | self.sos_id=self.text_enc.encode('\t')[0] 82 | self.eos_id=self.text_enc.encode('\n')[0] 83 | 84 | def def_placeholder_and_components(self): 85 | self.encoder = Encoder___('encoder', self.hparams) 86 | self.decoder = Decoder___('decoder', 
self.hparams) 87 | self.inputs = [tf.placeholder(tf.int32, [None, None], name='input_%d' % i) for i in range(0, self.input_num)] 88 | self.input_lens = [tf.placeholder(tf.int32, [None, ], name='input_len_%d' % i) for i in 89 | range(0, self.input_num)] 90 | self.target_in = tf.placeholder(tf.int32, [None, None], name='target_in') 91 | self.target_out = tf.placeholder(tf.int32, [None, None], name='target_out') 92 | self.target_len = tf.placeholder(tf.int32, [None], name='target_len') 93 | 94 | 95 | 96 | def build_training_model(self): 97 | self.def_placeholder_and_components() 98 | past_for_decoder=None 99 | for i in range(0,self.input_num): 100 | presents=self.encoder.encode(self.inputs[i],self.input_lens[i],past_for_decoder) 101 | if past_for_decoder is None: 102 | past_for_decoder=presents 103 | else: 104 | past_for_decoder=tf.concat([past_for_decoder,presents],axis=-2) 105 | all_logits=self.decoder.decode_all(self.hparams,tokens=self.target_in,past=past_for_decoder)['logits'] 106 | with tf.name_scope('loss'): 107 | batch_max_seq_len = tf.shape(self.target_in)[1] 108 | target_mask = tf.sequence_mask(self.target_len, maxlen=batch_max_seq_len, dtype=tf.float32) 109 | cost = sequence_loss(logits=all_logits, targets=self.target_out, 110 | weights=target_mask) 111 | return cost 112 | 113 | 114 | def build_beam_search_graph(self, beam_size, batch_size, max_decode_length, decode_alpha=0.6): 115 | self.def_placeholder_and_components() 116 | past_for_decoder = None 117 | for i in range(0, self.input_num): 118 | presents = self.encoder.encode(self.inputs[i], self.input_lens[i], past_for_decoder) 119 | if past_for_decoder is None: 120 | past_for_decoder = presents 121 | else: 122 | past_for_decoder = tf.concat([past_for_decoder, presents], axis=-2) 123 | with tf.name_scope('beam_search'): 124 | init_seq = tf.fill(dims=(batch_size, 1), value=self.sos_id) 125 | seqs, scores = beamsearch.create_inference_graph(init_seqs=init_seq, state=past_for_decoder, 126 | step_fn=self.decoder.decode_one_step, hparams=self.hparams, 127 | decode_length=max_decode_length, 128 | batch_size=batch_size, beam_size=beam_size, 129 | decode_alpha=decode_alpha, eos_id=self.eos_id, 130 | ensemble=False,concat_state_dim=None) 131 | return seqs, scores 132 | 133 | 134 | class NMT_GPT_Trainer(): 135 | def __init__(self,model_fn:NMT_GPT): 136 | self.model_fn=model_fn 137 | self.learning_rate=1e-4 138 | self.sep_flag='\t' 139 | self.graph=tf.Graph() 140 | self.vars_for_infer = [] 141 | self.vars_for_train = [] 142 | self.losses=[] 143 | self.only_predict_target=True 144 | tf.logging.set_verbosity(tf.logging.INFO) 145 | self.is_append_sep=True 146 | self.hier_enc_end_token=self.model_fn.text_enc.encode('\t') 147 | 148 | 149 | def average_gradients(self,tower_grads): 150 | average_grads = [] 151 | for grad_and_vars in zip(*tower_grads): 152 | grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars] 153 | grads = tf.concat(grads, 0) 154 | grad = tf.reduce_mean(grads, 0) 155 | grad_and_var = (grad, grad_and_vars[0][1]) 156 | # [(grad0, var0),(grad1, var1),...] 
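# grad is the element-wise mean over towers for this variable; pairing it with
# the variable from the first tower is safe because all towers share the same
# variable list in the same order.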
157 | average_grads.append(grad_and_var) 158 | return average_grads 159 | 160 | 161 | def build_graph(self): 162 | with self.graph.as_default(): 163 | self.opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 164 | self.tower_grads = [] 165 | loss = self.model_fn.build_training_model() 166 | self.losses.append(loss) 167 | grads = self.opt.compute_gradients(loss) 168 | tvs = tf.trainable_variables() 169 | self.accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) 170 | for tv in 171 | tvs] 172 | self.zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars] 173 | self.accum_grad_ops = [self.accum_vars[j].assign_add(gv[0]) for j, gv in 174 | enumerate(grads) if gv[0] is not None] 175 | self.tower_grads.append([(self.accum_vars[j], gv[1]) for j, gv in enumerate(grads) ]) 176 | grads = self.average_gradients(self.tower_grads) 177 | with tf.device('/gpu:0'): 178 | self.accum_steps=tf.placeholder(tf.float32, [], name='accum_stpes') 179 | self.train_step = self.opt.apply_gradients([(g/self.accum_steps, v) for g,v in grads]) 180 | self.avg_loss=tf.stack(self.losses,axis=0) 181 | self.avg_loss=tf.reduce_mean(self.avg_loss) 182 | 183 | 184 | def create_session_init_and_print_all_trainable_vars(self, max_to_save, ori_gpt_model_path=None): 185 | # Print parameters 186 | with self.graph.as_default(): 187 | all_weights = {v.name: v for v in tf.trainable_variables()} 188 | total_size = 0 189 | for v_name in sorted(list(all_weights)): 190 | v = all_weights[v_name] 191 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 192 | str(v.shape).ljust(20)) 193 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 194 | total_size += v_size 195 | tf.logging.info("Total trainable variables size: %d", total_size) 196 | all_var_list = slim.get_variables_to_restore() 197 | for v in all_var_list: 198 | if 'Adam' in v.name: 199 | self.vars_for_train.append(v) 200 | elif v.name.startswith('beta'): 201 | self.vars_for_train.append(v) 202 | elif v.name.startswith('parallel'): 203 | pass 204 | elif v.name.startswith('Variable'): 205 | pass 206 | else: 207 | self.vars_for_infer.append(v) 208 | if len(self.vars_for_infer) > 0: 209 | self.saver_infer = tf.train.Saver(self.vars_for_infer, max_to_keep=max_to_save) 210 | if len(self.vars_for_train) > 0: 211 | self.saver_train = tf.train.Saver(self.vars_for_train, max_to_keep=max_to_save) 212 | config = tf.ConfigProto() 213 | config.gpu_options.allow_growth = True 214 | sess = tf.Session(graph=self.graph, config=config) 215 | init_op = tf.global_variables_initializer() 216 | sess.run(init_op) 217 | restore_ops=[] 218 | if ori_gpt_model_path is not None: 219 | ckpt = tf.train.latest_checkpoint(ori_gpt_model_path) 220 | tf.logging.info("Loading %s" % ckpt) 221 | var_list = tf.train.list_variables(ckpt) 222 | values = {} 223 | reader = tf.train.load_checkpoint(ckpt) 224 | for (name, shape) in var_list: 225 | if not name.startswith('model/'): # ignore global_step 226 | continue 227 | tensor = reader.get_tensor(name) 228 | values[name] = tensor 229 | for v in self.vars_for_infer: 230 | #print(v.name) 231 | tmp = '/'.join(v.name.split('/')[1:]) 232 | v_name = tmp.split(':')[0] 233 | if v_name!='model/sen_attn_w': 234 | op = tf.assign(v, values[v_name]) 235 | restore_ops.append(op) 236 | sess.run(restore_ops) 237 | return sess 238 | 239 | 240 | def padding_batch(self, input_list): 241 | in_len = [len(i) for i in input_list] 242 | new_in = pad_sequences(input_list, padding='post') 243 | return new_in, in_len 244 | 245 | 246 | def 
train_or_eval_batch_with_raw_text(self, sess, input_text, mini_batch, is_train=True, 247 | run_options=None): 248 | batch_size = len(input_text) 249 | batch_input = {} 250 | batch_target_in = [] 251 | batch_target_out =[] 252 | batch_target_len =[] 253 | batch_input_len = {} 254 | for text in input_text: 255 | strs=text.split(self.sep_flag) 256 | inputs=strs[:-1] 257 | target=strs[-1] 258 | if self.is_append_sep: 259 | inputs_tokens = [self.model_fn.text_enc.encode(item)+self.hier_enc_end_token for item in inputs] 260 | else: 261 | inputs_tokens = [self.model_fn.text_enc.encode(item) for item in inputs] 262 | target_tokens=self.model_fn.text_enc.encode(target) 263 | for i in range(0,len(inputs_tokens)): 264 | if i not in batch_input: 265 | batch_input[i]=[] 266 | batch_input[i].append(inputs_tokens[i]) 267 | if i not in batch_input_len: 268 | batch_input_len[i]=[len(inputs_tokens[i])] 269 | else: 270 | batch_input_len[i].append(len(inputs_tokens[i])) 271 | tar_in=[self.model_fn.sos_id]+target_tokens 272 | tar_out=target_tokens+[self.model_fn.eos_id] 273 | batch_target_len.append(len(tar_out)) 274 | batch_target_in.append(tar_in) 275 | batch_target_out.append(tar_out) 276 | # gradient accum and update 277 | #assert batch_size%mini_batch==0 278 | with self.graph.as_default(): 279 | data_num = batch_size 280 | losses = [] 281 | low = 0 282 | if is_train: 283 | sess.run(self.zero_ops) 284 | while low < data_num: 285 | n_samples = min([mini_batch, data_num - low]) 286 | mini_batch_input = [batch_input[i][low:low + n_samples] for i in range(0,len(batch_input))] 287 | mini_batch_input_len = [batch_input_len[i][low:low + n_samples] for i in range(0, len(batch_input))] 288 | mini_batch_target_in = batch_target_in[low:low + n_samples] 289 | mini_batch_target_out = batch_target_out[low:low + n_samples] 290 | mini_batch_target_len = batch_target_len[low:low + n_samples] 291 | mini_batch_target_in_padded, _ = self.padding_batch(mini_batch_target_in) 292 | mini_batch_target_out_padded, _ = self.padding_batch(mini_batch_target_out) 293 | feed_dict={} 294 | for i in range(0,self.model_fn.input_num): 295 | p,_ = self.padding_batch(mini_batch_input[i]) 296 | feed_dict[self.model_fn.inputs[i]]=p 297 | feed_dict[self.model_fn.input_lens[i]]=mini_batch_input_len[i] 298 | feed_dict[self.model_fn.target_in] = mini_batch_target_in_padded 299 | feed_dict[self.model_fn.target_out] = mini_batch_target_out_padded 300 | feed_dict[self.model_fn.target_len] = mini_batch_target_len 301 | if is_train: 302 | result = sess.run([self.accum_grad_ops, self.avg_loss], feed_dict=feed_dict, options=run_options) 303 | loss=result[-1] 304 | else: 305 | loss = sess.run(self.avg_loss, feed_dict=feed_dict) 306 | low += n_samples 307 | losses.append(loss*n_samples) 308 | if is_train: 309 | sess.run(self.train_step,feed_dict={self.accum_steps:batch_size/mini_batch}) 310 | return sum(losses) / batch_size 311 | 312 | 313 | def training(self, eos_id=None, train_corpus='./story/story.train', dev_corpus='./story/story.dev', 314 | init_step_num=1, learning_rate=1e-4, batch_size=64, mini_batch=16, total_steps=100000, 315 | train_ckpt_path='./models/117M/model_train_1/', infer_ckpt_path='./models/117M/', 316 | eval_per_n_steps=1, max_to_save=3, early_stop_steps=6000,append_eos=True,ori_gpt_model_path=None): 317 | self.learning_rate=learning_rate 318 | sess=self.create_session_init_and_print_all_trainable_vars(max_to_save,ori_gpt_model_path=ori_gpt_model_path) 319 | if ori_gpt_model_path is None: 320 | self.restore_model_and_init(sess, 
infer_ckpt_path, train_ckpt_path) 321 | train = load_corpus(train_corpus) 322 | #random.shuffle(train) 323 | # train=[' '.join(['you' for j in range(0,512)]) for i in range(0,512)] 324 | dev = load_corpus(dev_corpus) 325 | step = init_step_num 326 | low = 0 327 | epoch_num = 1 328 | train_data_num = len(train) 329 | eval_data_num = len(dev) 330 | last_improvement_step = init_step_num 331 | best_loss = 100000 332 | saved_steps = [] 333 | tf.logging.info('start training...') 334 | self.graph.finalize() 335 | start_time = time.time() 336 | while step < total_steps: 337 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True) 338 | n_samples = min([batch_size, train_data_num - low]) 339 | train_loss = self.train_or_eval_batch_with_raw_text(sess, train[low:low + n_samples], 340 | run_options=run_options, 341 | mini_batch=mini_batch, 342 | ) 343 | ###eval: 344 | if step % eval_per_n_steps == 0: 345 | eval_low = 0 346 | eval_losses = [] 347 | while eval_low < eval_data_num: 348 | eval_n_samples = min([batch_size, eval_data_num - eval_low]) 349 | eval_losses.append(self.train_or_eval_batch_with_raw_text( 350 | sess, dev[eval_low:eval_low + eval_n_samples], is_train=False, mini_batch=mini_batch)) 351 | eval_low += eval_n_samples 352 | eval_avg_loss = sum(eval_losses) / len(eval_losses) 353 | time_dif = get_time_dif(start_time) 354 | if eval_avg_loss < best_loss: 355 | best_loss = eval_avg_loss 356 | last_improvement_step = step 357 | tf.logging.info('save step %d', last_improvement_step) 358 | self.save_model(sess, infer_ckpt_path, train_ckpt_path, step=step) 359 | saved_steps.append(last_improvement_step) 360 | tf.logging.info("%s: step %d: train loss %f; eval loss %f *", time_dif, step, train_loss, 361 | eval_avg_loss) 362 | if len(saved_steps) > max_to_save: 363 | saved_steps = saved_steps[1:] 364 | else: 365 | tf.logging.info("%s: step %d: train loss %f; eval loss %f", time_dif, step, train_loss, 366 | eval_avg_loss) 367 | if step - last_improvement_step > early_stop_steps: 368 | tf.logging.info("early stopping...") 369 | break 370 | ### 371 | step += 1 372 | low += n_samples 373 | if low == train_data_num: 374 | low = 0 375 | epoch_num += 1 376 | #random.shuffle(train) 377 | sess.close() 378 | print('all work has finished') 379 | 380 | 381 | def restore_model_and_init(self, sess, ckpt_for_infer, ckpt_for_train): 382 | with self.graph.as_default(): 383 | if ckpt_for_infer is not None: 384 | ckpt = tf.train.latest_checkpoint(ckpt_for_infer) 385 | if ckpt is not None: 386 | self.saver_infer.restore(sess, ckpt) 387 | tf.logging.info('restored inferring params from %s',ckpt) 388 | if ckpt_for_train is not None: 389 | ckpt = tf.train.latest_checkpoint(ckpt_for_train) 390 | if ckpt is not None: 391 | self.saver_train.restore(sess, ckpt) 392 | tf.logging.info('restored training params from %s', ckpt) 393 | 394 | 395 | def save_model(self, sess, infer_ckpt_path, train_ckpt_path, step): 396 | with self.graph.as_default(): 397 | if infer_ckpt_path is not None and len(self.vars_for_infer) > 0: 398 | self.saver_infer.save(sess, os.path.join(infer_ckpt_path,'model'), global_step=step) 399 | if train_ckpt_path is not None and len(self.vars_for_train) > 0: 400 | self.saver_train.save(sess, os.path.join(train_ckpt_path,'model'), global_step=step) 401 | 402 | 403 | def padding_for_target_mask(self,mask_list,input_len): 404 | batch_size= len(mask_list) 405 | assert batch_size==len(input_len) 406 | max_len=max(input_len) 407 | for i in range(0,batch_size): 408 | l=input_len[i] 409 | 
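# extend each mask in place with 0.0 up to the longest sequence in the batch,
# so padded positions contribute nothing to the sequence loss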
mask_list[i]=mask_list[i]+[0.0]*(max_len-l) 410 | 411 | 412 | def load_corpus(path): 413 | lines = [] 414 | with open(path, 'r', encoding='utf-8') as f: 415 | for line in f: 416 | lines.append(line.strip()) 417 | return lines 418 | 419 | def get_time_dif(start_time): 420 | end_time = time.time() 421 | time_dif = end_time - start_time 422 | return timedelta(seconds=int(round(time_dif))) 423 | 424 | 425 | def test(config_path,input_num,model_dir='./models/ori_rule/formality_infer/',input_path='../training_data/dif_models/eval.ori_rule', 426 | output_path='../evaluate/gyafc_model_outputs/fr_out/formal.gpt.cat_ori_rule.old',beam_size=4,max_dec_len=60,dec_alpha=0.6): 427 | gpt2 = NMT_GPT(config_path=config_path,input_num=input_num) 428 | generator = beam_search_generator(gpt2, beam_size=beam_size, 429 | model_directory=model_dir, max_dec_len=max_dec_len, 430 | dec_alpha=dec_alpha) 431 | sess=generator.build_graph_and_restore(eos_id=gpt2.text_enc.encode('\n')[0]) 432 | lines=read_file_lines(input_path) 433 | result=[] 434 | for line in lines: 435 | result.append(generator.generate(sess,line,multi_pls=True)) 436 | print(line+' ||| '+result[-1].strip()) 437 | sess.close() 438 | write_file_lines(output_path, result) 439 | 440 | 441 | def train(config_path,input_num,ori_gpt_model=None,sep_flag='\t', 442 | train_corpus='../training_data/preprocessed/Family_Relationships/train.ori.txt', 443 | dev_corpus='../training_data/preprocessed/Family_Relationships/val.ori.txt', 444 | infer_ckpt_path='./models/ori_data_fr/formality_infer/', 445 | train_ckpt_path='./models/ori_data_fr/formality_train/'): 446 | gpt2 = NMT_GPT(input_num,config_path) 447 | trainer = NMT_GPT_Trainer(gpt2) 448 | trainer.build_graph() 449 | trainer.sep_flag=sep_flag 450 | trainer.training(train_corpus=train_corpus, 451 | dev_corpus=dev_corpus, 452 | infer_ckpt_path=infer_ckpt_path, train_ckpt_path=train_ckpt_path, 453 | learning_rate=1e-4, init_step_num=1, 454 | batch_size=128, mini_batch=16, 455 | eval_per_n_steps=100, 456 | total_steps=3000, 457 | early_stop_steps=200, 458 | max_to_save=2, 459 | append_eos=True, 460 | eos_id=gpt2.text_enc.encode('\n')[0],ori_gpt_model_path=ori_gpt_model) 461 | 462 | 463 | def concat_finetuning(domain='fr',max_len_limit=220,only_test=False): 464 | methods = ['ori', 'rule'] 465 | model_path='./models_cat_'+domain+'/'+'_'.join(methods) 466 | init_model_path = './models/formality_infer' 467 | if not os.path.exists('./models_cat_'+domain): 468 | os.mkdir('./models_cat_'+domain) 469 | if not os.path.exists(model_path): 470 | os.mkdir(model_path) 471 | os.mkdir(model_path+'/formality_train') 472 | shutil.copytree(init_model_path, model_path+'/formality_infer') 473 | data_path = '../training_data/dif_models_'+domain+'/' 474 | cat_files([data_path + 'informal.train.'+m for m in methods]+ [ data_path + 'formal.train.rule', ], 475 | data_path + 'train.'+'_'.join(methods), 476 | tokenizer=text_enc, max_len=max_len_limit) 477 | cat_files([data_path + 'informal.val.' + m for m in methods] + [data_path + 'formal.val.rule', ], 478 | data_path + 'val.' + '_'.join(methods), 479 | tokenizer=text_enc, max_len=max_len_limit) 480 | lp = cat_files([data_path + 'informal.test.' + m for m in methods], 481 | data_path + 'eval.' 
+ '_'.join(methods), 482 | tokenizer=text_enc, max_len=max_len_limit) 483 | if lp: 484 | print('_'.join(methods)+' data droped') 485 | if not only_test: 486 | train(config_path=config_path,input_num=len(methods),sep_flag='\t', ori_gpt_model=init_model_path, 487 | train_corpus=data_path + 'train.'+'_'.join(methods), 488 | dev_corpus=data_path + 'val.'+'_'.join(methods), 489 | infer_ckpt_path=model_path+'/formality_infer', 490 | train_ckpt_path=model_path+'/formality_train') 491 | test(config_path=config_path,input_num=len(methods), 492 | model_dir=model_path+'/formality_infer', 493 | input_path=data_path + 'eval.'+'_'.join(methods), 494 | output_path='../evaluate/gyafc_model_outputs/' + domain + '_out/formal.gpt.cat_no_share.'+'_'.join(methods)) 495 | 496 | 497 | def domain_combined(test_domain,only_test=False): 498 | methods = ['ori', 'rule'] 499 | model_path = './models_domain_combined/' + '_'.join(methods) 500 | init_model_path = './models/formality_infer' 501 | if not os.path.exists('./models_domain_combined'): 502 | os.mkdir('./models_domain_combined') 503 | if not os.path.exists(model_path): 504 | os.mkdir(model_path) 505 | os.mkdir(model_path + '/formality_train') 506 | shutil.copytree(init_model_path, model_path + '/formality_infer') 507 | data_path = '../training_data/domain_combined/' 508 | if not only_test: 509 | train(config_path=config_path, input_num=len(methods), sep_flag='\t', ori_gpt_model=init_model_path, 510 | train_corpus=data_path + 'train.' + '_'.join(methods), 511 | dev_corpus=data_path + 'val.' + '_'.join(methods), 512 | infer_ckpt_path=model_path + '/formality_infer', 513 | train_ckpt_path=model_path + '/formality_train') 514 | test(config_path=config_path, input_num=len(methods), 515 | model_dir=model_path + '/formality_infer', 516 | input_path='../training_data/dif_models_'+test_domain+'/eval.' + '_'.join(methods), 517 | output_path='../evaluate/gyafc_model_outputs/' + test_domain + '_out/formal.gpt.cat.domain_cmb.' + '_'.join(methods)) -------------------------------------------------------------------------------- /gpt/src/encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | import os 4 | import json 5 | import regex as re 6 | from functools import lru_cache 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns list of utf-8 byte and a corresponding list of unicode strings. 12 | The reversible bpe codes work on unicode strings. 13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 15 | This is a signficant percentage of your normal, say, 32K bpe vocab. 16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 17 | And avoids mapping to whitespace/control characters the bpe code barfs on. 18 | """ 19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 20 | cs = bs[:] 21 | n = 0 22 | for b in range(2**8): 23 | if b not in bs: 24 | bs.append(b) 25 | cs.append(2**8+n) 26 | n += 1 27 | cs = [chr(n) for n in cs] 28 | return dict(zip(bs, cs)) 29 | 30 | def get_pairs(word): 31 | """Return set of symbol pairs in a word. 32 | 33 | Word is represented as tuple of symbols (symbols being variable-length strings). 
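For example, get_pairs(('l', 'o', 'w')) returns {('l', 'o'), ('o', 'w')}.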
34 | """ 35 | pairs = set() 36 | prev_char = word[0] 37 | for char in word[1:]: 38 | pairs.add((prev_char, char)) 39 | prev_char = char 40 | return pairs 41 | 42 | class Encoder: 43 | def __init__(self, encoder, bpe_merges, errors='replace'): 44 | self.encoder = encoder 45 | self.decoder = {v:k for k,v in self.encoder.items()} 46 | self.errors = errors # how to handle errors in decoding 47 | self.byte_encoder = bytes_to_unicode() 48 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 49 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 50 | self.cache = {} 51 | 52 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 53 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 54 | 55 | def bpe(self, token): 56 | if token in self.cache: 57 | return self.cache[token] 58 | word = tuple(token) 59 | pairs = get_pairs(word) 60 | 61 | if not pairs: 62 | return token 63 | 64 | while True: 65 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 66 | if bigram not in self.bpe_ranks: 67 | break 68 | first, second = bigram 69 | new_word = [] 70 | i = 0 71 | while i < len(word): 72 | try: 73 | j = word.index(first, i) 74 | new_word.extend(word[i:j]) 75 | i = j 76 | except: 77 | new_word.extend(word[i:]) 78 | break 79 | 80 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 81 | new_word.append(first+second) 82 | i += 2 83 | else: 84 | new_word.append(word[i]) 85 | i += 1 86 | new_word = tuple(new_word) 87 | word = new_word 88 | if len(word) == 1: 89 | break 90 | else: 91 | pairs = get_pairs(word) 92 | word = ' '.join(word) 93 | self.cache[token] = word 94 | return word 95 | 96 | def encode(self, text): 97 | bpe_tokens = [] 98 | for token in re.findall(self.pat, text): 99 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 100 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 101 | return bpe_tokens 102 | 103 | def decode(self, tokens): 104 | sub_words=[self.decoder[token] for token in tokens] 105 | text = ''.join(sub_words) 106 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 107 | return text 108 | 109 | def get_encoder(model_name): 110 | with open(os.path.join(model_name, 'encoder.json'), 'r') as f: 111 | encoder = json.load(f) 112 | with open(os.path.join(model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: 113 | bpe_data = f.read() 114 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 115 | return Encoder( 116 | encoder=encoder, 117 | bpe_merges=bpe_merges, 118 | ) 119 | -------------------------------------------------------------------------------- /gpt/src/generate_unconditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from src import model, sample, encoder 10 | 11 | def sample_model( 12 | model_name='117M', 13 | seed=None, 14 | nsamples=0, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | ): 20 | """ 21 | Run the sample_model 22 | :model_name=117M : String, which model to use 23 | :seed=None : Integer seed for random number generators, fix seed to 24 | reproduce results 25 | :nsamples=0 : Number of samples to return, if 0, continues to 26 | generate samples indefinately. 
27 | :batch_size=1 : Number of batches (only affects speed/memory). 28 | :length=None : Number of tokens in generated text, if None (default), is 29 | determined by model hyperparameters 30 | :temperature=1 : Float value controlling randomness in boltzmann 31 | distribution. Lower temperature results in less random completions. As the 32 | temperature approaches zero, the model will become deterministic and 33 | repetitive. Higher temperature results in more random completions. 34 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 35 | considered for each step (token), resulting in deterministic completions, 36 | while 40 means 40 words are considered at each step. 0 (default) is a 37 | special setting meaning no restrictions. 40 generally is a good value. 38 | """ 39 | enc = encoder.get_encoder(model_name) 40 | hparams = model.default_hparams() 41 | with open(os.path.join('models', model_name, 'hparams.json')) as f: 42 | hparams.override_from_dict(json.load(f)) 43 | 44 | if length is None: 45 | length = hparams.n_ctx 46 | elif length > hparams.n_ctx: 47 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 48 | 49 | with tf.Session(graph=tf.Graph()) as sess: 50 | np.random.seed(seed) 51 | tf.set_random_seed(seed) 52 | 53 | output = sample.sample_sequence( 54 | hparams=hparams, length=length, 55 | start_token=enc.encoder['<|endoftext|>'], 56 | batch_size=batch_size, 57 | temperature=temperature, top_k=top_k 58 | )[:, 1:] 59 | 60 | saver = tf.train.Saver() 61 | ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) 62 | saver.restore(sess, ckpt) 63 | 64 | generated = 0 65 | while nsamples == 0 or generated < nsamples: 66 | out = sess.run(output) 67 | for i in range(batch_size): 68 | generated += batch_size 69 | text = enc.decode(out[i]) 70 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 71 | print(text) 72 | 73 | if __name__ == '__main__': 74 | fire.Fire(sample_model) 75 | 76 | -------------------------------------------------------------------------------- /gpt/src/gpt2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import json 3 | import os 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.contrib.seq2seq import sequence_loss 7 | from gpt.src import model, encoder, sample 8 | from gpt.src import beamsearch 9 | 10 | 11 | class GPT2: 12 | def __init__(self,config_path='./models/117M'): 13 | self.config_path = config_path 14 | self.text_enc = encoder.get_encoder(self.config_path) 15 | self.hparams = model.default_hparams() 16 | with open(os.path.join(self.config_path, 'hparams.json')) as f: 17 | self.hparams.override_from_dict(json.load(f)) 18 | self.eos_id = self.text_enc.encode('\n')[0] 19 | 20 | 21 | 22 | def ensemble_decoding_beam_search_graph(self,context_list,beam_size,batch_size,max_decode_length,eos_id,model_num,decode_alpha=0.6): 23 | def step(hparams, tokens, past=None, scope=None): 24 | if scope is not None: 25 | with tf.variable_scope(scope): 26 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 27 | logits = lm_output['logits'] 28 | presents = lm_output['present'] 29 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=None)) 30 | return { 31 | 'logits': logits, 32 | 'presents': presents, 33 | } 34 | context_output_list=[] 35 | context_state_list=[] 36 | all_scopes=[] 37 | for i in range(0,model_num): 38 | with tf.variable_scope('model_'+str(i)) as sc: 39 | with 
tf.name_scope('sample_sequence'): 40 | context_output_list.append(step(self.hparams, context_list[i][:, :-1],scope=sc)) 41 | context_state_list.append(context_output_list[-1]['presents']) 42 | all_scopes.append('model_'+str(i)) 43 | with tf.name_scope('beam_search'): 44 | init_seq = tf.expand_dims(context_list[0][:, -1], axis=1) 45 | seqs, scores = beamsearch.create_inference_graph(init_seqs=init_seq, state=context_state_list, 46 | step_fn=step, hparams=self.hparams, 47 | decode_length=max_decode_length, 48 | batch_size=batch_size, beam_size=beam_size, 49 | decode_alpha=decode_alpha, eos_id=eos_id, scopes_for_ensemble=all_scopes, 50 | ensemble=True, concat_state_dim=None) 51 | return seqs, scores 52 | 53 | 54 | 55 | def build_beam_search_graph(self,beam_size,batch_size,max_decode_length,decode_alpha=0.6): 56 | self.inputs = tf.placeholder(tf.int32, [1, None]) 57 | def step(hparams, tokens, past=None): 58 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 59 | logits = lm_output['logits'] 60 | presents = lm_output['present'] 61 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=None)) 62 | return { 63 | 'logits': logits, 64 | 'presents': presents, 65 | } 66 | with tf.name_scope('sample_sequence'): 67 | context_output = step(self.hparams, self.inputs[:, :-1]) 68 | context_state=context_output['presents'] 69 | with tf.name_scope('beam_search'): 70 | init_seq = tf.expand_dims(self.inputs[:, -1], axis=1) 71 | seqs, scores=beamsearch.create_inference_graph(init_seqs=init_seq,state=context_state, 72 | step_fn=step,hparams=self.hparams, 73 | decode_length=max_decode_length, 74 | batch_size=batch_size,beam_size=beam_size, 75 | decode_alpha=decode_alpha,eos_id=self.eos_id, 76 | ensemble=False, concat_state_dim=-2) 77 | 78 | return seqs,scores 79 | 80 | def build_inferring_graph(self, context, seed=None, nsamples=1, 81 | batch_size=1, length=None, temperature=1, top_k=40): 82 | self.generate_batch_size = batch_size 83 | self.generate_n_samples = nsamples 84 | if batch_size is None: 85 | batch_size = 1 86 | 87 | if length is None: 88 | length = self.hparams.n_ctx // 2 89 | elif length > self.hparams.n_ctx: 90 | raise ValueError("Can't get samples longer than window size: %s" % self.hparams.n_ctx) 91 | 92 | np.random.seed(seed) 93 | tf.set_random_seed(seed) 94 | output = sample.sample_sequence( 95 | hparams=self.hparams, length=length, 96 | context=context, 97 | batch_size=batch_size, 98 | temperature=temperature, top_k=top_k 99 | ) 100 | return output 101 | 102 | 103 | def build_training_graph(self,input,input_len,target,target_mask=None): 104 | batch_max_seq_len = tf.shape(input)[1] 105 | def step(hparams, tokens, past=None): 106 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 107 | logits = lm_output['logits'] 108 | presents = lm_output['present'] 109 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=None)) 110 | return { 111 | 'logits': logits, 112 | 'presents': presents, 113 | } 114 | with tf.name_scope('sample_sequence'): 115 | all_logits = step(hparams=self.hparams, tokens=input)['logits'] 116 | with tf.name_scope('loss'): 117 | if target_mask is None: 118 | target_mask = tf.sequence_mask(input_len, maxlen=batch_max_seq_len, dtype=tf.float32) 119 | cost = sequence_loss(logits=all_logits, targets=target, 120 | weights=target_mask) 121 | return cost 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- 
/gpt/src/hierarchical_attention.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.contrib.seq2seq import sequence_loss 7 | from gpt.src import model 8 | from gpt.src import beamsearch 9 | import tensorflow.contrib.slim as slim 10 | from datetime import timedelta 11 | import time 12 | from tensorflow.python.keras.preprocessing.sequence import pad_sequences 13 | from gpt.src.single_gpu_serving import beam_search_generator 14 | from utils.file_api import read_file_lines,write_file_lines 15 | from gpt.src.model import positions_for,Encoder,Decoder 16 | from utils.cat_files import cat_files 17 | from gpt.config import * 18 | 19 | 20 | class NMT_GPT(): 21 | def __init__(self,input_num,config_path): 22 | self.hparams = model.default_hparams() 23 | self.config_path = config_path 24 | with open(os.path.join(self.config_path, 'hparams.json')) as f: 25 | self.hparams.override_from_dict(json.load(f)) 26 | self.input_num=input_num 27 | self.text_enc = encoder.get_encoder(self.config_path) 28 | self.sos_id=self.text_enc.encode('\t')[0] 29 | self.eos_id=self.text_enc.encode('\n')[0] 30 | 31 | def def_placeholder_and_components(self): 32 | # embeddings: 33 | with tf.variable_scope('encoder'): 34 | with tf.variable_scope('model'): 35 | self.wpe = tf.get_variable('wpe', [self.hparams.n_ctx, self.hparams.n_embd], 36 | initializer=tf.random_normal_initializer(stddev=0.01)) 37 | self.wte = tf.get_variable('wte', [self.hparams.n_vocab, self.hparams.n_embd], 38 | initializer=tf.random_normal_initializer(stddev=0.02)) 39 | self.encoder = Encoder('encoder', self.hparams) 40 | self.decoder = Decoder('encoder', self.hparams) 41 | self.inputs = [tf.placeholder(tf.int32, [None, None], name='input_%d' % i) for i in range(0, self.input_num)] 42 | self.input_lens = [tf.placeholder(tf.int32, [None, ], name='input_len_%d' % i) for i in 43 | range(0, self.input_num)] 44 | self.target_in = tf.placeholder(tf.int32, [None, None], name='target_in') 45 | self.target_out = tf.placeholder(tf.int32, [None, None], name='target_out') 46 | self.target_len = tf.placeholder(tf.int32, [None], name='target_len') 47 | 48 | 49 | 50 | def build_training_model(self): 51 | self.def_placeholder_and_components() 52 | emb_out=[] 53 | enc_h_out=[] 54 | past_for_decoder=[] 55 | for i in range(0,self.input_num): 56 | past_length=0 57 | h = tf.gather(self.wte, self.inputs[i]) + tf.gather(self.wpe, positions_for(self.inputs[i], past_length)) 58 | emb_out.append(h) 59 | presents, h_enc=self.encoder.encode(h,self.input_lens[i]) 60 | enc_h_out.append(h_enc) 61 | past_for_decoder.append(presents) 62 | all_logits=self.decoder.decode_all(tokens=self.target_in,past_list=past_for_decoder,enc_h_list=enc_h_out)['logits'] 63 | with tf.name_scope('loss'): 64 | batch_max_seq_len = tf.shape(self.target_in)[1] 65 | target_mask = tf.sequence_mask(self.target_len, maxlen=batch_max_seq_len, dtype=tf.float32) 66 | cost = sequence_loss(logits=all_logits, targets=self.target_out, 67 | weights=target_mask) 68 | return cost 69 | 70 | 71 | def build_beam_search_graph(self, beam_size, batch_size, max_decode_length, decode_alpha=0.6): 72 | self.def_placeholder_and_components() 73 | emb_out = [] 74 | enc_h_out = [] 75 | past_for_decoder = [] 76 | for i in range(0, self.input_num): 77 | past_length = 0 78 | h = tf.gather(self.wte, self.inputs[i]) + tf.gather(self.wpe, positions_for(self.inputs[i], past_length)) 79 | 
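# h is the sum of token embeddings (wte) and position embeddings (wpe) for
# source i, as in GPT-2's input layer; each source's `presents` returned by the
# encoder are collected below as the decoder's attention memory for beam search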
emb_out.append(h) 80 | presents, h_enc = self.encoder.encode(h, self.input_lens[i]) 81 | enc_h_out.append(h_enc) 82 | past_for_decoder.append(presents) 83 | past_length = 0 if enc_h_out[0] is None else tf.shape(enc_h_out[0])[-2] 84 | self.decoder.sef_var_for_beam_search(past_length,enc_h_out,beam_size=beam_size) 85 | with tf.name_scope('beam_search'): 86 | init_seq = tf.fill(dims=(batch_size, 1), value=self.sos_id) 87 | seqs, scores = beamsearch.create_inference_graph(init_seqs=init_seq, state=past_for_decoder, 88 | step_fn=self.decoder.decode_one_step, hparams=self.hparams, 89 | decode_length=max_decode_length, 90 | batch_size=batch_size, beam_size=beam_size, 91 | decode_alpha=decode_alpha, eos_id=self.eos_id, 92 | ensemble=False, concat_state_dim=None) 93 | return seqs, scores 94 | 95 | 96 | class NMT_GPT_Trainer(): 97 | def __init__(self,model_fn:NMT_GPT): 98 | self.model_fn=model_fn 99 | self.learning_rate=1e-4 100 | self.sep_flag='\t' 101 | self.graph=tf.Graph() 102 | self.vars_for_infer = [] 103 | self.vars_for_train = [] 104 | self.losses=[] 105 | self.only_predict_target=True 106 | tf.logging.set_verbosity(tf.logging.INFO) 107 | self.is_hierarchical=True 108 | self.hier_enc_end_token=self.model_fn.text_enc.encode('\t') 109 | 110 | 111 | def average_gradients(self,tower_grads): 112 | average_grads = [] 113 | for grad_and_vars in zip(*tower_grads): 114 | grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars] 115 | grads = tf.concat(grads, 0) 116 | grad = tf.reduce_mean(grads, 0) 117 | grad_and_var = (grad, grad_and_vars[0][1]) 118 | # [(grad0, var0),(grad1, var1),...] 119 | average_grads.append(grad_and_var) 120 | return average_grads 121 | 122 | 123 | def build_graph(self): 124 | with self.graph.as_default(): 125 | self.opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 126 | self.tower_grads = [] 127 | loss = self.model_fn.build_training_model() 128 | self.losses.append(loss) 129 | grads = self.opt.compute_gradients(loss) 130 | tvs = tf.trainable_variables() 131 | self.accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) 132 | for tv in 133 | tvs] 134 | self.zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars] 135 | self.accum_grad_ops = [self.accum_vars[j].assign_add(gv[0]) for j, gv in 136 | enumerate(grads) if gv[0] is not None] 137 | self.tower_grads.append([(self.accum_vars[j], gv[1]) for j, gv in enumerate(grads) ]) 138 | grads = self.average_gradients(self.tower_grads) 139 | with tf.device('/gpu:0'): 140 | self.accum_steps=tf.placeholder(tf.float32, [], name='accum_stpes') 141 | self.train_step = self.opt.apply_gradients([(g/self.accum_steps, v) for g,v in grads]) 142 | self.avg_loss=tf.stack(self.losses,axis=0) 143 | self.avg_loss=tf.reduce_mean(self.avg_loss) 144 | 145 | 146 | def create_session_init_and_print_all_trainable_vars(self, max_to_save, ori_gpt_model_path=None): 147 | # Print parameters 148 | with self.graph.as_default(): 149 | all_weights = {v.name: v for v in tf.trainable_variables()} 150 | total_size = 0 151 | for v_name in sorted(list(all_weights)): 152 | v = all_weights[v_name] 153 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 154 | str(v.shape).ljust(20)) 155 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 156 | total_size += v_size 157 | tf.logging.info("Total trainable variables size: %d", total_size) 158 | all_var_list = slim.get_variables_to_restore() 159 | for v in all_var_list: 160 | if 'Adam' in v.name: 161 | self.vars_for_train.append(v) 162 | elif 
v.name.startswith('beta'): 163 | self.vars_for_train.append(v) 164 | elif v.name.startswith('parallel'): 165 | pass 166 | elif v.name.startswith('Variable'): 167 | pass 168 | else: 169 | self.vars_for_infer.append(v) 170 | if len(self.vars_for_infer) > 0: 171 | self.saver_infer = tf.train.Saver(self.vars_for_infer, max_to_keep=max_to_save) 172 | if len(self.vars_for_train) > 0: 173 | self.saver_train = tf.train.Saver(self.vars_for_train, max_to_keep=max_to_save) 174 | config = tf.ConfigProto() 175 | config.gpu_options.allow_growth = True 176 | sess = tf.Session(graph=self.graph, config=config) 177 | init_op = tf.global_variables_initializer() 178 | sess.run(init_op) 179 | restore_ops=[] 180 | if ori_gpt_model_path is not None: 181 | ckpt = tf.train.latest_checkpoint(ori_gpt_model_path) 182 | tf.logging.info("Loading %s" % ckpt) 183 | var_list = tf.train.list_variables(ckpt) 184 | values = {} 185 | reader = tf.train.load_checkpoint(ckpt) 186 | for (name, shape) in var_list: 187 | if not name.startswith('model/'): # ignore global_step 188 | continue 189 | tensor = reader.get_tensor(name) 190 | values[name] = tensor 191 | for v in self.vars_for_infer: 192 | #print(v.name) 193 | tmp = '/'.join(v.name.split('/')[1:]) 194 | v_name = tmp.split(':')[0] 195 | if v_name!='model/sen_attn_w': 196 | op = tf.assign(v, values[v_name]) 197 | restore_ops.append(op) 198 | sess.run(restore_ops) 199 | return sess 200 | 201 | 202 | def padding_batch(self, input_list): 203 | in_len = [len(i) for i in input_list] 204 | new_in = pad_sequences(input_list, padding='post') 205 | return new_in, in_len 206 | 207 | 208 | def train_or_eval_batch_with_raw_text(self, sess, input_text, mini_batch, is_train=True, 209 | run_options=None): 210 | batch_size = len(input_text) 211 | batch_input = {} 212 | batch_target_in = [] 213 | batch_target_out =[] 214 | batch_target_len =[] 215 | batch_input_len = {} 216 | for text in input_text: 217 | strs=text.split(self.sep_flag) 218 | inputs=strs[:-1] 219 | target=strs[-1] 220 | if self.is_hierarchical: 221 | inputs_tokens = [self.model_fn.text_enc.encode(item)+self.hier_enc_end_token for item in inputs] 222 | else: 223 | inputs_tokens = [self.model_fn.text_enc.encode(item) for item in inputs] 224 | target_tokens=self.model_fn.text_enc.encode(target) 225 | for i in range(0,len(inputs_tokens)): 226 | if i not in batch_input: 227 | batch_input[i]=[] 228 | batch_input[i].append(inputs_tokens[i]) 229 | if i not in batch_input_len: 230 | batch_input_len[i]=[len(inputs_tokens[i])] 231 | else: 232 | batch_input_len[i].append(len(inputs_tokens[i])) 233 | tar_in=[self.model_fn.sos_id]+target_tokens 234 | tar_out=target_tokens+[self.model_fn.eos_id] 235 | batch_target_len.append(len(tar_out)) 236 | batch_target_in.append(tar_in) 237 | batch_target_out.append(tar_out) 238 | # gradient accum and update 239 | #assert batch_size%mini_batch==0 240 | with self.graph.as_default(): 241 | data_num = batch_size 242 | losses = [] 243 | low = 0 244 | if is_train: 245 | sess.run(self.zero_ops) 246 | while low < data_num: 247 | n_samples = min([mini_batch, data_num - low]) 248 | mini_batch_input = [batch_input[i][low:low + n_samples] for i in range(0,len(batch_input))] 249 | mini_batch_input_len = [batch_input_len[i][low:low + n_samples] for i in range(0, len(batch_input))] 250 | mini_batch_target_in = batch_target_in[low:low + n_samples] 251 | mini_batch_target_out = batch_target_out[low:low + n_samples] 252 | mini_batch_target_len = batch_target_len[low:low + n_samples] 253 | mini_batch_target_in_padded, _ = 
self.padding_batch(mini_batch_target_in) 254 | mini_batch_target_out_padded, _ = self.padding_batch(mini_batch_target_out) 255 | feed_dict={} 256 | for i in range(0,self.model_fn.input_num): 257 | p,_ = self.padding_batch(mini_batch_input[i]) 258 | feed_dict[self.model_fn.inputs[i]]=p 259 | feed_dict[self.model_fn.input_lens[i]]=mini_batch_input_len[i] 260 | feed_dict[self.model_fn.target_in] = mini_batch_target_in_padded 261 | feed_dict[self.model_fn.target_out] = mini_batch_target_out_padded 262 | feed_dict[self.model_fn.target_len] = mini_batch_target_len 263 | if is_train: 264 | result = sess.run([self.accum_grad_ops, self.avg_loss], feed_dict=feed_dict, options=run_options) 265 | loss=result[-1] 266 | else: 267 | loss = sess.run(self.avg_loss, feed_dict=feed_dict) 268 | low += n_samples 269 | losses.append(loss*n_samples) 270 | if is_train: 271 | sess.run(self.train_step,feed_dict={self.accum_steps:batch_size/mini_batch}) 272 | return sum(losses) / batch_size 273 | 274 | 275 | def training(self, eos_id=None, train_corpus='./story/story.train', dev_corpus='./story/story.dev', 276 | init_step_num=1, learning_rate=1e-4, batch_size=64, mini_batch=16, total_steps=100000, 277 | train_ckpt_path='./models/117M/model_train_1/', infer_ckpt_path='./models/117M/', 278 | eval_per_n_steps=1, max_to_save=3, early_stop_steps=6000,append_eos=True,ori_gpt_model_path=None): 279 | self.learning_rate=learning_rate 280 | sess=self.create_session_init_and_print_all_trainable_vars(max_to_save,ori_gpt_model_path=ori_gpt_model_path) 281 | if ori_gpt_model_path is None: 282 | self.restore_model_and_init(sess, infer_ckpt_path, train_ckpt_path) 283 | train = load_corpus(train_corpus) 284 | # train=[' '.join(['you' for j in range(0,512)]) for i in range(0,512)] 285 | dev = load_corpus(dev_corpus) 286 | step = init_step_num 287 | low = 0 288 | epoch_num = 1 289 | train_data_num = len(train) 290 | eval_data_num = len(dev) 291 | last_improvement_step = init_step_num 292 | best_loss = 100000 293 | saved_steps = [] 294 | tf.logging.info('start training...') 295 | self.graph.finalize() 296 | start_time = time.time() 297 | while step < total_steps: 298 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True) 299 | n_samples = min([batch_size, train_data_num - low]) 300 | train_loss = self.train_or_eval_batch_with_raw_text(sess, train[low:low + n_samples], 301 | run_options=run_options, 302 | mini_batch=mini_batch, 303 | ) 304 | ###eval: 305 | if step % eval_per_n_steps == 0: 306 | eval_low = 0 307 | eval_losses = [] 308 | while eval_low < eval_data_num: 309 | eval_n_samples = min([batch_size, eval_data_num - eval_low]) 310 | eval_losses.append(self.train_or_eval_batch_with_raw_text( 311 | sess, dev[eval_low:eval_low + eval_n_samples], is_train=False, mini_batch=mini_batch)) 312 | eval_low += eval_n_samples 313 | eval_avg_loss = sum(eval_losses) / len(eval_losses) 314 | time_dif = get_time_dif(start_time) 315 | if eval_avg_loss < best_loss: 316 | best_loss = eval_avg_loss 317 | last_improvement_step = step 318 | tf.logging.info('save step %d', last_improvement_step) 319 | self.save_model(sess, infer_ckpt_path, train_ckpt_path, step=step) 320 | saved_steps.append(last_improvement_step) 321 | tf.logging.info("%s: step %d: train loss %f; eval loss %f *", time_dif, step, train_loss, 322 | eval_avg_loss) 323 | if len(saved_steps) > max_to_save: 324 | saved_steps = saved_steps[1:] 325 | else: 326 | tf.logging.info("%s: step %d: train loss %f; eval loss %f", time_dif, step, train_loss, 327 | eval_avg_loss) 328 | if 
step - last_improvement_step > early_stop_steps: 329 | tf.logging.info("early stopping...") 330 | break 331 | ### 332 | step += 1 333 | low += n_samples 334 | if low == train_data_num: 335 | low = 0 336 | epoch_num += 1 337 | sess.close() 338 | print('all work has finished') 339 | 340 | 341 | def restore_model_and_init(self, sess, ckpt_for_infer, ckpt_for_train): 342 | with self.graph.as_default(): 343 | if ckpt_for_infer is not None: 344 | ckpt = tf.train.latest_checkpoint(ckpt_for_infer) 345 | if ckpt is not None: 346 | self.saver_infer.restore(sess, ckpt) 347 | tf.logging.info('restored inferring params from %s',ckpt) 348 | if ckpt_for_train is not None: 349 | ckpt = tf.train.latest_checkpoint(ckpt_for_train) 350 | if ckpt is not None: 351 | self.saver_train.restore(sess, ckpt) 352 | tf.logging.info('restored training params from %s', ckpt) 353 | 354 | 355 | def save_model(self, sess, infer_ckpt_path, train_ckpt_path, step): 356 | with self.graph.as_default(): 357 | if infer_ckpt_path is not None and len(self.vars_for_infer) > 0: 358 | self.saver_infer.save(sess, os.path.join(infer_ckpt_path,'model'), global_step=step) 359 | if train_ckpt_path is not None and len(self.vars_for_train) > 0: 360 | self.saver_train.save(sess, os.path.join(train_ckpt_path,'model'), global_step=step) 361 | 362 | 363 | def padding_for_target_mask(self,mask_list,input_len): 364 | batch_size= len(mask_list) 365 | assert batch_size==len(input_len) 366 | max_len=max(input_len) 367 | for i in range(0,batch_size): 368 | l=input_len[i] 369 | mask_list[i]=mask_list[i]+[0.0]*(max_len-l) 370 | 371 | 372 | 373 | def load_corpus(path): 374 | lines = [] 375 | with open(path, 'r', encoding='utf-8') as f: 376 | for line in f: 377 | lines.append(line.strip()) 378 | return lines 379 | 380 | def get_time_dif(start_time): 381 | end_time = time.time() 382 | time_dif = end_time - start_time 383 | return timedelta(seconds=int(round(time_dif))) 384 | 385 | 386 | def test(config_path,input_num,model_dir='./models/ori_rule/formality_infer/',input_path='../training_data/dif_models/eval.ori_rule', 387 | output_path='../evaluate/gyafc_model_outputs/fr_out/formal.gpt.cat_ori_rule.old',beam_size=4,max_dec_len=60,dec_alpha=0.6): 388 | gpt2 = NMT_GPT(config_path=config_path,input_num=input_num) 389 | generator = beam_search_generator(gpt2, beam_size=beam_size, 390 | model_directory=model_dir, max_dec_len=max_dec_len, 391 | dec_alpha=dec_alpha) 392 | sess=generator.build_graph_and_restore(eos_id=gpt2.text_enc.encode('\n')[0]) 393 | lines=read_file_lines(input_path) 394 | result=[] 395 | for line in lines: 396 | result.append(generator.generate(sess,line,multi_pls=True)) 397 | print(line+' ||| '+result[-1].strip()) 398 | sess.close() 399 | write_file_lines(output_path, result) 400 | 401 | 402 | def train(config_path,input_num,ori_gpt_model=None,sep_flag='\t', 403 | train_corpus='../training_data/preprocessed/Family_Relationships/train.ori.txt', 404 | dev_corpus='../training_data/preprocessed/Family_Relationships/val.ori.txt', 405 | infer_ckpt_path='./models/ori_data_fr/formality_infer/', 406 | train_ckpt_path='./models/ori_data_fr/formality_train/'): 407 | gpt2 = NMT_GPT(input_num,config_path) 408 | trainer = NMT_GPT_Trainer(gpt2) 409 | trainer.build_graph() 410 | trainer.sep_flag=sep_flag 411 | trainer.training(train_corpus=train_corpus, 412 | dev_corpus=dev_corpus, 413 | infer_ckpt_path=infer_ckpt_path, train_ckpt_path=train_ckpt_path, 414 | learning_rate=1e-4, init_step_num=1, 415 | batch_size=128, mini_batch=16, 416 | eval_per_n_steps=100, 
417 | total_steps=3000, 418 | early_stop_steps=200, 419 | max_to_save=2, 420 | append_eos=True, 421 | eos_id=gpt2.text_enc.encode('\n')[0],ori_gpt_model_path=ori_gpt_model) 422 | 423 | 424 | def HA(domain='fr',max_len_limit=220,only_test=False): 425 | methods = ['ori', 'rule'] 426 | model_path='./models_hie_'+domain+'/'+'_'.join(methods) 427 | init_model_path = './models/formality_infer' 428 | if not os.path.exists('./models_hie_'+domain): 429 | os.mkdir('./models_hie_'+domain) 430 | if not os.path.exists(model_path): 431 | os.mkdir(model_path) 432 | os.mkdir(model_path+'/formality_train') 433 | shutil.copytree(init_model_path, model_path+'/formality_infer') 434 | data_path = '../training_data/dif_models_'+domain+'/' 435 | cat_files([data_path + 'informal.train.'+m for m in methods]+ [ data_path + 'formal.train.rule', ], 436 | data_path + 'train.'+'_'.join(methods), 437 | tokenizer=text_enc, max_len=max_len_limit) 438 | cat_files([data_path + 'informal.val.' + m for m in methods] + [data_path + 'formal.val.rule', ], 439 | data_path + 'val.' + '_'.join(methods), 440 | tokenizer=text_enc, max_len=max_len_limit) 441 | lp = cat_files([data_path + 'informal.test.' + m for m in methods], 442 | data_path + 'eval.' + '_'.join(methods), 443 | tokenizer=text_enc, max_len=max_len_limit) 444 | if lp: 445 | print('_'.join(methods)+' data dropped') 446 | if not only_test: 447 | train(config_path=config_path,input_num=len(methods),sep_flag='\t', ori_gpt_model=init_model_path, 448 | train_corpus=data_path + 'train.'+'_'.join(methods), 449 | dev_corpus=data_path + 'val.'+'_'.join(methods), 450 | infer_ckpt_path=model_path+'/formality_infer', 451 | train_ckpt_path=model_path+'/formality_train') 452 | test(config_path=config_path,input_num=len(methods), 453 | model_dir=model_path+'/formality_infer', 454 | input_path=data_path + 'eval.'+'_'.join(methods), 455 | output_path='../evaluate/gyafc_model_outputs/' + domain + '_out/formal.gpt.hie'+'_'.join(methods)) 456 | -------------------------------------------------------------------------------- /gpt/src/interactive_conditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from src import model, sample, encoder 10 | 11 | def interact_model( 12 | model_name='117M', 13 | seed=None, 14 | nsamples=1, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=40, 19 | ): 20 | """ 21 | Interactively run the model 22 | :model_name=117M : String, which model to use 23 | :seed=None : Integer seed for random number generators, fix seed to reproduce 24 | results 25 | :nsamples=1 : Number of samples to return total 26 | :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples. 27 | :length=None : Number of tokens in generated text, if None (default), is 28 | determined by model hyperparameters 29 | :temperature=1 : Float value controlling randomness in Boltzmann 30 | distribution. Lower temperature results in less random completions. As the 31 | temperature approaches zero, the model will become deterministic and 32 | repetitive. Higher temperature results in more random completions. 33 | :top_k=40 : Integer value controlling diversity. 1 means only 1 word is 34 | considered for each step (token), resulting in deterministic completions, 35 | while 40 means 40 words are considered at each step. 0 is a 36 | special setting meaning no restrictions. 
40 generally is a good value. 37 | """ 38 | if batch_size is None: 39 | batch_size = 1 40 | assert nsamples % batch_size == 0 41 | 42 | enc = encoder.get_encoder(model_name) 43 | hparams = model.default_hparams() 44 | with open(os.path.join('models', model_name, 'hparams.json')) as f: 45 | hparams.override_from_dict(json.load(f)) 46 | 47 | if length is None: 48 | length = hparams.n_ctx // 2 49 | elif length > hparams.n_ctx: 50 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 51 | 52 | config = tf.ConfigProto() 53 | config.gpu_options.allow_growth = True 54 | with tf.Session(graph=tf.Graph(),config=config) as sess: 55 | context = tf.placeholder(tf.int32, [batch_size, None]) 56 | np.random.seed(seed) 57 | tf.set_random_seed(seed) 58 | output = sample.sample_sequence( 59 | hparams=hparams, length=length, 60 | context=context, 61 | batch_size=batch_size, 62 | temperature=temperature, top_k=top_k 63 | ) 64 | 65 | saver = tf.train.Saver() 66 | ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) 67 | saver.restore(sess, ckpt) 68 | 69 | while True: 70 | raw_text = input("Model prompt >>> ") 71 | while not raw_text: 72 | print('Prompt should not be empty!') 73 | raw_text = input("Model prompt >>> ") 74 | context_tokens = enc.encode(raw_text) 75 | generated = 0 76 | for _ in range(nsamples // batch_size): 77 | out = sess.run(output, feed_dict={ 78 | context: [context_tokens for _ in range(batch_size)] 79 | })[:, len(context_tokens):] 80 | for i in range(batch_size): 81 | generated += 1 82 | text = enc.decode(out[i]) 83 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 84 | print(text) 85 | print("=" * 80) 86 | 87 | if __name__ == '__main__': 88 | fire.Fire(interact_model) 89 | 90 | -------------------------------------------------------------------------------- /gpt/src/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.training import HParams 4 | from utils.common import gather_2d,tile_to_beam_size,merge_first_two_dims 5 | from tensorflow.python.util import nest 6 | 7 | def default_hparams(): 8 | return HParams( 9 | n_vocab=0, 10 | n_ctx=1024, 11 | n_embd=768, 12 | n_head=12, 13 | n_layer=12, 14 | ) 15 | 16 | def shape_list(x): 17 | """Deal with dynamic shape in tensorflow cleanly.""" 18 | static = x.shape.as_list() 19 | dynamic = tf.shape(x) 20 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 21 | 22 | def softmax(x, axis=-1): 23 | x = x - tf.reduce_max(x, axis=axis, keepdims=True) 24 | ex = tf.exp(x) 25 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) 26 | 27 | def gelu(x): 28 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) 29 | 30 | def norm(x, scope, *, axis=-1, epsilon=1e-5): 31 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 32 | with tf.variable_scope(scope): 33 | n_state = x.shape[-1].value 34 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) 35 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) 36 | u = tf.reduce_mean(x, axis=axis, keepdims=True) 37 | s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) 38 | x = (x - u) * tf.rsqrt(s + epsilon) 39 | x = x*g + b 40 | return x 41 | 42 | def split_states(x, n): 43 | """Reshape the last dimension of x into [n, x.shape[-1]/n].""" 44 | *start, m = shape_list(x) 45 | return tf.reshape(x, start + [n, m//n]) 46 | 47 | def 
merge_states(x): 48 | """Smash the last two dimensions of x into a single dimension.""" 49 | *start, a, b = shape_list(x) 50 | return tf.reshape(x, start + [a*b]) 51 | 52 | def conv1d(x, scope, nf, *, w_init_stdev=0.02): 53 | with tf.variable_scope(scope): 54 | *start, nx = shape_list(x) 55 | w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) 56 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) 57 | c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) 58 | return c 59 | 60 | def attention_mask(nd, ns, *, dtype): 61 | """1's in the lower triangle, counting from the lower right corner. 62 | 63 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 64 | """ 65 | i = tf.range(nd)[:,None] 66 | j = tf.range(ns) 67 | m = i >= j - ns + nd 68 | return tf.cast(m, dtype) 69 | 70 | 71 | def attn(x, scope, n_state, *, past, hparams): 72 | assert x.shape.ndims == 3 # Should be [batch, sequence, features] 73 | assert n_state % hparams.n_head == 0 74 | if past is not None: 75 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] 76 | 77 | def split_heads(x): 78 | # From [batch, sequence, features] to [batch, heads, sequence, features] 79 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) 80 | 81 | def merge_heads(x): 82 | # Reverse of split_heads 83 | return merge_states(tf.transpose(x, [0, 2, 1, 3])) 84 | 85 | def mask_attn_weights(w): 86 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 87 | _, _, nd, ns = shape_list(w) 88 | b = attention_mask(nd, ns, dtype=w.dtype) 89 | b = tf.reshape(b, [1, 1, nd, ns]) 90 | w = w*b - tf.cast(1e10, w.dtype)*(1-b) 91 | return w 92 | 93 | def multihead_attn(q, k, v): 94 | # q, k, v have shape [batch, heads, sequence, features] 95 | w = tf.matmul(q, k, transpose_b=True) 96 | w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype)) 97 | 98 | w = mask_attn_weights(w) 99 | w = softmax(w) 100 | a = tf.matmul(w, v) 101 | return a 102 | 103 | with tf.variable_scope(scope): 104 | c = conv1d(x, 'c_attn', n_state*3) 105 | q, k, v = map(split_heads, tf.split(c, 3, axis=2)) 106 | present = tf.stack([k, v], axis=1) 107 | if past is not None: 108 | pk, pv = tf.unstack(past, axis=1) 109 | k = tf.concat([pk, k], axis=-2) 110 | v = tf.concat([pv, v], axis=-2) 111 | a = multihead_attn(q, k, v) 112 | a = merge_heads(a) 113 | a = conv1d(a, 'c_proj', n_state) 114 | return a, present 115 | 116 | 117 | def mlp(x, scope, n_state, *, hparams): 118 | with tf.variable_scope(scope): 119 | nx = x.shape[-1].value 120 | h = gelu(conv1d(x, 'c_fc', n_state)) 121 | h2 = conv1d(h, 'c_proj', nx) 122 | return h2 123 | 124 | 125 | def block(x, scope, *, past, hparams): 126 | with tf.variable_scope(scope): 127 | nx = x.shape[-1].value 128 | a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams) 129 | x = x + a 130 | m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams) 131 | x = x + m 132 | return x, present 133 | 134 | def past_shape(*, hparams, batch_size=None, sequence=None): 135 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head] 136 | 137 | def expand_tile(value, size): 138 | """Add a new axis of given size.""" 139 | value = tf.convert_to_tensor(value, name='value') 140 | ndims = value.shape.ndims 141 | return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims) 142 | 
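# ---------------------------------------------------------------------------
# Editor's note: the worked example below is not part of the original
# model.py; it is added only to illustrate attention_mask() defined above.
# With nd=2 new (destination) tokens attending over ns=4 total source
# positions -- i.e. 2 cached positions plus the 2 new tokens -- the mask
# keeps source position j for destination row i whenever i >= j - ns + nd:
#
#     attention_mask(2, 4, dtype=tf.float32)
#     # => [[1., 1., 1., 0.],
#     #     [1., 1., 1., 1.]]
#
# The first new token may attend to the two cached positions and to itself;
# the second new token may additionally attend to the first new token. In
# mask_attn_weights() above, the logits are multiplied by this mask and the
# masked entries are pushed to -1e10, so future positions receive ~zero
# weight after the softmax.
# ---------------------------------------------------------------------------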
143 | def positions_for(tokens, past_length): 144 | batch_size = tf.shape(tokens)[0] 145 | nsteps = tf.shape(tokens)[1] 146 | return expand_tile(past_length + tf.range(nsteps), batch_size) 147 | 148 | 149 | class Encoder(): 150 | def __init__(self,scope,hparam): 151 | if scope is None: 152 | self.scope='encoder' 153 | else: 154 | self.scope=scope 155 | self.hparams=hparam 156 | 157 | def encode(self, h, h_len, past=None, scope='encoder', reuse=tf.AUTO_REUSE): 158 | with tf.variable_scope(scope, reuse=reuse): 159 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE): 160 | # Transformer 161 | presents = [] 162 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * self.hparams.n_layer 163 | assert len(pasts) == self.hparams.n_layer 164 | for layer, past_one in enumerate(pasts): 165 | h, present = block(h, 'h%d' % layer, past=past_one, hparams=self.hparams) 166 | presents.append(present) 167 | presents = tf.stack(presents, axis=1) 168 | h = norm(h, 'ln_f') 169 | final_id = h_len - 1 170 | h = gather_2d(h, tf.expand_dims(final_id, axis=1)) 171 | target_mask = tf.sequence_mask(h_len-1, maxlen=tf.shape(h)[1], dtype=tf.float32)# h_len-1 masks out the sentence token 172 | target_mask = tf.expand_dims(target_mask, 2) 173 | encode_out = tf.transpose(presents, perm=(0, 4, 2, 3, 1, 5)) 174 | ori_enc_shape = tf.shape(encode_out) 175 | encode_out = tf.reshape(encode_out, shape=(tf.shape(presents)[0], tf.shape(presents)[4], -1)) 176 | encode_out = tf.multiply(encode_out, target_mask) 177 | encode_out = tf.reshape(encode_out, shape=ori_enc_shape) 178 | encode_out = tf.transpose(encode_out, perm=(0, 4, 2, 3, 1, 5)) 179 | encode_out.set_shape(past_shape(hparams=self.hparams, batch_size=None)) 180 | return encode_out, h 181 | 182 | 183 | 184 | class Decoder(): 185 | def __init__(self,scope,hparams): 186 | self.scope = scope 187 | self.hparams = hparams 188 | with tf.variable_scope(scope): 189 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE): 190 | self.wpe=tf.get_variable('wpe', [self.hparams.n_ctx, self.hparams.n_embd], 191 | initializer=tf.random_normal_initializer(stddev=0.01)) 192 | self.wte = tf.get_variable('wte', [self.hparams.n_vocab, self.hparams.n_embd], 193 | initializer=tf.random_normal_initializer(stddev=0.02)) 194 | self.attn_w = tf.get_variable(shape=(self.hparams.n_embd, self.hparams.n_embd), name='sen_attn_w') 195 | 196 | 197 | #def decode_all 198 | def decode_all(self,tokens,past_list,enc_h_list): 199 | with tf.variable_scope(self.scope,reuse=tf.AUTO_REUSE): 200 | with tf.variable_scope('model',reuse=tf.AUTO_REUSE): 201 | results = {} 202 | if type(past_list)!=list: 203 | past_list=[past_list] 204 | batch, sequence = shape_list(tokens) 205 | #past_length = 0 206 | all_past_length=[0 if past_list[0] is None else tf.shape(past_list[0])[-2]] 207 | past_length = tf.reduce_max(tf.stack(all_past_length,axis=0),axis=0) 208 | h = tf.gather(self.wte, tokens) + tf.gather(self.wpe, positions_for(tokens, past_length)) 209 | values_present = {} 210 | for i in range(0, self.hparams.n_layer): 211 | querys = h 212 | values_h = [] 213 | for j in range(0, len(past_list)): 214 | past = past_list[j] 215 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * self.hparams.n_layer 216 | assert len(pasts) == self.hparams.n_layer 217 | h, present = block(querys, 'h%d' % i, past=pasts[i], hparams=self.hparams) 218 | values_h.append(h) 219 | if j in values_present: 220 | values_present[j].append(present) 221 | else: 222 | values_present[j]=[present] 223 | enc_h_all = tf.concat(enc_h_list, 
axis=1) 224 | attn_score = tf.tensordot(querys, self.attn_w, axes=(2, 0)) # sentence-level (hierarchical) attention: project the decoder states 225 | attn_score = tf.matmul(attn_score, tf.transpose(enc_h_all, perm=(0, 2, 1))) # batch*seq*context_num 226 | attn_score = tf.nn.softmax(attn_score,axis=2) # normalize the attention weights over the encoder contexts 227 | val_h_cat = tf.stack(values_h, axis=2) 228 | val_h_cat = tf.expand_dims(attn_score, axis=3) * val_h_cat 229 | val_h_cat = tf.reduce_sum(val_h_cat, axis=2) # attention-weighted sum of the per-context hidden states 230 | h = val_h_cat 231 | for j in range(0,len(past_list)): 232 | values_present[j]=tf.stack(values_present[j],axis=1) 233 | past_list[j]=tf.concat([past_list[j],values_present[j]],axis=-2) 234 | h = norm(h, 'ln_f') 235 | # Language model loss. Do tokens