├── LICENSE ├── README.md ├── main.py ├── model.py ├── my_attention_decoder_fn.py ├── my_loss.py ├── my_seq2seq.py ├── output_projection.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating Informative Responses with Controlled Sentence Function 2 | 3 | ## Introduction 4 | 5 | Sentence function is a significant factor in achieving the purpose of the speaker. In this paper, we present a novel model to generate informative responses with controlled sentence function. Given a user post and a sentence function label, our model generates a response that is not only consistent with the specified function category but also informative in content. 6 | 7 | This project is a TensorFlow implementation of our work. 8 | 9 | ## Dependencies 10 | 11 | * Python 2.7 12 | * NumPy 13 | * TensorFlow 1.3.0 14 | 15 | ## Quick Start 16 | 17 | * Dataset 18 | 19 | Our dataset contains single-turn post-response pairs with corresponding sentence function labels. The sentence function labels of the responses have been automatically annotated by a self-attentive classifier. 20 | 21 | Please download the [Chinese Dialogue Dataset with Sentence Function Labels](http://coai.cs.tsinghua.edu.cn/hml/dataset) to the data directory. 22 | 23 | * Train 24 | 25 | ```python main.py ``` 26 | 27 | * Test 28 | 29 | ```python main.py --is_train=False --inference_path='xxx' --inference_version='yyy' ``` 30 | 31 | Use this command to test the model. Set the path of the test set with inference_path and the checkpoint to use with inference_version. The generated responses will be written to the 'xxx.out' file. 32 | 33 | 34 | ## Details 35 | 36 | ### Training 37 | 38 | You can change the model parameters using: 39 | 40 | --symbols xxx size of the full vocabulary 41 | --topic_symbols xxx size of the topic vocabulary 42 | --full_kl_step xxx total number of steps for KL annealing 43 | --units xxx size of hidden units 44 | --embed_units xxx dimension of word embeddings 45 | --batch_size xxx batch size used in training 46 | --per_checkpoint xxx number of steps between saving and evaluating the model 47 | --data_dir xxx data directory 48 | --train_dir xxx training directory 49 | 50 | 51 | ## Paper 52 | 53 | Pei Ke, Jian Guan, Minlie Huang, Xiaoyan Zhu. 54 | [Generating Informative Responses with Controlled Sentence Function.](http://aclweb.org/anthology/P18-1139) 55 | ACL 2018, Melbourne, Australia. 
56 | 57 | **Please kindly cite our paper if you find the paper or the code helpful.** 58 | 59 | 60 | ## License 61 | 62 | Apache License 2.0 63 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import sys 4 | import time 5 | import random 6 | import pickle as pkl 7 | import math 8 | random.seed(time.time()) 9 | 10 | from model import Seq2SeqModel, _START_VOCAB 11 | 12 | # import tokenizer 13 | try: 14 | from wordseg_python import Global 15 | except: 16 | Global = None 17 | 18 | tf.app.flags.DEFINE_boolean("is_train", True, "Set to False to inference.") 19 | tf.app.flags.DEFINE_integer("symbols", 40000, "vocabulary size.") 20 | tf.app.flags.DEFINE_integer("topic_symbols", 10000, "topic vocabulary size.") 21 | tf.app.flags.DEFINE_integer("full_kl_step", 80000, "Total steps to finish annealing") 22 | tf.app.flags.DEFINE_integer("embed_units", 100, "Size of word embedding.") 23 | tf.app.flags.DEFINE_integer("units", 256, "Size of hidden units.") 24 | tf.app.flags.DEFINE_integer("batch_size", 128, "Batch size to use during training.") 25 | tf.app.flags.DEFINE_string("data_dir", "/home/kepei/seq2seq_rec_nostop/data", "Data directory") 26 | tf.app.flags.DEFINE_string("train_dir", "./train", "Training directory.") 27 | tf.app.flags.DEFINE_integer("per_checkpoint", 1000, "How many steps to do per checkpoint.") 28 | tf.app.flags.DEFINE_integer("inference_version", 0, "The version for inferencing.") 29 | tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters") 30 | tf.app.flags.DEFINE_string("inference_path", "", "Set filename of inference, default is screen") 31 | tf.app.flags.DEFINE_integer("num_keywords", 2, "Number of keywords extracted from responses") 32 | 33 | FLAGS = tf.app.flags.FLAGS 34 | 35 | def load_data(path, fname): 36 | # data sample: (post, response, keyword, label) 37 | # post: tokenized post sequence 38 | # response: tokenized response sequence 39 | # keyword: keywords extracted from the response (using PMI in this work) 40 | # label: one-hot sentence function label of the corresponding response (annotated by the self-attentive classifier) 41 | with open('%s/%s.post' % (path, fname)) as f: 42 | post = [line.strip().split() for line in f.readlines()] 43 | with open('%s/%s.response' % (path, fname)) as f: 44 | response = [line.strip().split() for line in f.readlines()] 45 | with open('%s/%s.keyword' % (path, fname)) as f: 46 | keyword = [line.strip().split() for line in f.readlines()] 47 | with open('%s/%s.label' % (path, fname)) as f: 48 | label = [line.strip().split('\t') for line in f.readlines()] 49 | data = [] 50 | for p, r, k, l in zip(post, response, keyword, label): 51 | data.append({'post': p, 'response': r, 'keyword': k, 'label':l}) 52 | return data 53 | 54 | def build_vocab(path, data, stop_list, func_list): 55 | print("Creating vocabulary...") 56 | vocab = {} 57 | vocab_topic = {} 58 | for i, pair in enumerate(data): 59 | if i % 100000 == 0: 60 | print(" processing line %d" % i) 61 | for token in pair['post']+pair['response']: 62 | if token in vocab: 63 | vocab[token] += 1 64 | else: 65 | vocab[token] = 1 66 | for token in pair['keyword']: 67 | if token not in stop_list: # remove stopwords from vocab_topic 68 | if token in vocab_topic: 69 | vocab_topic[token] += 1 70 | else: 71 | vocab_topic[token] = 1 72 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 73 
| vocab_topic_list = sorted(vocab_topic, key = vocab_topic.get, reverse = True) 74 | 75 | if len(vocab_list) > FLAGS.symbols: 76 | vocab_list = vocab_list[:FLAGS.symbols] # remove words with low frequency from vocab_list 77 | vocab_topic_list_new = [] 78 | for word in vocab_topic_list: 79 | if word in vocab_list: 80 | vocab_topic_list_new.append(word) # keep topic words in vocab_list 81 | if len(vocab_topic_list_new) > FLAGS.topic_symbols: 82 | vocab_topic_list_new = vocab_topic_list_new[:FLAGS.topic_symbols] 83 | 84 | topic_pos_list = [] # record the position of topic words 85 | topic_cnt = 0 86 | for ele in vocab_topic_list_new: 87 | if ele in vocab_list and ele not in func_list: 88 | topic_cnt += 1 89 | topic_pos_list.append(vocab_list.index(ele)) 90 | print 'topic_cnt = ', topic_cnt 91 | 92 | func_pos_list = [] # record the position of function-related words 93 | for ele in func_list.items(): 94 | if ele[0] in vocab_list: 95 | func_pos_list.append(vocab_list.index(ele[0])) 96 | 97 | # Load pre-trained word vectors from path/vector.txt 98 | # Format of word vectors (e.g. word "function"): function -0.1 0.2 ... 0.5 99 | print("Loading word vectors...") 100 | vectors = {} 101 | with open('%s/vector.txt' % path) as f: 102 | for i, line in enumerate(f): 103 | if i % 100000 == 0: 104 | print(" processing line %d" % i) 105 | s = line.strip() 106 | word = s[:s.find(' ')] 107 | vector = s[s.find(' ')+1:] 108 | vectors[word] = vector 109 | 110 | embed = [] 111 | for word in vocab_list: 112 | if word in vectors: 113 | vector = map(float, vectors[word].split()) 114 | else: 115 | vector = np.zeros((FLAGS.embed_units), dtype=np.float32) 116 | embed.append(vector) 117 | embed = np.array(embed, dtype=np.float32) 118 | 119 | return vocab_list, embed, vocab_topic_list_new, topic_pos_list, func_pos_list 120 | 121 | 122 | def gen_batched_data(data): 123 | encoder_len = max([len(item['post']) for item in data])+1 124 | decoder_len = max([len(item['response']) for item in data])+1 125 | 126 | posts, responses, posts_length, responses_length, labels = [], [], [], [], [] 127 | def padding(sent, l): 128 | return sent + ['_EOS'] + ['_PAD'] * (l-len(sent)-1) 129 | 130 | for item in data: 131 | posts.append(padding(item['post'], encoder_len)) 132 | responses.append(padding(item['response'], decoder_len)) 133 | posts_length.append(len(item['post'])+1) 134 | responses_length.append(len(item['response'])+1) 135 | labels.append(item['label']) 136 | 137 | batched_data = {'posts': np.array(posts), 138 | 'responses': np.array(responses), 139 | 'posts_length': posts_length, 140 | 'responses_length': responses_length, 141 | 'labels':np.array(labels)} 142 | return batched_data 143 | 144 | 145 | def train(model, sess, data_train, global_t): 146 | batched_data = gen_batched_data(data_train) 147 | outputs = model.step_decoder(sess, batched_data, global_t = global_t) 148 | return outputs 149 | 150 | 151 | def evaluate(model, sess, data_dev): 152 | # Evaluation on dev set 153 | loss = np.zeros((1, )) 154 | kl_loss, dec_loss, dis_loss = np.zeros((1, )), np.zeros((1, )), np.zeros((1, )) 155 | st, ed, times = 0, FLAGS.batch_size, 0 156 | while st < len(data_dev): 157 | selected_data = data_dev[st:ed] 158 | batched_data = gen_batched_data(selected_data) 159 | outputs = model.step_decoder(sess, batched_data, forward_only=True, global_t = FLAGS.full_kl_step) 160 | kl_loss += outputs[1] 161 | dec_loss += outputs[2] 162 | dis_loss += outputs[3] 163 | loss += outputs[0] 164 | st, ed = ed, ed+FLAGS.batch_size 165 | times += 1 166 | loss 
/= times 167 | kl_loss /= times 168 | dec_loss /= times 169 | dis_loss /= times 170 | show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a])) 171 | print('perplexity on dev set: %s kl_loss: %s dec_loss: %s dis_loss: %s ' % (show(np.exp(dec_loss)), show(kl_loss), show(dec_loss), show(dis_loss))) 172 | 173 | 174 | def inference(model, sess, posts, label_no): 175 | length = [len(p)+1 for p in posts] 176 | def padding(sent, l): 177 | return sent + ['_EOS'] + ['_PAD'] * (l-len(sent)-1) 178 | 179 | batched_posts = [padding(p, max(length)) for p in posts] 180 | batched_data = {'posts': np.array(batched_posts), 181 | 'posts_length': np.array(length, dtype=np.int32)} 182 | 183 | results_inf = model.inference(sess, batched_data, label_no) 184 | responses = results_inf[0] 185 | results = [] 186 | res_cnt = 0 187 | for response in responses: 188 | result = [] 189 | token_cnt = 0 190 | for token in response: 191 | if token != '_EOS': 192 | result.append(token) 193 | token_cnt += 1 194 | else: 195 | break 196 | res_cnt += 1 197 | results.append(result) 198 | return results 199 | 200 | 201 | config = tf.ConfigProto() 202 | config.gpu_options.allow_growth = True 203 | with tf.Session(config=config) as sess: 204 | # load dataset 205 | data_train = load_data(FLAGS.data_dir, 'weibo_pair_train_pattern') 206 | data_dev = load_data(FLAGS.data_dir, 'weibo_pair_dev_pattern') 207 | 208 | # load stopword list 209 | stop_list = {} 210 | stop_file = open('stopword_utf8.txt', 'r') 211 | line_stop = stop_file.readline() 212 | while line_stop: 213 | temp = line_stop.strip() 214 | if temp not in stop_list: 215 | stop_list[temp] = 1 216 | else: 217 | stop_list[temp] += 1 218 | line_stop = stop_file.readline() 219 | stop_file.close() 220 | print 'stop_list=', len(stop_list) 221 | 222 | # load function-related word list 223 | func_list = {} 224 | func_file = open('functionword-utf8.txt', 'r') 225 | line_func = func_file.readline() 226 | while line_func: 227 | temp = line_func.strip() 228 | if temp not in func_list: 229 | func_list[temp] = 1 230 | else: 231 | func_list[temp] += 1 232 | line_func = func_file.readline() 233 | func_file.close() 234 | print 'func_list=', len(func_list) 235 | 236 | # build vocabularies 237 | vocab, embed, vocab_topic, topic_pos, func_pos = build_vocab(FLAGS.data_dir, data_train, stop_list, func_list) 238 | print 'num_topic_vocab=', len(vocab_topic) 239 | print 'num_func_vocab=', len(func_pos) 240 | 241 | # Training mode 242 | if FLAGS.is_train: 243 | model = Seq2SeqModel( 244 | FLAGS.symbols, 245 | FLAGS.embed_units, 246 | FLAGS.units, 247 | is_train=True, 248 | vocab=vocab, 249 | topic_pos=topic_pos, 250 | func_pos = func_pos, 251 | embed=embed, 252 | full_kl_step=FLAGS.full_kl_step) 253 | 254 | if FLAGS.log_parameters: 255 | model.print_parameters() 256 | 257 | if tf.train.get_checkpoint_state(FLAGS.train_dir): 258 | print("Reading model parameters from %s" % FLAGS.train_dir) 259 | model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir)) 260 | model.symbol2index.init.run() 261 | else: 262 | print("Created model with fresh parameters.") 263 | tf.global_variables_initializer().run() 264 | model.symbol2index.init.run() 265 | 266 | temp_total_losses, total_loss_step, kl_loss_step, decoder_loss_step, dis_loss_step, time_step = np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), .0 267 | previous_losses = [1e18]*6 268 | 269 | num_batch = len(data_train) / FLAGS.batch_size 270 | random.shuffle(data_train) 271 | pre_train = 
[data_train[i:i+FLAGS.batch_size] for i in range(0, len(data_train), FLAGS.batch_size)] 272 | if len(data_train) % FLAGS.batch_size != 0: 273 | pre_train.pop() 274 | random.shuffle(pre_train) 275 | ptr = 0 276 | global_t = 0 277 | while True: 278 | if model.global_step.eval() % FLAGS.per_checkpoint == 0: 279 | show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a])) 280 | print("global step %d learning rate %.4f step-time %.2f perplexity %s kl_loss %s dec_loss %s dis_loss %s" 281 | % (model.global_step.eval(), model.learning_rate.eval(), 282 | time_step, show(np.exp(decoder_loss_step)), show(kl_loss_step), show(decoder_loss_step), show(dis_loss_step))) 283 | model.saver.save(sess, '%s/checkpoint' % FLAGS.train_dir, global_step=model.global_step) 284 | evaluate(model, sess, data_dev) 285 | if np.sum(temp_total_losses) > max(previous_losses): 286 | sess.run(model.learning_rate_decay_op) 287 | previous_losses = previous_losses[1:]+[np.sum(temp_total_losses)] 288 | temp_total_losses, total_loss_step, kl_loss_step, decoder_loss_step, dis_loss_step, time_step = np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), np.zeros((1, )), .0 289 | 290 | global_t = model.global_step.eval() 291 | start_time = time.time() 292 | temp_loss = train(model, sess, pre_train[ptr], global_t) 293 | total_loss_step += temp_loss[0] / FLAGS.per_checkpoint 294 | kl_loss_step += temp_loss[1] / FLAGS.per_checkpoint 295 | decoder_loss_step += temp_loss[2] / FLAGS.per_checkpoint 296 | dis_loss_step += temp_loss[3] / FLAGS.per_checkpoint 297 | if global_t>=1: 298 | temp_total_losses += decoder_loss_step + kl_loss_step*FLAGS.full_kl_step/global_t + dis_loss_step 299 | time_step += (time.time() - start_time) / FLAGS.per_checkpoint 300 | ptr += 1 301 | if ptr == num_batch: 302 | random.shuffle(pre_train) 303 | ptr = 0 304 | 305 | else: 306 | model = Seq2SeqModel( 307 | FLAGS.symbols, 308 | FLAGS.embed_units, 309 | FLAGS.units, 310 | is_train=False, 311 | topic_pos = topic_pos, 312 | func_pos = func_pos, 313 | vocab=None) 314 | 315 | if FLAGS.inference_version == 0: 316 | model_path = tf.train.latest_checkpoint(FLAGS.train_dir) 317 | else: 318 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, FLAGS.inference_version) 319 | print('restore from %s' % model_path) 320 | model.saver.restore(sess, model_path) 321 | model.symbol2index.init.run() 322 | 323 | # tokenizer 324 | def split(sent): 325 | if Global == None: 326 | return sent.split() 327 | 328 | sent = sent.decode('utf-8', 'ignore').encode('gbk', 'ignore') 329 | tuples = [(word.decode("gbk").encode("utf-8"), pos) 330 | for word, pos in Global.GetTokenPos(sent)] 331 | return [each[0] for each in tuples] 332 | 333 | 334 | posts = [] 335 | posts_ori = [] 336 | with open(FLAGS.inference_path) as f: 337 | for line in f: 338 | sent = line.strip() 339 | posts_ori.append(sent) 340 | cur_post = split(sent) 341 | posts.append(cur_post) 342 | 343 | responses = [[], [], []] 344 | st, ed = 0, FLAGS.batch_size 345 | while st < len(posts): 346 | for i in range(3): 347 | temp = inference(model, sess, posts[st: ed], i) 348 | responses[i] += temp 349 | st, ed = ed, ed+FLAGS.batch_size 350 | 351 | with open(FLAGS.inference_path+'.out', 'w') as f: 352 | for i in range(len(posts)): 353 | # Output interrogative, declarative and imperative responses in turn 354 | for k in range(3): 355 | f.writelines('%s\n' % (''.join(responses[k][i]))) 356 | f.writelines('\n') 357 | -------------------------------------------------------------------------------- /model.py: 
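model.py (listed next) builds the latent-variable part of the model: a recognition network and a prior network each produce a mean and log-variance, a latent sample is drawn from one of them with `sample_gaussian`, and the KL term between the two Gaussians (`gaussian_kld`) is annealed over `full_kl_step` steps. Both helpers are imported from utils.py, which is not included in this listing, so the sketch below only shows what they presumably compute — the usual reparameterization trick and the KL divergence between two diagonal Gaussians; the actual code in utils.py may differ.

```python
import tensorflow as tf

def sample_gaussian(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    epsilon = tf.random_normal(tf.shape(logvar))
    return mu + tf.exp(0.5 * logvar) * epsilon

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, recog_var) || N(prior_mu, prior_var) ) for diagonal Gaussians,
    # summed over the latent dimensions, giving one value per batch element.
    return -0.5 * tf.reduce_sum(
        1.0 + (recog_logvar - prior_logvar)
        - tf.div(tf.pow(prior_mu - recog_mu, 2), tf.exp(prior_logvar))
        - tf.div(tf.exp(recog_logvar), tf.exp(prior_logvar)),
        axis=1)
```

model.py then scales this KL term by `kl_weight = min(global_t / full_kl_step, 1.0)`, so the regularization is phased in gradually during training.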
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import my_attention_decoder_fn 4 | import my_loss 5 | import my_seq2seq 6 | 7 | from tensorflow.python.ops.nn import dynamic_rnn 8 | from tensorflow.python.ops.rnn_cell_impl import GRUCell, LSTMCell, MultiRNNCell 9 | from tensorflow.contrib.lookup.lookup_ops import HashTable, KeyValueTensorInitializer 10 | from tensorflow.contrib.layers.python.layers import layers 11 | from output_projection import output_projection_layer 12 | from tensorflow.python.ops import variable_scope 13 | 14 | from utils import sample_gaussian 15 | from utils import gaussian_kld 16 | 17 | PAD_ID = 0 18 | UNK_ID = 1 19 | GO_ID = 2 20 | EOS_ID = 3 21 | _START_VOCAB = ['_PAD', '_UNK', '_GO', '_EOS'] 22 | 23 | class Seq2SeqModel(object): 24 | def __init__(self, 25 | num_symbols, 26 | num_embed_units, 27 | num_units, 28 | is_train, 29 | vocab=None, 30 | topic_pos=None, 31 | func_pos = None, 32 | embed=None, 33 | learning_rate=0.1, 34 | learning_rate_decay_factor=0.9995, 35 | max_gradient_norm=5.0, 36 | max_length=30, 37 | latent_size=128, 38 | use_lstm=False, 39 | num_classes=3, 40 | full_kl_step=80000): 41 | 42 | self.posts = tf.placeholder(tf.string, shape=(None, None)) 43 | self.posts_length = tf.placeholder(tf.int32, shape=(None)) 44 | self.responses = tf.placeholder(tf.string, shape=(None, None)) 45 | self.responses_length = tf.placeholder(tf.int32, shape=(None)) 46 | self.labels = tf.placeholder(tf.float32, shape=(None, num_classes)) 47 | self.use_prior = tf.placeholder(tf.bool) 48 | self.global_t = tf.placeholder(tf.int32) 49 | self.topic_mask = tf.reduce_sum(tf.one_hot(topic_pos, num_symbols, 1.0, 0.0), axis = 0) 50 | self.func_mask = tf.reduce_sum(tf.one_hot(func_pos, num_symbols, 1.0, 0.0), axis = 0) 51 | self.ordinary_mask = tf.cast(tf.ones_like(self.topic_mask), tf.float32) - self.topic_mask - self.func_mask 52 | 53 | # build the vocab table (string to index) 54 | if is_train: 55 | self.symbols = tf.Variable(vocab, trainable=False, name="symbols") 56 | else: 57 | self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols") 58 | self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols, 59 | tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)), 60 | default_value=UNK_ID, name="symbol2index") 61 | 62 | self.posts_input = self.symbol2index.lookup(self.posts) 63 | self.responses_target = self.symbol2index.lookup(self.responses) 64 | batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1] 65 | self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID, 66 | tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1) 67 | self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, 68 | decoder_len), reverse=True, axis=1), [-1, decoder_len]) 69 | 70 | # build the embedding table (index to vector) 71 | if embed is None: 72 | # initialize the embedding randomly 73 | self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32) 74 | else: 75 | # initialize the embedding by pre-trained word vectors 76 | self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed) 77 | 78 | self.pattern_embed = tf.get_variable('pattern_embed', [num_classes, num_embed_units], tf.float32) 79 | 80 | self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input) 81 | self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input) 82 | 
83 | if use_lstm: 84 | cell_fw = LSTMCell(num_units) 85 | cell_bw = LSTMCell(num_units) 86 | cell_dec = LSTMCell(2*num_units) 87 | else: 88 | cell_fw = GRUCell(num_units) 89 | cell_bw = GRUCell(num_units) 90 | cell_dec = GRUCell(2*num_units) 91 | 92 | # post encoder 93 | with variable_scope.variable_scope("encoder"): 94 | encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.encoder_input, 95 | self.posts_length, dtype=tf.float32) 96 | post_sum_state = tf.concat(encoder_state, 1) 97 | encoder_output = tf.concat(encoder_output, 2) 98 | 99 | # response encoder 100 | with variable_scope.variable_scope("encoder", reuse = True): 101 | decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.decoder_input, 102 | self.responses_length, dtype=tf.float32) 103 | response_sum_state = tf.concat(decoder_last_state, 1) 104 | 105 | # recognition network 106 | with variable_scope.variable_scope("recog_net"): 107 | recog_input = tf.concat([post_sum_state, response_sum_state], 1) 108 | recog_mulogvar = tf.contrib.layers.fully_connected(recog_input, latent_size * 2, activation_fn=None, scope="muvar") 109 | recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1) 110 | 111 | # prior network 112 | with variable_scope.variable_scope("prior_net"): 113 | prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1") 114 | prior_mulogvar = tf.contrib.layers.fully_connected(prior_fc1, latent_size * 2, activation_fn=None, scope="muvar") 115 | prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1) 116 | 117 | latent_sample = tf.cond(self.use_prior, 118 | lambda: sample_gaussian(prior_mu, prior_logvar), 119 | lambda: sample_gaussian(recog_mu, recog_logvar)) 120 | 121 | # Discriminator 122 | with variable_scope.variable_scope("discriminator"): 123 | dis_input = latent_sample 124 | pattern_fc1 = tf.contrib.layers.fully_connected(dis_input, latent_size, activation_fn=tf.tanh, scope="pattern_fc1") 125 | self.pattern_logits = tf.contrib.layers.fully_connected(pattern_fc1, num_classes, activation_fn=None, scope="pattern_logits") 126 | 127 | self.label_embedding = tf.matmul(self.labels, self.pattern_embed) 128 | 129 | output_fn, my_sequence_loss = output_projection_layer(2*num_units, num_symbols, latent_size, num_embed_units, self.topic_mask, self.ordinary_mask, self.func_mask) 130 | 131 | attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(encoder_output, 'luong', 2*num_units) 132 | 133 | with variable_scope.variable_scope("dec_start"): 134 | temp_start = tf.concat([post_sum_state, self.label_embedding, latent_sample], 1) 135 | dec_fc1 = tf.contrib.layers.fully_connected(temp_start, 2*num_units, activation_fn=tf.tanh, scope="dec_start_fc1") 136 | dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1, 2*num_units, activation_fn=None, scope="dec_start_fc2") 137 | 138 | if is_train: 139 | # rnn decoder 140 | extra_info = tf.concat([self.label_embedding, latent_sample], 1) 141 | decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(dec_fc2, 142 | attention_keys, attention_values, attention_score_fn, attention_construct_fn, extra_info) 143 | self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_train, 144 | self.decoder_input, self.responses_length, scope = "decoder") 145 | 146 | # calculate the loss 147 | self.decoder_loss = my_loss.sequence_loss(logits = self.decoder_output, 148 | targets = 
self.responses_target, weights = self.decoder_mask, 149 | extra_information = latent_sample, label_embedding = self.label_embedding, softmax_loss_function = my_sequence_loss) 150 | temp_klloss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)) 151 | self.kl_weight = tf.minimum(tf.to_float(self.global_t)/full_kl_step, 1.0) 152 | self.klloss = self.kl_weight * temp_klloss 153 | temp_labels = tf.argmax(self.labels, 1) 154 | self.disloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pattern_logits, labels=temp_labels)) 155 | self.loss = self.decoder_loss + self.klloss + self.disloss # need to anneal the kl_weight 156 | 157 | # building graph finished and get all parameters 158 | self.params = tf.trainable_variables() 159 | 160 | # initialize the training process 161 | self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32) 162 | self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor) 163 | self.global_step = tf.Variable(0, trainable=False) 164 | 165 | # calculate the gradient of parameters 166 | opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9) 167 | gradients = tf.gradients(self.loss, self.params) 168 | clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 169 | max_gradient_norm) 170 | self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 171 | global_step=self.global_step) 172 | 173 | else: 174 | # rnn decoder 175 | decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(output_fn, 176 | dec_fc2, attention_keys, attention_values, attention_score_fn, 177 | attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols, latent_sample, self.label_embedding) 178 | self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_inference, scope="decoder") 179 | self.generation_index = tf.argmax(tf.split(self.decoder_distribution, 180 | [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK 181 | self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index) 182 | 183 | self.params = tf.trainable_variables() 184 | 185 | self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2, 186 | max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) 187 | 188 | def print_parameters(self): 189 | for item in self.params: 190 | print('%s: %s' % (item.name, item.get_shape())) 191 | 192 | def step_decoder(self, session, data, forward_only=False, global_t=None): 193 | input_feed = {self.posts: data['posts'], 194 | self.posts_length: data['posts_length'], 195 | self.responses: data['responses'], 196 | self.responses_length: data['responses_length'], 197 | self.labels: data['labels'], 198 | self.use_prior: False} 199 | if global_t is not None: 200 | input_feed[self.global_t] = global_t 201 | if forward_only: #On the dev set 202 | output_feed = [self.loss, self.klloss, self.decoder_loss, self.disloss] 203 | else: 204 | output_feed = [self.loss, self.klloss, self.decoder_loss, self.disloss, self.gradient_norm, self.update] 205 | return session.run(output_feed, input_feed) 206 | 207 | def inference(self, session, data, label_no): 208 | if label_no == 0: 209 | temp_labels = np.tile(np.array([1, 0, 0]),(len(data['posts']),1)) 210 | else: 211 | if label_no == 1: 212 | temp_labels = np.tile(np.array([0, 1, 0]), (len(data['posts']), 1)) 213 | else: 214 | temp_labels = np.tile(np.array([0, 0, 1]), (len(data['posts']), 1)) 215 | 
input_feed = {self.posts: data['posts'], self.posts_length: data['posts_length'], 216 | self.responses: data['posts'], self.responses_length: data['posts_length'], 217 | self.labels: temp_labels, self.use_prior: True} 218 | output_feed = [self.generation] 219 | return session.run(output_feed, input_feed) 220 | -------------------------------------------------------------------------------- /my_attention_decoder_fn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.contrib.layers.python.layers import layers 6 | #from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl 7 | from tensorflow.python.ops import rnn_cell_impl 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import function 10 | from tensorflow.python.framework import ops 11 | from tensorflow.python.ops import array_ops 12 | from tensorflow.python.ops import control_flow_ops 13 | from tensorflow.python.ops import math_ops 14 | from tensorflow.python.ops import nn_ops 15 | from tensorflow.python.ops import variable_scope 16 | from tensorflow.python.util import nest 17 | import tensorflow as tf 18 | 19 | __all__ = [ 20 | "prepare_attention", "attention_decoder_fn_train", 21 | "attention_decoder_fn_inference" 22 | ] 23 | 24 | 25 | def attention_decoder_fn_train(encoder_state, 26 | attention_keys, 27 | attention_values, 28 | attention_score_fn, 29 | attention_construct_fn, 30 | extra_information, 31 | name=None): 32 | """Attentional decoder function for `dynamic_rnn_decoder` during training. 33 | 34 | The `attention_decoder_fn_train` is a training function for an 35 | attention-based sequence-to-sequence model. It should be used when 36 | `dynamic_rnn_decoder` is in the training mode. 37 | 38 | The `attention_decoder_fn_train` is called with a set of the user arguments 39 | and returns the `decoder_fn`, which can be passed to the 40 | `dynamic_rnn_decoder`, such that 41 | 42 | ``` 43 | dynamic_fn_train = attention_decoder_fn_train(encoder_state) 44 | outputs_train, state_train = dynamic_rnn_decoder( 45 | decoder_fn=dynamic_fn_train, ...) 46 | ``` 47 | 48 | Further usage can be found in the `kernel_tests/seq2seq_test.py`. 49 | 50 | Args: 51 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 52 | attention_keys: to be compared with target states. 53 | attention_values: to be used to construct context vectors. 54 | attention_score_fn: to compute similarity between key and target states. 55 | attention_construct_fn: to build attention states. 56 | extra_information: other embeddings like latent samples. 57 | name: (default: `None`) NameScope for the decoder function; 58 | defaults to "simple_decoder_fn_train" 59 | 60 | Returns: 61 | A decoder function with the required interface of `dynamic_rnn_decoder` 62 | intended for training. 63 | """ 64 | with ops.name_scope(name, "attention_decoder_fn_train", [ 65 | encoder_state, attention_keys, attention_values, attention_score_fn, 66 | attention_construct_fn 67 | ]): 68 | pass 69 | 70 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 71 | """Decoder function used in the `dynamic_rnn_decoder` for training. 72 | 73 | Args: 74 | time: positive integer constant reflecting the current timestep. 75 | cell_state: state of RNNCell. 76 | cell_input: input provided by `dynamic_rnn_decoder`. 77 | cell_output: output of RNNCell. 
78 | context_state: context state provided by `dynamic_rnn_decoder`. 79 | 80 | Returns: 81 | A tuple (done, next state, next input, emit output, next context state) 82 | where: 83 | 84 | done: `None`, which is used by the `dynamic_rnn_decoder` to indicate 85 | that `sequence_lengths` in `dynamic_rnn_decoder` should be used. 86 | 87 | next state: `cell_state`, this decoder function does not modify the 88 | given state. 89 | 90 | next input: `cell_input`, this decoder function does not modify the 91 | given input. The input could be modified when applying e.g. attention. 92 | 93 | emit output: `cell_output`, this decoder function does not modify the 94 | given output. 95 | 96 | next context state: `context_state`, this decoder function does not 97 | modify the given context state. The context state could be modified when 98 | applying e.g. beam search. 99 | """ 100 | with ops.name_scope( 101 | name, "attention_decoder_fn_train", 102 | [time, cell_state, cell_input, cell_output, context_state]): 103 | if cell_state is None: # first call, return encoder_state 104 | cell_state = encoder_state 105 | 106 | # init attention 107 | attention = _init_attention(encoder_state) 108 | else: 109 | # construct attention 110 | attention = attention_construct_fn(cell_output, attention_keys, 111 | attention_values) #batch*2units 112 | cell_output = attention 113 | 114 | # combine cell_input and attention 115 | next_input = array_ops.concat([cell_input, attention, extra_information], 1) #batch*(embed_units + 2num_units + pattern_embed) 116 | 117 | return (None, cell_state, next_input, cell_output, context_state) 118 | 119 | return decoder_fn 120 | 121 | 122 | def attention_decoder_fn_inference(output_fn, 123 | encoder_state, 124 | attention_keys, 125 | attention_values, 126 | attention_score_fn, 127 | attention_construct_fn, 128 | embeddings, 129 | start_of_sequence_id, 130 | end_of_sequence_id, 131 | maximum_length, 132 | num_decoder_symbols, 133 | latent_sample, 134 | label_embedding, 135 | dtype=dtypes.int32, 136 | name=None): 137 | """Attentional decoder function for `dynamic_rnn_decoder` during inference. 138 | 139 | The `attention_decoder_fn_inference` is a simple inference function for a 140 | sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is 141 | in the inference mode. 142 | 143 | The `attention_decoder_fn_inference` is called with user arguments 144 | and returns the `decoder_fn`, which can be passed to the 145 | `dynamic_rnn_decoder`, such that 146 | 147 | ``` 148 | dynamic_fn_inference = attention_decoder_fn_inference(...) 149 | outputs_inference, state_inference = dynamic_rnn_decoder( 150 | decoder_fn=dynamic_fn_inference, ...) 151 | ``` 152 | 153 | Further usage can be found in the `kernel_tests/seq2seq_test.py`. 154 | 155 | Args: 156 | output_fn: An output function to project your `cell_output` onto class 157 | logits. 158 | 159 | An example of an output function; 160 | 161 | ``` 162 | tf.variable_scope("decoder") as varscope 163 | output_fn = lambda x: layers.linear(x, num_decoder_symbols, 164 | scope=varscope) 165 | 166 | outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...) 167 | logits_train = output_fn(outputs_train) 168 | 169 | varscope.reuse_variables() 170 | logits_inference, state_inference = seq2seq.dynamic_rnn_decoder( 171 | output_fn=output_fn, ...) 172 | ``` 173 | 174 | If `None` is supplied it will act as an identity function, which 175 | might be wanted when using the RNNCell `OutputProjectionWrapper`. 
176 | 177 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 178 | attention_keys: to be compared with target states. 179 | attention_values: to be used to construct context vectors. 180 | attention_score_fn: to compute similarity between key and target states. 181 | attention_construct_fn: to build attention states. 182 | embeddings: The embeddings matrix used for the decoder sized 183 | `[num_decoder_symbols, embedding_size]`. 184 | start_of_sequence_id: The start of sequence ID in the decoder embeddings. 185 | end_of_sequence_id: The end of sequence ID in the decoder embeddings. 186 | maximum_length: The maximum allowed of time steps to decode. 187 | num_decoder_symbols: The number of classes to decode at each time step. 188 | dtype: (default: `dtypes.int32`) The default data type to use when 189 | handling integer objects. 190 | name: (default: `None`) NameScope for the decoder function; 191 | defaults to "attention_decoder_fn_inference" 192 | 193 | Returns: 194 | A decoder function with the required interface of `dynamic_rnn_decoder` 195 | intended for inference. 196 | """ 197 | with ops.name_scope(name, "attention_decoder_fn_inference", [ 198 | output_fn, encoder_state, attention_keys, attention_values, 199 | attention_score_fn, attention_construct_fn, embeddings, 200 | start_of_sequence_id, end_of_sequence_id, maximum_length, 201 | num_decoder_symbols, dtype 202 | ]): 203 | start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype) 204 | end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype) 205 | maximum_length = ops.convert_to_tensor(maximum_length, dtype) 206 | num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype) 207 | encoder_info = nest.flatten(encoder_state)[0] 208 | batch_size = encoder_info.get_shape()[0].value 209 | if output_fn is None: 210 | output_fn = lambda x: x 211 | if batch_size is None: 212 | batch_size = array_ops.shape(encoder_info)[0] 213 | 214 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 215 | """Decoder function used in the `dynamic_rnn_decoder` for inference. 216 | 217 | The main difference between this decoder function and the `decoder_fn` in 218 | `attention_decoder_fn_train` is how `next_cell_input` is calculated. In 219 | decoder function we calculate the next input by applying an argmax across 220 | the feature dimension of the output from the decoder. This is a 221 | greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014) 222 | use beam-search instead. 223 | 224 | Args: 225 | time: positive integer constant reflecting the current timestep. 226 | cell_state: state of RNNCell. 227 | cell_input: input provided by `dynamic_rnn_decoder`. 228 | cell_output: output of RNNCell. 229 | context_state: context state provided by `dynamic_rnn_decoder`. 230 | 231 | Returns: 232 | A tuple (done, next state, next input, emit output, next context state) 233 | where: 234 | 235 | done: A boolean vector to indicate which sentences has reached a 236 | `end_of_sequence_id`. This is used for early stopping by the 237 | `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with 238 | all elements as `true` is returned. 239 | 240 | next state: `cell_state`, this decoder function does not modify the 241 | given state. 242 | 243 | next input: The embedding from argmax of the `cell_output` is used as 244 | `next_input`. 
245 | 246 | emit output: If `output_fn is None` the supplied `cell_output` is 247 | returned, else the `output_fn` is used to update the `cell_output` 248 | before calculating `next_input` and returning `cell_output`. 249 | 250 | next context state: `context_state`, this decoder function does not 251 | modify the given context state. The context state could be modified when 252 | applying e.g. beam search. 253 | 254 | Raises: 255 | ValueError: if cell_input is not None. 256 | 257 | """ 258 | with ops.name_scope( 259 | name, "attention_decoder_fn_inference", 260 | [time, cell_state, cell_input, cell_output, context_state]): 261 | if cell_input is not None: 262 | raise ValueError("Expected cell_input to be None, but saw: %s" % 263 | cell_input) 264 | if cell_output is None: 265 | # invariant that this is time == 0 266 | next_input_id = array_ops.ones( 267 | [batch_size,], dtype=dtype) * (start_of_sequence_id) 268 | done = array_ops.zeros([batch_size,], dtype=dtypes.bool) 269 | cell_state = encoder_state 270 | cell_output = array_ops.zeros( 271 | [num_decoder_symbols], dtype=dtypes.float32) 272 | cell_input = array_ops.gather(embeddings, next_input_id) 273 | cell_type = array_ops.zeros( 274 | [3], dtype=dtypes.float32) 275 | 276 | # init attention 277 | attention = _init_attention(encoder_state) 278 | else: 279 | # construct attention 280 | attention = attention_construct_fn(cell_output, attention_keys, 281 | attention_values) 282 | cell_output = attention #batch*2num_units 283 | 284 | # argmax decoder 285 | cell_output, cell_type = output_fn(cell_output, latent_sample, label_embedding) # logits 286 | next_input_id = math_ops.cast( 287 | math_ops.argmax(cell_output, 1), dtype=dtype) 288 | done = math_ops.equal(next_input_id, end_of_sequence_id) 289 | cell_input = array_ops.gather(embeddings, next_input_id) 290 | 291 | # combine cell_input and attention 292 | next_input = array_ops.concat([cell_input, attention, label_embedding, latent_sample], 1) 293 | 294 | # if time > maxlen, return all true vector 295 | done = control_flow_ops.cond( 296 | math_ops.greater(time, maximum_length), 297 | lambda: array_ops.ones([batch_size,], dtype=dtypes.bool), 298 | lambda: done) 299 | return (done, cell_state, next_input, cell_output, context_state, cell_type) 300 | 301 | return decoder_fn 302 | 303 | 304 | ## Helper functions ## 305 | def prepare_attention(attention_states, 306 | attention_option, 307 | num_units, 308 | reuse=False): 309 | """Prepare keys/values/functions for attention. 310 | 311 | Args: 312 | attention_states: hidden states to attend over. 313 | attention_option: how to compute attention, either "luong" or "bahdanau". 314 | num_units: hidden state dimension. 315 | reuse: whether to reuse variable scope. 316 | 317 | Returns: 318 | attention_keys: to be compared with target states. 319 | attention_values: to be used to construct context vectors. 320 | attention_score_fn: to compute similarity between key and target states. 321 | attention_construct_fn: to build attention states. 
322 | """ 323 | 324 | # Prepare attention keys / values from attention_states 325 | with variable_scope.variable_scope("attention_keys", reuse=reuse) as scope: 326 | attention_keys = layers.linear( 327 | attention_states, num_units, biases_initializer=None, scope=scope) 328 | attention_values = attention_states 329 | 330 | # Attention score function 331 | attention_score_fn = _create_attention_score_fn("attention_score", num_units, 332 | attention_option, reuse) 333 | 334 | # Attention construction function 335 | attention_construct_fn = _create_attention_construct_fn("attention_construct", 336 | num_units, 337 | attention_score_fn, 338 | reuse) 339 | 340 | return (attention_keys, attention_values, attention_score_fn, 341 | attention_construct_fn) 342 | 343 | 344 | def _init_attention(encoder_state): 345 | """Initialize attention. Handling both LSTM and GRU. 346 | 347 | Args: 348 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 349 | 350 | Returns: 351 | attn: initial zero attention vector. 352 | """ 353 | 354 | # Multi- vs single-layer 355 | # TODO(thangluong): is this the best way to check? 356 | if isinstance(encoder_state, tuple): 357 | top_state = encoder_state[-1] 358 | else: 359 | top_state = encoder_state 360 | 361 | # LSTM vs GRU 362 | if isinstance(top_state, rnn_cell_impl.LSTMStateTuple): 363 | attn = array_ops.zeros_like(top_state.h) 364 | else: 365 | attn = array_ops.zeros_like(top_state) 366 | 367 | return attn 368 | 369 | 370 | def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse): 371 | """Function to compute attention vectors. 372 | 373 | Args: 374 | name: to label variables. 375 | num_units: hidden state dimension. 376 | attention_score_fn: to compute similarity between key and target states. 377 | reuse: whether to reuse variable scope. 378 | 379 | Returns: 380 | attention_construct_fn: to build attention states. 381 | """ 382 | with variable_scope.variable_scope(name, reuse=reuse) as scope: 383 | 384 | def construct_fn(attention_query, attention_keys, attention_values): 385 | context = attention_score_fn(attention_query, attention_keys, 386 | attention_values) 387 | concat_input = array_ops.concat([attention_query, context], 1) 388 | attention = layers.linear( 389 | concat_input, num_units, biases_initializer=None, scope=scope) 390 | return attention 391 | 392 | return construct_fn 393 | 394 | 395 | # keys: [batch_size, attention_length, attn_size] 396 | # query: [batch_size, 1, attn_size] 397 | # return weights [batch_size, attention_length] 398 | @function.Defun(func_name="attn_add_fun", noinline=True) 399 | def _attn_add_fun(v, keys, query): 400 | return math_ops.reduce_sum(v * math_ops.tanh(keys + query), [2]) 401 | 402 | 403 | @function.Defun(func_name="attn_mul_fun", noinline=True) 404 | def _attn_mul_fun(keys, query): 405 | return math_ops.reduce_sum(keys * query, [2]) 406 | 407 | 408 | def _create_attention_score_fn(name, 409 | num_units, 410 | attention_option, 411 | reuse, 412 | dtype=dtypes.float32): 413 | """Different ways to compute attention scores. 414 | 415 | Args: 416 | name: to label variables. 417 | num_units: hidden state dimension. 418 | attention_option: how to compute attention, either "luong" or "bahdanau". 419 | "bahdanau": additive (Bahdanau et al., ICLR'2015) 420 | "luong": multiplicative (Luong et al., EMNLP'2015) 421 | reuse: whether to reuse variable scope. 422 | dtype: (default: `dtypes.float32`) data type to use. 
423 | 424 | Returns: 425 | attention_score_fn: to compute similarity between key and target states. 426 | """ 427 | with variable_scope.variable_scope(name, reuse=reuse): 428 | if attention_option == "bahdanau": 429 | query_w = variable_scope.get_variable( 430 | "attnW", [num_units, num_units], dtype=dtype) 431 | score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype) 432 | 433 | def attention_score_fn(query, keys, values): 434 | """Put attention masks on attention_values using attention_keys and query. 435 | 436 | Args: 437 | query: A Tensor of shape [batch_size, num_units]. 438 | keys: A Tensor of shape [batch_size, attention_length, num_units]. 439 | values: A Tensor of shape [batch_size, attention_length, num_units]. 440 | 441 | Returns: 442 | context_vector: A Tensor of shape [batch_size, num_units]. 443 | 444 | Raises: 445 | ValueError: if attention_option is neither "luong" or "bahdanau". 446 | 447 | 448 | """ 449 | if attention_option == "bahdanau": 450 | # transform query 451 | query = math_ops.matmul(query, query_w) 452 | 453 | # reshape query: [batch_size, 1, num_units] 454 | query = array_ops.reshape(query, [-1, 1, num_units]) 455 | 456 | # attn_fun 457 | scores = _attn_add_fun(score_v, keys, query) 458 | elif attention_option == "luong": 459 | # reshape query: [batch_size, 1, num_units] 460 | query = array_ops.reshape(query, [-1, 1, num_units]) 461 | 462 | # attn_fun 463 | scores = _attn_mul_fun(keys, query) 464 | else: 465 | raise ValueError("Unknown attention option %s!" % attention_option) 466 | 467 | # Compute alignment weights 468 | # scores: [batch_size, length] 469 | # alignments: [batch_size, length] 470 | # TODO(thangluong): not normalize over padding positions. 471 | alignments = nn_ops.softmax(scores) 472 | 473 | # Now calculate the attention-weighted vector. 474 | alignments = array_ops.expand_dims(alignments, 2) 475 | context_vector = math_ops.reduce_sum(alignments * values, [1]) 476 | context_vector.set_shape([None, num_units]) 477 | 478 | return context_vector 479 | 480 | return attention_score_fn 481 | -------------------------------------------------------------------------------- /my_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.python.framework import ops 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import nn_ops 8 | from tensorflow.python.ops import math_ops 9 | import tensorflow as tf 10 | 11 | __all__ = ["sequence_loss"] 12 | 13 | def sequence_loss(logits, targets, weights, extra_information, label_embedding, 14 | average_across_timesteps=True, average_across_batch=True, 15 | softmax_loss_function=None, name=None): 16 | """Weighted cross-entropy loss for a sequence of logits (per example). 17 | 18 | Args: 19 | logits: A 3D Tensor of shape 20 | [batch_size x sequence_length x num_decoder_symbols] and dtype float. 21 | The logits correspond to the prediction across all classes at each 22 | timestep. 23 | targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype 24 | int. The target represents the true class at each timestep. 25 | weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype 26 | float. Weights constitutes the weighting of each prediction in the 27 | sequence. When using weights as masking set all valid timesteps to 1 and 28 | all padded timesteps to 0. 
    extra_information: A 2D Tensor of shape [batch_size x latent_size]; tiled
      across timesteps and passed to `softmax_loss_function`.
    label_embedding: A 2D Tensor of shape [batch_size x embed_size]; tiled
      across timesteps and passed to `softmax_loss_function`.
29 |     average_across_timesteps: If set, sum the cost across the sequence
30 |       dimension and divide the cost by the total label weight across
31 |       timesteps.
32 |     average_across_batch: If set, sum the cost across the batch dimension and
33 |       divide the returned cost by the batch size.
34 |     softmax_loss_function: Function called as (states-batch, labels-batch, extra-batch, label-batch, max_time) -> loss-batch,
35 |       to be used instead of the standard sparse softmax cross-entropy (the default if this is None).
36 |     name: Optional name for this operation, defaults to "sequence_loss".
37 |
38 |   Returns:
39 |     A float Tensor: the weighted average log-perplexity per symbol. This is a scalar only when averaging across both the batch and the timesteps; otherwise the non-averaged dimension(s) are kept.
40 |
41 |   Raises:
42 |     ValueError: if logits does not have 3 dimensions, targets does not have 2
43 |       dimensions, or weights does not have 2 dimensions.
44 |   """
45 |   if len(logits.get_shape()) != 3:
46 |     raise ValueError("Logits must be a "
47 |                      "[batch_size x sequence_length x num_decoder_symbols] tensor")
48 |   if len(targets.get_shape()) != 2:
49 |     raise ValueError("Targets must be a [batch_size x sequence_length] "
50 |                      "tensor")
51 |   if len(weights.get_shape()) != 2:
52 |     raise ValueError("Weights must be a [batch_size x sequence_length] "
53 |                      "tensor")
54 |   with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
55 |     num_classes = array_ops.shape(logits)[2]
56 |     max_time = array_ops.shape(logits)[1]
57 |     batch_size = array_ops.shape(logits)[0]
58 |     latent_size = array_ops.shape(extra_information)[1]
59 |     embed_size = array_ops.shape(label_embedding)[1]
60 |     probs_flat = array_ops.reshape(logits, [-1, num_classes])
61 |     targets = array_ops.reshape(targets, [-1])
62 |     expand_extra_information = array_ops.reshape(tf.tile(extra_information, [1, max_time]), [batch_size*max_time, latent_size])
63 |     expand_label_embedding = array_ops.reshape(tf.tile(label_embedding, [1, max_time]), [batch_size*max_time, embed_size])
64 |     if softmax_loss_function is None:
65 |       crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
66 |           labels=targets, logits=probs_flat)
67 |     else:
68 |       crossent = softmax_loss_function(probs_flat, targets, expand_extra_information, expand_label_embedding, max_time)
69 |     #crossent = crossent * array_ops.reshape(weights, [-1])
70 |     crossent = array_ops.reshape(crossent, [-1, max_time])
71 |     crossent = crossent * weights
72 |     if average_across_timesteps and average_across_batch:
73 |       #crossent = math_ops.reduce_sum(crossent, 1)
74 |       #total_size = math_ops.reduce_sum(weights, 1)
75 |       crossent = math_ops.reduce_sum(crossent)
76 |       total_size = math_ops.reduce_sum(weights)
77 |       total_size += 1e-12  # to avoid division by 0 for all-0 weights
78 |       crossent /= total_size
79 |       #crossent = math_ops.reduce_mean(crossent)
80 |     else:
81 |       batch_size = array_ops.shape(logits)[0]
82 |       sequence_length = array_ops.shape(logits)[1]
83 |       crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
84 |       if average_across_timesteps and not average_across_batch:
85 |         crossent = math_ops.reduce_sum(crossent, axis=[1])
86 |         total_size = math_ops.reduce_sum(weights, axis=[1])
87 |         total_size += 1e-12  # to avoid division by 0 for all-0 weights
88 |         crossent /= total_size
89 |       if not average_across_timesteps and average_across_batch:
90 |         crossent = math_ops.reduce_sum(crossent, axis=[0])
91 |         total_size = math_ops.reduce_sum(weights, axis=[0])
92 |         total_size += 1e-12  # to avoid division by 0 for all-0 weights
93 |         crossent /= total_size
94 |     return crossent
95 |
--------------------------------------------------------------------------------
/my_seq2seq.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tensorflow.contrib import layers
6 | from tensorflow.python.framework import ops
7 | from tensorflow.python.ops import array_ops
8 | from tensorflow.python.ops import control_flow_ops
9 | from tensorflow.python.ops import math_ops
10 | from tensorflow.python.ops import rnn
11 | #import rnn
12 | from tensorflow.python.ops import tensor_array_ops
13 | from tensorflow.python.ops import variable_scope as vs
14 |
15 | __all__ = ["dynamic_rnn_decoder"]
16 |
17 | def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None,
18 |                         parallel_iterations=None, swap_memory=False,
19 |                         time_major=False, scope=None, name=None):
20 |   """Dynamic RNN decoder for a sequence-to-sequence model specified by
21 |   RNNCell and decoder function.
22 |
23 |   The `dynamic_rnn_decoder` is similar to `tf.python.ops.rnn.dynamic_rnn`
24 |   in that the decoder makes no assumptions about the sequence length or
25 |   batch size of the input.
26 |
27 |   The `dynamic_rnn_decoder` has two modes: training and inference, and expects
28 |   the user to create separate functions for each.
29 |
30 |   Under both training and inference, both `cell` and `decoder_fn` are expected,
31 |   where `cell` performs computation at every timestep using `raw_rnn`, and
32 |   `decoder_fn` allows modeling of early stopping, output, state, and next
33 |   input and context.
34 |
35 |   When training, the user is expected to supply `inputs`. At every time step a
36 |   slice of the supplied input is fed to the `decoder_fn`, which modifies and
37 |   returns the input for the next time step.
38 |
39 |   `sequence_length` is needed at training time, i.e., when `inputs` is not
40 |   None, for dynamic unrolling. At test time, when `inputs` is None,
41 |   `sequence_length` is not needed.
42 |
43 |   Under inference, `inputs` is expected to be `None` and the input is inferred
44 |   solely from the `decoder_fn`.
45 |
46 |   Args:
47 |     cell: An instance of RNNCell.
48 |     decoder_fn: A function that takes time, cell state, cell input,
49 |       cell output and context state. It returns an early stopping vector,
50 |       cell state, next input, cell output and context state.
51 |       Examples of decoder_fn can be found in my_attention_decoder_fn.py.
52 |     inputs: The inputs for decoding (embedded format).
53 |
54 |       If `time_major == False` (default), this must be a `Tensor` of shape:
55 |         `[batch_size, max_time, ...]`.
56 |
57 |       If `time_major == True`, this must be a `Tensor` of shape:
58 |         `[max_time, batch_size, ...]`.
59 |
60 |       The input to `cell` at each time step will be a `Tensor` with dimensions
61 |         `[batch_size, ...]`.
62 |
63 |     sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
64 |       If `inputs` is not None and `sequence_length` is None, it is inferred
65 |       from the `inputs` as the maximal possible sequence length.
66 |     parallel_iterations: (Default: 32). The number of iterations to run in
67 |       parallel. Those operations which do not have any temporal dependency
68 |       and can be run in parallel, will be. This parameter trades off
69 |       time for space. Values >> 1 use more memory but take less time,
70 |       while smaller values use less memory but computations take longer.
71 |     swap_memory: Transparently swap the tensors produced in forward inference
72 |       but needed for back prop from GPU to CPU. This allows training RNNs
73 |       which would typically not fit on a single GPU, with very minimal (or no)
74 |       performance penalty.
75 |     time_major: The shape format of the `inputs` and `outputs` Tensors.
76 |       If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
77 |       If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
78 |       Using `time_major = True` is a bit more efficient because it avoids
79 |       transposes at the beginning and end of the RNN calculation. However,
80 |       most TensorFlow data is batch-major, so by default this function
81 |       accepts input and emits output in batch-major form.
82 |     scope: VariableScope for the `raw_rnn`;
83 |       defaults to None.
84 |     name: NameScope for the decoder;
85 |       defaults to "dynamic_rnn_decoder".
86 |
87 |   Returns:
88 |     A tuple (outputs, final_state, final_context_state) where:
89 |
90 |     outputs: the RNN output `Tensor`.
91 |
92 |       If time_major == False (default), this will be a `Tensor` shaped:
93 |         `[batch_size, max_time, cell.output_size]`.
94 |
95 |       If time_major == True, this will be a `Tensor` shaped:
96 |         `[max_time, batch_size, cell.output_size]`.
97 |
98 |     final_state: The final state, shaped
99 |       `[batch_size, cell.state_size]`.
100 |
101 |     final_context_state: The context state returned by the final call
102 |       to decoder_fn. This is useful if the context state maintains internal
103 |       data which is required after the graph is run.
104 |       For example, one way to diversify the inference output is to use
105 |       a stochastic decoder_fn, in which case one would want to store the
106 |       decoded outputs, not just the RNN outputs. This can be done by
107 |       maintaining a TensorArray in context_state and storing the decoded
108 |       output of each iteration therein.
109 |
110 |   Raises:
111 |     ValueError: if inputs is not None and has fewer than two dimensions.
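
  Example:
    A minimal illustrative sketch, not code from this repository;
    `my_decoder_fn`, `decoder_inputs`, `decoder_lengths` and `num_units`
    are hypothetical and only show the expected call signature (a
    decoder_fn would typically be built from the encoder state, e.g. with
    the helpers in my_attention_decoder_fn.py):

      cell = tf.contrib.rnn.GRUCell(num_units)
      outputs, final_state, final_context_state = dynamic_rnn_decoder(
          cell=cell,
          decoder_fn=my_decoder_fn,
          inputs=decoder_inputs,            # [batch_size, max_time, embed_size]
          sequence_length=decoder_lengths)  # [batch_size]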
112 |   """
113 |   with ops.name_scope(name, "dynamic_rnn_decoder",
114 |                       [cell, decoder_fn, inputs, sequence_length,
115 |                        parallel_iterations, swap_memory, time_major, scope]):
116 |     if inputs is not None:
117 |       # Convert to tensor
118 |       inputs = ops.convert_to_tensor(inputs)
119 |
120 |       # Test input dimensions
121 |       if inputs.get_shape().ndims is not None and (
122 |           inputs.get_shape().ndims < 2):
123 |         raise ValueError("Inputs must have at least two dimensions")
124 |       # Setup of RNN (dimensions, sizes, length, initial state, dtype)
125 |       if not time_major:
126 |         # [batch, seq, features] -> [seq, batch, features]
127 |         inputs = array_ops.transpose(inputs, perm=[1, 0, 2])
128 |
129 |       dtype = inputs.dtype
130 |       # Get data input information
131 |       input_depth = int(inputs.get_shape()[2])
132 |       batch_depth = inputs.get_shape()[1].value
133 |       max_time = inputs.get_shape()[0].value
134 |       if max_time is None:
135 |         max_time = array_ops.shape(inputs)[0]
136 |       # Setup decoder inputs as TensorArray
137 |       inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time)
138 |       inputs_ta = inputs_ta.unstack(inputs)
139 |
140 |     def loop_fn(time, cell_output, cell_state, loop_state):
141 |       if cell_state is None:  # first call, before while loop (in raw_rnn)
142 |         if cell_output is not None:
143 |           raise ValueError("Expected cell_output to be None when cell_state "
144 |                            "is None, but saw: %s" % cell_output)
145 |         if loop_state is not None:
146 |           raise ValueError("Expected loop_state to be None when cell_state "
147 |                            "is None, but saw: %s" % loop_state)
148 |         context_state = None
149 |       else:  # subsequent calls, inside while loop, after cell execution
150 |         if isinstance(loop_state, tuple):
151 |           (done, context_state) = loop_state
152 |         else:
153 |           done = loop_state
154 |           context_state = None
155 |
156 |       # call decoder function
157 |       if inputs is not None:  # training
158 |         # get next_cell_input
159 |         if cell_state is None:
160 |           next_cell_input = inputs_ta.read(0)
161 |         else:
162 |           if batch_depth is not None:
163 |             batch_size = batch_depth
164 |           else:
165 |             batch_size = array_ops.shape(done)[0]
166 |           next_cell_input = control_flow_ops.cond(
167 |               math_ops.equal(time, max_time),
168 |               lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype),
169 |               lambda: inputs_ta.read(time))
170 |         #(next_done, next_cell_state, next_cell_input, emit_output,
171 |         # next_context_state, current_type) = decoder_fn(time, cell_state, next_cell_input,
172 |         #cell_output, context_state)
173 |         (next_done, next_cell_state, next_cell_input, emit_output,
174 |          next_context_state) = decoder_fn(time, cell_state, next_cell_input,
175 |                                           cell_output, context_state)
176 |       else:  # inference
177 |         # next_cell_input is obtained through decoder_fn
178 |         (next_done, next_cell_state, next_cell_input, emit_output,
179 |          next_context_state, current_type) = decoder_fn(time, cell_state, None, cell_output,
180 |                                                         context_state)
181 |
182 |       # check if we are done
183 |       if next_done is None:  # training
184 |         next_done = time >= sequence_length
185 |
186 |       # build next_loop_state
187 |       if next_context_state is None:
188 |         next_loop_state = next_done
189 |       else:
190 |         next_loop_state = (next_done, next_context_state)
191 |
192 |       return (next_done, next_cell_input, next_cell_state,
193 |               #emit_output, next_loop_state, current_type)
194 |               emit_output, next_loop_state)
195 |
196 |     # Run raw_rnn function
197 |     #outputs_ta, final_state, final_loop_state, output_type = rnn.raw_rnn(
198 |     outputs_ta, final_state, final_loop_state = rnn.raw_rnn(
199 |         cell,
loop_fn, parallel_iterations=parallel_iterations, 200 | swap_memory=swap_memory, scope=scope) 201 | outputs = outputs_ta.stack() 202 | #outputs_type = output_type.stack() 203 | 204 | # Get final context_state, if generated by user 205 | if isinstance(final_loop_state, tuple): 206 | final_context_state = final_loop_state[1] 207 | else: 208 | final_context_state = None 209 | 210 | if not time_major: 211 | # [seq, batch, features] -> [batch, seq, features] 212 | outputs = array_ops.transpose(outputs, perm=[1, 0, 2]) 213 | #outputs_type = array_ops.transpose(outputs_type, perm=[1, 0, 2]) 214 | #return outputs, final_state, final_context_state, outputs_type 215 | return outputs, final_state, final_context_state 216 | -------------------------------------------------------------------------------- /output_projection.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import layers 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def output_projection_layer(num_units, num_symbols, latent_size, num_embed_units, topic_mask, ordinary_mask, func_mask, name="output_projection"): 6 | def output_fn(outputs, latent_z, label_embedding): 7 | with variable_scope.variable_scope(name): 8 | local_d = tf.reshape(outputs, [-1, num_units]) 9 | local_l = tf.reshape(tf.concat([outputs, latent_z], 1), [-1, num_units + latent_size]) 10 | local_d2 = tf.reshape(tf.concat([outputs, latent_z, label_embedding], 1), [-1, num_units + latent_size + num_embed_units]) 11 | 12 | # type controller 13 | l_fc1 = tf.contrib.layers.fully_connected(local_l, num_units + latent_size, activation_fn=tf.tanh, scope = 'l_fc1') 14 | l_fc2 = tf.contrib.layers.fully_connected(l_fc1, 3, activation_fn=None, scope = 'l_fc2') 15 | p_dis = tf.nn.softmax(l_fc2) 16 | p_dis_1, p_dis_2, p_dis_3 = tf.split(p_dis, 3, axis = 1) 17 | p_dis_1 = tf.reshape(tf.tile(p_dis_1, [1, num_symbols]), [-1, num_symbols]) 18 | p_dis_2 = tf.reshape(tf.tile(p_dis_2, [1, num_symbols]), [-1, num_symbols]) 19 | p_dis_3 = tf.reshape(tf.tile(p_dis_3, [1, num_symbols]), [-1, num_symbols]) 20 | type_index = p_dis 21 | 22 | # topic words 23 | w_fc2 = tf.contrib.layers.fully_connected(local_d, num_symbols, activation_fn=None, scope = 'w_fc2') 24 | p_w = tf.exp(w_fc2) 25 | p_w = p_w * tf.tile(tf.reshape(topic_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 26 | temp_normal = tf.tile(tf.reduce_sum(p_w, 1, keep_dims=True), [1, num_symbols]) 27 | y_prob_d = tf.div(p_w, temp_normal) 28 | 29 | # ordinary words 30 | d1_fc2 = tf.contrib.layers.fully_connected(local_d, num_symbols, activation_fn=None, scope = 'd1_fc2') 31 | temp_d1 = tf.exp(d1_fc2) 32 | temp_d1 = temp_d1 * tf.tile(tf.reshape(ordinary_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 33 | temp_normal = tf.tile(tf.reduce_sum(temp_d1, 1, keep_dims=True), [1, num_symbols]) 34 | y_prob_d1 = tf.div(temp_d1, temp_normal) 35 | 36 | # function-related words 37 | d2_fc2 = tf.contrib.layers.fully_connected(local_d2, num_symbols, activation_fn=None, scope = 'd2_fc2') 38 | temp_d2 = tf.exp(d2_fc2) 39 | temp_d2 = temp_d2 * tf.tile(tf.reshape(func_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 40 | temp_normal = tf.tile(tf.reduce_sum(temp_d2, 1, keep_dims=True), [1, num_symbols]) 41 | y_prob_d2 = tf.div(temp_d2, temp_normal) 42 | 43 | y_prob = p_dis_1 * y_prob_d + p_dis_2 * y_prob_d1 + p_dis_3 * y_prob_d2 44 | return y_prob, type_index 45 | 46 | def my_sequence_loss(outputs, targets, latent_z, 
label_embedding, max_time): 47 | with variable_scope.variable_scope("decoder/%s" % name): 48 | local_labels = tf.reshape(targets, [-1]) 49 | local_d = tf.reshape(outputs, [-1, num_units]) 50 | local_l = tf.reshape(tf.concat([outputs, latent_z], 1), [-1, num_units + latent_size]) 51 | local_d2 = tf.reshape(tf.concat([outputs, latent_z, label_embedding], 1), [-1, num_units + latent_size + num_embed_units]) 52 | 53 | # type controller 54 | l_fc1 = tf.contrib.layers.fully_connected(local_l, num_units + latent_size, activation_fn=tf.tanh, scope = 'l_fc1') 55 | l_fc2 = tf.contrib.layers.fully_connected(l_fc1, 3, activation_fn=None, scope = 'l_fc2') 56 | p_dis = tf.nn.softmax(l_fc2) 57 | p_dis_1, p_dis_2, p_dis_3 = tf.split(p_dis, 3, axis = 1) 58 | p_dis_1 = tf.reshape(tf.tile(p_dis_1, [1, num_symbols]), [-1, num_symbols]) 59 | p_dis_2 = tf.reshape(tf.tile(p_dis_2, [1, num_symbols]), [-1, num_symbols]) 60 | p_dis_3 = tf.reshape(tf.tile(p_dis_3, [1, num_symbols]), [-1, num_symbols]) 61 | 62 | # topic words 63 | w_fc2 = tf.contrib.layers.fully_connected(local_d, num_symbols, activation_fn=None, scope = 'w_fc2') 64 | p_w = tf.exp(w_fc2) 65 | p_w = p_w * tf.tile(tf.reshape(topic_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 66 | temp_normal = tf.tile(tf.reduce_sum(p_w, 1, keep_dims=True), [1, num_symbols]) 67 | y_prob_d = tf.div(p_w, temp_normal) 68 | 69 | # ordinary words 70 | d1_fc2 = tf.contrib.layers.fully_connected(local_d, num_symbols, activation_fn=None, scope = 'd1_fc2') 71 | temp_d1 = tf.exp(d1_fc2) 72 | temp_d1 = temp_d1 * tf.tile(tf.reshape(ordinary_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 73 | temp_normal = tf.tile(tf.reduce_sum(temp_d1, 1, keep_dims=True), [1, num_symbols]) 74 | y_prob_d1 = tf.div(temp_d1, temp_normal) 75 | 76 | # function-related words 77 | d2_fc2 = tf.contrib.layers.fully_connected(local_d2, num_symbols, activation_fn=None, scope = 'd2_fc2') 78 | temp_d2 = tf.exp(d2_fc2) 79 | temp_d2 = temp_d2 * tf.tile(tf.reshape(func_mask, [1, num_symbols]), [tf.shape(local_d)[0], 1]) 80 | temp_normal = tf.tile(tf.reduce_sum(temp_d2, 1, keep_dims=True), [1, num_symbols]) 81 | y_prob_d2 = tf.div(temp_d2, temp_normal) 82 | 83 | y_prob = p_dis_1 * y_prob_d + p_dis_2 * y_prob_d1 + p_dis_3 * y_prob_d2 84 | 85 | # cross entropy 86 | labels_onehot = tf.one_hot(local_labels, num_symbols) 87 | labels_onehot = tf.clip_by_value(labels_onehot, 0.0, 1.0) 88 | y_prob = tf.clip_by_value(y_prob, 1e-18, 1.0) 89 | cross_entropy = tf.reshape(tf.reduce_sum(-labels_onehot * tf.log(y_prob), 1), [-1, 1]) 90 | 91 | return cross_entropy 92 | 93 | return output_fn, my_sequence_loss 94 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar): 5 | kld = -0.5 * tf.reduce_sum(1 + (recog_logvar - prior_logvar) 6 | - tf.div(tf.pow(prior_mu - recog_mu, 2), tf.exp(prior_logvar)) 7 | - tf.div(tf.exp(recog_logvar), tf.exp(prior_logvar)), reduction_indices=1) 8 | return kld 9 | 10 | def sample_gaussian(mu, logvar): 11 | epsilon = tf.random_normal(tf.shape(logvar), name="epsilon") 12 | std = tf.exp(0.5 * logvar) 13 | z = mu + tf.multiply(std, epsilon) 14 | return z 15 | --------------------------------------------------------------------------------
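How the pieces above fit together: `output_projection_layer` returns the pair `(output_fn, my_sequence_loss)`; both mix three vocabulary distributions (topic, ordinary, and function-related words) using the weights produced by the type controller, and `my_sequence_loss` is meant to be handed to `sequence_loss` in my_loss.py as its `softmax_loss_function`, which tiles the latent vector and the label embedding across timesteps before calling it. The following is a minimal, hypothetical wiring sketch, not code from this repository; the sizes and placeholder tensors are assumptions made only for illustration:

import tensorflow as tf
from my_loss import sequence_loss
from output_projection import output_projection_layer

# Hypothetical hyperparameters and placeholders (assumptions, for illustration only).
num_units, num_symbols, latent_size, num_embed_units = 512, 40000, 128, 100
topic_mask = tf.placeholder(tf.float32, [num_symbols])      # per-vocabulary mask for topic words
ordinary_mask = tf.placeholder(tf.float32, [num_symbols])   # per-vocabulary mask for ordinary words
func_mask = tf.placeholder(tf.float32, [num_symbols])       # per-vocabulary mask for function-related words

decoder_states = tf.placeholder(tf.float32, [None, None, num_units])   # [batch, time, num_units]
targets = tf.placeholder(tf.int32, [None, None])                       # [batch, time]
weights = tf.placeholder(tf.float32, [None, None])                     # [batch, time] padding mask
latent_z = tf.placeholder(tf.float32, [None, latent_size])             # [batch, latent_size]
label_embedding = tf.placeholder(tf.float32, [None, num_embed_units])  # [batch, num_embed_units]

output_fn, my_sequence_loss = output_projection_layer(
    num_units, num_symbols, latent_size, num_embed_units,
    topic_mask, ordinary_mask, func_mask)

# With a custom softmax_loss_function, the first argument of sequence_loss
# carries the decoder hidden states rather than vocabulary logits; the
# projection to the vocabulary happens inside my_sequence_loss.
loss = sequence_loss(decoder_states, targets, weights, latent_z, label_embedding,
                     softmax_loss_function=my_sequence_loss)

The `output_fn` half of the pair is the inference-time counterpart: for each decoding step it returns the mixed word distribution together with the type-controller weights.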
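utils.py holds the two standard CVAE ingredients used around the decoder: `sample_gaussian` implements the reparameterization trick, z = mu + exp(0.5 * logvar) * epsilon with epsilon ~ N(0, I), and `gaussian_kld` is the closed-form KL divergence between two diagonal Gaussians, KL(q || p) = -0.5 * sum(1 + (logvar_q - logvar_p) - (mu_p - mu_q)^2 / var_p - var_q / var_p), summed over the latent dimensions. A minimal, hypothetical usage sketch (the placeholder names and shapes are assumptions, not code from this repository):

import tensorflow as tf
from utils import gaussian_kld, sample_gaussian

# Hypothetical recognition / prior network outputs (assumptions for illustration);
# in the model these would come from the encoder-side networks.
latent_size = 128
recog_mu = tf.placeholder(tf.float32, [None, latent_size])
recog_logvar = tf.placeholder(tf.float32, [None, latent_size])
prior_mu = tf.placeholder(tf.float32, [None, latent_size])
prior_logvar = tf.placeholder(tf.float32, [None, latent_size])

# Reparameterization: the randomness stays in epsilon, so gradients flow
# through mu and logvar.
latent_z = sample_gaussian(recog_mu, recog_logvar)

# Per-example KL(q(z|x,y) || p(z|x)); typically averaged over the batch and
# added to the reconstruction loss.
kl_loss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))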