├── data
│   ├── tra
│   │   └── tra_0.pkl
│   ├── tst
│   │   └── tst_0.pkl
│   └── val
│       └── val_0.pkl
├── train_w2v.py
├── prepro.py
├── predict_model.py
├── LICENSE
├── train_model.py
├── data_utils.py
├── README.md
└── textsum_model.py
/data/tra/tra_0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/tra/tra_0.pkl
--------------------------------------------------------------------------------
/data/tst/tst_0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/tst/tst_0.pkl
--------------------------------------------------------------------------------
/data/val/val_0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adrian9631/TextSumma/HEAD/data/val/val_0.pkl
--------------------------------------------------------------------------------
/train_w2v.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import multiprocessing
4 | import logging
5 | from gensim.models import Word2Vec
6 | from gensim.models.word2vec import LineSentence
7 |
8 | if __name__ == '__main__':
9 | program = os.path.basename(sys.argv[0])
10 | logger = logging.getLogger(program)
11 |
12 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
13 | logging.root.setLevel(level=logging.INFO)
14 | logger.info("running %s" % ' '.join(sys.argv))
15 |
16 | if len(sys.argv) < 4:
17 | print("Using: python train_w2v.py one-billion-word-benchmark output_gensim_model output_word_vector")
18 | sys.exit(1)
19 | inp, outp1, outp2 = sys.argv[1:4]
20 |
21 | model = Word2Vec(LineSentence(inp), size=150, window=6, min_count=2, workers=(multiprocessing.cpu_count()-2), hs=1, sg=1, negative=10)
22 |
23 | model.save(outp1)
24 | model.wv.save_word2vec_format(outp2, binary=True)
25 |
26 |
27 |
--------------------------------------------------------------------------------
/prepro.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | import os
3 | import sys
4 | import codecs
5 | import pickle
6 | import logging
7 |
8 |
9 | def load(filename):
10 | with open(filename, 'rb') as output:
11 | data = pickle.load(output)
12 | return data
13 |
14 | def save(filename, data):
15 | with open(filename, 'wb') as output:
16 | pickle.dump(data, output)
17 |
18 | def compute(inp, oup, logger):
19 | cnt_file = 0
20 | for filename in os.listdir(inp):
21 | data_path1 = os.path.join(inp, filename)
22 | data_path2 = oup +'example_'+ str(cnt_file) + '.pkl'
23 | data = {}
24 | entity,abstract,article,label = [],[],[],[]
25 | cnt = 0
26 | with codecs.open(data_path1, 'r', encoding='utf-8', errors='ignore') as f:
27 | for line in f.readlines():
28 | if line == '\n':
29 | cnt += 1
30 | continue
31 | if cnt == 0:
32 | pass
33 | if cnt == 1:
34 | article.append(line.replace('\t\t\t', '').replace('\n', ''))
35 | if cnt == 2:
36 | abstract.append(line.replace('\n', '').replace('*', ''))
37 | if cnt == 3:
38 | entity.append(line.replace('\n', ''))
39 | for idx, sent in enumerate(article):
40 | if sent[-1] == '1':
41 | label.append(idx)
42 | article = [sent[:len(sent)-1] for idx, sent in enumerate(article)]
43 | entity_dict = {}
44 | if len(entity) != 0:
45 | for pair in entity:
46 | key = pair.split(':', 1)[0]
47 | value = pair.split(':', 1)[1]
48 | entity_dict[key] = value
49 | data['entity'] = entity_dict
50 | data['abstract'] = abstract
51 | data['article'] = article
52 | data['label'] = label
53 | save(data_path2, data)
54 | cnt_file += 1
55 | if cnt_file % 500 == 0:
56 | logger.info("running the script, extract %d examples already..." % cnt_file)
57 | logger.info("extract %d examples totally this time, done." % (cnt_file+1))
58 |
59 | if __name__ == "__main__":
60 |
61 | program = os.path.basename(sys.argv[0])
62 | logger = logging.getLogger(program)
63 |
64 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
65 | logging.root.setLevel(level=logging.INFO)
66 | logger.info("running %s" % ' '.join(sys.argv))
67 |
68 | if len(sys.argv) < 3:
69 | print("Using: python prepro.py ./source_dir/ ./target_dir/")
70 | sys.exit(1)
71 | inp, oup = sys.argv[1:3]
72 |
73 | compute(inp, oup, logger)
74 |
75 |
--------------------------------------------------------------------------------
/predict_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import re
3 | import os
4 | import math
5 | import pickle
6 | import codecs
7 | import json
8 | import tensorflow as tf
9 | import numpy as np
10 | from data_utils import *
11 | from textsum_model import Neuralmodel
12 | from gensim.models import KeyedVectors
13 | from rouge import Rouge, FilesRouge
14 |
15 | #configuration
16 | FLAGS=tf.app.flags.FLAGS
17 |
18 | tf.app.flags.DEFINE_string("hyp_path","../res/hyp.txt","file of summary.")
19 | tf.app.flags.DEFINE_string("ref_path","../res/ref.txt","file of abstract.")
20 | tf.app.flags.DEFINE_string("result_path","../res/","path to store the predicted results.")
21 | tf.app.flags.DEFINE_string("tst_data_path","../src/neuralsum/dailymail/tst/","path of test data.")
22 | tf.app.flags.DEFINE_string("tst_file_path","../src/neuralsum/dailymail/tst/","file of test data.")
23 | tf.app.flags.DEFINE_boolean("use_tst_dataset", True,"using test dataset, set False to use the file as targets")
24 | tf.app.flags.DEFINE_string("entity_path","../cache/entity_dict.pkl", "path of entity data.")
25 | tf.app.flags.DEFINE_string("vocab_path","../cache/vocab","path of vocab frequency list")
26 | tf.app.flags.DEFINE_integer("vocab_size",199900,"maximum vocab size.")
27 |
28 | tf.app.flags.DEFINE_float("learning_rate",0.0001,"learning rate")
29 |
30 | tf.app.flags.DEFINE_integer("is_frozen_step", 0, "how many steps before fine-tuning the embedding.")
31 | tf.app.flags.DEFINE_integer("cur_learning_step", 0, "how many steps before using the predicted labels instead of true labels.")
32 | tf.app.flags.DEFINE_integer("decay_step", 5000, "how many steps before decay learning rate.")
33 | tf.app.flags.DEFINE_float("decay_rate", 0.1, "Rate of decay for learning rate.")
34 | tf.app.flags.DEFINE_string("ckpt_dir","../ckpt/","checkpoint location for the model")
35 | tf.app.flags.DEFINE_integer("batch_size", 1, "Batch size for training/evaluating.")
36 | tf.app.flags.DEFINE_integer("embed_size", 150,"embedding size")
37 | tf.app.flags.DEFINE_integer("input_y2_max_length", 40,"the max length of a sentence in abstracts")
38 | tf.app.flags.DEFINE_integer("max_num_sequence", 30,"the max number of sequence in documents")
39 | tf.app.flags.DEFINE_integer("max_num_abstract", 4,"the max number of abstract in documents")
40 | tf.app.flags.DEFINE_integer("sequence_length", 100,"the max length of a sentence in documents")
41 | tf.app.flags.DEFINE_integer("hidden_size", 300,"the hidden size of the encoder and decoder")
42 | tf.app.flags.DEFINE_boolean("use_highway_flag", True,"using highway network or not.")
43 | tf.app.flags.DEFINE_integer("highway_layers", 1,"How many layers in highway network.")
44 | tf.app.flags.DEFINE_integer("document_length", 1000,"the max vocabulary of documents")
45 | tf.app.flags.DEFINE_integer("beam_width", 4,"the beam search max width")
46 | tf.app.flags.DEFINE_integer("attention_size", 150,"the attention size of the decoder")
47 | tf.app.flags.DEFINE_boolean("extract_sentence_flag", True,"using sentence extractor")
48 | tf.app.flags.DEFINE_boolean("is_training", False,"is traning.true:tranining,false:testing/inference")
49 | tf.app.flags.DEFINE_boolean("use_embedding",True,"whether to use embedding or not.")
50 | tf.app.flags.DEFINE_string("word2vec_model_path","../w2v/benchmark_sg1_e150_b.vector","word2vec's vocabulary and vectors")
51 | filter_sizes = [1,2,3,4,5,6,7]
52 | feature_map = [20,20,30,40,50,70,70]
53 | cur_learning_steps = [0,0]
54 |
55 | def load(filename):
56 | with open(filename, 'rb') as output:
57 | data = pickle.load(output)
58 | return data
59 |
60 | def save(filename, data):
61 | with open(filename, 'wb') as output:
62 | pickle.dump(data, output)
63 |
64 | def dump(filename, data):
65 | with open(filename, 'w') as output:
66 | json.dump(data, output, cls=MyEncoder)
67 |
68 | def main(_):
69 | config=tf.ConfigProto()
70 | config.gpu_options.allow_growth = True
71 | results = []
72 | with tf.Session(config=config) as sess:
73 | Model=Neuralmodel(FLAGS.extract_sentence_flag, FLAGS.is_training, FLAGS.vocab_size, FLAGS.batch_size, FLAGS.embed_size, FLAGS.learning_rate, cur_learning_steps, FLAGS.decay_step, FLAGS.decay_rate, FLAGS.max_num_sequence, FLAGS.sequence_length,
74 | filter_sizes, feature_map, FLAGS.use_highway_flag, FLAGS.highway_layers, FLAGS.hidden_size, FLAGS.document_length, FLAGS.max_num_abstract, FLAGS.beam_width, FLAGS.attention_size, FLAGS.input_y2_max_length)
75 | saver=tf.train.Saver()
76 | if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
77 | print("Restoring Variables from Checkpoint")
78 | saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
79 | else:
80 | print("Can't find the checkpoint.going to stop")
81 | return
82 | if FLAGS.use_tst_dataset:
83 | predict_gen = Batch_P(FLAGS.tst_data_path, FLAGS.vocab_path, FLAGS)
84 | else:
85 | predict_gen = Batch_F(process_file(FLAGS.tst_file_path, FLAGS.entity_path), FLAGS.vocab_path, FLAGS)
86 | iteration = 0
87 | for batch in predict_gen:
88 | iteration += 1
89 | feed_dict={}
90 | feed_dict[Model.dropout_keep_prob] = 1.0
91 | feed_dict[Model.input_x] = batch['article_words']
92 | feed_dict[Model.tst] = False
93 | feed_dict[Model.cur_learning] = False
94 | logits = sess.run(Model.logits, feed_dict=feed_dict)
95 | results.append(compute_score(logits, batch))
96 | evaluate_file(logits, batch)
97 | if iteration % 500 == 0:
98 | print ('Processed %d examples so far...' % iteration)
99 |
100 | print ('Waiting to store the results...')
101 | for idx, data in enumerate(results):
102 | filename = os.path.join(FLAGS.result_path, 'tst_%d.json' % idx)
103 | dump(filename, data)
104 | print ('Computing the ROUGE scores...')
105 | scores = evaluate_rouge(FLAGS.hyp_path, FLAGS.ref_path)
106 | print (scores)
107 | print ('Done.')
108 |
109 | def process_file(data_path, entity_path): # TODO
110 | examples = []
111 | entitys = load(entity_path)
112 | with codecs.open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
113 | for line in f.readlines():
114 | if line == '\n':
115 | continue
116 | example = {}
117 | entity_dict = {}
118 | for idx, name in entitys.items():
119 | if re.search(name, line):
120 | article = line.replace(name, idx)
121 | entity_dict[idx] = name
122 | example['article'] = article.split('.')
123 | example['entity'] = entity_dict
124 | examples.append(example)
125 | return examples
126 |
127 | def evaluate_file(logits, batch):
128 | data = batch['original']
129 | score_list = []
130 | pos = 0
131 | for sent, score in zip(data['article'], logits[0][:len(data['article'])]):
132 | score_list.append((pos, score, sent))
133 | pos += 1
134 | data['score'] = sorted(score_list, key=lambda x:x[1], reverse=True)
135 | summary = '. '.join([highest[2] for highest in sorted(data['score'][:3], key=lambda x:x[0], reverse=False)])
136 | abstract = '. '.join(data['abstract'])
137 |
138 | with open(FLAGS.hyp_path, 'a') as f:
139 | f.write(summary)
140 | f.write('\n')
141 |
142 | with open(FLAGS.ref_path, 'a') as f:
143 | f.write(abstract)
144 | f.write('\n')
145 |
146 | def evaluate_rouge(hyp_path, ref_path):
147 | files_rouge = FilesRouge(hyp_path, ref_path)
148 | rouge = files_rouge.get_scores(avg=True)
149 | return rouge
150 |
151 | def compute_score(logits, batch):
152 | data = batch['original']
153 | score_list = []
154 | pos = 0
155 | for sent, score in zip(data['article'], logits[0][:len(data['article'])]):
156 | score_list.append((pos, score, sent))
157 | pos += 1
158 | data['score'] = sorted(score_list, key=lambda x:x[1], reverse=True)
159 | return data
160 |
161 | if __name__ == '__main__':
162 | tf.app.run()
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | from data_utils import *
5 | from textsum_model import Neuralmodel
6 | from gensim.models import KeyedVectors
7 | from rouge import Rouge
8 | import os
9 | import math
10 | import pickle
11 | from tqdm import tqdm
12 |
13 | #configuration
14 | FLAGS=tf.app.flags.FLAGS
15 |
16 | tf.app.flags.DEFINE_string("log_path","../log/","path of summary log.")
17 | tf.app.flags.DEFINE_string("tra_data_path","../src/neuralsum/dailymail/tra/","path of training data.")
18 | tf.app.flags.DEFINE_string("tst_data_path","../src/neuralsum/dailymail/tst/","path of test data.")
19 | tf.app.flags.DEFINE_string("val_data_path","../src/neuralsum/dailymail/val/","path of validation data.")
20 | tf.app.flags.DEFINE_string("vocab_path","../cache/vocab","path of vocab frequency list")
21 | tf.app.flags.DEFINE_integer("vocab_size",199900,"maximum vocab size.")
22 |
23 | tf.app.flags.DEFINE_float("learning_rate",0.0001,"learning rate")
24 |
25 | tf.app.flags.DEFINE_integer("is_frozen_step", 400, "how many steps before fine-tuning the embedding.")
26 | tf.app.flags.DEFINE_integer("decay_step", 5000, "how many steps before decay learning rate.")
27 | tf.app.flags.DEFINE_float("decay_rate", 0.1, "Rate of decay for learning rate.")
28 | tf.app.flags.DEFINE_string("ckpt_dir","../ckpt/","checkpoint location for the model")
29 | tf.app.flags.DEFINE_integer("batch_size", 20, "Batch size for training/evaluating.")
30 | tf.app.flags.DEFINE_integer("embed_size", 150,"embedding size")
31 | tf.app.flags.DEFINE_integer("input_y2_max_length", 40,"the max length of a sentence in abstracts")
32 | tf.app.flags.DEFINE_integer("max_num_sequence", 30,"the max number of sequence in documents")
33 | tf.app.flags.DEFINE_integer("max_num_abstract", 4,"the max number of abstract in documents")
34 | tf.app.flags.DEFINE_integer("sequence_length", 100,"the max length of a sentence in documents")
35 | tf.app.flags.DEFINE_integer("hidden_size", 300,"the hidden size of the encoder and decoder")
36 | tf.app.flags.DEFINE_boolean("use_highway_flag", True,"using highway network or not.")
37 | tf.app.flags.DEFINE_integer("highway_layers", 1,"How many layers in highway network.")
38 | tf.app.flags.DEFINE_integer("document_length", 1000,"the max vocabulary of documents")
39 | tf.app.flags.DEFINE_integer("beam_width", 4,"the beam search max width")
40 | tf.app.flags.DEFINE_integer("attention_size", 150,"the attention size of the decoder")
41 | tf.app.flags.DEFINE_boolean("extract_sentence_flag", True,"using sentence extractor")
42 | tf.app.flags.DEFINE_boolean("is_training", True,"is traning.true:tranining,false:testing/inference")
43 | tf.app.flags.DEFINE_integer("num_epochs",10,"number of epochs to run.")
44 | tf.app.flags.DEFINE_integer("validate_every", 1, "Validate every validate_every epochs.")
45 | tf.app.flags.DEFINE_boolean("use_embedding",True,"whether to use embedding or not.")
46 | tf.app.flags.DEFINE_string("word2vec_model_path","../w2v/benchmark_sg1_e150_b.vector","word2vec's vocabulary and vectors")
47 | filter_sizes = [1,2,3,4,5,6,7]
48 | feature_map = [20,20,30,40,50,70,70]
49 | cur_learning_steps = [500,2500]
50 |
51 | def main(_):
52 | config = tf.ConfigProto()
53 | config.gpu_options.allow_growth=True
54 | with tf.Session(config=config) as sess:
55 | # instantiate model
56 | Model = Neuralmodel(FLAGS.extract_sentence_flag, FLAGS.is_training, FLAGS.vocab_size, FLAGS.batch_size, FLAGS.embed_size, FLAGS.learning_rate, cur_learning_steps, FLAGS.decay_step, FLAGS.decay_rate, FLAGS.max_num_sequence, FLAGS.sequence_length,
57 | filter_sizes, feature_map, FLAGS.use_highway_flag, FLAGS.highway_layers, FLAGS.hidden_size, FLAGS.document_length, FLAGS.max_num_abstract, FLAGS.beam_width, FLAGS.attention_size, FLAGS.input_y2_max_length)
58 | # initialize saver
59 | saver = tf.train.Saver()
60 | if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
61 | print("Restoring Variables from Checkpoint.")
62 | saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
63 | summary_writer = tf.summary.FileWriter(logdir=FLAGS.log_path, graph=sess.graph)
64 | else:
65 | print('Initializing Variables')
66 | sess.run(tf.global_variables_initializer())
67 | summary_writer = tf.summary.FileWriter(logdir=FLAGS.log_path, graph=sess.graph)
68 | if FLAGS.use_embedding: #load pre-trained word embedding
69 | assign_pretrained_word_embedding(sess, FLAGS.vocab_path, FLAGS.vocab_size, Model,FLAGS.word2vec_model_path)
70 | curr_epoch=sess.run(Model.epoch_step)
71 |
72 | batch_size=FLAGS.batch_size
73 | iteration=0
74 | for epoch in range(curr_epoch,FLAGS.num_epochs):
75 | loss, counter = 0.0, 0
76 | train_gen = Batch(FLAGS.tra_data_path,FLAGS.vocab_path,FLAGS.batch_size,FLAGS)
77 | for batch in tqdm(train_gen):
78 | iteration=iteration+1
79 | if epoch==0 and counter==0:
80 | print("train_batch", batch['abstracts_len'])
81 | feed_dict={}
82 | if FLAGS.extract_sentence_flag:
83 | feed_dict[Model.dropout_keep_prob] = 0.5
84 | feed_dict[Model.input_x] = batch['article_words']
85 | feed_dict[Model.input_y1] = batch['label_sentences']
86 | feed_dict[Model.input_y1_length] = batch['article_len']
87 | feed_dict[Model.tst] = FLAGS.is_training
88 | feed_dict[Model.cur_learning] = True if cur_learning_steps[1] > iteration and epoch == 0 else False
89 | else:
90 | feed_dict[Model.dropout_keep_prob] = 0.5
91 | feed_dict[Model.input_x] = batch['article_words']
92 | feed_dict[Model.input_y2_length] = batch['abstracts_len']
93 | feed_dict[Model.input_y2] = batch['abstracts_inputs']
94 | feed_dict[Model.input_decoder_x] = batch['abstracts_targets']
95 | feed_dict[Model.value_decoder_x] = batch['article_value']
96 | feed_dict[Model.tst] = FLAGS.is_training
97 | train_op = Model.train_op_frozen if FLAGS.is_frozen_step > iteration and epoch == 0 else Model.train_op
98 | curr_loss,lr,_,_,summary,logits=sess.run([Model.loss_val,Model.learning_rate,train_op,Model.global_increment,Model.merge,Model.logits],feed_dict)
99 | summary_writer.add_summary(summary, global_step=iteration)
100 | loss,counter=loss+curr_loss,counter+1
101 | if counter %50==0:
102 | print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f" %(epoch,counter,loss/float(counter),lr))
103 | if iteration % 1000 == 0:
104 | eval_loss = do_eval(sess, Model)
105 | print("Epoch %d Validation Loss:%.3f\t " % (epoch, eval_loss))
106 | # TODO eval_loss, acc_score = do_eval(sess, Model)
107 | # TODO print("Epoch %d Validation Loss:%.3f\t Acc:%.3f" % (epoch, eval_loss, acc_score))
108 | # save model to checkpoint
109 | save_path = FLAGS.ckpt_dir + "model.ckpt"
110 | saver.save(sess, save_path, global_step=epoch)
111 | #epoch increment
112 | print("going to increment epoch counter....")
113 | sess.run(Model.epoch_increment)
114 | print(epoch,FLAGS.validate_every,(epoch % FLAGS.validate_every==0))
115 | if epoch % FLAGS.validate_every==0:
116 | #save model to checkpoint
117 | save_path=FLAGS.ckpt_dir+"model.ckpt"
118 | saver.save(sess,save_path,global_step=epoch)
119 | summary_writer.close()
120 |
121 | def do_eval(sess, Model):
122 | eval_loss, eval_counter= 0.0, 0
123 | # eval_loss, eval_counter, acc_score= 0.0, 0, 0.0
124 | batch_size = 20
125 | valid_gen = Batch(FLAGS.val_data_path,FLAGS.vocab_path,batch_size,FLAGS) # evaluate on the validation split
126 | for batch in valid_gen:
127 | feed_dict={}
128 | if FLAGS.extract_sentence_flag:
129 | feed_dict[Model.dropout_keep_prob] = 1.0
130 | feed_dict[Model.input_x] = batch['article_words']
131 | feed_dict[Model.input_y1] = batch['label_sentences']
132 | feed_dict[Model.input_y1_length] = batch['article_len']
133 | feed_dict[Model.tst] = not FLAGS.is_training
134 | feed_dict[Model.cur_learning] = False
135 | else:
136 | feed_dict[Model.dropout_keep_prob] = 1.0
137 | feed_dict[Model.input_x] = batch['article_words']
138 | feed_dict[Model.input_y2] = batch['abstracts_inputs']
139 | feed_dict[Model.input_y2_length] = batch['abstracts_len']
140 | feed_dict[Model.input_decoder_x] = batch['abstracts_targets']
141 | feed_dict[Model.value_decoder_x] = batch['article_value']
142 | feed_dict[Model.tst] = not FLAGS.is_training
143 | curr_eval_loss,logits=sess.run([Model.loss_val,Model.logits],feed_dict)
144 | # curr_acc_score = compute_label(logits, batch)
145 | # acc_score += curr_acc_score
146 | eval_loss += curr_eval_loss
147 | eval_counter += 1
148 |
149 | return eval_loss/float(eval_counter) # acc_score/float(eval_counter)
150 |
151 | def compute_label(logits, batch): # TODO
152 | imp_pos = np.argsort(logits)
153 | lab_num = [ len(res['label']) for res in batch['original']]
154 | lab_pos = [ res['label'] for res in batch['original']]
155 | abs_num = [ res['abstract'] for res in batch['original']]
156 | sen_pos = [ pos[:num] for pos, num in zip(imp_pos, lab_num)]
157 |
158 | # compute
159 | acc_list = []
160 | for sen, lab, abst in zip(sen_pos, lab_pos, abs_num):
161 | sen = set(sen)
162 | lab = set(lab)
163 | if len(lab) == 0 or len(abst) == 0:
164 | continue
165 | score = float(len(sen&lab)) / len(abst)
166 | acc = 1.0 if score > 1.0 else score
167 | acc_list.append(acc)
168 | acc_score = np.mean(acc_list)
169 |
170 | return acc_score
171 |
172 | def assign_pretrained_word_embedding(sess,vocab_path,vocab_size,Model,word2vec_model_path):
173 | print("using pre-trained word emebedding.started.word2vec_model_path:",word2vec_model_path)
174 | vocab = Vocab(vocab_path, vocab_size)
175 | word2vec_model = KeyedVectors.load_word2vec_format(word2vec_model_path, binary=True)
176 | bound = np.sqrt(6.0) / np.sqrt(vocab_size) # bound for random variables.
177 | count_exist = 0
178 | count_not_exist = 0
179 | word_embedding_2dlist = [[]] * vocab_size # create an empty word_embedding list.
180 | word_embedding_2dlist[0] = np.zeros(FLAGS.embed_size, dtype=np.float32) # assign empty for first word:'PAD'
181 | for i in range(1, vocab_size): # loop each word
182 | word = vocab.id2word(i)
183 | embedding = None
184 | try:
185 | embedding = word2vec_model[word] # try to get vector:it is an array.
186 | except Exception:
187 | embedding = None
188 | if embedding is not None: # the 'word' exist a embedding
189 | word_embedding_2dlist[i] = embedding # assign array to this word.
190 | count_exist = count_exist + 1
191 | else: # no embedding for this word
192 | word_embedding_2dlist[i] = np.random.uniform(-bound, bound, FLAGS.embed_size)
193 | count_not_exist = count_not_exist + 1 # init a random value for the word.
194 | word_embedding_final = np.array(word_embedding_2dlist) # convert to 2d array.
195 | word_embedding = tf.constant(word_embedding_final, dtype=tf.float32) # convert to tensor
196 | t_assign_embedding = tf.assign(Model.Embedding,word_embedding) # assign this value to our embedding variables of our model.
197 | sess.run(t_assign_embedding)
198 |
199 | word_embedding_2dlist_ = [[]] * 2 # create an empty word_embedding list for GO END.
200 | word_embedding_2dlist_[0] = np.random.uniform(-bound, bound, FLAGS.hidden_size) # GO
201 | word_embedding_2dlist_[1] = np.random.uniform(-bound, bound, FLAGS.hidden_size) # END
202 | word_embedding_final_ = np.array(word_embedding_2dlist_) # convert to 2d array.
203 | word_embedding_ = tf.constant(word_embedding_final_, dtype=tf.float32) # convert to tensor
204 | t_assign_embedding_ = tf.assign(Model.Embedding_,word_embedding_) # assign this value to our embedding variables of our model.
205 | sess.run(t_assign_embedding_)
206 | print("word. exists embedding:", count_exist, " ;word not exist embedding:", count_not_exist)
207 | print("using pre-trained word emebedding.ended...")
208 |
209 | if __name__ == "__main__":
210 | tf.app.run()
211 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import re
3 | import os
4 | import json
5 | import pickle
6 | import random
7 | import numpy as np
8 | from tflearn.data_utils import pad_sequences
9 |
10 | PAD_TOKEN = '[PAD]'
11 | UNKNOWN_TOKEN = '[UNK]'
12 | START_DECODING = '[START]'
13 | STOP_DECODING = '[STOP]'
14 |
15 | def load(filename):
16 | with open(filename, 'rb') as output:
17 | data = pickle.load(output)
18 | return data
19 |
20 | def save(filename, data):
21 | with open(filename, 'wb') as output:
22 | pickle.dump(data, output)
23 |
24 | def Batch(data_path, vocab_path, size, hps):
25 |
26 | res = {}
27 | filenames = os.listdir(data_path)
28 | random.shuffle(filenames)
29 | label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len, results = [], [], [], [], [], [], [], []
30 | vocab = Vocab(vocab_path, hps.vocab_size)
31 | for cnt, filename in enumerate(filenames[:200000]):
32 | pickle_path = os.path.join(data_path, filename)
33 | res = load(pickle_path)
34 | label, value, words, len_a, targets, inputs, lens = Example(res['article'],res['abstract'],res['label'],res['entity'], vocab, hps) # TODO
35 | label_sentences.append(label)
36 | article_value.append(value)
37 | article_words.append(words)
38 | article_len.append(len_a)
39 | abstracts_targets.append(targets)
40 | abstracts_inputs.append(inputs)
41 | abstracts_len.append(lens)
42 | results.append(res)
43 |
44 | if (cnt+1) % size == 0:
45 | data_dict ={}
46 | data_dict['label_sentences'] = label_sentences
47 | data_dict['article_value'] = article_value
48 | data_dict['article_words'] = article_words
49 | data_dict['article_len'] = article_len
50 | data_dict['abstracts_targets'] = abstracts_targets
51 | data_dict['abstracts_inputs'] = abstracts_inputs
52 | data_dict['abstracts_len'] = abstracts_len
53 | data_dict['original'] = results
54 | label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len, results = [], [], [], [], [], [], [], []
55 | yield data_dict
56 |
57 | def Example(article, abstracts, label, entity, vocab, hps):
58 |
59 | # get ids of special tokens
60 | start_decoding = vocab.word2id(START_DECODING)
61 | stop_decoding = vocab.word2id(STOP_DECODING)
62 | pad_id = vocab.word2id(PAD_TOKEN)
63 |
64 | """process the label"""
65 | # pos 2 multi one-hot
66 | label_sentences = label2ids(label, hps.max_num_sequence)
67 |
68 | """process the article"""
69 | # create vocab and word 2 id
70 | article_value = value2ids(article, vocab, hps.document_length)
71 | # word 2 id
72 | article_words = article2ids(article, vocab)
73 | # num sentence
74 | article_len = len(article)
75 | # word level padding
76 | article_words = pad_sequences(article_words, maxlen=hps.sequence_length, value=pad_id)
77 | # sentence level padding
78 | pad_article = np.expand_dims(np.zeros(hps.sequence_length, dtype=np.int32), axis = 0)
79 | if article_words.shape[0] > hps.max_num_sequence:
80 | article_words = article_words[:hps.max_num_sequence]
81 | while article_words.shape[0] < hps.max_num_sequence:
82 | article_words = np.concatenate((article_words, pad_article))
83 |
84 | """process the abstract"""
85 | # word 2 id
86 | abstracts_words = abstract2ids(abstracts, vocab)
87 | # add tokens
88 | abstracts_inputs, abstracts_targets = token2add(abstracts_words, hps.input_y2_max_length, start_decoding, stop_decoding)
89 | # padding
90 | abstracts_inputs = pad_sequences(abstracts_inputs, maxlen=hps.input_y2_max_length, value=pad_id)
91 | abstracts_targets = pad_sequences(abstracts_targets, maxlen=hps.input_y2_max_length, value=pad_id)
92 | # search id in value position
93 | abstract_targets = value2pos(abstracts_targets, article_value, vocab)
94 | # sentence level padding
95 | pad_abstracts = np.expand_dims(np.zeros(hps.input_y2_max_length, dtype=np.int32), axis = 0)
96 | if abstracts_inputs.shape[0] > hps.max_num_abstract:
97 | abstracts_inputs = abstracts_inputs[:hps.max_num_abstract]
98 | while abstracts_inputs.shape[0] < hps.max_num_abstract:
99 | abstracts_inputs = np.concatenate((abstracts_inputs, pad_abstracts))
100 | if abstracts_targets.shape[0] > hps.max_num_abstract:
101 | abstracts_targets = abstracts_targets[:hps.max_num_abstract]
102 | while abstracts_targets.shape[0] < hps.max_num_abstract:
103 | abstracts_targets = np.concatenate((abstracts_targets, pad_abstracts))
104 | # mask
105 | abstracts_len = abstract2len(abstracts, hps.input_y2_max_length)
106 | if abstracts_len.shape[0] > hps.max_num_abstract:
107 | abstracts_len = abstracts_len[:hps.max_num_abstract]
108 | while abstracts_len.shape[0] < hps.max_num_abstract:
109 | abstracts_len = np.concatenate((abstracts_len, [1]))
110 |
111 | return label_sentences, article_value, article_words, article_len, abstracts_targets, abstracts_inputs, abstracts_len
112 |
113 | def Batch_F(file_data, vocab_path, hps):
114 |
115 | vocab = Vocab(vocab_path, hps.vocab_size)
116 | for res in file_data:
117 | article_value, article_words, article_len = Example_P(res['article'],res['entity'], vocab, hps)
118 | data_dict ={}
119 | data_dict['article_value'] = [article_value]
120 | data_dict['article_words'] = [article_words]
121 | data_dict['article_len'] = [article_len]
122 | data_dict['original'] = res
123 | yield data_dict
124 |
125 | def Batch_P(data_path, vocab_path, hps):
126 |
127 | filenames = os.listdir(data_path)
128 | vocab = Vocab(vocab_path, hps.vocab_size)
129 | for cnt, filename in enumerate(filenames):
130 | pickle_path = os.path.join(data_path, filename)
131 | res = load(pickle_path)
132 | article_value, article_words, article_len = Example_P(res['article'],res['entity'], vocab, hps)
133 | data_dict ={}
134 | data_dict['article_value'] = [article_value]
135 | data_dict['article_words'] = [article_words]
136 | data_dict['article_len'] = [article_len]
137 | data_dict['original'] = res
138 | yield data_dict
139 |
140 | def Example_P(article, entity, vocab, hps):
141 |
142 | # get ids of special tokens
143 | pad_id = vocab.word2id(PAD_TOKEN)
144 |
145 | """process the article"""
146 | # create vocab and word 2 id
147 | article_value = value2ids(article, vocab, hps.document_length)
148 | # word 2 id
149 | article_words = article2ids(article, vocab)
150 | # num sentence
151 | article_len = len(article)
152 | # word level padding
153 | article_words = pad_sequences(article_words, maxlen=hps.sequence_length, value=pad_id)
154 | # sentence level padding
155 | pad_article = np.expand_dims(np.zeros(hps.sequence_length, dtype=np.int32), axis = 0)
156 | if article_words.shape[0] > hps.max_num_sequence:
157 | article_words = article_words[:hps.max_num_sequence]
158 | while article_words.shape[0] < hps.max_num_sequence:
159 | article_words = np.concatenate((article_words, pad_article))
160 |
161 | return article_value, article_words, article_len
162 |
163 | class Vocab(object):
164 | def __init__(self, vocab_file, max_size):
165 | self._word_to_id = {}
166 | self._id_to_word = {}
167 | self._count = 0
168 |
169 | for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]:
170 | self._word_to_id[w] = self._count
171 | self._id_to_word[self._count] = w
172 | self._count += 1
173 |
174 | with open(vocab_file, 'r') as vocab_f:
175 | for line in vocab_f:
176 | pieces = line.split()
177 | if len(pieces) != 2:
178 | continue
179 | w = pieces[0]
180 | if w in [UNKNOWN_TOKEN, PAD_TOKEN,START_DECODING, STOP_DECODING]:
181 | raise Exception('[UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is'% w)
182 | if w in self._word_to_id:
183 | raise Exception('Duplicated word in vocabulary file: %s' % w)
184 | self._word_to_id[w] = self._count
185 | self._id_to_word[self._count] = w
186 | self._count += 1
187 | if max_size != 0 and self._count >= max_size:
188 | break
189 |
190 | def word2id(self, word):
191 | if word not in self._word_to_id:
192 | return self._word_to_id[UNKNOWN_TOKEN]
193 | return self._word_to_id[word]
194 |
195 | def id2word(self, word_id):
196 | return self._id_to_word[word_id]
197 |
198 | def size(self):
199 | return self._count
200 |
201 | def label2ids(labels, label_size):
202 | res = np.zeros(label_size, dtype=np.int32)
203 | label_list = [ pos for pos in labels if pos < label_size]
204 | res[label_list] = 1
205 | return res
206 |
207 | def value2ids(article, vocab, document_length):
208 | value = []
209 | pad_id = vocab.word2id(PAD_TOKEN)
210 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
211 | stop_id = vocab.word2id(STOP_DECODING)
212 | value.append(unk_id)
213 | value.append(stop_id)
214 | for sent in article:
215 | article_words = sent.split()
216 | for w in article_words:
217 | i = vocab.word2id(w)
218 | if i == unk_id:
219 | pass
220 | if i not in value:
221 | value.append(i)
222 | cnt = 4
223 | while len(value) < document_length:
224 | if cnt not in value:
225 | value.append(cnt)
226 | cnt += 1
227 | return np.array(value)
228 |
229 | def value2pos(abstract, value, vocab):
230 | poss = []
231 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
232 | for sent in abstract:
233 | pos=[]
234 | for i in sent:
235 | if i in value:
236 | pos.append(np.argwhere(value==i)[0])
237 | else:
238 | pos.append(np.argwhere(value==unk_id)[0])
239 | poss.append(np.array(pos))
240 | return np.array(poss)
241 |
242 | def article2ids(article, vocab):
243 | idss = []
244 | oovs = []
245 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
246 | for sent in article:
247 | ids = []
248 | article_words = sent.split()
249 | for w in article_words:
250 | i = vocab.word2id(w)
251 | if i == unk_id:
252 | if w not in oovs:
253 | oovs.append(w)
254 | ids.append(i)
255 | else:
256 | ids.append(i)
257 | idss.append(ids)
258 | return idss
259 |
260 | def abstract2ids(abstracts, vocab):
261 | idss= []
262 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
263 | for sent in abstracts:
264 | ids = []
265 | abstract_words = sent.split()
266 | for w in abstract_words:
267 | i = vocab.word2id(w)
268 | if i == unk_id:
269 | ids.append(i)
270 | else:
271 | ids.append(i)
272 | idss.append(ids)
273 | return idss
274 |
275 | def token2add(abstracts, max_len, start_id, stop_id):
276 | inps = []
277 | targets = []
278 | for sequence in abstracts:
279 | inp = [start_id] + sequence[:]
280 | target = sequence[:]
281 | if len(inp) > max_len:
282 | inp = inp[:max_len]
283 | target = target[:max_len]
284 | else:
285 | target.append(stop_id)
286 | assert len(inp) == len(target)
287 | inps.append(inp)
288 | targets.append(target)
289 | return inps, targets
290 |
291 | def abstract2len(abstracts, max_len):
292 | length = []
293 | for sent in abstracts:
294 | abstract_words = sent.split()
295 | if len(abstract_words) + 1 > max_len:
296 | length.append(max_len)
297 | else:
298 | length.append(len(abstract_words)+1)
299 | return np.array(length)
300 |
301 | def outputids2words(id_list, vocab):
302 | words = []
303 | for i in id_list:
304 | w = vocab.id2word(i)
305 | words.append(w)
306 | return words
307 |
308 | class MyEncoder(json.JSONEncoder):
309 | def default(self, obj):
310 | if isinstance(obj, np.integer):
311 | return int(obj)
312 | elif isinstance(obj, np.floating):
313 | return float(obj)
314 | elif isinstance(obj, np.ndarray):
315 | return obj.tolist()
316 | else:
317 | return super(MyEncoder, self).default(obj)
318 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | TextSumma
3 | =
4 | An attempt at reproducing the ACL 2016 paper [*Neural Summarization by Extracting Sentences and Words*](https://arxiv.org/abs/1603.07252). The authors' original code can be found [*here*](https://github.com/cheng6076/NeuralSum).
5 |
6 | ## Quick Start
7 | - **Step1 : Obtain datasets**
8 | Go [*here*](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) to download the ***one-billion-word-language-modeling-benchmark*** corpus and its scripts for training the word vectors. Run the following, and see the [*details*](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark/blob/master/README.corpus_generation) for more:
9 | ```bash
10 | $ tar --extract -v --file ../statmt.org/tar_archives/training-monolingual.tgz --wildcards training-monolingual/news.20??.en.shuffled
11 | $ ./scripts/get-data.sh
12 | ```
13 | The ***cnn-dailymail*** dataset with highlights used in the paper, provided by the authors, is available [*here*](https://docs.google.com/uc?id=0B0Obe9L1qtsnSXZEd0JCenIyejg&export=download); the vocab file is included in this repository.
14 | - **Step2 : Preprocess**
15 | Run this script to train the word vectors on the ***one-billion-word-language-modeling-benchmark*** dataset:
16 | ```bash
17 | $ python train_w2v.py './one-billion-word-benchmark' 'output_gensim_model' 'output_word_vector'
18 | ```
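A quick way to sanity-check the trained vectors (a minimal sketch, assuming the `output_word_vector` file produced above; the query word `news` is only an illustration):
```python
from gensim.models import KeyedVectors

# load the binary word2vec vectors written by train_w2v.py
w2v = KeyedVectors.load_word2vec_format('output_word_vector', binary=True)
print(w2v.vector_size)                   # 150, matching the embed_size flag of the model
print(w2v.most_similar('news', topn=5))  # nearest neighbours as a rough sanity check
```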
19 | Run this script to extract the sentences, labels and entities from the ***cnn-dailymail*** dataset and pickle them:
20 | ```bash
21 | $ python prepro.py './source_dir/' './target_dir/'
22 | ```
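Each pickled file holds one example as a plain dict (a minimal sketch of inspecting one; the filename follows the `example_<n>.pkl` pattern written by `prepro.py`):
```python
import pickle

# load one preprocessed example written by prepro.py
with open('./target_dir/example_0.pkl', 'rb') as f:
    example = pickle.load(f)

print(example.keys())    # 'entity', 'abstract', 'article', 'label'
print(example['label'])  # indices of the sentences labelled as extract-worthy
```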
23 | - **Step3 : Install nvidia-docker**
24 | Use it for GPU acceleration. See the [*installation*](https://github.com/NVIDIA/nvidia-docker) guide for more information.
25 | - **Step4 : Obtain *Deepo***
26 | A series of Docker images (and their generator) that allows you to quickly set up your deep learning research environment. See [*Deepo*](https://github.com/ufoym/deepo) for more details. Run the following and expose **port 6006** for TensorBoard:
27 |
28 | ```bash
29 | $ nvidia-docker pull ufoym/deepo
30 | $ nvidia-docker run -p 0.0.0.0:6006:6006 -it -v /home/usrs/yourdir:/data ufoym/deepo env LANG=C.UTF-8 bash
31 | ```
32 | Inside the *Deepo* container's bash, run pip to install the remaining dependencies:
33 | ```bash
34 | $ pip install gensim rouge tflearn tqdm
35 | ```
36 | * **Step5: Train the model and predict**
37 |
38 | Add the **-h** option for more help with the flag settings.
39 | ```bash
40 | $ python train_model.py
41 | $ python predict_model.py
42 | ```
43 | * **Requirements**:
44 | Python 3.6, TensorFlow 1.8.0
45 |
46 | ## Model details
47 |
48 | * **Structure NN-SE**
49 |
50 |
51 |
52 | * **Sentence extractor**
53 | Here is a single step of the customized LSTM with a score layer.
54 | ```python
55 | def lstm_single_step(self, St, At, h_t_minus_1, c_t_minus_1, p_t_minus_1):
56 | p_t_minus_1 = tf.reshape(p_t_minus_1, [-1, 1])
57 | # Xt = p_t_minus_1 * St
58 | Xt = tf.multiply(p_t_minus_1, St)
59 | # dropout
60 | Xt = tf.nn.dropout(Xt, keep_prob=self.dropout_keep_prob)
61 | # compute the gate of input, forget, output
62 | i_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_i) + tf.matmul(h_t_minus_1, self.U_i) + self.b_i)
63 | f_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_f) + tf.matmul(h_t_minus_1, self.U_f) + self.b_f)
64 | c_t_candidate = tf.nn.tanh(tf.matmul(Xt, self.W_c) + tf.matmul(h_t_minus_1, self.U_c) + self.b_c)
65 | c_t = f_t * c_t_minus_1 + i_t * c_t_candidate
66 | o_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_o) + tf.matmul(h_t_minus_1, self.U_o) + self.b_o)
67 | h_t = o_t * tf.nn.tanh(c_t)
68 | # compute prob
69 | with tf.name_scope("Score_Layer"):
70 | concat_h = tf.concat([At, h_t], axis=1)
71 | concat_h_dropout = tf.nn.dropout(concat_h, keep_prob=self.dropout_keep_prob)
72 | score = tf.layers.dense(concat_h_dropout, 1, activation=tf.nn.tanh, name="score", reuse=tf.AUTO_REUSE)
73 | # activation and normalization
74 | p_t = self.sigmoid_norm(score)
75 | return h_t, c_t, p_t
76 | ```
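In more standard notation, the step above computes the following (a direct transcription of the code: $s_t$ is `St`, $a_t$ is `At`, $[\,\cdot\,;\,\cdot\,]$ is concatenation, and `sigmoid_norm` is the model's own normalized sigmoid):
```latex
x_t = p_{t-1} \cdot s_t \\
i_t = \sigma(x_t W_i + h_{t-1} U_i + b_i), \quad
f_t = \sigma(x_t W_f + h_{t-1} U_f + b_f), \quad
o_t = \sigma(x_t W_o + h_{t-1} U_o + b_o) \\
\tilde{c}_t = \tanh(x_t W_c + h_{t-1} U_c + b_c), \quad
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \quad
h_t = o_t \odot \tanh(c_t) \\
p_t = \mathrm{sigmoid\_norm}\bigl(\tanh([a_t ; h_t]\, W_s + b_s)\bigr)
```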
77 | * **Curriculum learning**
78 | Being new to curriculum learning, we simply interpolate between the true labels and the predicted ones, with a weight that grows with the training step.
79 | ```python
80 | def weight_control(self, time_step, p_t):
81 | # curriculum learning control the weight between true labels and those predicted
82 | labels = tf.cast(self.input_y1[:,time_step:time_step+1], dtype=tf.float32)
83 | start = tf.cast(self.cur_step_start, dtype=tf.float32)
84 | end = tf.cast(self.cur_step_end, dtype=tf.float32)
85 | global_step = tf.cast(self.global_step, dtype=tf.float32)
86 | weight = tf.divide(tf.subtract(global_step, start), tf.subtract(end, start))
87 | merge = (1. - weight) * labels + weight * p_t
88 | cond = tf.greater(start, global_step)
89 | p_t_curr = tf.cond(cond, lambda:labels, lambda:merge)
90 | return p_t_curr
91 | ```
92 | * **Loss function**
93 | The loss function is coded manually instead of using *tf.losses.sigmoid_cross_entropy*, because the logits are already between 0 and 1 after sigmoid activation and normalization.
94 | ```python
95 | # loss:z*-log(x)+(1-z)*-log(1-x)
96 | # z=0 --> loss:-log(1-x)
97 | # z=1 --> loss:-log(x)
98 | with tf.name_scope("loss_sentence"):
99 | logits = tf.convert_to_tensor(self.logits)
100 | labels = tf.cast(self.input_y1, logits.dtype)
101 | zeros = tf.zeros_like(labels, dtype=labels.dtype)
102 | ones = tf.ones_like(logits, dtype=logits.dtype)
103 | cond = ( labels > zeros )
104 | logits_ = tf.where(cond, logits, ones-logits)
105 | logits_log = tf.log(logits_)
106 | losses = -logits_log
107 | losses *= self.mask
108 | loss = tf.reduce_sum(losses, axis=1)
109 | loss = tf.reduce_mean(loss)
110 | ```
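Written out, the masked loss computed above is (with $y$ = `input_y1`, $p$ = `logits`, $m$ = `mask`, and $B$ the batch size):
```latex
\mathcal{L} = \frac{1}{B} \sum_{b=1}^{B} \sum_{t} m_{b,t}
\bigl[ -\, y_{b,t} \log p_{b,t} \;-\; (1 - y_{b,t}) \log (1 - p_{b,t}) \bigr]
```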
111 |
112 |
113 | ## Performance
114 |
115 | * **Probabilities for the sentences over several timesteps**
116 |
117 |
118 |
119 | * **Training loss**
120 |
121 |
122 |
123 | * **Figure**
124 | Some results seem to be reasonably good.
125 | ```json
126 | {
127 | "entity": {
128 | "@entity31": "Jason Kernick",
129 | "@entity1": "Manchester",
130 | "@entity9": "Ashton Canal",
131 | "@entity46": "Environment Agency",
132 | "@entity44": "Etihad stadium",
133 | "@entity45": "Manchester City",
134 | "@entity115": "Easter Sunday",
135 | "@entity85": "Clayton",
136 | "@entity66": "Richard Kernick",
137 | "@entity109": "Etihad",
138 | "@entity137": "Greater Manchester Fire and Rescue Service",
139 | "@entity136": "Salford"
140 | },
141 | "abstract": [
142 | "the @entity9 became filled with heavy suds due to a 6ft wall of foam created by fire crews tackling a blaze",
143 | "the fire at a nearby chemical plant saw water from fire service mix with detergents that were being stored there",
144 | "the foam covered a 30 metre stretch of the canal near @entity45 's @entity44 in @entity85"
145 | ],
146 | "article": [
147 | "a @entity1 canal was turned into a giant bubble bath after fire crews tackling a nearby chemical plant blaze saw their water mix with a detergent creating a six foot wall of foam",
148 | "the @entity9 was filled with heavy suds which appeared after a fire at an industrial unit occupied by a drug development company",
149 | "it is believed that the water used by firefighters to dampen down the flames mixed with the detergent being stored in the burning buildings",
150 | "now the @entity46 have launched an investigation to assess if the foam has impacted on wildlife after concerns were raised for the safety of fish in the affected waters",
151 | "a spokesman for the agency said : ' @entity46 is investigating after receiving reports of foam on a 30 metre stretch of the @entity9 , @entity1",
152 | "' initial investigations by @entity46 officers show that there appears to have been minimal impact on water quality , but our officers will continue to monitor and respond as necessary",
153 | "@entity66 takes a picture on his mobile phone of his boat trying to negotiate a lock and the foam , which ran into the @entity9 a cyclist takes a picture on his mobile phone as the foam comes up on to the cycle path",
154 | "the @entity46 are investigating to assess of the foam has harmed any wildlife the foam reached as high as six foot in some places and covered a 30 metre stretch along the water in the @entity85 area of @entity1 ' we are working with the fire service and taking samples of the foam to understand what it is made of , and what impact it may have on local wildlife in and around the canal",
155 | "' at the height of the blaze on sunday afternoon , which caused the foam , up to 50 firefighters were tackling the fire and police were also forced to wear face masks",
156 | "families in east @entity1 were urged to say indoors after a blast was reported at the industrial unit , which is just a few hundred yards from the @entity45 training ground on the @entity109 campus",
157 | "the fire at the chemical factory next to @entity45 's @entity44 send a huge plume of smoke across the city on @entity115 police wearing face masks went around neighbouring streets with loudspeakers urging people to stay inside while the fire raged police officers also told children on bikes and mothers pushing prams near the scene to go home and went around neighbouring streets with loudspeakers urging people to stay inside",
158 | "a huge plume of smoke also turned the sky black and could be seen right across the city and even into @entity136",
159 | "according to @entity137 , the fire was fueled by wooden pallets and unidentified chemicals but an investigation into the cause of the fire is still ongoing ."
160 | ],
161 | "label": [0, 1, 4, 7, 10],
162 | "score": [
163 | [10, 0.6629698276519775, "the fire at the chemical factory next to @entity45 's @entity44 send a huge plume of smoke across the city on @entity115 police wearing face masks went around neighbouring streets with loudspeakers urging people to stay inside while the fire raged police officers also told children on bikes and mothers pushing prams near the scene to go home and went around neighbouring streets with loudspeakers urging people to stay inside"],
164 | [0, 0.6484572291374207, "a @entity1 canal was turned into a giant bubble bath after fire crews tackling a nearby chemical plant blaze saw their water mix with a detergent creating a six foot wall of foam"],
165 | [7, 0.5045493841171265, "the @entity46 are investigating to assess of the foam has harmed any wildlife the foam reached as high as six foot in some places and covered a 30 metre stretch along the water in the @entity85 area of @entity1 ' we are working with the fire service and taking samples of the foam to understand what it is made of , and what impact it may have on local wildlife in and around the canal"],
166 | [1, 0.45766133069992065, "the @entity9 was filled with heavy suds which appeared after a fire at an industrial unit occupied by a drug development company"],
167 | [4, 0.3478981852531433, "a spokesman for the agency said : ' @entity46 is investigating after receiving reports of foam on a 30 metre stretch of the @entity9 , @entity1"],
168 | [3, 0.3398599326610565, "now the @entity46 have launched an investigation to assess if the foam has impacted on wildlife after concerns were raised for the safety of fish in the affected waters"],
169 | [8, 0.3396754860877991, "' at the height of the blaze on sunday afternoon , which caused the foam , up to 50 firefighters were tackling the fire and police were also forced to wear face masks"],
170 | [6, 0.32800495624542236, "@entity66 takes a picture on his mobile phone of his boat trying to negotiate a lock and the foam , which ran into the @entity9 a cyclist takes a picture on his mobile phone as the foam comes up on to the cycle path"],
171 | [9, 0.29064181447029114, "families in east @entity1 were urged to say indoors after a blast was reported at the industrial unit , which is just a few hundred yards from the @entity45 training ground on the @entity109 campus"],
172 | [2, 0.25459226965904236, "it is believed that the water used by firefighters to dampen down the flames mixed with the detergent being stored in the burning buildings"],
173 | [5, 0.2020452618598938, "' initial investigations by @entity46 officers show that there appears to have been minimal impact on water quality , but our officers will continue to monitor and respond as necessary"],
174 | [12, 0.05926991254091263, "according to @entity137 , the fire was fueled by wooden pallets and unidentified chemicals but an investigation into the cause of the fire is still ongoing ."],
175 | [11, 0.05400165915489197, "a huge plume of smoke also turned the sky black and could be seen right across the city and even into @entity136"]
176 | ]
177 | }
178 |
179 | ```
180 | This one is a bit more complicated.
181 | ```json
182 | {
183 | "entity": {
184 | "@entity27": "Belichick",
185 | "@entity24": "Hiss",
186 | "@entity80": "Gisele Bündchen",
187 | "@entity97": "Sport Illustrated",
188 | "@entity115": "Julian Edelman",
189 | "@entity84": "Washington",
190 | "@entity86": "Seattle Seahawks",
191 | "@entity110": "Massachusetts US Senator",
192 | "@entity3": "Patriots",
193 | "@entity2": "Super Bowl",
194 | "@entity0": "Obama",
195 | "@entity4": "White House",
196 | "@entity8": "South Lawn",
197 | "@entity56": "Boomer Esiason",
198 | "@entity111": "Linda Holliday",
199 | "@entity75": "Donovan McNabb",
200 | "@entity96": "Las Vegas",
201 | "@entity30": "Chicago",
202 | "@entity33": "Boston",
203 | "@entity102": "Rob Gronkowski",
204 | "@entity99": "CBS",
205 | "@entity98": "Les Moonves",
206 | "@entity108": "Bellichick",
207 | "@entity109": "John Kerry",
208 | "@entity95": "Floyd Mayweather Jr.",
209 | "@entity94": "Manny Pacquiao",
210 | "@entity117": "Danny Amendola",
211 | "@entity62": "Bush Administration",
212 | "@entity44": "Bob Kraft",
213 | "@entity47": "Super Bowl MVP",
214 | "@entity68": "Showoffs",
215 | "@entity66": "US Senator",
216 | "@entity67": "White House Correspondents dinner",
217 | "@entity113": "George W. Bush",
218 | "@entity48": "Brady"
219 | },
220 | "abstract": [
221 | "@entity48 cited ' prior family commitments ' in bowing out of meeting with @entity0",
222 | "has been to the @entity4 to meet president @entity113 for previous @entity2 wins"
223 | ],
224 | "article": [
225 | "president @entity0 invited the @entity2 champion @entity3 to the @entity4 on thursday - but could n't help but get one last deflategate joke in",
226 | "the president opened his speech on the @entity8 by remarking ' that whole ( deflategate ) story got blown out of proportion , ' referring to an investigation that 11 out of 12 footballs used in the afc championship game were under - inflated",
227 | "but then came the zinger : ' i usually tell a bunch of jokes at these events , but with the @entity3 in town i was worried that 11 out of 12 of them would fall flat",
228 | "coach @entity27 , who is notoriously humorless , responded by giving the president a thumbs down",
229 | "@entity0 was flanked by @entity27 and billionaire @entity3 owner @entity44",
230 | "missing from the occasion , though was the @entity47 and the team 's biggest star - @entity48",
231 | "a spokesman for the team cited ' prior family commitments ' as the reason @entity48 , 37 , did n't attend the ceremony",
232 | "sports commentators , including retired football great @entity56 , speculated that @entity48 snubbed @entity0 because he 's from the ' wrong political party",
233 | "' the superstar athlete has been to the @entity4 before",
234 | "he does have three other @entity2 rings , afterall",
235 | "but all the prior championships were under the @entity62",
236 | "february 's win was the first for the @entity3 since @entity0 took office",
237 | "@entity48 has also met @entity0 at least once before , as well",
238 | "he was pictured with the then - @entity66 at the 2005 @entity67",
239 | "@entity68 : the @entity3 gathered the team 's four @entity2 trophies won under coach @entity27 ( right , next to president @entity0 )",
240 | "@entity48 won his fourth @entity2 ring in february - and his first since president @entity0 took office @entity48 met president @entity0 at least once",
241 | "he is pictured here with the then - @entity66 and rival quarterback @entity75 in 2005 it 's not clear what @entity48 's prior commitment was",
242 | "his supermodel wife @entity80 , usually active on social media , gives no hint where the family is today if not in @entity84",
243 | "@entity48 led the @entity3 to his fourth @entity2 victory in february after defeating the @entity86 28 - 24",
244 | "despite his arm and movement being somewhat diminished by age , @entity48 's leadership and calm under pressure also won him @entity47 - his third",
245 | "whatever is taking up @entity48 's time this week , he made time next week to be ringside at the @entity95 - @entity94 fight in @entity96 next weekend",
246 | "according to @entity97 , @entity48 appealed directly to @entity99 president @entity98 for tickets to the much - touted matchup",
247 | "@entity3 tight end @entity102 could n't help but mug for the camera as the commander in chief gave a speech @entity0 walks with billionaire @entity3 owner @entity44 and coach @entity108 to the speech secretary of state @entity109 , a former @entity110 , greets @entity27 's girlfriend @entity111 at the ceremony @entity48 went to the @entity4 to meet president @entity113 after winning the @entity2 in 2005 and in 2004",
248 | "he 's not going to be there this year @entity3 players @entity115 and @entity117 snap pics in the @entity4 before meeting president @entity0 on thursday"
249 | ],
250 | "label": [0, 6, 7, 12, 14, 15, 18, 22, 23],
251 | "score": [
252 | [1, 0.8683828115463257, "the president opened his speech on the @entity8 by remarking ' that whole ( deflategate ) story got blown out of proportion , ' referring to an investigation that 11 out of 12 footballs used in the afc championship game were under - inflated"],
253 | [0, 0.8339700102806091, "president @entity0 invited the @entity2 champion @entity3 to the @entity4 on thursday - but could n't help but get one last deflategate joke in"],
254 | [22, 0.7730730772018433, "@entity3 tight end @entity102 could n't help but mug for the camera as the commander in chief gave a speech @entity0 walks with billionaire @entity3 owner @entity44 and coach @entity108 to the speech secretary of state @entity109 , a former @entity110 , greets @entity27 's girlfriend @entity111 at the ceremony @entity48 went to the @entity4 to meet president @entity113 after winning the @entity2 in 2005 and in 2004"],
255 | [14, 0.7569227814674377, "@entity68 : the @entity3 gathered the team 's four @entity2 trophies won under coach @entity27 ( right , next to president @entity0 )"],
256 | [15, 0.6214166879653931, "@entity48 won his fourth @entity2 ring in february - and his first since president @entity0 took office @entity48 met president @entity0 at least once"],
257 | [18, 0.4963235855102539, "@entity48 led the @entity3 to his fourth @entity2 victory in february after defeating the @entity86 28 - 24"],
258 | [16, 0.45303720235824585, "he is pictured here with the then - @entity66 and rival quarterback @entity75 in 2005 it 's not clear what @entity48 's prior commitment was"],
259 | [5, 0.4204302430152893, "missing from the occasion , though was the @entity47 and the team 's biggest star - @entity48"],
260 | [7, 0.41678884625434875, "sports commentators , including retired football great @entity56 , speculated that @entity48 snubbed @entity0 because he 's from the ' wrong political party"],
261 | [20, 0.4135805070400238, "whatever is taking up @entity48 's time this week , he made time next week to be ringside at the @entity95 - @entity94 fight in @entity96 next weekend"],
262 | [6, 0.3958345353603363, "a spokesman for the team cited ' prior family commitments ' as the reason @entity48 , 37 , did n't attend the ceremony"],
263 | [4, 0.37495893239974976, "@entity0 was flanked by @entity27 and billionaire @entity3 owner @entity44"],
264 | [21, 0.3466879427433014, "according to @entity97 , @entity48 appealed directly to @entity99 president @entity98 for tickets to the much - touted matchup"],
265 | [19, 0.3316606283187866, "despite his arm and movement being somewhat diminished by age , @entity48 's leadership and calm under pressure also won him @entity47 - his third"],
266 | [2, 0.29267093539237976, "but then came the zinger : ' i usually tell a bunch of jokes at these events , but with the @entity3 in town i was worried that 11 out of 12 of them would fall flat"],
267 | [23, 0.27186375856399536, "he 's not going to be there this year @entity3 players @entity115 and @entity117 snap pics in the @entity4 before meeting president @entity0 on thursday"],
268 | [11, 0.26710671186447144, "february 's win was the first for the @entity3 since @entity0 took office"],
269 | [17, 0.17511016130447388, "his supermodel wife @entity80 , usually active on social media , gives no hint where the family is today if not in @entity84"],
270 | [3, 0.16352418065071106, "coach @entity27 , who is notoriously humorless , responded by giving the president a thumbs down"],
271 | [13, 0.14906153082847595, "he was pictured with the then - @entity66 at the 2005 @entity67"],
272 | [12, 0.1384015828371048, "@entity48 has also met @entity0 at least once before , as well"],
273 | [10, 0.07186555117368698, "but all the prior championships were under the @entity62"],
274 | [8, 0.07148505747318268, "' the superstar athlete has been to the @entity4 before"],
275 | [9, 0.035264041274785995, "he does have three other @entity2 rings , afterall"]
276 | ]
277 | }
278 | ```
279 | Please find more predicted results [*here*](https://drive.google.com/open?id=1cXrR1kY-tlxArB-F9FSZba2T2RscAYVS).
280 |
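The examples above also show how a prediction is stored: the `score` list holds `[sentence_index, probability, sentence]` triples ranked by probability. Below is a minimal sketch for inspecting one downloaded example, assuming it is saved as a pickled dict with the keys shown above (the file name is hypothetical):

```python
import pickle

# Hypothetical file name; the downloadable results may be packaged differently.
with open('predict_0.pkl', 'rb') as f:
    example = pickle.load(f)

# "score" entries are [sentence_index, probability, sentence] ranked by
# probability, so an extractive summary is simply the top-k entries,
# re-ordered by their original position in the article.
top_k = sorted(example['score'][:3], key=lambda item: item[0])
for idx, prob, sentence in top_k:
    print('(%d, %.3f) %s' % (idx, prob, sentence))
```
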
281 | ## Discuss
282 |
283 | - Tuning the learning rate.
284 | - Whether or not to freeze the embedding weights for the first several steps.
285 | - Choose a proper step range over which to shift gradually from the true labels to the probabilities predicted by the model (curriculum learning; see the sketch after this list).
286 | - The initialization of the weights and biases.
287 | - Find a proper way to evaluate while training (currently I just watch the validation loss and apply early stopping).
288 |
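As a reference for the curriculum-learning point above, here is a minimal sketch of the schedule implemented in `weight_control` of `textsum_model.py`: before the start step the decoder is fed the gold label, and afterwards the label is blended linearly with the model's own prediction. The step values below are made up for illustration.

```python
def blended_probability(label, predicted, global_step, start, end):
    """Mirror of weight_control: blend the gold label with the prediction."""
    if global_step < start:
        return label                      # pure teacher forcing before `start`
    weight = (global_step - start) / float(end - start)
    # The weight is not clipped, so the choice of [start, end] matters:
    # past `end` the blend keeps drifting beyond the predicted probability.
    return (1.0 - weight) * label + weight * predicted

# Hypothetical shift range of steps 1000..5000, gold label 1.0, prediction 0.7.
for step in (0, 1000, 3000, 5000):
    print(step, round(blended_probability(1.0, 0.7, step, 1000, 5000), 3))
```
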
289 | ## TODO list
290 | - The NN-WE word extractor remains to be done.
291 | - OOV and NER problems remain when using raw data.
292 | - Use threading and queues in TensorFlow to load batches (see the sketch after this list).
293 |
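A rough sketch of what queue-based batch loading could look like with TensorFlow 1.x queue runners (the random tensor and the shapes below are placeholders, not the repo's actual preprocessing):

```python
import tensorflow as tf

# Made-up shapes standing in for one tokenized document.
MAX_NUM_SEQUENCE, SEQUENCE_LENGTH, BATCH_SIZE = 64, 50, 8

# Dummy "document" producer; a real loader would read the .pkl files instead.
doc = tf.random_uniform([MAX_NUM_SEQUENCE, SEQUENCE_LENGTH], maxval=100, dtype=tf.int32)

queue = tf.FIFOQueue(capacity=256, dtypes=[tf.int32],
                     shapes=[[MAX_NUM_SEQUENCE, SEQUENCE_LENGTH]])
enqueue_op = queue.enqueue([doc])
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 2))
batch_x = queue.dequeue_many(BATCH_SIZE)  # [BATCH_SIZE, MAX_NUM_SEQUENCE, SEQUENCE_LENGTH]

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(batch_x).shape)        # (8, 64, 50)
    coord.request_stop()
    coord.join(threads)
```
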
294 | ## Credits
295 | - Thanks to the authors of the paper.
296 | - Borrowed some code from [*text_classification*](https://github.com/brightmart/text_classification) and learned a lot from it.
297 | - [*pointer-generator*](https://github.com/abisee/pointer-generator) is a great text summarization project that deserves appreciation.
298 |
--------------------------------------------------------------------------------
/textsum_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | from tensorflow.contrib.seq2seq.python.ops import *
4 | import numpy as np
5 |
6 | class Neuralmodel:
7 | def __init__(self,extract_sentence_flag,is_training,vocab_size,batch_size,embed_size,learning_rate,cur_step,decay_step,decay_rate,max_num_sequence,
8 | sequence_length,filter_sizes,feature_map,use_highway_flag,highway_layers,hidden_size,document_length,max_num_abstract,beam_width,
9 | attention_size,input_y2_max_length,clip_gradients=5.0, initializer=tf.random_normal_initializer(stddev=0.1)):
10 | """init all hyperparameter:"""
11 | self.initializer = tf.contrib.layers.xavier_initializer()
12 | self.initializer_uniform = tf.random_uniform_initializer(minval=-0.05,maxval=0.05)
13 |
14 | """Basic"""
15 | self.extract_sentence_flag = extract_sentence_flag
16 | self.vocab_size = vocab_size
17 | self.batch_size = batch_size
18 | self.embed_size = embed_size
19 |
20 | """learning_rate"""
21 | self.is_training = is_training
22 | self.tst = tf.placeholder(tf.bool, name='is_training_flag')
23 | self.learning_rate = tf.Variable(learning_rate, trainable=False, name='learning_rate')
24 | self.cur_step_start = tf.Variable(cur_step[0], trainable=False, name='start_for_cur_learning')
25 | self.cur_step_end = tf.Variable(cur_step[1], trainable=False, name='end_for_cur_learning')
26 | self.decay_step = decay_step
27 | self.decay_rate = decay_rate
28 |
29 | """Overfit"""
30 | self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
31 | self.clip_gradients = clip_gradients
32 |
33 | """CNN (word)"""
34 | self.max_num_sequence = max_num_sequence
35 | self.sequence_length = sequence_length
36 | self.filter_sizes = filter_sizes
37 | self.feature_map = feature_map
38 |
39 | """Highway Network"""
40 | self.use_highway_flag = use_highway_flag
41 | self.highway_layers = highway_layers
42 |
43 | """LSTM (sentence)"""
44 | self.hidden_size = hidden_size
45 | self.document_length = document_length
46 |
47 | """LSTM + Attention (generating)"""
48 | self.max_num_abstract = max_num_abstract
49 | self.beam_width = beam_width
50 | self.attention_size = attention_size
51 | self.input_y2_max_length = input_y2_max_length
52 |
53 | """Input"""
54 | self.input_x = tf.placeholder(tf.int32, [None, self.max_num_sequence, self.sequence_length], name="input_x")
55 |
56 | if extract_sentence_flag:
57 | self.input_y1 = tf.placeholder(tf.int32, [None, self.max_num_sequence], name="input_y_sentence")
58 | self.input_y1_length = tf.placeholder(tf.int32, [None], name="input_y_length")
59 | self.mask = tf.sequence_mask(self.input_y1_length, self.max_num_sequence, dtype=tf.float32, name='input_y_mask')
60 | self.cur_learning = tf.placeholder(tf.bool, name="use_cur_lr_strategy")
61 | else:
62 | self.input_y2_length = tf.placeholder(tf.int32, [None, self.max_num_abstract], name="input_y_word_length")
63 | self.input_y2 = tf.placeholder(tf.int32, [None, self.max_num_abstract, self.input_y2_max_length], name="input_y_word")
64 | self.input_decoder_x = tf.placeholder(tf.int32, [None, self.max_num_abstract, self.input_y2_max_length], name="input_decoder_x")
65 | self.value_decoder_x = tf.placeholder(tf.int32, [None, self.document_length], name="value_decoder_x")
66 | self.mask_list = [tf.sequence_mask(tf.squeeze(self.input_y2_length[idx:idx+1], axis=0), self.input_y2_max_length, dtype=tf.float32) for idx in range(self.batch_size)]
67 | self.targets = [tf.squeeze(self.input_y2[idx:idx+1], axis=0) for idx in range(self.batch_size)]
68 |
69 | """Count"""
70 | self.global_step = tf.Variable(0, trainable=False, name='Global_step')
71 | self.epoch_step = tf.Variable(0, trainable=False, name='Epoch_step')
72 | self.epoch_increment = tf.assign(self.epoch_step, tf.add(self.epoch_step, tf.constant(1)))
73 | self.global_increment = tf.assign(self.global_step, tf.add(self.global_step, tf.constant(1)))
74 |
75 | """Process"""
76 | self.instantiate_weights()
77 |
78 | """Logits"""
79 | if extract_sentence_flag:
80 | self.logits = self.inference()
81 | else:
82 | self.logits, self.final_sequence_lengths = self.inference()
83 |
84 | if not self.is_training:
85 | return
86 |
87 | if extract_sentence_flag:
88 | print('using sentence extractor...')
89 | self.loss_val = self.loss_sentence()
90 | else:
91 | print('using word extractor...')
92 | self.loss_val = self.loss_word()
93 |
94 | self.train_op = self.train()
95 | self.train_op_frozen = self.train_frozen()
96 | self.merge = tf.summary.merge_all()
97 |
98 | def instantiate_weights(self):
99 | with tf.name_scope("Embedding"):
100 | self.Embedding = tf.get_variable("embedding",shape=[self.vocab_size, self.embed_size],initializer=self.initializer)
101 | self.Embedding_ = tf.get_variable("embedding_", shape=[2, self.hidden_size], initializer=self.initializer)
102 |
103 | with tf.name_scope("Cell"):
104 | # input gate
105 | self.W_i = tf.get_variable("W_i", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
106 | self.U_i = tf.get_variable("U_i", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
107 | self.b_i = tf.get_variable("b_i", shape=[self.hidden_size],initializer=tf.zeros_initializer())
108 | # forget gate
109 | self.W_f = tf.get_variable("W_f", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
110 | self.U_f = tf.get_variable("U_f", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
111 | self.b_f = tf.get_variable("b_f", shape=[self.hidden_size],initializer=tf.ones_initializer())
112 | # cell gate
113 | self.W_c = tf.get_variable("W_c", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
114 | self.U_c = tf.get_variable("U_c", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
115 | self.b_c = tf.get_variable("b_c", shape=[self.hidden_size],initializer=tf.zeros_initializer())
116 | # output gate
117 | self.W_o = tf.get_variable("W_o", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
118 | self.U_o = tf.get_variable("U_o", shape=[self.hidden_size,self.hidden_size], initializer=self.initializer_uniform)
119 | self.b_o = tf.get_variable("b_o", shape=[self.hidden_size],initializer=tf.zeros_initializer())
120 |
121 | def document_reader(self):
122 | """1.embedding"""
123 | # self.input_x : [batch_size, max_num_sequence, sentence_length]
124 | # self.embedded_words : [max_num_sequence, sentence_length, embed_size]
125 | # self.embedded_words_expanded : [batch_size, max_num_sequence, sentence_length, embed_size]
126 | embedded_words = []
127 | for idx in range(self.batch_size):
128 | self.embedded_words = tf.nn.embedding_lookup(self.Embedding, self.input_x[idx:idx+1])
129 | self.embedded_words_squeezed = tf.squeeze(self.embedded_words, axis=0)
130 | self.embedded_words_expanded = tf.expand_dims(self.embedded_words_squeezed, axis=-1)
131 | embedded_words.append(self.embedded_words_expanded)
132 |
133 | """2.CNN(word)"""
134 | # conv: [max_num_sequence, sequence_length-filter_size+1, 1, num_filters]
135 | # pooled: [max_num_sequence, 1, 1, num_filters]
136 | # pooled_temp: [max_num_sequence, num_filters * class_filters]
137 | # cnn_outputs: [batch_size, max_num_sequence, num_filters * class_filters]
138 | with tf.name_scope("CNN-Layer-Encoder"):
139 | pooled_outputs = []
140 | for m, conv_s in enumerate(embedded_words):
141 | pooled_temp = []
142 | for i, filter_size in enumerate(self.filter_sizes):
143 | with tf.variable_scope("convolution-pooling-%s" % filter_size, reuse=tf.AUTO_REUSE):
144 | filter=tf.get_variable("filter-%s"%filter_size,[filter_size,self.embed_size,1,self.feature_map[i]],initializer=self.initializer)
145 | conv=tf.nn.conv2d(conv_s, filter, strides=[1,1,1,1], padding="VALID",name="conv")
146 | conv=tf.contrib.layers.batch_norm(conv, is_training = self.tst, scope='cnn_bn_')
147 | b=tf.get_variable("b-%s"%filter_size,[self.feature_map[i]])
148 | h=tf.nn.tanh(tf.nn.bias_add(conv,b),"tanh")
149 | pooled=tf.nn.max_pool(h, ksize=[1,self.sequence_length-filter_size+1,1,1], strides=[1,1,1,1], padding='VALID',name="pool")
150 | pooled_temp.append(pooled)
151 | pooled_temp = tf.concat(pooled_temp, axis=3)
152 | pooled_temp = tf.reshape(pooled_temp, [-1, self.hidden_size])
153 | """3.Highway Network"""
154 | if self.use_highway_flag:
155 | pooled_temp = self.highway(pooled_temp, pooled_temp.get_shape()[1], m, self.highway_layers, 0)
156 | pooled_outputs.append(pooled_temp)
157 | cnn_outputs = tf.stack(pooled_outputs, axis=0)
158 |
159 | """4.LSTM(sentence)"""
160 | # lstm_outputs: [batch_size, max_time, hidden_size]
161 | # cell_state: [batch_size, hidden_size]
162 | with tf.variable_scope("LSTM-Layer-Encoder", initializer=self.initializer_uniform):
163 | lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
164 | lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob = self.dropout_keep_prob)
165 | lstm_outputs, cell_state = tf.nn.dynamic_rnn(lstm_cell, cnn_outputs, dtype = tf.float32)
166 | return cnn_outputs, lstm_outputs, cell_state
167 |
168 | def highway(self, input_, size, mark, layer_size=1, bias=-2.0, f=tf.nn.relu):
169 | # t = sigmoid( W * y + b)
170 | # z = t * g(W * y + b) + (1 - t) * y
171 | # where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
172 |
173 | def linear(input_, output_size, mark, scope=None):
174 | shape = input_.get_shape().as_list()
175 | if len(shape) != 2:
176 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shape))
177 | if not shape[1]:
178 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape))
179 | input_size = shape[1]
180 | with tf.variable_scope(scope or "simplelinear"):
181 | W = tf.get_variable("W_%d" % mark, [output_size, input_size], initializer=self.initializer_uniform, dtype = input_.dtype)
182 | b = tf.get_variable("b_%d" % mark, [output_size], initializer=self.initializer_uniform, dtype = input_.dtype)
183 | return tf.matmul(input_, tf.transpose(W)) + b
184 |
185 | with tf.variable_scope("highway"):
186 | for idx in range(layer_size):
187 | g = f(linear(input_, size, mark, scope="highway_lin_%d" % idx))
188 | t = tf.sigmoid(linear(input_, size, mark, scope="highway_gate_%d" % idx ) + bias)
189 | output = t * g + (1. - t) * input_
190 | input_ = output
191 | return output
192 |
193 | def sigmoid_norm(self, score):
194 | # sigmoid(tanh) --> sigmoid([-1,1]) --> [0.26,0.73] --> [0,1]
195 | with tf.name_scope("sigmoid_norm"):
196 | Min = tf.sigmoid(tf.constant(-1, dtype=tf.float32))
197 | Max = tf.sigmoid(tf.constant(1, dtype=tf.float32))
198 | prob = tf.sigmoid(score)
199 | prob_norm = (prob - Min) / (Max - Min)
200 | return prob_norm
201 |
202 | def lstm_single_step(self, St, At, h_t_minus_1, c_t_minus_1, p_t_minus_1):
203 |
204 | p_t_minus_1 = tf.reshape(p_t_minus_1, [-1, 1])
205 | # Xt = p_t_minus_1 * St
206 | Xt = tf.multiply(p_t_minus_1, St)
207 | # dropout
208 | Xt = tf.nn.dropout(Xt, self.dropout_keep_prob)
209 | # input forget output compute
210 | i_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_i) + tf.matmul(h_t_minus_1, self.U_i) + self.b_i)
211 | f_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_f) + tf.matmul(h_t_minus_1, self.U_f) + self.b_f)
212 | c_t_candidate = tf.nn.tanh(tf.matmul(Xt, self.W_c) + tf.matmul(h_t_minus_1, self.U_c) + self.b_c)
213 | c_t = f_t * c_t_minus_1 + i_t * c_t_candidate
214 | o_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_o) + tf.matmul(h_t_minus_1, self.U_o) + self.b_o)
215 | h_t = o_t * tf.nn.tanh(c_t)
216 | tf.summary.histogram("input:pt*st", Xt)
217 |         tf.summary.histogram("attention_z_value", At)
218 | tf.summary.histogram("hidden_z_value", h_t)
219 | # prob compute
220 | with tf.name_scope("Score_Layer"):
221 | concat_h = tf.concat([At, h_t], axis=1)
222 | tf.summary.histogram("concat", concat_h)
223 | concat_h_dropout = tf.nn.dropout(concat_h, keep_prob=self.dropout_keep_prob)
224 | score = tf.layers.dense(concat_h_dropout, 1, activation=tf.nn.tanh, name="score", reuse=tf.AUTO_REUSE)
225 | p_t = self.sigmoid_norm(score)
226 |
227 | return h_t, c_t, p_t
228 |
229 | def weight_control(self, time_step, p_t):
230 | # curriculum learning control the weight between true labels and those predicted
231 | labels = tf.cast(self.input_y1[:,time_step:time_step+1], dtype=tf.float32)
232 | start = tf.cast(self.cur_step_start, dtype=tf.float32)
233 | end = tf.cast(self.cur_step_end, dtype=tf.float32)
234 | global_step = tf.cast(self.global_step, dtype=tf.float32)
235 | weight = tf.divide(tf.subtract(global_step, start), tf.subtract(end, start))
236 | merge = (1. - weight) * labels + weight * p_t
237 | cond = tf.greater(start, global_step)
238 | p_t_curr = tf.cond(cond, lambda:labels, lambda:merge)
239 | return p_t_curr
240 |
241 | def sentence_extractor(self):
242 | """4.1.1 LSTM(decoder)"""
243 | # decoder input each time: activation (MLP(h_t:At)) * St
244 | # h_t: decoder LSTM output
245 | # At: encoder LSTM output (document level)
246 | # St: encoder CNN output (sentence level)
247 | # probability value: [p_t = activation(MLP(h_t:At)) for h_t in h_t_steps ]
248 | with tf.name_scope("LSTM-Layer-Decoder"):
249 | # initialize
250 | h_t_lstm_list = []
251 | p_t_lstm_list = []
252 | lstm_tuple = self.initial_state
253 | c_t_0 = lstm_tuple[0]
254 | h_t_0 = lstm_tuple[1]
255 | p_t_0 = tf.ones((self.batch_size))
256 | cnn_outputs = tf.split(self.cnn_outputs, self.max_num_sequence, axis=1)
257 | cnn_outputs = [tf.squeeze(i, axis=1) for i in cnn_outputs]
258 | attention_state = tf.split(self.attention_state, self.max_num_sequence, axis=1)
259 | attention_state = [tf.squeeze(i, axis=1) for i in attention_state]
260 | # first step
261 | start_tokens = tf.zeros([self.batch_size], tf.int32) # id for ['GO']
262 | St_0 = tf.nn.embedding_lookup(self.Embedding_, start_tokens)
263 | At_0 = attention_state[0]
264 | h_t, c_t, p_t = self.lstm_single_step(St_0, At_0, h_t_0, c_t_0, p_t_0)
265 | p_t_lstm_list.append(p_t)
266 | tf.summary.histogram("prob_t", p_t)
267 | # next steps
268 | for time_step, merge in enumerate(zip(cnn_outputs[:-1], attention_state[1:])):
269 | St, At = merge[0], merge[1]
270 | if self.is_training:
271 | p_t = tf.cond(self.cur_learning, lambda: self.weight_control(time_step, p_t), lambda: p_t)
272 | h_t, c_t, p_t = self.lstm_single_step(St, At, h_t, c_t, p_t)
273 | p_t_lstm_list.append(p_t)
274 | tf.summary.histogram("sen_t", St)
275 | tf.summary.histogram("prob_t", p_t)
276 | # results
277 | logits = tf.concat(p_t_lstm_list, axis=1)
278 |
279 | return logits
280 |
281 | def word_extractor(self): # TODO
282 | # LSTM inputs: h_t = LSTM(wt-1,h_t-1)
283 | # Attention: h~t = Attention(h_t,h)
284 | logits_list = []
285 | length_list = []
286 | # values_decoder_embedded: [batch_size, document_length]
287 | # inputs_decoder_embedded: [batch_size, max_num_abstract, input_y2_max_length]
288 | attent_decoder_embedded = []
289 | values_decoder_embedded = []
290 | inputs_decoder_embedded = []
291 | initial_state_embedded =[]
292 | encoder_inputs_lengths = []
293 | embedded_values = tf.nn.embedding_lookup(self.Embedding, self.value_decoder_x)
294 | for idx in range(self.batch_size):
295 | c = tf.concat([self.initial_state[0][idx:idx+1] for _ in range(self.max_num_abstract)], axis=0)
296 | h = tf.concat([self.initial_state[1][idx:idx+1] for _ in range(self.max_num_abstract)], axis=0)
297 | embedded_initial_expand = tf.nn.rnn_cell.LSTMStateTuple(c, h)
298 | initial_state_embedded.append(embedded_initial_expand)
299 | embedded_attent_expand = tf.concat([self.attention_state[idx:idx+1] for _ in range(self.max_num_abstract)], axis=0)
300 | attent_decoder_embedded.append(embedded_attent_expand)
301 | embedded_abstracts = tf.nn.embedding_lookup(self.Embedding, self.input_decoder_x[idx:idx+1])
302 | embedded_abstracts_squeezed = tf.squeeze(embedded_abstracts, axis=0)
303 | inputs_decoder_embedded.append(embedded_abstracts_squeezed)
304 | embedded_values_squeezed = embedded_values[idx:idx+1]
305 | #embedded_values_squeezed = tf.squeeze(embedded_values[idx:idx+1], axis=0)
306 | values_decoder_embedded.append(embedded_values_squeezed)
307 | encoder_inputs_length = tf.squeeze(self.input_y2_length[idx:idx+1], axis=0)
308 | encoder_inputs_lengths.append(encoder_inputs_length)
309 |
310 | for attent_embedded, inputs_embedded, values_embedded, initial_state, encoder_inputs_length in zip(attent_decoder_embedded, inputs_decoder_embedded, values_decoder_embedded, initial_state_embedded, encoder_inputs_lengths):
311 |
312 | with tf.variable_scope("attention-word-decoder", reuse=tf.AUTO_REUSE ):
313 | if self.is_training:
314 | attention_state = attent_embedded
315 | document_state = values_embedded
316 | document_length = self.document_length * tf.ones([1,], dtype=tf.int32)
317 | encoder_final_state = initial_state
318 | else:
319 | """4.2 beam search preparation"""
320 | attention_state = tf.contrib.seq2seq.tile_batch(attent_embedded, multiplier=self.beam_width)
321 | document_state = tf.contrib.seq2seq.tile_batch(values_embedded, multiplier=self.beam_width)
322 | encoder_inputs_length = tf.contrib.seq2seq.tile_batch(encoder_inputs_length, multiplier=self.beam_width)
323 | document_length = tf.contrib.seq2seq.tile_batch(self.document_length * tf.ones([1,], dtype=tf.int32), multiplier=self.beam_width)
324 | encoder_final_state = tf.contrib.framework.nest.map_structure(lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_width), initial_state)
325 | """4.2 Attention(Bahdanau)"""
326 | # building attention cell
327 | lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
328 | lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.dropout_keep_prob)
329 | attention_mechanism1 = attention_wrapper.BahdanauAttention(
330 | num_units=self.hidden_size, memory=attention_state, memory_sequence_length=encoder_inputs_length
331 | )
332 | attention_cell = attention_wrapper.AttentionWrapper(
333 | cell=lstm_cell, attention_mechanism=attention_mechanism1, attention_layer_size=self.attention_size, \
334 | # cell_input_fn=(lambda inputs, attention: tf.layers.Dense(self.hidden_size, dtype=tf.float32, name="attention_inputs")(array.ops.concat([inputs, attention],-1))) TODO \
335 | #cell_input_fn=(lambda inputs, attention: tf.squeeze(tf.layers.Dense(self.hidden_size, dtype=tf.float32, name="attention_inputs")(inputs), axis=0)), \
336 | cell_input_fn=(lambda inputs, attention: tf.layers.Dense(self.hidden_size, dtype=tf.float32, name="attention_inputs")(inputs)), \
337 | alignment_history=False, name='Attention_Wrapper' \
338 | )
339 |
340 | batch_size = self.max_num_abstract if self.is_training else self.max_num_abstract * self.beam_width
341 | decoder_initial_state = attention_cell.zero_state(batch_size=(batch_size), dtype=tf.float32).clone(cell_state=encoder_final_state)
342 | #tf.scalar_mul(inputs_embedded, inputs_embedded)
343 | if self.is_training:
344 | helper = tf.contrib.seq2seq.TrainingHelper(inputs=inputs_embedded, sequence_length=encoder_inputs_length, time_major=False, name="training_helper")
345 | training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=attention_cell,helper=helper,initial_state=decoder_initial_state, output_layer=None)
346 | decoder_outputs, _, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder,output_time_major=False,impute_finished=True,maximum_iterations=self.input_y2_max_length)
347 | else:
348 | start_tokens=tf.ones([self.max_num_abstract,], tf.int32) * 2
349 | end_token= 3
350 | inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=attention_cell,embedding=document_state,start_tokens=start_tokens,end_token=end_token,initial_state=decoder_initial_state,beam_width=self.beam_width,output_layer=None)
351 | decoder_outputs, _, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder,output_time_major=False,impute_finished=True,maximum_iterations=self.input_y2_max_length)
352 | length_list.append(final_sequence_lengths)
353 |
354 | """4.2 attention * document mat"""
355 | # decoder_outputs: [batch_size, input_y2_max_length, attention_size]
356 | # final_sequence_lengths: [batch_size]
357 | # logits: [batch_size, input_y2_max_length, document_length]
358 | with tf.variable_scope("attention-vocab", reuse=tf.AUTO_REUSE):
359 | attention_mechanism2 =attention_wrapper.BahdanauAttention(
360 | num_units=self.attention_size, memory=document_state, memory_sequence_length=document_length
361 | )
362 |                 state = tf.constant(True, dtype=tf.bool) # TODO: proper initial attention state is still unclear
363 | decoder_outputs = decoder_outputs[0]
364 | list2 = []
365 | for idx in range(self.max_num_abstract):
366 | list1=[]
367 | for step in range(self.input_y2_max_length):
368 | src = decoder_outputs[idx:idx+1,step:step+1,:]
369 |                         print(src.get_shape())
370 |                         #print (src.get_shape() == (1,1,self.attention_size))
371 |                         cond = tf.constant((src.get_shape() == (1,1,self.attention_size)), tf.bool)
372 |                         query = tf.cond(cond, lambda:tf.squeeze(src, axis=1), lambda:tf.zeros([1,self.attention_size],tf.float32))
373 |                         logits, state = attention_mechanism2(query=query, state=state)
374 |                         list1.append(logits)
375 | logits = tf.stack(list1, axis=1)
376 | list2.append(logits)
377 | logits = tf.concat(list2, axis=0)
378 | logits_list.append(logits)
379 |
380 | if self.is_training:
381 | return logits_list, []
382 | else:
383 | return logits_list, length_list
384 |
385 | def inference(self):
386 | """
387 | compute graph:
388 | 1.Embedding--> 2.CNN(word)-->3.LSTM(sentence) (Document Reader)
389 | 4.1 LSTM + MLP(labeling) (Sentence Extractor)
390 | 4.2 LSTM + Attention(generating) (Word Extractor)
391 | """
392 | self.cnn_outputs, self.attention_state, self.initial_state = self.document_reader()
393 | if self.extract_sentence_flag:
394 | logits = self.sentence_extractor()
395 | return logits
396 | else:
397 | logits, final_sequence_lengths = self.word_extractor()
398 | return logits, final_sequence_lengths
399 |
400 | def loss_sentence(self, l2_lambda = 0.0001):
401 | # multi_class_labels: [batch_size, max_num_sequence]
402 | # logits: [batch_size, max_num_sequence]
403 | # losses: [batch_size, max_num_sequence]
404 | # origin:sigmoid log: max(x, 0) + x * z + log(1 + exp(-x))
405 | # z*-log(x)+(1-z)*-log(1-x)
406 | # z=0 --> -log(1-x)
407 | # z=1 --> -log(x)
408 | with tf.name_scope("loss_sentence"):
409 | logits = tf.convert_to_tensor(self.logits)
410 | labels = tf.cast(self.input_y1, logits.dtype)
411 | zeros = tf.zeros_like(labels, dtype=labels.dtype)
412 | ones = tf.ones_like(logits, dtype=logits.dtype)
413 | cond = ( labels > zeros )
414 | logits_ = tf.where(cond, logits, ones-logits)
415 | logits_log = tf.log(logits_)
416 | losses = -logits_log
417 | losses *= self.mask
418 | l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name]) * l2_lambda
419 | tf.summary.scalar("l2_loss", l2_loss)
420 | loss = tf.reduce_sum(losses, axis=1)
421 | loss = tf.reduce_mean(loss)
422 | tf.summary.scalar("loss", loss)
423 |
424 | return loss+l2_loss
425 |
426 | def loss_word(self, l2_lambda=0.001):
427 | # logits: [batch_size, sequence_length, document_length]
428 | # targets: [batch_size, sequence_length]
429 | # weights: [batch_size, sequence_length]
430 | # loss: scalar
431 | with tf.name_scope("loss_word"):
432 | loss = tf.Variable(0.0, trainable=False, dtype= tf.float32)
433 | for logits, targets, mask in zip(self.logits, self.targets, self.mask_list):
434 | loss += tf.contrib.seq2seq.sequence_loss(logits=logits,targets=targets,weights=mask,average_across_timesteps=True,average_across_batch=True)
435 | #l2_losses = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name]) * l2_lambda
436 | #loss = loss + l2_losses
437 | tf.summary.scalar("loss", loss)
438 | return loss
439 |
440 | def train_frozen(self):
441 | with tf.name_scope("train_op_frozen"):
442 | learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step, self.decay_step, self.decay_rate, staircase=True)
443 | self.learning_rate = learning_rate
444 | optimizer = tf.train.AdamOptimizer(learning_rate,beta1=0.99)
445 | tvars = [tvar for tvar in tf.trainable_variables() if 'embedding' not in tvar.name]
446 | gradients, variables = zip(*optimizer.compute_gradients(self.loss_val, tvars))
447 | gradients, _ = tf.clip_by_global_norm(gradients, self.clip_gradients)
448 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
449 | with tf.control_dependencies(update_ops):
450 | train_op = optimizer.apply_gradients(zip(gradients, variables))
451 | return train_op
452 |
453 | def train(self):
454 | with tf.name_scope("train_op"):
455 | learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step, self.decay_step, self.decay_rate, staircase=True)
456 | self.learning_rate = learning_rate
457 | optimizer = tf.train.AdamOptimizer(learning_rate,beta1=0.99)
458 | gradients, variables = zip(*optimizer.compute_gradients(self.loss_val))
459 | gradients, _ = tf.clip_by_global_norm(gradients, self.clip_gradients)
460 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
461 | with tf.control_dependencies(update_ops):
462 | train_op = optimizer.apply_gradients(zip(gradients, variables))
463 | return train_op
464 |
465 |
--------------------------------------------------------------------------------