# ├── data ├── tra │ └── tra_0.pkl ├── tst │ └── tst_0.pkl └── val │ └── val_0.pkl ├── train_w2v.py ├── prepro.py ├── predict_model.py ├── LICENSE ├── train_model.py ├── data_utils.py ├── README.md └── textsum_model.py /data/tra/tra_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/tra/tra_0.pkl -------------------------------------------------------------------------------- /data/tst/tst_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/tst/tst_0.pkl -------------------------------------------------------------------------------- /data/val/val_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/val/val_0.pkl -------------------------------------------------------------------------------- /train_w2v.py: --------------------------------------------------------------------------------
import os
import sys
import multiprocessing
import logging
from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence

# Command-line script: train a skip-gram word2vec model on a one-sentence-per-
# line corpus and write both the gensim model and the binary word vectors.
if __name__ == '__main__':
    program = os.path.basename(sys.argv[0])
    logger = logging.getLogger(program)

    logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
    logging.root.setLevel(level=logging.INFO)
    logger.info("running %s" % ' '.join(sys.argv))

    if len(sys.argv) < 4:
        print("Using: python train_w2v.py one-billion-word-benchmark output_gensim_model output_word_vector")
        sys.exit(1)
    inp, outp1, outp2 = sys.argv[1:4]

    # Leave two cores free, but never ask for fewer than one worker:
    # cpu_count() - 2 is 0 or negative on hosts with one or two cores.
    workers = max(1, multiprocessing.cpu_count() - 2)

    # NOTE(review): `size=` is the pre-4.0 gensim keyword (renamed to
    # `vector_size` later) -- confirm against the pinned gensim version.
    model = Word2Vec(LineSentence(inp), size=150, window=6, min_count=2, workers=workers, hs=1, sg=1, negative=10)

    model.save(outp1)
    model.wv.save_word2vec_format(outp2, binary=True)
# -------------------------------------------------------------------------------- /prepro.py: --------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import os
import sys
import codecs
import pickle
import logging


def load(filename):
    """Return the object unpickled from *filename*."""
    # NOTE: unpickling runs code embedded in the file; only load trusted data.
    with open(filename, 'rb') as source:
        return pickle.load(source)


def save(filename, data):
    """Pickle *data* into *filename*, replacing any existing file."""
    with open(filename, 'wb') as sink:
        pickle.dump(data, sink)


def _parse_example(path):
    """Parse one raw example file into the dict that ``compute`` stores.

    The file is read as blank-line separated sections: section 1 holds the
    article sentences, section 2 the abstract lines, section 3 the
    ``key:value`` entity lines. Section 0 and anything after 3 are ignored.
    """
    entity, abstract, article = [], [], []
    section = 0
    with codecs.open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f.readlines():
            if line == '\n':
                section += 1
                continue
            if section == 1:
                article.append(line.replace('\t\t\t', '').replace('\n', ''))
            elif section == 2:
                abstract.append(line.replace('\n', '').replace('*', ''))
            elif section == 3:
                entity.append(line.replace('\n', ''))

    # The last character of every article sentence is its label digit; a
    # sentence is selected when that digit is '1'. endswith() is safe on an
    # empty sentence, where sent[-1] used to raise IndexError.
    label = [idx for idx, sent in enumerate(article) if sent.endswith('1')]
    article = [sent[:-1] for sent in article]

    entity_dict = {}
    for pair in entity:
        # Split on the first colon only, so values that themselves contain a
        # colon are kept whole; lines without a colon are skipped instead of
        # raising IndexError.
        key, sep, value = pair.partition(':')
        if sep:
            entity_dict[key] = value

    data = {}
    data['entity'] = entity_dict
    data['abstract'] = abstract
    data['article'] = article
    data['label'] = label
    return data


def compute(inp, oup, logger):
    """Convert every file in directory *inp* into ``<oup>example_<n>.pkl``.

    Args:
        inp: directory holding the raw example files.
        oup: output prefix; it is concatenated as-is, so pass a directory
            with a trailing separator (as in the usage message).
        logger: logger used for progress messages.

    Returns:
        The number of examples written.
    """
    cnt_file = 0
    # Sorted so that the example numbering is reproducible between runs.
    for filename in sorted(os.listdir(inp)):
        source_path = os.path.join(inp, filename)
        target_path = oup + 'example_' + str(cnt_file) + '.pkl'
        save(target_path, _parse_example(source_path))
        cnt_file += 1
        if cnt_file % 500 == 0:
            logger.info("running the script, extract %d examples already..." % cnt_file)
    # cnt_file already equals the number of files processed (was cnt_file+1).
    logger.info("extract %d examples totally this time, done." % cnt_file)
    return cnt_file


if __name__ == "__main__":

    program = os.path.basename(sys.argv[0])
    logger = logging.getLogger(program)

    logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
    logging.root.setLevel(level=logging.INFO)
    logger.info("running %s" % ' '.join(sys.argv))

    if len(sys.argv) < 3:
        print("Using: python prepro.py ./source_dir/ ./target_dir/")
        sys.exit(1)
    inp, oup = sys.argv[1:3]

    compute(inp, oup, logger)

# -------------------------------------------------------------------------------- /predict_model.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import re
import os
import math
import pickle
import codecs
import json
import tensorflow as tf
import numpy as np
from data_utils import *
from textsum_model import Neuralmodel
from gensim.models import KeyedVectors
from rouge import Rouge, FilesRouge

#configuration
FLAGS=tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("hyp_path","../res/hyp.txt","file of summary.")
tf.app.flags.DEFINE_string("ref_path","../res/ref.txt","file of abstract.")
tf.app.flags.DEFINE_string("result_path","../res/","path to store the predicted results.")
tf.app.flags.DEFINE_string("tst_data_path","../src/neuralsum/dailymail/tst/","path of test data.")
tf.app.flags.DEFINE_string("tst_file_path","../src/neuralsum/dailymail/tst/","file of test data.")
tf.app.flags.DEFINE_boolean("use_tst_dataset", True,"using test dataset, set False to use the file as targets")
tf.app.flags.DEFINE_string("entity_path","../cache/entity_dict.pkl", "path of entity data.")
tf.app.flags.DEFINE_string("vocab_path","../cache/vocab","path of vocab frequency list")
tf.app.flags.DEFINE_integer("vocab_size",199900,"maximum vocab size.")

# Optimiser-related flags. They are forwarded to the Neuralmodel constructor
# even though this script only runs inference.
tf.app.flags.DEFINE_float("learning_rate",0.0001,"learning rate")

tf.app.flags.DEFINE_integer("is_frozen_step", 0, "how many steps before fine-tuning the embedding.")
tf.app.flags.DEFINE_integer("cur_learning_step", 0, "how many steps before using the predicted labels instead of true labels.")
tf.app.flags.DEFINE_integer("decay_step", 5000, "how many steps before decay learning rate.")
tf.app.flags.DEFINE_float("decay_rate", 0.1, "Rate of decay for learning rate.")
# Checkpoint location and batch size (the default here is one example per batch).
tf.app.flags.DEFINE_string("ckpt_dir","../ckpt/","checkpoint location for the model")
tf.app.flags.DEFINE_integer("batch_size", 1, "Batch size for training/evaluating.")
# Model dimensions and input length limits.
tf.app.flags.DEFINE_integer("embed_size", 150,"embedding size")
tf.app.flags.DEFINE_integer("input_y2_max_length", 40,"the max length of a sentence in abstracts")
tf.app.flags.DEFINE_integer("max_num_sequence", 30,"the max number of sequence in documents")
tf.app.flags.DEFINE_integer("max_num_abstract", 4,"the max number of abstract in documents")
tf.app.flags.DEFINE_integer("sequence_length", 100,"the max length of a sentence in documents")
tf.app.flags.DEFINE_integer("hidden_size", 300,"the hidden size of the encoder and decoder")
tf.app.flags.DEFINE_boolean("use_highway_flag", True,"using highway network or not.")
tf.app.flags.DEFINE_integer("highway_layers", 1,"How many layers in highway network.")
tf.app.flags.DEFINE_integer("document_length", 1000,"the max vocabulary of documents")
tf.app.flags.DEFINE_integer("beam_width", 4,"the beam search max width")
tf.app.flags.DEFINE_integer("attention_size", 150,"the attention size of the decoder")
# Mode switches; is_training defaults to False in this prediction script.
tf.app.flags.DEFINE_boolean("extract_sentence_flag", True,"using sentence extractor")
tf.app.flags.DEFINE_boolean("is_training", False,"is traning.true:tranining,false:testing/inference")
tf.app.flags.DEFINE_boolean("use_embedding",True,"whether to use embedding or not.")
tf.app.flags.DEFINE_string("word2vec_model_path","../w2v/benchmark_sg1_e150_b.vector","word2vec's vocabulary and vectors")
# Constructor arguments handed to Neuralmodel together with the flags.
filter_sizes = [1,2,3,4,5,6,7]
feature_map = [20,20,30,40,50,70,70]
cur_learning_steps = [0,0]


def load(filename):
    """Return the object unpickled from *filename*."""
    # NOTE: unpickling runs code embedded in the file; only load trusted data.
    with open(filename, 'rb') as source:
        return pickle.load(source)


def save(filename, data):
    """Pickle *data* into *filename*, replacing any existing file."""
    with open(filename, 'wb') as sink:
        pickle.dump(data, sink)


def dump(filename, data):
    """Write *data* to *filename* as JSON, encoded with ``MyEncoder``."""
    # MyEncoder is not defined in this file; presumably it arrives through
    # `from data_utils import *` -- TODO confirm.
    with open(filename, 'w') as sink:
        json.dump(data, sink, cls=MyEncoder)


def main(_):
    """Restore the latest checkpoint, score every batch and report ROUGE.

    Each scored example is dumped to ``<result_path>/tst_<idx>.json``; the
    summaries and abstracts written by ``evaluate_file`` are then compared by
    ``evaluate_rouge``. Returns early when no checkpoint file is found.
    """
    config=tf.ConfigProto()
    config.gpu_options.allow_growth = True
    results = []
    with tf.Session(config=config) as sess:
        Model=Neuralmodel(FLAGS.extract_sentence_flag, FLAGS.is_training, FLAGS.vocab_size, FLAGS.batch_size, FLAGS.embed_size, FLAGS.learning_rate, cur_learning_steps, FLAGS.decay_step, FLAGS.decay_rate, FLAGS.max_num_sequence, FLAGS.sequence_length,
                          filter_sizes, feature_map, FLAGS.use_highway_flag, FLAGS.highway_layers, FLAGS.hidden_size, FLAGS.document_length, FLAGS.max_num_abstract, FLAGS.beam_width, FLAGS.attention_size, FLAGS.input_y2_max_length)
        saver=tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
            print("Restoring Variables from Checkpoint")
            saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print("Can't find the checkpoint.going to stop")
            return
        if FLAGS.use_tst_dataset:
            predict_gen = Batch_P(FLAGS.tst_data_path, FLAGS.vocab_path, FLAGS)
        else:
            predict_gen = Batch_F(process_file(FLAGS.tst_file_path, FLAGS.entity_path), FLAGS.vocab_path, FLAGS)
        iteration = 0
        for batch in predict_gen:
            iteration += 1
            feed_dict={}
            feed_dict[Model.dropout_keep_prob] = 1.0
            feed_dict[Model.input_x] = batch['article_words']
            feed_dict[Model.tst] = False
            feed_dict[Model.cur_learning] = False
            logits = sess.run(Model.logits, feed_dict=feed_dict)
            results.append(compute_score(logits, batch))
            evaluate_file(logits, batch)
            if iteration % 500 == 0:
                print ('Dealing with %d examples already...' % iteration)

    print ('Waitting for storing the results...')
    for idx, data in enumerate(results):
        filename = os.path.join(FLAGS.result_path, 'tst_%d.json' % idx)
        dump(filename, data)
    print ('Counting for the rouge...')
    # NOTE(review): evaluate_file opens hyp_path/ref_path in append mode, so
    # lines left over from an earlier run are scored as well.
    scores = evaluate_rouge(FLAGS.hyp_path, FLAGS.ref_path)
    print (scores)
    print ('Done.')


def process_file(data_path, entity_path): # TODO
    """Turn each non-blank line of *data_path* into an example dict.

    Every entity name found in a line is replaced by its key from the pickled
    mapping at *entity_path* (iterated as ``key -> name``); the substituted
    text is split on '.' into ``example['article']`` and the substitutions
    made are recorded in ``example['entity']``.

    Returns:
        list of ``{'article': [...], 'entity': {...}}`` dicts, one per line.
    """
    examples = []
    entitys = load(entity_path)
    with codecs.open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f.readlines():
            if line == '\n':
                continue
            # Start from the raw line so `article` is always bound, and apply
            # the replacements cumulatively (previously each one restarted
            # from `line`, keeping only the last substitution).
            article = line
            entity_dict = {}
            for idx, name in entitys.items():
                # Literal containment matches what str.replace substitutes;
                # the old re.search treated the name as an unescaped regex.
                if name and name in line:
                    article = article.replace(name, idx)
                    entity_dict[idx] = name
            example = {}
            example['article'] = article.split('.')  # was the non-existent .splits()
            example['entity'] = entity_dict
            examples.append(example)
    return examples


def evaluate_file(logits, batch, top_k=3):
    """Append the extracted summary and the reference abstract to disk.

    The *top_k* highest-scored sentences are put back into document order and
    joined into one line appended to ``FLAGS.hyp_path``; the abstract lines
    are joined and appended to ``FLAGS.ref_path``. Also stores the ranking in
    ``batch['original']['score']`` (via ``compute_score``).

    Args:
        logits: model output; row 0 holds one score per sentence position.
        batch: batch dict whose ``'original'`` entry has ``'article'`` and
            ``'abstract'`` lists.
        top_k: number of sentences to extract (default 3, as before).
    """
    data = compute_score(logits, batch)
    # data['score'] is sorted best-first. The previous code sliced the
    # unsorted list, which always selected the first three sentences.
    chosen = sorted(data['score'][:top_k], key=lambda item: item[0])
    summary = '. '.join(sent for _, _, sent in chosen)
    abstract = '. '.join(data['abstract'])

    with open(FLAGS.hyp_path, 'a') as f:
        f.write(summary)
        f.write('\n')

    with open(FLAGS.ref_path, 'a') as f:
        f.write(abstract)
        f.write('\n')


def evaluate_rouge(hyp_path, ref_path):
    """Return the averaged ROUGE scores of *hyp_path* against *ref_path*."""
    files_rouge = FilesRouge(hyp_path, ref_path)
    return files_rouge.get_scores(avg=True)


def compute_score(logits, batch):
    """Attach a best-first sentence ranking to the batch's original example.

    Pairs each article sentence with the score at the same position in
    ``logits[0]`` and stores ``(position, score, sentence)`` tuples, sorted by
    descending score, under ``data['score']``.

    Returns:
        ``batch['original']`` (mutated in place).
    """
    data = batch['original']
    sentences = data['article']
    scored = [(pos, score, sent)
              for pos, (sent, score) in enumerate(zip(sentences, logits[0][:len(sentences)]))]
    data['score'] = sorted(scored, key=lambda item: item[1], reverse=True)
    return data


if __name__ == '__main__':
    tf.app.run()

# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | from data_utils import * 5 | from textsum_model import Neuralmodel 6 | from gensim.models import KeyedVectors 7 | from rouge import Rouge 8 | import os 9 | import math 10 | import pickle 11 | from tqdm import tqdm 12 | 13 | #configuration 14 | FLAGS=tf.app.flags.FLAGS 15 | 16 | tf.app.flags.DEFINE_string("log_path","../log/","path of summary log.") 17 | tf.app.flags.DEFINE_string("tra_data_path","../src/neuralsum/dailymail/tra/","path of training data.") 18 | tf.app.flags.DEFINE_string("tst_data_path","../src/neuralsum/dailymail/tst/","path of test data.") 19 | tf.app.flags.DEFINE_string("val_data_path","../src/neuralsum/dailymail/val/","path of validation data.") 20 | tf.app.flags.DEFINE_string("vocab_path","../cache/vocab","path of vocab 
frequency list") 21 | tf.app.flags.DEFINE_integer("vocab_size",199900,"maximum vocab size.") 22 | 23 | tf.app.flags.DEFINE_float("learning_rate",0.0001,"learning rate") 24 | 25 | tf.app.flags.DEFINE_integer("is_frozen_step", 400, "how many steps before fine-tuning the embedding.") 26 | tf.app.flags.DEFINE_integer("decay_step", 5000, "how many steps before decay learning rate.") 27 | tf.app.flags.DEFINE_float("decay_rate", 0.1, "Rate of decay for learning rate.") 28 | tf.app.flags.DEFINE_string("ckpt_dir","../ckpt/","checkpoint location for the model") 29 | tf.app.flags.DEFINE_integer("batch_size", 20, "Batch size for training/evaluating.") 30 | tf.app.flags.DEFINE_integer("embed_size", 150,"embedding size") 31 | tf.app.flags.DEFINE_integer("input_y2_max_length", 40,"the max length of a sentence in abstracts") 32 | tf.app.flags.DEFINE_integer("max_num_sequence", 30,"the max number of sequence in documents") 33 | tf.app.flags.DEFINE_integer("max_num_abstract", 4,"the max number of abstract in documents") 34 | tf.app.flags.DEFINE_integer("sequence_length", 100,"the max length of a sentence in documents") 35 | tf.app.flags.DEFINE_integer("hidden_size", 300,"the hidden size of the encoder and decoder") 36 | tf.app.flags.DEFINE_boolean("use_highway_flag", True,"using highway network or not.") 37 | tf.app.flags.DEFINE_integer("highway_layers", 1,"How many layers in highway network.") 38 | tf.app.flags.DEFINE_integer("document_length", 1000,"the max vocabulary of documents") 39 | tf.app.flags.DEFINE_integer("beam_width", 4,"the beam search max width") 40 | tf.app.flags.DEFINE_integer("attention_size", 150,"the attention size of the decoder") 41 | tf.app.flags.DEFINE_boolean("extract_sentence_flag", True,"using sentence extractor") 42 | tf.app.flags.DEFINE_boolean("is_training", True,"is traning.true:tranining,false:testing/inference") 43 | tf.app.flags.DEFINE_integer("num_epochs",10,"number of epochs to run.") 44 | tf.app.flags.DEFINE_integer("validate_every", 1, "Validate 
every validate_every epochs.") 45 | tf.app.flags.DEFINE_boolean("use_embedding",True,"whether to use embedding or not.") 46 | tf.app.flags.DEFINE_string("word2vec_model_path","../w2v/benchmark_sg1_e150_b.vector","word2vec's vocabulary and vectors") 47 | filter_sizes = [1,2,3,4,5,6,7] 48 | feature_map = [20,20,30,40,50,70,70] 49 | cur_learning_steps = [500,2500] 50 | 51 | def main(_): 52 | config = tf.ConfigProto() 53 | config.gpu_options.allow_growth=True 54 | with tf.Session(config=config) as sess: 55 | # instantiate model 56 | Model = Neuralmodel(FLAGS.extract_sentence_flag, FLAGS.is_training, FLAGS.vocab_size, FLAGS.batch_size, FLAGS.embed_size, FLAGS.learning_rate, cur_learning_steps, FLAGS.decay_step, FLAGS.decay_rate, FLAGS.max_num_sequence, FLAGS.sequence_length, 57 | filter_sizes, feature_map, FLAGS.use_highway_flag, FLAGS.highway_layers, FLAGS.hidden_size, FLAGS.document_length, FLAGS.max_num_abstract, FLAGS.beam_width, FLAGS.attention_size, FLAGS.input_y2_max_length) 58 | # initialize saver 59 | saver = tf.train.Saver() 60 | if os.path.exists(FLAGS.ckpt_dir+"checkpoint"): 61 | print("Restoring Variables from Checkpoint.") 62 | saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir)) 63 | summary_writer = tf.summary.FileWriter(logdir=FLAGS.log_path, graph=sess.graph) 64 | else: 65 | print('Initializing Variables') 66 | sess.run(tf.global_variables_initializer()) 67 | summary_writer = tf.summary.FileWriter(logdir=FLAGS.log_path, graph=sess.graph) 68 | if FLAGS.use_embedding: #load pre-trained word embedding 69 | assign_pretrained_word_embedding(sess, FLAGS.vocab_path, FLAGS.vocab_size, Model,FLAGS.word2vec_model_path) 70 | curr_epoch=sess.run(Model.epoch_step) 71 | 72 | batch_size=FLAGS.batch_size 73 | iteration=0 74 | for epoch in range(curr_epoch,FLAGS.num_epochs): 75 | loss, counter = 0.0, 0 76 | train_gen = Batch(FLAGS.tra_data_path,FLAGS.vocab_path,FLAGS.batch_size,FLAGS) 77 | for batch in tqdm(train_gen): 78 | iteration=iteration+1 79 | if 
epoch==0 and counter==0: 80 | print("train_batch", batch['abstracts_len']) 81 | feed_dict={} 82 | if FLAGS.extract_sentence_flag: 83 | feed_dict[Model.dropout_keep_prob] = 0.5 84 | feed_dict[Model.input_x] = batch['article_words'] 85 | feed_dict[Model.input_y1] = batch['label_sentences'] 86 | feed_dict[Model.input_y1_length] = batch['article_len'] 87 | feed_dict[Model.tst] = FLAGS.is_training 88 | feed_dict[Model.cur_learning] = True if cur_learning_steps[1] > iteration and epoch == 0 else False 89 | else: 90 | feed_dict[Model.dropout_keep_prob] = 0.5 91 | feed_dict[Model.input_x] = batch['article_words'] 92 | feed_dict[Model.input_y2_length] = batch['abstracts_len'] 93 | feed_dict[Model.input_y2] = batch['abstracts_inputs'] 94 | feed_dict[Model.input_decoder_x] = batch['abstracts_targets'] 95 | feed_dict[Model.value_decoder_x] = batch['article_value'] 96 | feed_dict[Model.tst] = FLAGS.is_training 97 | train_op = Model.train_op_frozen if FLAGS.is_frozen_step > iteration and epoch == 0 else Model.train_op 98 | curr_loss,lr,_,_,summary,logits=sess.run([Model.loss_val,Model.learning_rate,train_op,Model.global_increment,Model.merge,Model.logits],feed_dict) 99 | summary_writer.add_summary(summary, global_step=iteration) 100 | loss,counter=loss+curr_loss,counter+1 101 | if counter %50==0: 102 | print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f" %(epoch,counter,loss/float(counter),lr)) 103 | if iteration % 1000 == 0: 104 | eval_loss = do_eval(sess, Model) 105 | print("Epoch %d Validation Loss:%.3f\t " % (epoch, eval_loss)) 106 | # TODO eval_loss, acc_score = do_eval(sess, Model) 107 | # TODO print("Epoch %d Validation Loss:%.3f\t Acc:%.3f" % (epoch, eval_loss, acc_score)) 108 | # save model to checkpoint 109 | save_path = FLAGS.ckpt_dir + "model.ckpt" 110 | saver.save(sess, save_path, global_step=epoch) 111 | #epoch increment 112 | print("going to increment epoch counter....") 113 | sess.run(Model.epoch_increment) 114 | print(epoch,FLAGS.validate_every,(epoch 
% FLAGS.validate_every==0)) 115 | if epoch % FLAGS.validate_every==0: 116 | #save model to checkpoint 117 | save_path=FLAGS.ckpt_dir+"model.ckpt" 118 | saver.save(sess,save_path,global_step=epoch) 119 | summary_writer.close() 120 | 121 | def do_eval(sess, Model): 122 | eval_loss, eval_counter= 0.0, 0 123 | # eval_loss, eval_counter, acc_score= 0.0, 0, 0.0 124 | batch_size = 20 125 | valid_gen = Batch(FLAGS.tst_data_path,FLAGS.vocab_path,batch_size,FLAGS) 126 | for batch in valid_gen: 127 | feed_dict={} 128 | if FLAGS.extract_sentence_flag: 129 | feed_dict[Model.dropout_keep_prob] = 1.0 130 | feed_dict[Model.input_x] = batch['article_words'] 131 | feed_dict[Model.input_y1] = batch['label_sentences'] 132 | feed_dict[Model.input_y1_length] = batch['article_len'] 133 | feed_dict[Model.tst] = not FLAGS.is_training 134 | feed_dict[Model.cur_learning] = False 135 | else: 136 | feed_dict[Model.dropout_keep_prob] = 1.0 137 | feed_dict[Model.input_x] = batch['article_words'] 138 | feed_dict[Model.input_y2] = batch['abstracts_inputs'] 139 | feed_dict[Model.input_y2_length] = batch['abstracts_len'] 140 | feed_dict[Model.input_decoder_x] = batch['abstracts_targets'] 141 | feed_dict[Model.value_decoder_x] = batch['article_value'] 142 | feed_dict[Model.tst] = not FLAGS.is_training 143 | curr_eval_loss,logits=sess.run([Model.loss_val,Model.logits],feed_dict) 144 | # curr_acc_score = compute_label(logits, batch) 145 | # acc_score += curr_acc_score 146 | eval_loss += curr_eval_loss 147 | eval_counter += 1 148 | 149 | return eval_loss/float(eval_counter) # acc_score/float(eval_counter) 150 | 151 | def compute_label(logits, batch): # TODO 152 | imp_pos = np.argsort(logits) 153 | lab_num = [ len(res['label']) for res in batch['original']] 154 | lab_pos = [ res['label'] for res in batch['original']] 155 | abs_num = [ res['abstract'] for res in batch['original']] 156 | sen_pos = [ pos[:num] for pos, num in zip(imp_pos, lab_num)] 157 | 158 | # compute 159 | acc_list = [] 160 | for sen, 
lab, abst in zip(sen_pos, lab_pos, abs_num): 161 | sen = set(sen) 162 | lab = set(lab) 163 | if len(lab) == 0 or len(abst) == 0: 164 | continue 165 | score = float(len(sen&lab)) / len(abst) 166 | acc = 1.0 if score > 1.0 else score 167 | acc_list.append(acc) 168 | acc_score = np.mean(acc_list) 169 | 170 | return acc_score 171 | 172 | def assign_pretrained_word_embedding(sess,vocab_path,vocab_size,Model,word2vec_model_path): 173 | print("using pre-trained word emebedding.started.word2vec_model_path:",word2vec_model_path) 174 | vocab = Vocab(vocab_path, vocab_size) 175 | word2vec_model = KeyedVectors.load_word2vec_format(word2vec_model_path, binary=True) 176 | bound = np.sqrt(6.0) / np.sqrt(vocab_size) # bound for random variables. 177 | count_exist = 0; 178 | count_not_exist = 0 179 | word_embedding_2dlist = [[]] * vocab_size # create an empty word_embedding list. 180 | word_embedding_2dlist[0] = np.zeros(FLAGS.embed_size, dtype=np.float32) # assign empty for first word:'PAD' 181 | for i in range(1, vocab_size): # loop each word 182 | word = vocab.id2word(i) 183 | embedding = None 184 | try: 185 | embedding = word2vec_model[word] # try to get vector:it is an array. 186 | except Exception: 187 | embedding = None 188 | if embedding is not None: # the 'word' exist a embedding 189 | word_embedding_2dlist[i] = embedding; 190 | count_exist = count_exist + 1 # assign array to this word. 191 | else: # no embedding for this word 192 | word_embedding_2dlist[i] = np.random.uniform(-bound, bound, FLAGS.embed_size) 193 | count_not_exist = count_not_exist + 1 # init a random value for the word. 194 | word_embedding_final = np.array(word_embedding_2dlist) # covert to 2d array. 195 | word_embedding = tf.constant(word_embedding_final, dtype=tf.float32) # convert to tensor 196 | t_assign_embedding = tf.assign(Model.Embedding,word_embedding) # assign this value to our embedding variables of our model. 
197 | sess.run(t_assign_embedding) 198 | 199 | word_embedding_2dlist_ = [[]] * 2 # create an empty word_embedding list for GO END. 200 | word_embedding_2dlist_[0] = np.random.uniform(-bound, bound, FLAGS.hidden_size) # GO 201 | word_embedding_2dlist_[1] = np.random.uniform(-bound, bound, FLAGS.hidden_size) # END 202 | word_embedding_final_ = np.array(word_embedding_2dlist_) # covert to 2d array. 203 | word_embedding_ = tf.constant(word_embedding_final_, dtype=tf.float32) # convert to tensor 204 | t_assign_embedding_ = tf.assign(Model.Embedding_,word_embedding_) # assign this value to our embedding variables of our model. 205 | sess.run(t_assign_embedding_) 206 | print("word. exists embedding:", count_exist, " ;word not exist embedding:", count_not_exist) 207 | print("using pre-trained word emebedding.ended...") 208 | 209 | if __name__ == "__main__": 210 | tf.app.run() 211 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import re 3 | import os 4 | import json 5 | import pickle 6 | import random 7 | import numpy as np 8 | from tflearn.data_utils import pad_sequences 9 | 10 | PAD_TOKEN = '[PAD]' 11 | UNKNOWN_TOKEN = '[UNK]' 12 | START_DECODING = '[START]' 13 | STOP_DECODING = '[STOP]' 14 | 15 | def load(filename): 16 | with open(filename, 'rb') as output: 17 | data = pickle.load(output) 18 | return data 19 | 20 | def save(filename, data): 21 | with open(filename, 'wb') as output: 22 | pickle.dump(data, output) 23 | 24 | def Batch(data_path, vocab_path, size, hps): 25 | 26 | res = {} 27 | filenames = os.listdir(data_path) 28 | random.shuffle(filenames) 29 | label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len, results = [], [], [], [], [], [], [], [] 30 | vocab = Vocab(vocab_path, hps.vocab_size) 31 | for cnt, filename in enumerate(filenames[:200000]): 
32 | pickle_path = os.path.join(data_path, filename) 33 | res = load(pickle_path) 34 | label, value, words, len_a, targets, inputs, lens = Example(res['article'],res['abstract'],res['label'],res['entity'], vocab, hps) # TODO 35 | label_sentences.append(label) 36 | article_value.append(value) 37 | article_words.append(words) 38 | article_len.append(len_a) 39 | abstracts_targets.append(targets) 40 | abstracts_inputs.append(inputs) 41 | abstracts_len.append(lens) 42 | results.append(res) 43 | 44 | if (cnt+1) % size == 0: 45 | data_dict ={} 46 | data_dict['label_sentences'] = label_sentences 47 | data_dict['article_value'] = article_value 48 | data_dict['article_words'] = article_words 49 | data_dict['article_len'] = article_len 50 | data_dict['abstracts_targets'] = abstracts_targets 51 | data_dict['abstracts_inputs'] = abstracts_inputs 52 | data_dict['abstracts_len'] = abstracts_len 53 | data_dict['original'] = results 54 | label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len, results = [], [], [], [], [], [], [], [] 55 | yield data_dict 56 | 57 | def Example(article, abstracts, label, entity, vocab, hps): 58 | 59 | # get ids of special tokens 60 | start_decoding = vocab.word2id(START_DECODING) 61 | stop_decoding = vocab.word2id(STOP_DECODING) 62 | pad_id = vocab.word2id(PAD_TOKEN) 63 | 64 | """process the label""" 65 | # pos 2 multi one-hot 66 | label_sentences = label2ids(label, hps.max_num_sequence) 67 | 68 | """process the article""" 69 | # create vocab and word 2 id 70 | article_value = value2ids(article, vocab, hps.document_length) 71 | # word 2 id 72 | article_words = article2ids(article, vocab) 73 | # num sentence 74 | article_len = len(article) 75 | # word level padding 76 | article_words = pad_sequences(article_words, maxlen=hps.sequence_length, value=pad_id) 77 | # sentence level padding 78 | pad_article = np.expand_dims(np.zeros(hps.sequence_length, dtype=np.int32), axis = 0) 79 | if 
article_words.shape[0] > hps.max_num_sequence: 80 | article_words = article_words[:hps.max_num_sequence] 81 | while article_words.shape[0] < hps.max_num_sequence: 82 | article_words = np.concatenate((article_words, pad_article)) 83 | 84 | """process the abstract""" 85 | # word 2 id 86 | abstracts_words = abstract2ids(abstracts, vocab) 87 | # add tokens 88 | abstracts_inputs, abstracts_targets = token2add(abstracts_words, hps.input_y2_max_length, start_decoding, stop_decoding) 89 | # padding 90 | abstracts_inputs = pad_sequences(abstracts_inputs, maxlen=hps.input_y2_max_length, value=pad_id) 91 | abstracts_targets = pad_sequences(abstracts_targets, maxlen=hps.input_y2_max_length, value=pad_id) 92 | # search id in value position 93 | abstract_targets = value2pos(abstracts_targets, article_value, vocab) 94 | # sentence level padding 95 | pad_abstracts = np.expand_dims(np.zeros(hps.input_y2_max_length, dtype=np.int32), axis = 0) 96 | if abstracts_inputs.shape[0] > hps.max_num_abstract: 97 | abstracts_inputs = abstracts_inputs[:hps.max_num_abstract] 98 | while abstracts_inputs.shape[0] < hps.max_num_abstract: 99 | abstracts_inputs = np.concatenate((abstracts_inputs, pad_abstracts)) 100 | if abstracts_targets.shape[0] > hps.max_num_abstract: 101 | abstracts_targets = abstracts_targets[:hps.max_num_abstract] 102 | while abstracts_targets.shape[0] < hps.max_num_abstract: 103 | abstracts_targets = np.concatenate((abstracts_targets, pad_abstracts)) 104 | # mask 105 | abstracts_len = abstract2len(abstracts, hps.input_y2_max_length) 106 | if abstracts_len.shape[0] > hps.max_num_abstract: 107 | abstracts_len = abstracts_len[:hps.max_num_abstract] 108 | while abstracts_len.shape[0] < hps.max_num_abstract: 109 | abstracts_len = np.concatenate((abstracts_len, [1])) 110 | 111 | return label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len 112 | 113 | def Batch_F(file_data, vocab_path, hps): 114 | 115 | vocab = 
Vocab(vocab_path, hps.vocab_size) 116 | for res in file_data: 117 | article_value, article_words, article_len = Example_P(res['article'],res['entity'], vocab, hps) 118 | data_dict ={} 119 | data_dict['article_value'] = [article_value] 120 | data_dict['article_words'] = [article_words] 121 | data_dict['article_len'] = [article_len] 122 | data_dict['original'] = res 123 | yield data_dict 124 | 125 | def Batch_P(data_path, vocab_path, hps): 126 | 127 | filenames = os.listdir(data_path) 128 | vocab = Vocab(vocab_path, hps.vocab_size) 129 | for cnt, filename in enumerate(filenames): 130 | pickle_path = os.path.join(data_path, filename) 131 | res = load(pickle_path) 132 | article_value, article_words, article_len = Example_P(res['article'],res['entity'], vocab, hps) 133 | data_dict ={} 134 | data_dict['article_value'] = [article_value] 135 | data_dict['article_words'] = [article_words] 136 | data_dict['article_len'] = [article_len] 137 | data_dict['original'] = res 138 | yield data_dict 139 | 140 | def Example_P(article, entity, vocab, hps): 141 | 142 | # get ids of special tokens 143 | pad_id = vocab.word2id(PAD_TOKEN) 144 | 145 | """process the article""" 146 | # create vocab and word 2 id 147 | article_value = value2ids(article, vocab, hps.document_length) 148 | # word 2 id 149 | article_words = article2ids(article, vocab) 150 | # num sentence 151 | article_len = len(article) 152 | # word level padding 153 | article_words = pad_sequences(article_words, maxlen=hps.sequence_length, value=pad_id) 154 | # sentence level padding 155 | pad_article = np.expand_dims(np.zeros(hps.sequence_length, dtype=np.int32), axis = 0) 156 | if article_words.shape[0] > hps.max_num_sequence: 157 | article_words = article_words[:hps.max_num_sequence] 158 | while article_words.shape[0] < hps.max_num_sequence: 159 | article_words = np.concatenate((article_words, pad_article)) 160 | 161 | return article_value, article_words, article_len 162 | 163 | class Vocab(object): 164 | def __init__(self, 
vocab_file, max_size): 165 | self._word_to_id = {} 166 | self._id_to_word = {} 167 | self._count = 0 168 | 169 | for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]: 170 | self._word_to_id[w] = self._count 171 | self._id_to_word[self._count] = w 172 | self._count += 1 173 | 174 | with open(vocab_file, 'r') as vocab_f: 175 | for line in vocab_f: 176 | pieces = line.split() 177 | if len(pieces) != 2: 178 | continue 179 | w = pieces[0] 180 | if w in [UNKNOWN_TOKEN, PAD_TOKEN,START_DECODING, STOP_DECODING]: 181 | raise Exception('[UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is'% w) 182 | if w in self._word_to_id: 183 | raise Exception('Duplicated word in vocabulary file: %s' % w) 184 | self._word_to_id[w] = self._count 185 | self._id_to_word[self._count] = w 186 | self._count += 1 187 | if max_size != 0 and self._count >= max_size: 188 | break 189 | 190 | def word2id(self, word): 191 | if word not in self._word_to_id: 192 | return self._word_to_id[UNKNOWN_TOKEN] 193 | return self._word_to_id[word] 194 | 195 | def id2word(self, word_id): 196 | return self._id_to_word[word_id] 197 | 198 | def size(self): 199 | return self._count 200 | 201 | def label2ids(labels, label_size): 202 | res = np.zeros(label_size, dtype=np.int32) 203 | label_list = [ pos for pos in labels if pos < label_size] 204 | res[label_list] = 1 205 | return res 206 | 207 | def value2ids(article, vocab, document_length): 208 | value = [] 209 | pad_id = vocab.word2id(PAD_TOKEN) 210 | unk_id = vocab.word2id(UNKNOWN_TOKEN) 211 | stop_id = vocab.word2id(STOP_DECODING) 212 | value.append(unk_id) 213 | value.append(stop_id) 214 | for sent in article: 215 | article_words = sent.split() 216 | for w in article_words: 217 | i = vocab.word2id(w) 218 | if i == unk_id: 219 | pass 220 | if i not in value: 221 | value.append(i) 222 | cnt = 4 223 | while len(value) < document_length: 224 | if cnt not in value: 225 | value.append(cnt) 226 | cnt += 1 227 | return np.array(value) 
def value2pos(abstract, value, vocab):
    """Map every token id of an abstract to its position inside ``value``.

    Args:
        abstract: iterable of sentences, each an iterable of token ids.
        value: numpy array of token ids; it is compared element-wise with
            ``==`` and probed with ``in``.
        vocab: object exposing ``word2id``; used to resolve the id of the
            unknown token.

    Returns:
        ``np.array`` built from one array per sentence. Each entry is the
        first row returned by ``np.argwhere`` for the token id, or for the
        unknown-token id when the token id does not occur in ``value``.

    Raises:
        IndexError: if the fallback is needed but the unknown-token id does
            not occur in ``value`` either.
    """
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    poss = []
    for sent in abstract:
        # Ids missing from ``value`` are redirected to the unknown token.
        pos = [
            np.argwhere(value == (i if i in value else unk_id))[0]
            for i in sent
        ]
        poss.append(np.array(pos))
    return np.array(poss)


def article2ids(article, vocab):
    """Convert each article sentence into a list of vocabulary ids.

    Sentences are split on whitespace; out-of-vocabulary handling is
    delegated entirely to ``vocab.word2id``.

    Args:
        article: iterable of sentence strings.
        vocab: object exposing ``word2id(word) -> int``.

    Returns:
        A list with one list of ids per sentence (empty for an empty
        sentence).
    """
    # The previous version also collected OOV words in a list that was never
    # returned or read; that dead bookkeeping is dropped.
    return [[vocab.word2id(w) for w in sent.split()] for sent in article]


def abstract2ids(abstracts, vocab):
    """Convert each abstract sentence into a list of vocabulary ids.

    Args:
        abstracts: iterable of sentence strings.
        vocab: object exposing ``word2id(word) -> int``.

    Returns:
        A list with one list of ids per sentence.
    """
    return [[vocab.word2id(w) for w in sent.split()] for sent in abstracts]


def token2add(abstracts, max_len, start_id, stop_id):
    """Build decoder inputs and targets from id sequences.

    The input is the sequence prefixed with ``start_id``. The target is the
    sequence followed by ``stop_id``. When the prefixed input would exceed
    ``max_len`` both are truncated to ``max_len`` and the target carries no
    stop token.

    Args:
        abstracts: iterable of id sequences.
        max_len: maximum length of every produced sequence.
        start_id: id placed in front of every input.
        stop_id: id appended to every non-truncated target.

    Returns:
        Tuple ``(inps, targets)`` of two equally long lists of id lists;
        ``inps[k]`` and ``targets[k]`` always have the same length.
    """
    inps = []
    targets = []
    for sequence in abstracts:
        tokens = list(sequence)  # copy; also accepts tuples and other sequences
        if len(tokens) + 1 > max_len:
            inp = ([start_id] + tokens)[:max_len]
            target = tokens[:max_len]
        else:
            inp = [start_id] + tokens
            target = tokens + [stop_id]
        assert len(inp) == len(target)
        inps.append(inp)
        targets.append(target)
    return inps, targets


def abstract2len(abstracts, max_len):
    """Return the decoder length of every abstract sentence.

    The length is the whitespace-token count plus one, capped at
    ``max_len``.

    Args:
        abstracts: iterable of sentence strings.
        max_len: upper bound for every returned length.

    Returns:
        ``np.array`` with one length per sentence.
    """
    return np.array([min(len(sent.split()) + 1, max_len) for sent in abstracts])


def outputids2words(id_list, vocab):
    """Translate vocabulary ids back into words via ``vocab.id2word``."""
    # NOTE: the remainder of this loop body and the return statement sit on
    # the following source line of this dump.
    words = []
    for i in id_list:
        w = vocab.id2word(i)
words.append(w) 306 | return words 307 | 308 | class MyEncoder(json.JSONEncoder): 309 | def default(self, obj): 310 | if isinstance(obj, np.integer): 311 | return int(obj) 312 | elif isinstance(obj, np.floating): 313 | return float(obj) 314 | elif isinstance(obj, np.ndarray): 315 | return obj.tolist() 316 | else: 317 | return super(MyEncoder, self).default(obj) 318 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | TextSumma 3 | = 4 | An attempt at reproducing the ACL 2016 paper [*Neural Summarization by Extracting Sentences and Words*](https://arxiv.org/abs/1603.07252). The authors' original code can be found [*here*](https://github.com/cheng6076/NeuralSum). 5 | 6 | ## Quick Start 7 | - **Step1 : Obtain datasets** 8 | Go [*here*](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) to download the corpus and get the scripts of ***one-billion-word-language-modeling-benchmark*** for training the word vectors. Run the following and see more [*details*](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark/blob/master/README.corpus_generation): 9 | ```bash 10 | $ tar --extract -v --file ../statmt.org/tar_archives/training-monolingual.tgz --wildcards training-monolingual/news.20??.en.shuffled 11 | $ ./scripts/get-data.sh 12 | ``` 13 | The ***cnn-dailymail*** dataset with highlights used in this paper, offered by the authors, is available [*here*](https://docs.google.com/uc?id=0B0Obe9L1qtsnSXZEd0JCenIyejg&export=download), and the vocab is in this repository.
14 | - **Step2 : Preprocess** 15 | Run this script to train the word vectors on the dataset ***one-billion-word-language-modeling-benchmark***: 16 | ```bash 17 | $ python train_w2v.py './one-billion-word-benchmark' 'output_gensim_model' 'output_word_vector' 18 | ``` 19 | Run this script to extract the sentences, labels and entities from the dataset ***cnn-dailymail*** and pickle them: 20 | ```bash 21 | $ python prepro.py './source_dir/' './target_dir/' 22 | ``` 23 | - **Step3 : Install nvidia-docker** 24 | Use GPU acceleration. See [*installation*](https://github.com/NVIDIA/nvidia-docker) for more information. 25 | - **Step4 : Obtain *Deepo*** 26 | *Deepo* is a series of Docker images (and their generator) that allow you to quickly set up your deep learning research environment. See [*Deepo*](https://github.com/ufoym/deepo) for more details. Run the following and open **port 6006** for TensorBoard: 27 | 28 | ```bash 29 | $ nvidia-docker pull ufoym/deepo 30 | $ nvidia-docker run -p 0.0.0.0:6006:6006 -it -v /home/usrs/yourdir:/data ufoym/deepo env LANG=C.UTF-8 bash 31 | ``` 32 | Enter the bash shell of *Deepo* and run pip to install the rest: 33 | ```bash 34 | $ pip install gensim rouge tflearn tqdm 35 | ``` 36 | * **Step5: Train the model and predict** 37 | 38 | Add the option **-h** to get help on the flag settings. 39 | ```bash 40 | $ python train_model.py 41 | $ python predict_model.py 42 | ``` 43 | * **Requirements**: 44 | Python3.6 Tensorflow 1.8.0 45 | 46 | ## Model details 47 | 48 | * **Structure NN-SE** 49 | 50 | sentence_model 51 | 52 | * **Sentence extractor** 53 | Here is the single step of the customized LSTM with a score layer.
54 | ```python 55 | def lstm_single_step(self, St, At, h_t_minus_1, c_t_minus_1, p_t_minus_1): 56 | p_t_minus_1 = tf.reshape(p_t_minus_1, [-1, 1]) 57 | # Xt = p_t_minus_1 * St 58 | Xt = tf.multiply(p_t_minus_1, St) 59 | # dropout 60 | Xt = tf.nn.dropout(Xt, keep_prob=self.dropout_keep_prob) 61 | # compute the gate of input, forget, output 62 | i_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_i) + tf.matmul(h_t_minus_1, self.U_i) + self.b_i) 63 | f_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_f) + tf.matmul(h_t_minus_1, self.U_f) + self.b_f) 64 | c_t_candidate = tf.nn.tanh(tf.matmul(Xt, self.W_c) + tf.matmul(h_t_minus_1, self.U_c) + self.b_c) 65 | c_t = f_t * c_t_minus_1 + i_t * c_t_candidate 66 | o_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_o) + tf.matmul(h_t_minus_1, self.U_o) + self.b_o) 67 | h_t = o_t * tf.nn.tanh(c_t) 68 | # compute prob 69 | with tf.name_scope("Score_Layer"): 70 | concat_h = tf.concat([At, h_t], axis=1) 71 | concat_h_dropout = tf.nn.dropout(concat_h, keep_prob=self.dropout_keep_prob) 72 | score = tf.layers.dense(concat_h_dropout, 1, activation=tf.nn.tanh, name="score", reuse=tf.AUTO_REUSE) 73 | # activation and normalization 74 | p_t = self.sigmoid_norm(score) 75 | return h_t, c_t, p_t 76 | ``` 77 | * **Curriculum learning** 78 | Actually new to curriculum learning, just simply connect the weight of the true labels and those predicted with the rate of steps. 79 | ```python 80 | def weight_control(self, time_step, p_t): 81 | # curriculum learning control the weight between true labels and those predicted 82 | labels = tf.cast(self.input_y1[:,time_step:time_step+1], dtype=tf.float32) 83 | start = tf.cast(self.cur_step_start, dtype=tf.float32) 84 | end = tf.cast(self.cur_step_end, dtype=tf.float32) 85 | global_step = tf.cast(self.global_step, dtype=tf.float32) 86 | weight = tf.divide(tf.subtract(global_step, start), tf.subtract(end, start)) 87 | merge = (1. 
- weight) * labels + weight * p_t 88 | cond = tf.greater(start, global_step) 89 | p_t_curr = tf.cond(cond, lambda:labels, lambda:merge) 90 | return p_t_curr 91 | ``` 92 | * **Loss function** 93 | The loss function is coded manually instead of using *tf.losses.sigmoid_cross_entropy*, because the logits are already between 0 and 1 after the sigmoid activation and normalization. 94 | ```python 95 | # loss:z*-log(x)+(1-z)*-log(1-x) 96 | # z=0 --> loss:-log(1-x) 97 | # z=1 --> loss:-log(x) 98 | with tf.name_scope("loss_sentence"): 99 | logits = tf.convert_to_tensor(self.logits) 100 | labels = tf.cast(self.input_y1, logits.dtype) 101 | zeros = tf.zeros_like(labels, dtype=labels.dtype) 102 | ones = tf.ones_like(logits, dtype=logits.dtype) 103 | cond = ( labels > zeros ) 104 | logits_ = tf.where(cond, logits, ones-logits) 105 | logits_log = tf.log(logits_) 106 | losses = -logits_log 107 | losses *= self.mask 108 | loss = tf.reduce_sum(losses, axis=1) 109 | loss = tf.reduce_mean(loss) 110 | ``` 111 | 112 | 113 | ## Performance 114 | 115 | * **Probability for the sentences in several timesteps** 116 | 117 | sentence_model 118 | 119 | * **Training loss** 120 | 121 | sentence_loss 122 | 123 | * **Figure** 124 | Some results seem to be nice.
125 | ```json 126 | { 127 | "entity": { 128 | "@entity31": "Jason Kernick", 129 | "@entity1": "Manchester", 130 | "@entity9": "Ashton Canal", 131 | "@entity46": "Environment Agency", 132 | "@entity44": "Etihad stadium", 133 | "@entity45": "Manchester City", 134 | "@entity115": "Easter Sunday", 135 | "@entity85": "Clayton", 136 | "@entity66": "Richard Kernick", 137 | "@entity109": "Etihad", 138 | "@entity137": "Greater Manchester Fire and Rescue Service", 139 | "@entity136": "Salford" 140 | }, 141 | "abstract": [ 142 | "the @entity9 became filled with heavy suds due to a 6ft wall of foam created by fire crews tackling a blaze", 143 | "the fire at a nearby chemical plant saw water from fire service mix with detergents that were being stored there", 144 | "the foam covered a 30 metre stretch of the canal near @entity45 's @entity44 in @entity85" 145 | ], 146 | "article": [ 147 | "a @entity1 canal was turned into a giant bubble bath after fire crews tackling a nearby chemical plant blaze saw their water mix with a detergent creating a six foot wall of foam", 148 | "the @entity9 was filled with heavy suds which appeared after a fire at an industrial unit occupied by a drug development company", 149 | "it is believed that the water used by firefighters to dampen down the flames mixed with the detergent being stored in the burning buildings", 150 | "now the @entity46 have launched an investigation to assess if the foam has impacted on wildlife after concerns were raised for the safety of fish in the affected waters", 151 | "a spokesman for the agency said : ' @entity46 is investigating after receiving reports of foam on a 30 metre stretch of the @entity9 , @entity1", 152 | "' initial investigations by @entity46 officers show that there appears to have been minimal impact on water quality , but our officers will continue to monitor and respond as necessary", 153 | "@entity66 takes a picture on his mobile phone of his boat trying to negotiate a lock and the foam , which ran 
into the @entity9 a cyclist takes a picture on his mobile phone as the foam comes up on to the cycle path", 154 | "the @entity46 are investigating to assess of the foam has harmed any wildlife the foam reached as high as six foot in some places and covered a 30 metre stretch along the water in the @entity85 area of @entity1 ' we are working with the fire service and taking samples of the foam to understand what it is made of , and what impact it may have on local wildlife in and around the canal", 155 | "' at the height of the blaze on sunday afternoon , which caused the foam , up to 50 firefighters were tackling the fire and police were also forced to wear face masks", 156 | "families in east @entity1 were urged to say indoors after a blast was reported at the industrial unit , which is just a few hundred yards from the @entity45 training ground on the @entity109 campus", 157 | "the fire at the chemical factory next to @entity45 's @entity44 send a huge plume of smoke across the city on @entity115 police wearing face masks went around neighbouring streets with loudspeakers urging people to stay inside while the fire raged police officers also told children on bikes and mothers pushing prams near the scene to go home and went around neighbouring streets with loudspeakers urging people to stay inside", 158 | "a huge plume of smoke also turned the sky black and could be seen right across the city and even into @entity136", 159 | "according to @entity137 , the fire was fueled by wooden pallets and unidentified chemicals but an investigation into the cause of the fire is still ongoing ." 
160 | ], 161 | "label": [0, 1, 4, 7, 10], 162 | "score": [ 163 | [10, 0.6629698276519775, "the fire at the chemical factory next to @entity45 's @entity44 send a huge plume of smoke across the city on @entity115 police wearing face masks went around neighbouring streets with loudspeakers urging people to stay inside while the fire raged police officers also told children on bikes and mothers pushing prams near the scene to go home and went around neighbouring streets with loudspeakers urging people to stay inside"], 164 | [0, 0.6484572291374207, "a @entity1 canal was turned into a giant bubble bath after fire crews tackling a nearby chemical plant blaze saw their water mix with a detergent creating a six foot wall of foam"], 165 | [7, 0.5045493841171265, "the @entity46 are investigating to assess of the foam has harmed any wildlife the foam reached as high as six foot in some places and covered a 30 metre stretch along the water in the @entity85 area of @entity1 ' we are working with the fire service and taking samples of the foam to understand what it is made of , and what impact it may have on local wildlife in and around the canal"], 166 | [1, 0.45766133069992065, "the @entity9 was filled with heavy suds which appeared after a fire at an industrial unit occupied by a drug development company"], 167 | [4, 0.3478981852531433, "a spokesman for the agency said : ' @entity46 is investigating after receiving reports of foam on a 30 metre stretch of the @entity9 , @entity1"], 168 | [3, 0.3398599326610565, "now the @entity46 have launched an investigation to assess if the foam has impacted on wildlife after concerns were raised for the safety of fish in the affected waters"], 169 | [8, 0.3396754860877991, "' at the height of the blaze on sunday afternoon , which caused the foam , up to 50 firefighters were tackling the fire and police were also forced to wear face masks"], 170 | [6, 0.32800495624542236, "@entity66 takes a picture on his mobile phone of his boat trying 
to negotiate a lock and the foam , which ran into the @entity9 a cyclist takes a picture on his mobile phone as the foam comes up on to the cycle path"], 171 | [9, 0.29064181447029114, "families in east @entity1 were urged to say indoors after a blast was reported at the industrial unit , which is just a few hundred yards from the @entity45 training ground on the @entity109 campus"], 172 | [2, 0.25459226965904236, "it is believed that the water used by firefighters to dampen down the flames mixed with the detergent being stored in the burning buildings"], 173 | [5, 0.2020452618598938, "' initial investigations by @entity46 officers show that there appears to have been minimal impact on water quality , but our officers will continue to monitor and respond as necessary"], 174 | [12, 0.05926991254091263, "according to @entity137 , the fire was fueled by wooden pallets and unidentified chemicals but an investigation into the cause of the fire is still ongoing ."], 175 | [11, 0.05400165915489197, "a huge plume of smoke also turned the sky black and could be seen right across the city and even into @entity136"] 176 | ] 177 | } 178 | 179 | ``` 180 | This one remains a little complicated. 
181 | ```json 182 | { 183 | "entity": { 184 | "@entity27": "Belichick", 185 | "@entity24": "Hiss", 186 | "@entity80": "Gisele Bündchen", 187 | "@entity97": "Sport Illustrated", 188 | "@entity115": "Julian Edelman", 189 | "@entity84": "Washington", 190 | "@entity86": "Seattle Seahawks", 191 | "@entity110": "Massachusetts US Senator", 192 | "@entity3": "Patriots", 193 | "@entity2": "Super Bowl", 194 | "@entity0": "Obama", 195 | "@entity4": "White House", 196 | "@entity8": "South Lawn", 197 | "@entity56": "Boomer Esiason", 198 | "@entity111": "Linda Holliday", 199 | "@entity75": "Donovan McNabb", 200 | "@entity96": "Las Vegas", 201 | "@entity30": "Chicago", 202 | "@entity33": "Boston", 203 | "@entity102": "Rob Gronkowski", 204 | "@entity99": "CBS", 205 | "@entity98": "Les Moonves", 206 | "@entity108": "Bellichick", 207 | "@entity109": "John Kerry", 208 | "@entity95": "Floyd Mayweather Jr.", 209 | "@entity94": "Manny Pacquiao", 210 | "@entity117": "Danny Amendola", 211 | "@entity62": "Bush Administration", 212 | "@entity44": "Bob Kraft", 213 | "@entity47": "Super Bowl MVP", 214 | "@entity68": "Showoffs", 215 | "@entity66": "US Senator", 216 | "@entity67": "White House Correspondents dinner", 217 | "@entity113": "George W. 
Bush", 218 | "@entity48": "Brady" 219 | }, 220 | "abstract": [ 221 | "@entity48 cited ' prior family commitments ' in bowing out of meeting with @entity0", 222 | "has been to the @entity4 to meet president @entity113 for previous @entity2 wins" 223 | ], 224 | "article": [ 225 | "president @entity0 invited the @entity2 champion @entity3 to the @entity4 on thursday - but could n't help but get one last deflategate joke in", 226 | "the president opened his speech on the @entity8 by remarking ' that whole ( deflategate ) story got blown out of proportion , ' referring to an investigation that 11 out of 12 footballs used in the afc championship game were under - inflated", 227 | "but then came the zinger : ' i usually tell a bunch of jokes at these events , but with the @entity3 in town i was worried that 11 out of 12 of them would fall flat", 228 | "coach @entity27 , who is notoriously humorless , responded by giving the president a thumbs down", 229 | "@entity0 was flanked by @entity27 and billionaire @entity3 owner @entity44", 230 | "missing from the occasion , though was the @entity47 and the team 's biggest star - @entity48", 231 | "a spokesman for the team cited ' prior family commitments ' as the reason @entity48 , 37 , did n't attend the ceremony", 232 | "sports commentators , including retired football great @entity56 , speculated that @entity48 snubbed @entity0 because he 's from the ' wrong political party", 233 | "' the superstar athlete has been to the @entity4 before", 234 | "he does have three other @entity2 rings , afterall", 235 | "but all the prior championships were under the @entity62", 236 | "february 's win was the first for the @entity3 since @entity0 took office", 237 | "@entity48 has also met @entity0 at least once before , as well", 238 | "he was pictured with the then - @entity66 at the 2005 @entity67", 239 | "@entity68 : the @entity3 gathered the team 's four @entity2 trophies won under coach @entity27 ( right , next to president @entity0 )", 
240 | "@entity48 won his fourth @entity2 ring in february - and his first since president @entity0 took office @entity48 met president @entity0 at least once", 241 | "he is pictured here with the then - @entity66 and rival quarterback @entity75 in 2005 it 's not clear what @entity48 's prior commitment was", 242 | "his supermodel wife @entity80 , usually active on social media , gives no hint where the family is today if not in @entity84", 243 | "@entity48 led the @entity3 to his fourth @entity2 victory in february after defeating the @entity86 28 - 24", 244 | "despite his arm and movement being somewhat diminished by age , @entity48 's leadership and calm under pressure also won him @entity47 - his third", 245 | "whatever is taking up @entity48 's time this week , he made time next week to be ringside at the @entity95 - @entity94 fight in @entity96 next weekend", 246 | "according to @entity97 , @entity48 appealed directly to @entity99 president @entity98 for tickets to the much - touted matchup", 247 | "@entity3 tight end @entity102 could n't help but mug for the camera as the commander in chief gave a speech @entity0 walks with billionaire @entity3 owner @entity44 and coach @entity108 to the speech secretary of state @entity109 , a former @entity110 , greets @entity27 's girlfriend @entity111 at the ceremony @entity48 went to the @entity4 to meet president @entity113 after winning the @entity2 in 2005 and in 2004", 248 | "he 's not going to be there this year @entity3 players @entity115 and @entity117 snap pics in the @entity4 before meeting president @entity0 on thursday" 249 | ], 250 | "label": [0, 6, 7, 12, 14, 15, 18, 22, 23], 251 | "score": [ 252 | [1, 0.8683828115463257, "the president opened his speech on the @entity8 by remarking ' that whole ( deflategate ) story got blown out of proportion , ' referring to an investigation that 11 out of 12 footballs used in the afc championship game were under - inflated"], 253 | [0, 0.8339700102806091, "president 
@entity0 invited the @entity2 champion @entity3 to the @entity4 on thursday - but could n't help but get one last deflategate joke in"], 254 | [22, 0.7730730772018433, "@entity3 tight end @entity102 could n't help but mug for the camera as the commander in chief gave a speech @entity0 walks with billionaire @entity3 owner @entity44 and coach @entity108 to the speech secretary of state @entity109 , a former @entity110 , greets @entity27 's girlfriend @entity111 at the ceremony @entity48 went to the @entity4 to meet president @entity113 after winning the @entity2 in 2005 and in 2004"], 255 | [14, 0.7569227814674377, "@entity68 : the @entity3 gathered the team 's four @entity2 trophies won under coach @entity27 ( right , next to president @entity0 )"], 256 | [15, 0.6214166879653931, "@entity48 won his fourth @entity2 ring in february - and his first since president @entity0 took office @entity48 met president @entity0 at least once"], 257 | [18, 0.4963235855102539, "@entity48 led the @entity3 to his fourth @entity2 victory in february after defeating the @entity86 28 - 24"], 258 | [16, 0.45303720235824585, "he is pictured here with the then - @entity66 and rival quarterback @entity75 in 2005 it 's not clear what @entity48 's prior commitment was"], 259 | [5, 0.4204302430152893, "missing from the occasion , though was the @entity47 and the team 's biggest star - @entity48"], 260 | [7, 0.41678884625434875, "sports commentators , including retired football great @entity56 , speculated that @entity48 snubbed @entity0 because he 's from the ' wrong political party"], 261 | [20, 0.4135805070400238, "whatever is taking up @entity48 's time this week , he made time next week to be ringside at the @entity95 - @entity94 fight in @entity96 next weekend"], 262 | [6, 0.3958345353603363, "a spokesman for the team cited ' prior family commitments ' as the reason @entity48 , 37 , did n't attend the ceremony"], 263 | [4, 0.37495893239974976, "@entity0 was flanked by @entity27 and 
billionaire @entity3 owner @entity44"], 264 | [21, 0.3466879427433014, "according to @entity97 , @entity48 appealed directly to @entity99 president @entity98 for tickets to the much - touted matchup"], 265 | [19, 0.3316606283187866, "despite his arm and movement being somewhat diminished by age , @entity48 's leadership and calm under pressure also won him @entity47 - his third"], 266 | [2, 0.29267093539237976, "but then came the zinger : ' i usually tell a bunch of jokes at these events , but with the @entity3 in town i was worried that 11 out of 12 of them would fall flat"], 267 | [23, 0.27186375856399536, "he 's not going to be there this year @entity3 players @entity115 and @entity117 snap pics in the @entity4 before meeting president @entity0 on thursday"], 268 | [11, 0.26710671186447144, "february 's win was the first for the @entity3 since @entity0 took office"], 269 | [17, 0.17511016130447388, "his supermodel wife @entity80 , usually active on social media , gives no hint where the family is today if not in @entity84"], 270 | [3, 0.16352418065071106, "coach @entity27 , who is notoriously humorless , responded by giving the president a thumbs down"], 271 | [13, 0.14906153082847595, "he was pictured with the then - @entity66 at the 2005 @entity67"], 272 | [12, 0.1384015828371048, "@entity48 has also met @entity0 at least once before , as well"], 273 | [10, 0.07186555117368698, "but all the prior championships were under the @entity62"], 274 | [8, 0.07148505747318268, "' the superstar athlete has been to the @entity4 before"], 275 | [9, 0.035264041274785995, "he does have three other @entity2 rings , afterall"] 276 | ] 277 | } 278 | ``` 279 | More predicted results are available [*here*](https://drive.google.com/open?id=1cXrR1kY-tlxArB-F9FSZba2T2RscAYVS). 280 | 281 | ## Discuss 282 | 283 | - Tuning the learning rate. 284 | - Whether or not to freeze the embedding weights for the first several steps.
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import *
import numpy as py


class Neuralmodel:
    """Extractive summarisation graph: document reader plus one extractor.

    The document reader embeds words, runs a per-sentence CNN (optionally
    followed by highway layers) and feeds the sentence vectors to an LSTM.
    Depending on ``extract_sentence_flag`` the graph then adds either

    * the sentence extractor: a hand-written LSTM that emits one value in
      [0, 1] per sentence, or
    * the word extractor (marked TODO by the author): an attention decoder
      whose outputs are scored against the embedded document words.

    The whole graph is built inside ``__init__``. When ``is_training`` is
    true the loss, two train ops (all variables / embeddings frozen) and the
    merged summaries are built as well.
    """

    def __init__(self, extract_sentence_flag, is_training, vocab_size, batch_size, embed_size, learning_rate,
                 cur_step, decay_step, decay_rate, max_num_sequence, sequence_length, filter_sizes, feature_map,
                 use_highway_flag, highway_layers, hidden_size, document_length, max_num_abstract, beam_width,
                 attention_size, input_y2_max_length, clip_gradients=5.0,
                 initializer=tf.random_normal_initializer(stddev=0.1)):
        """Store the hyper-parameters, create the placeholders and build the graph.

        Args:
            extract_sentence_flag: True builds the sentence extractor, False
                the word extractor.
            is_training: Python bool; when False no loss / train op is built.
            vocab_size: number of rows of the word embedding matrix.
            batch_size: number of documents per batch. The graph is unrolled
                over this value with Python loops, so it is fixed.
            embed_size: word embedding width.
            learning_rate: initial learning rate.
            cur_step: pair ``(start, end)`` of global steps delimiting the
                curriculum-learning window used by ``weight_control``.
            decay_step, decay_rate: arguments of the staircase exponential
                learning-rate decay.
            max_num_sequence: number of sentences per document.
            sequence_length: number of words per sentence.
            filter_sizes: convolution heights (in words), one per filter bank.
            feature_map: number of filters for each entry of ``filter_sizes``.
                ``document_reader`` reshapes the pooled features to
                ``hidden_size`` columns, so these are expected to sum to
                ``hidden_size``.
            use_highway_flag: apply highway layers after the CNN.
            highway_layers: number of highway layers.
            hidden_size: LSTM width.
            document_length: number of candidate document words scored by the
                word extractor.
            max_num_abstract: number of abstract sentences per document.
            beam_width: beam width of the inference-time word decoder.
            attention_size: output width of the attention wrapper.
            input_y2_max_length: maximum number of words per abstract sentence.
            clip_gradients: global-norm clipping threshold.
            initializer: accepted for backward compatibility but ignored; the
                model always uses a Xavier initializer (see below).
        """
        # NOTE(review): the `initializer` argument is never read; the
        # original code overwrote it with Xavier and that is preserved here.
        self.initializer = tf.contrib.layers.xavier_initializer()
        self.initializer_uniform = tf.random_uniform_initializer(minval=-0.05, maxval=0.05)

        # Basic
        self.extract_sentence_flag = extract_sentence_flag
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.embed_size = embed_size

        # Learning rate
        self.is_training = is_training
        # Boolean placeholder fed to batch_norm's `is_training` argument.
        self.tst = tf.placeholder(tf.bool, name='is_training_flag')
        self.learning_rate = tf.Variable(learning_rate, trainable=False, name='learning_rate')
        self.cur_step_start = tf.Variable(cur_step[0], trainable=False, name='start_for_cur_learning')
        self.cur_step_end = tf.Variable(cur_step[1], trainable=False, name='end_for_cur_learning')
        self.decay_step = decay_step
        self.decay_rate = decay_rate
        # Decay schedule shared by train() and train_frozen(); built lazily.
        self._decayed_learning_rate = None

        # Overfit
        self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
        self.clip_gradients = clip_gradients

        # CNN (word)
        self.max_num_sequence = max_num_sequence
        self.sequence_length = sequence_length
        self.filter_sizes = filter_sizes
        self.feature_map = feature_map

        # Highway network
        self.use_highway_flag = use_highway_flag
        self.highway_layers = highway_layers

        # LSTM (sentence)
        self.hidden_size = hidden_size
        self.document_length = document_length

        # LSTM + attention (generating)
        self.max_num_abstract = max_num_abstract
        self.beam_width = beam_width
        self.attention_size = attention_size
        self.input_y2_max_length = input_y2_max_length

        # Input
        self.input_x = tf.placeholder(tf.int32, [None, self.max_num_sequence, self.sequence_length], name="input_x")

        if extract_sentence_flag:
            self.input_y1 = tf.placeholder(tf.int32, [None, self.max_num_sequence], name="input_y_sentence")
            self.input_y1_length = tf.placeholder(tf.int32, [None], name="input_y_length")
            self.mask = tf.sequence_mask(self.input_y1_length, self.max_num_sequence, dtype=tf.float32,
                                         name='input_y_mask')
            self.cur_learning = tf.placeholder(tf.bool, name="use_cur_lr_strategy")
        else:
            self.input_y2_length = tf.placeholder(tf.int32, [None, self.max_num_abstract],
                                                  name="input_y_word_length")
            self.input_y2 = tf.placeholder(tf.int32, [None, self.max_num_abstract, self.input_y2_max_length],
                                           name="input_y_word")
            self.input_decoder_x = tf.placeholder(tf.int32,
                                                  [None, self.max_num_abstract, self.input_y2_max_length],
                                                  name="input_decoder_x")
            self.value_decoder_x = tf.placeholder(tf.int32, [None, self.document_length], name="value_decoder_x")
            # One mask / target tensor per document of the batch.
            self.mask_list = [
                tf.sequence_mask(tf.squeeze(self.input_y2_length[idx:idx + 1], axis=0),
                                 self.input_y2_max_length, dtype=tf.float32)
                for idx in range(self.batch_size)
            ]
            self.targets = [tf.squeeze(self.input_y2[idx:idx + 1], axis=0) for idx in range(self.batch_size)]

        # Counters. The train ops do not touch global_step; it only advances
        # when `global_increment` is run.
        self.global_step = tf.Variable(0, trainable=False, name='Global_step')
        self.epoch_step = tf.Variable(0, trainable=False, name='Epoch_step')
        self.epoch_increment = tf.assign(self.epoch_step, tf.add(self.epoch_step, tf.constant(1)))
        self.global_increment = tf.assign(self.global_step, tf.add(self.global_step, tf.constant(1)))

        # Variables
        self.instantiate_weights()

        # Forward pass
        if extract_sentence_flag:
            self.logits = self.inference()
        else:
            self.logits, self.final_sequence_lengths = self.inference()

        if not self.is_training:
            return

        if extract_sentence_flag:
            print('using sentence extractor...')
            self.loss_val = self.loss_sentence()
        else:
            print('using word extractor...')
            self.loss_val = self.loss_word()

        self.train_op = self.train()
        self.train_op_frozen = self.train_frozen()
        self.merge = tf.summary.merge_all()

    def instantiate_weights(self):
        """Create the embedding matrices and the sentence-extractor LSTM weights."""
        with tf.name_scope("Embedding"):
            self.Embedding = tf.get_variable("embedding", shape=[self.vocab_size, self.embed_size],
                                             initializer=self.initializer)
            # Two-row table; row 0 is looked up as the decoder start input
            # in sentence_extractor().
            self.Embedding_ = tf.get_variable("embedding_", shape=[2, self.hidden_size],
                                              initializer=self.initializer)

        with tf.name_scope("Cell"):
            square = [self.hidden_size, self.hidden_size]
            vector = [self.hidden_size]
            # input gate
            self.W_i = tf.get_variable("W_i", shape=square, initializer=self.initializer_uniform)
            self.U_i = tf.get_variable("U_i", shape=square, initializer=self.initializer_uniform)
            self.b_i = tf.get_variable("b_i", shape=vector, initializer=tf.zeros_initializer())
            # forget gate (bias starts at one)
            self.W_f = tf.get_variable("W_f", shape=square, initializer=self.initializer_uniform)
            self.U_f = tf.get_variable("U_f", shape=square, initializer=self.initializer_uniform)
            self.b_f = tf.get_variable("b_f", shape=vector, initializer=tf.ones_initializer())
            # cell candidate
            self.W_c = tf.get_variable("W_c", shape=square, initializer=self.initializer_uniform)
            self.U_c = tf.get_variable("U_c", shape=square, initializer=self.initializer_uniform)
            self.b_c = tf.get_variable("b_c", shape=vector, initializer=tf.zeros_initializer())
            # output gate
            self.W_o = tf.get_variable("W_o", shape=square, initializer=self.initializer_uniform)
            self.U_o = tf.get_variable("U_o", shape=square, initializer=self.initializer_uniform)
            self.b_o = tf.get_variable("b_o", shape=vector, initializer=tf.zeros_initializer())

    def document_reader(self):
        """Encode the documents: embedding -> sentence CNN -> highway -> LSTM.

        Returns:
            Tuple ``(cnn_outputs, lstm_outputs, cell_state)``: the stacked
            per-sentence CNN features ``[batch_size, max_num_sequence,
            hidden_size]``, the encoder LSTM outputs and its final state as
            returned by ``tf.nn.dynamic_rnn``.
        """
        # 1. Embedding. Each document is handled separately so that its
        # sentences act as the batch dimension of the 2-D convolution:
        # [max_num_sequence, sequence_length, embed_size, 1].
        embedded_words = []
        for idx in range(self.batch_size):
            self.embedded_words = tf.nn.embedding_lookup(self.Embedding, self.input_x[idx:idx + 1])
            self.embedded_words_squeezed = tf.squeeze(self.embedded_words, axis=0)
            self.embedded_words_expanded = tf.expand_dims(self.embedded_words_squeezed, axis=-1)
            embedded_words.append(self.embedded_words_expanded)

        # 2. CNN (word)
        # conv:   [max_num_sequence, sequence_length - filter_size + 1, 1, num_filters]
        # pooled: [max_num_sequence, 1, 1, num_filters]
        with tf.name_scope("CNN-Layer-Encoder"):
            pooled_outputs = []
            for m, conv_s in enumerate(embedded_words):
                pooled_temp = []
                for i, filter_size in enumerate(self.filter_sizes):
                    # AUTO_REUSE shares the filters between the documents.
                    with tf.variable_scope("convolution-pooling-%s" % filter_size, reuse=tf.AUTO_REUSE):
                        conv_filter = tf.get_variable(
                            "filter-%s" % filter_size,
                            [filter_size, self.embed_size, 1, self.feature_map[i]],
                            initializer=self.initializer)
                        conv = tf.nn.conv2d(conv_s, conv_filter, strides=[1, 1, 1, 1], padding="VALID",
                                            name="conv")
                        conv = tf.contrib.layers.batch_norm(conv, is_training=self.tst, scope='cnn_bn_')
                        b = tf.get_variable("b-%s" % filter_size, [self.feature_map[i]])
                        h = tf.nn.tanh(tf.nn.bias_add(conv, b), "tanh")
                        # Max over every valid window position of the sentence.
                        pooled = tf.nn.max_pool(
                            h, ksize=[1, self.sequence_length - filter_size + 1, 1, 1],
                            strides=[1, 1, 1, 1], padding='VALID', name="pool")
                        pooled_temp.append(pooled)
                pooled_temp = tf.concat(pooled_temp, axis=3)
                # Requires sum(feature_map) == hidden_size.
                pooled_temp = tf.reshape(pooled_temp, [-1, self.hidden_size])
                # 3. Highway network
                # NOTE(review): `m` (the position in the batch) is used as the
                # variable-name suffix, so every batch position gets its own
                # highway weights. Kept as-is to stay checkpoint compatible.
                if self.use_highway_flag:
                    pooled_temp = self.highway(pooled_temp, pooled_temp.get_shape()[1], m,
                                               self.highway_layers, 0)
                pooled_outputs.append(pooled_temp)
            cnn_outputs = tf.stack(pooled_outputs, axis=0)

        # 4. LSTM (sentence)
        with tf.variable_scope("LSTM-Layer-Encoder", initializer=self.initializer_uniform):
            lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.dropout_keep_prob)
            lstm_outputs, cell_state = tf.nn.dynamic_rnn(lstm_cell, cnn_outputs, dtype=tf.float32)
        return cnn_outputs, lstm_outputs, cell_state

    def highway(self, input_, size, mark, layer_size=1, bias=-2.0, f=tf.nn.relu):
        """Apply ``layer_size`` highway layers to a 2-D tensor.

        Each layer computes ``t = sigmoid(W_t y + b_t + bias)`` and returns
        ``t * f(W y + b) + (1 - t) * y`` where ``t`` is the transform gate
        and ``1 - t`` the carry gate.

        Args:
            input_: 2-D tensor ``[rows, size]``.
            size: output width of every layer.
            mark: integer suffix used in the variable names.
            layer_size: number of stacked layers; 0 returns ``input_``.
            bias: constant added to the gate pre-activation.
            f: non-linearity of the transform branch.

        Returns:
            Tensor with the same shape as ``input_``.
        """

        def linear(input_, output_size, mark, scope=None):
            """Affine map ``input_ @ W^T + b`` with variables ``W_<mark>`` / ``b_<mark>``."""
            shape = input_.get_shape().as_list()
            if len(shape) != 2:
                raise ValueError("Linear is expecting 2D arguments: %s" % str(shape))
            if not shape[1]:
                raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape))
            input_size = shape[1]
            with tf.variable_scope(scope or "simplelinear"):
                W = tf.get_variable("W_%d" % mark, [output_size, input_size],
                                    initializer=self.initializer_uniform, dtype=input_.dtype)
                b = tf.get_variable("b_%d" % mark, [output_size],
                                    initializer=self.initializer_uniform, dtype=input_.dtype)
            return tf.matmul(input_, tf.transpose(W)) + b

        # Defined up front so that layer_size == 0 is the identity instead of
        # raising UnboundLocalError.
        output = input_
        with tf.variable_scope("highway"):
            for idx in range(layer_size):
                g = f(linear(input_, size, mark, scope="highway_lin_%d" % idx))
                t = tf.sigmoid(linear(input_, size, mark, scope="highway_gate_%d" % idx) + bias)
                output = t * g + (1. - t) * input_
                input_ = output
        return output

    def sigmoid_norm(self, score):
        """Map a tanh score in [-1, 1] to [0, 1].

        ``sigmoid`` of a value in [-1, 1] lies in about [0.27, 0.73]; the
        result is rescaled linearly so that -1 maps to 0 and 1 maps to 1.
        """
        with tf.name_scope("sigmoid_norm"):
            Min = tf.sigmoid(tf.constant(-1, dtype=tf.float32))
            Max = tf.sigmoid(tf.constant(1, dtype=tf.float32))
            prob = tf.sigmoid(score)
            prob_norm = (prob - Min) / (Max - Min)
        return prob_norm

    def lstm_single_step(self, St, At, h_t_minus_1, c_t_minus_1, p_t_minus_1):
        """Run one step of the sentence-extractor LSTM and score the sentence.

        Args:
            St: previous sentence vector (CNN output), ``[batch, hidden_size]``.
            At: encoder LSTM output for the current sentence.
            h_t_minus_1, c_t_minus_1: previous hidden and cell state.
            p_t_minus_1: previous sentence value; it scales ``St`` to form the
                LSTM input.

        Returns:
            Tuple ``(h_t, c_t, p_t)`` with ``p_t`` of shape ``[batch, 1]``.
        """
        p_t_minus_1 = tf.reshape(p_t_minus_1, [-1, 1])
        # Xt = p_{t-1} * S_{t-1}
        Xt = tf.multiply(p_t_minus_1, St)
        Xt = tf.nn.dropout(Xt, self.dropout_keep_prob)
        # gates
        i_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_i) + tf.matmul(h_t_minus_1, self.U_i) + self.b_i)
        f_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_f) + tf.matmul(h_t_minus_1, self.U_f) + self.b_f)
        c_t_candidate = tf.nn.tanh(tf.matmul(Xt, self.W_c) + tf.matmul(h_t_minus_1, self.U_c) + self.b_c)
        c_t = f_t * c_t_minus_1 + i_t * c_t_candidate
        o_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_o) + tf.matmul(h_t_minus_1, self.U_o) + self.b_o)
        h_t = o_t * tf.nn.tanh(c_t)
        tf.summary.histogram("input:pt*st", Xt)
        tf.summary.histogram("attenton_z_value", At)
        tf.summary.histogram("hidden_z_value", h_t)
        # score: one tanh unit over [At ; h_t], shared across time steps
        with tf.name_scope("Score_Layer"):
            concat_h = tf.concat([At, h_t], axis=1)
            tf.summary.histogram("concat", concat_h)
            concat_h_dropout = tf.nn.dropout(concat_h, keep_prob=self.dropout_keep_prob)
            score = tf.layers.dense(concat_h_dropout, 1, activation=tf.nn.tanh, name="score",
                                    reuse=tf.AUTO_REUSE)
        p_t = self.sigmoid_norm(score)

        return h_t, c_t, p_t

    def weight_control(self, time_step, p_t):
        """Curriculum learning: blend the true label with the predicted value.

        Before ``cur_step_start`` the true label is returned. Between
        ``cur_step_start`` and ``cur_step_end`` the result moves linearly from
        the label to ``p_t``. From ``cur_step_end`` on it is ``p_t``.

        Args:
            time_step: index of the sentence whose label is used.
            p_t: predicted value for that sentence, ``[batch, 1]``.

        Returns:
            Tensor ``[batch, 1]`` fed to the next decoder step.
        """
        labels = tf.cast(self.input_y1[:, time_step:time_step + 1], dtype=tf.float32)
        start = tf.cast(self.cur_step_start, dtype=tf.float32)
        end = tf.cast(self.cur_step_end, dtype=tf.float32)
        global_step = tf.cast(self.global_step, dtype=tf.float32)
        # Guard against an empty window (start == end) dividing by zero.
        span = tf.maximum(tf.subtract(end, start), 1.0)
        # Clip so the blend stays a convex combination; unclipped, the weight
        # grows past 1 after `end` and extrapolates outside [0, 1].
        weight = tf.clip_by_value(tf.divide(tf.subtract(global_step, start), span), 0.0, 1.0)
        merge = (1. - weight) * labels + weight * p_t
        cond = tf.greater(start, global_step)
        p_t_curr = tf.cond(cond, lambda: labels, lambda: merge)
        return p_t_curr

    def sentence_extractor(self):
        """Unroll the sentence-extractor LSTM over all sentences.

        At step t the LSTM input is ``p_{t-1} * S_{t-1}`` (previous value
        times previous sentence vector) and the score is computed from the
        encoder output ``A_t`` and the decoder state ``h_t``.

        Returns:
            Tensor ``[batch_size, max_num_sequence]`` of values in [0, 1].
        """
        with tf.name_scope("LSTM-Layer-Decoder"):
            p_t_lstm_list = []
            # The decoder starts from the encoder's final state.
            lstm_tuple = self.initial_state
            c_t_0 = lstm_tuple[0]
            h_t_0 = lstm_tuple[1]
            p_t_0 = tf.ones([self.batch_size])
            cnn_outputs = tf.split(self.cnn_outputs, self.max_num_sequence, axis=1)
            cnn_outputs = [tf.squeeze(i, axis=1) for i in cnn_outputs]
            attention_state = tf.split(self.attention_state, self.max_num_sequence, axis=1)
            attention_state = [tf.squeeze(i, axis=1) for i in attention_state]
            # first step: row 0 of Embedding_ plays the role of the 'GO' input
            start_tokens = tf.zeros([self.batch_size], tf.int32)
            St_0 = tf.nn.embedding_lookup(self.Embedding_, start_tokens)
            At_0 = attention_state[0]
            h_t, c_t, p_t = self.lstm_single_step(St_0, At_0, h_t_0, c_t_0, p_t_0)
            p_t_lstm_list.append(p_t)
            tf.summary.histogram("prob_t", p_t)
            # next steps
            for time_step, (St, At) in enumerate(zip(cnn_outputs[:-1], attention_state[1:])):
                if self.is_training:
                    p_t = tf.cond(self.cur_learning,
                                  lambda: self.weight_control(time_step, p_t),
                                  lambda: p_t)
                h_t, c_t, p_t = self.lstm_single_step(St, At, h_t, c_t, p_t)
                p_t_lstm_list.append(p_t)
                tf.summary.histogram("sen_t", St)
                tf.summary.histogram("prob_t", p_t)
            logits = tf.concat(p_t_lstm_list, axis=1)

        return logits

    def word_extractor(self):  # TODO (unfinished upstream)
        """Build the attention word decoder, one sub-graph per document.

        For each document an attention-wrapped LSTM decodes
        ``max_num_abstract`` abstract sentences. Every decoder output is then
        used as a query for a second Bahdanau attention over the embedded
        document words; the resulting alignments are returned as "logits".

        Returns:
            Tuple ``(logits_list, length_list)``. ``logits_list`` holds one
            tensor ``[max_num_abstract, input_y2_max_length, document_length]``
            per document. ``length_list`` holds the decoded lengths at
            inference time and is empty when training.
        """
        logits_list = []
        length_list = []
        attent_decoder_embedded = []
        values_decoder_embedded = []
        inputs_decoder_embedded = []
        initial_state_embedded = []
        encoder_inputs_lengths = []
        embedded_values = tf.nn.embedding_lookup(self.Embedding, self.value_decoder_x)
        for idx in range(self.batch_size):
            # Replicate the document-level state for every abstract sentence.
            c = tf.concat([self.initial_state[0][idx:idx + 1] for _ in range(self.max_num_abstract)], axis=0)
            h = tf.concat([self.initial_state[1][idx:idx + 1] for _ in range(self.max_num_abstract)], axis=0)
            initial_state_embedded.append(tf.nn.rnn_cell.LSTMStateTuple(c, h))
            embedded_attent_expand = tf.concat(
                [self.attention_state[idx:idx + 1] for _ in range(self.max_num_abstract)], axis=0)
            attent_decoder_embedded.append(embedded_attent_expand)
            embedded_abstracts = tf.nn.embedding_lookup(self.Embedding, self.input_decoder_x[idx:idx + 1])
            inputs_decoder_embedded.append(tf.squeeze(embedded_abstracts, axis=0))
            # Kept with a leading dimension of 1: [1, document_length, embed_size].
            values_decoder_embedded.append(embedded_values[idx:idx + 1])
            encoder_inputs_lengths.append(tf.squeeze(self.input_y2_length[idx:idx + 1], axis=0))

        per_document = zip(attent_decoder_embedded, inputs_decoder_embedded, values_decoder_embedded,
                           initial_state_embedded, encoder_inputs_lengths)
        for attent_embedded, inputs_embedded, values_embedded, initial_state, encoder_inputs_length \
                in per_document:

            with tf.variable_scope("attention-word-decoder", reuse=tf.AUTO_REUSE):
                if self.is_training:
                    attention_state = attent_embedded
                    document_state = values_embedded
                    document_length = self.document_length * tf.ones([1, ], dtype=tf.int32)
                    encoder_final_state = initial_state
                else:
                    # 4.2 beam search preparation: tile everything beam_width times
                    attention_state = tf.contrib.seq2seq.tile_batch(attent_embedded, multiplier=self.beam_width)
                    document_state = tf.contrib.seq2seq.tile_batch(values_embedded, multiplier=self.beam_width)
                    encoder_inputs_length = tf.contrib.seq2seq.tile_batch(encoder_inputs_length,
                                                                          multiplier=self.beam_width)
                    document_length = tf.contrib.seq2seq.tile_batch(
                        self.document_length * tf.ones([1, ], dtype=tf.int32), multiplier=self.beam_width)
                    encoder_final_state = tf.contrib.framework.nest.map_structure(
                        lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_width), initial_state)

                # 4.2 Attention (Bahdanau) over the encoder outputs
                lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
                lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.dropout_keep_prob)
                attention_mechanism1 = attention_wrapper.BahdanauAttention(
                    num_units=self.hidden_size, memory=attention_state,
                    memory_sequence_length=encoder_inputs_length)
                # The cell input is a dense projection of the word input only;
                # the attention vector is not fed back (upstream TODO).
                attention_cell = attention_wrapper.AttentionWrapper(
                    cell=lstm_cell, attention_mechanism=attention_mechanism1,
                    attention_layer_size=self.attention_size,
                    cell_input_fn=(lambda inputs, attention: tf.layers.Dense(
                        self.hidden_size, dtype=tf.float32, name="attention_inputs")(inputs)),
                    alignment_history=False, name='Attention_Wrapper')

                batch_size = self.max_num_abstract if self.is_training \
                    else self.max_num_abstract * self.beam_width
                decoder_initial_state = attention_cell.zero_state(
                    batch_size=batch_size, dtype=tf.float32).clone(cell_state=encoder_final_state)
                if self.is_training:
                    helper = tf.contrib.seq2seq.TrainingHelper(
                        inputs=inputs_embedded, sequence_length=encoder_inputs_length,
                        time_major=False, name="training_helper")
                    training_decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell=attention_cell, helper=helper, initial_state=decoder_initial_state,
                        output_layer=None)
                    decoder_outputs, _, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
                        decoder=training_decoder, output_time_major=False, impute_finished=True,
                        maximum_iterations=self.input_y2_max_length)
                else:
                    # NOTE(review): 2 / 3 are presumably the vocabulary ids of
                    # the start / end tokens -- confirm against data_utils.
                    start_tokens = tf.ones([self.max_num_abstract, ], tf.int32) * 2
                    end_token = 3
                    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=attention_cell, embedding=document_state, start_tokens=start_tokens,
                        end_token=end_token, initial_state=decoder_initial_state,
                        beam_width=self.beam_width, output_layer=None)
                    decoder_outputs, _, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
                        decoder=inference_decoder, output_time_major=False, impute_finished=True,
                        maximum_iterations=self.input_y2_max_length)
                    length_list.append(final_sequence_lengths)

            # 4.2 attention * document matrix
            # decoder_outputs: [max_num_abstract, decoded_steps, attention_size]
            # logits:          [max_num_abstract, input_y2_max_length, document_length]
            # NOTE(review): this scope is placed as a sibling of
            # "attention-word-decoder"; the nesting was reconstructed from a
            # whitespace-collapsed source -- confirm against existing
            # checkpoints' variable names.
            with tf.variable_scope("attention-vocab", reuse=tf.AUTO_REUSE):
                attention_mechanism2 = attention_wrapper.BahdanauAttention(
                    num_units=self.attention_size, memory=document_state,
                    memory_sequence_length=document_length)
                # Placeholder "previous alignments" argument required by the
                # attention mechanism's call signature.
                state = tf.constant(True, dtype=tf.bool)
                # NOTE(review): element 0 is `rnn_output` for BasicDecoder. For
                # the beam-search decoder it is the integer `predicted_ids`,
                # which cannot serve as a float query -- the inference path
                # looks unfinished upstream.
                decoder_outputs = decoder_outputs[0]
                # dynamic_decode may stop before input_y2_max_length steps, so
                # the presence of a step has to be tested at run time. (The
                # original compared the bound method `src.get_shape` with a
                # tuple, which is always False and zeroed every query.)
                decoded_steps = tf.shape(decoder_outputs)[1]
                list2 = []
                for idx in range(self.max_num_abstract):
                    list1 = []
                    for step in range(self.input_y2_max_length):
                        src = decoder_outputs[idx:idx + 1, step:step + 1, :]
                        has_step = tf.less(step, decoded_steps)
                        query = tf.cond(
                            has_step,
                            lambda: tf.reshape(src, [1, self.attention_size]),
                            lambda: tf.zeros([1, self.attention_size], tf.float32))
                        logits, state = attention_mechanism2(query=query, state=state)
                        list1.append(logits)
                    logits = tf.stack(list1, axis=1)
                    list2.append(logits)
                logits = tf.concat(list2, axis=0)
            logits_list.append(logits)

        if self.is_training:
            return logits_list, []
        return logits_list, length_list

    def inference(self):
        """Build the forward graph.

        1. Embedding -> 2. CNN (word) -> 3. LSTM (sentence)  (document reader)
        4.1 LSTM + MLP (labelling)                           (sentence extractor)
        4.2 LSTM + attention (generating)                    (word extractor)

        Returns:
            The sentence values when ``extract_sentence_flag`` is set,
            otherwise the tuple returned by ``word_extractor``.
        """
        self.cnn_outputs, self.attention_state, self.initial_state = self.document_reader()
        if self.extract_sentence_flag:
            return self.sentence_extractor()
        logits, final_sequence_lengths = self.word_extractor()
        return logits, final_sequence_lengths

    def loss_sentence(self, l2_lambda=0.0001, epsilon=1e-7):
        """Masked binary cross-entropy over sentence values plus L2 penalty.

        ``self.logits`` already holds values in [0, 1], so the loss is
        ``-log(p)`` for positive labels and ``-log(1 - p)`` otherwise, summed
        over the valid sentences of each document and averaged over the batch.

        Args:
            l2_lambda: weight of the L2 penalty on trainable variables whose
                name does not contain ``'bias'``.
            epsilon: lower clipping bound applied before the logarithm.
                ``sigmoid_norm`` can return exactly 0, and ``log(0) * 0`` on a
                masked position would turn the whole loss into NaN.

        Returns:
            Scalar loss tensor.
        """
        with tf.name_scope("loss_sentence"):
            logits = tf.convert_to_tensor(self.logits)
            labels = tf.cast(self.input_y1, logits.dtype)
            zeros = tf.zeros_like(labels, dtype=labels.dtype)
            ones = tf.ones_like(logits, dtype=logits.dtype)
            cond = (labels > zeros)
            # Probability assigned to the true class of every sentence.
            logits_ = tf.where(cond, logits, ones - logits)
            logits_ = tf.clip_by_value(logits_, epsilon, 1.0)
            losses = -tf.log(logits_)
            losses *= self.mask
            l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                                if 'bias' not in v.name]) * l2_lambda
            tf.summary.scalar("l2_loss", l2_loss)
            loss = tf.reduce_sum(losses, axis=1)
            loss = tf.reduce_mean(loss)
            tf.summary.scalar("loss", loss)

        return loss + l2_loss

    def loss_word(self, l2_lambda=0.001):
        """Sum of the per-document sequence losses of the word extractor.

        Args:
            l2_lambda: unused; kept for interface compatibility (the L2 term
                was disabled upstream).

        Returns:
            Scalar loss tensor.
        """
        # logits:  [max_num_abstract, input_y2_max_length, document_length]
        # targets: [max_num_abstract, input_y2_max_length]
        # mask:    [max_num_abstract, input_y2_max_length]
        # NOTE(review): the "logits" are attention alignments, while
        # sequence_loss applies its own softmax -- confirm this is intended.
        with tf.name_scope("loss_word"):
            # add_n replaces a tf.Variable that was only used as a zero
            # accumulator and needlessly ended up in the checkpoint.
            loss = tf.add_n([
                tf.contrib.seq2seq.sequence_loss(logits=logits, targets=targets, weights=mask,
                                                 average_across_timesteps=True, average_across_batch=True)
                for logits, targets, mask in zip(self.logits, self.targets, self.mask_list)
            ])
            tf.summary.scalar("loss", loss)
        return loss

    def _decayed_lr(self):
        """Return the staircase-decayed learning rate, building it only once.

        ``self.learning_rate`` is rebound to the decayed tensor, as before.
        Building it once keeps a second train op from decaying a rate that
        has already been decayed.
        """
        if self._decayed_learning_rate is None:
            self._decayed_learning_rate = tf.train.exponential_decay(
                self.learning_rate, self.global_step, self.decay_step, self.decay_rate, staircase=True)
            self.learning_rate = self._decayed_learning_rate
        return self._decayed_learning_rate

    def _build_train_op(self, var_list=None):
        """Adam step with global-norm clipping on ``var_list`` (None = all trainables)."""
        optimizer = tf.train.AdamOptimizer(self._decayed_lr(), beta1=0.99)
        gradients, variables = zip(*optimizer.compute_gradients(self.loss_val, var_list))
        gradients, _ = tf.clip_by_global_norm(gradients, self.clip_gradients)
        # Run the batch-norm moving-average updates together with the step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(zip(gradients, variables))
        return train_op

    def train_frozen(self):
        """Train op that leaves every variable with 'embedding' in its name untouched."""
        with tf.name_scope("train_op_frozen"):
            tvars = [tvar for tvar in tf.trainable_variables() if 'embedding' not in tvar.name]
            return self._build_train_op(tvars)

    def train(self):
        """Train op that updates all trainable variables."""
        with tf.name_scope("train_op"):
            return self._build_train_op()