├── .gitignore
├── README.md
├── assets
│   ├── demo.png
│   ├── model.png
│   └── model_demo.png
├── data
│   └── test
│       ├── article.txt
│       └── title.txt
├── loader.py
├── main.py
├── model
│   ├── base.py
│   ├── encoder.py
│   ├── nnlm.py
│   └── ops.py
├── processing.py
├── test.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# data
data
*.json
data/duc
*.pkl
logs

# trash
.dropbox

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/


### Vim ###
[._]*.s[a-w][a-z]
[._]s[a-w][a-z]
*.un~
Session.vim
.netrwhist
*~

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Attention-Based Summarization
=============================

Tensorflow implementation of [A Neural Attention Model for Abstractive Summarization](http://arxiv.org/abs/1509.00685). The authors' original code can be found [here](https://github.com/facebook/NAMAS).
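The model is a feed-forward neural network language model (NNLM) whose next-word distribution is conditioned on an attention-weighted encoding of the input article. Roughly, in the paper's notation (`x` is the article, `y_c` a window of already-generated summary words, `enc` the attention-based encoder):

    p(y_{i+1} | y_c, x) ∝ exp(V h + W enc(x, y_c))
    h = tanh(U [E y_c])
    enc(x, y_c) = pᵀ x̄,   with p ∝ exp(x̃ P ỹ_c)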
![model_demo](./assets/model_demo.png)


Prerequisites
-------------

- Python 2.7 or Python 3.3+
- [Tensorflow](https://www.tensorflow.org/)
- [Gensim](https://radimrehurek.com/gensim/)


Usage
-----

To train a model with the `duc2013` dataset:

    $ python main.py --dataset duc2013

To test an existing model:

    $ python main.py --dataset duc2014 --forward_only True

(This is still a work in progress, and I currently have no access to a summarization dataset.)


References
----------

- [EMNLP 2015 slides](http://people.seas.harvard.edu/~srush/emnlp2015_slides.pdf)


Author
------

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/neural-summary-tensorflow/12c15eff1eb0e0d2ac6c9d0c96087f0924e18588/assets/demo.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/neural-summary-tensorflow/12c15eff1eb0e0d2ac6c9d0c96087f0924e18588/assets/model.png
--------------------------------------------------------------------------------
/assets/model_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/neural-summary-tensorflow/12c15eff1eb0e0d2ac6c9d0c96087f0924e18588/assets/model_demo.png
--------------------------------------------------------------------------------
/data/test/article.txt:
--------------------------------------------------------------------------------
schizophrenia patients whose medication could n't stop the imaginary voices in their heads gained some relief after researchers repeatedly sent a magnetic field into a small area of their brains .
scientists trying to fathom the mystery of schizophrenia say they have found the strongest evidence to date that the disabling psychiatric disorder is caused by gene abnormalities , according to a researcher at two state universities .
a yale school of medicine study is expanding upon what scientists know about the link between schizophrenia and nicotine addiction .
exploring chaos in a search for order , scientists who study the reality-shattering mental disease schizophrenia are becoming fascinated by the chemical environment of areas of the brain where perception is regulated .
schizophrenia may be one of the most baffling of mental illnesses , but the medical profession and society at large could be making much better use of what we do know about combating the devastating disorder , a major study has found .
cesarean babies may be more susceptible to schizophrenia than children born naturally , say canadian researchers who studied the effects of the operation on rats .
a family history of schizophrenia remains the best predictor of whether a person will develop the illness , but environmental influences like place and season of birth are also significant risk factors , according to the findings of a large epidemiological study by danish researchers .
schizophrenia , the devastating mental illness that afflicts an estimated 2.7 million americans , is rarely diagnosed until it becomes full-blown .
a massachusetts scientist who helped dispel the once widely held belief that schizophrenia came from bad mothering has won an albert lasker award , widely regarded as america 's nobel prize , for his lifetime achievement .
by the time most psychiatrists encounter schizophrenia , its symptoms are already in full flower , but scientists have long surmised that the illness starts much earlier , the demons beginning to nibble at the edges of young people 's lives well before the most flagrant psychosis appears .
--------------------------------------------------------------------------------
/data/test/title.txt:
--------------------------------------------------------------------------------
Magnetic treatment may ease or lessen occurrence of schizophrenic voices.
Evidence shows schizophrenia caused by gene abnormalities of Chromosome 1.
Researchers examining evidence of link between schizophrenia and nicotine addiction.
Scientists focusing on chemical environment of brain to understand schizophrenia.
Schizophrenia study shows disparity between what's known and what's provided to patients.
Researchers find Cesarean babies may be more susceptible to schizophrenia.
Family history, season, and place of birth linked to schizophrenia.
This study shows schizophrenia can be identified before symptoms appear.
Neuroscientist awarded for research linking development of schizophrenia to genetics.
Researchers using newest drugs to detect early signs of schizophrenia.
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
import os
from gensim import corpora

def get_dictionary(fname, max_vocabulary_size=None):
    with open(fname) as f:
        texts = f.read().lower().split()
    return corpora.Dictionary([texts], prune_at=max_vocabulary_size)

class Loader(object):

    def __init__(self, dataset, batch_size, dataset_dir="data",
                 title_fname="title.txt", article_fname="article.txt"):
        self.dataset = dataset
        self.batch_size = batch_size
        self.dataset_dir = dataset_dir

        title_fname = os.path.join(self.dataset_dir, self.dataset, title_fname)
        article_fname = os.path.join(self.dataset_dir, self.dataset, article_fname)

        dict_fname = os.path.join(self.dataset_dir, self.dataset, "dict")

        # Build the article vocabulary if no cached dictionary exists.
        if not os.path.exists(dict_fname):
            self.article_dict = get_dictionary(article_fname)

        self.build_article_matrices()

    def build_article_matrices(self):
        pass
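if __name__ == "__main__":
    # Illustrative sanity check (a sketch, not part of the original file):
    # build a vocabulary over the bundled test articles and map a sentence to
    # token ids. doc2idx returns -1 for out-of-vocabulary words; it assumes a
    # gensim version that provides Dictionary.doc2idx.
    vocab = get_dictionary("data/test/article.txt")
    print(vocab.doc2idx("scientists study schizophrenia".split()))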
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import tensorflow as tf

from model import NAMAS
from utils import pp

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_integer("word_embed_dim", 650, "The dimension of word embedding matrix [650]")
flags.DEFINE_integer("char_embed_dim", 15, "The dimension of char embedding matrix [15]")
flags.DEFINE_integer("max_word_length", 65, "The maximum length of a word [65]")
flags.DEFINE_integer("batch_size", 100, "The size of a training batch [100]")
flags.DEFINE_integer("seq_length", 35, "The # of timesteps to unroll for [35]")
flags.DEFINE_float("learning_rate", 1.0, "Learning rate [1.0]")
flags.DEFINE_float("decay", 0.5, "Decay of RMSProp [0.5]")
flags.DEFINE_float("dropout_prob", 0.5, "Probability of dropout layer [0.5]")
flags.DEFINE_string("feature_maps", "[50,100,150,200,200,200,200]", "The # of feature maps in CNN [50,100,150,200,200,200,200]")
flags.DEFINE_string("kernels", "[1,2,3,4,5,6,7]", "The width of CNN kernels [1,2,3,4,5,6,7]")
flags.DEFINE_string("model", "NAMAS", "The type of model to train and test [NAMAS]")
flags.DEFINE_string("data_dir", "data", "The name of data directory [data]")
flags.DEFINE_string("dataset", "ptb", "The name of dataset [ptb]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_boolean("forward_only", False, "True for forward only, False for training [False]")
flags.DEFINE_boolean("use_char", True, "Use character-level language model [True]")
flags.DEFINE_boolean("use_word", False, "Use word-level language model [False]")
FLAGS = flags.FLAGS

model_dict = {
    'NAMAS': NAMAS,
}

def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        print(" [*] Creating checkpoint directory...")
        os.makedirs(FLAGS.checkpoint_dir)

    with tf.Session() as sess:
        model = model_dict[FLAGS.model](sess, checkpoint_dir=FLAGS.checkpoint_dir,
                                        seq_length=FLAGS.seq_length,
                                        word_embed_dim=FLAGS.word_embed_dim,
                                        char_embed_dim=FLAGS.char_embed_dim,
                                        feature_maps=eval(FLAGS.feature_maps),
                                        kernels=eval(FLAGS.kernels),
                                        batch_size=FLAGS.batch_size,
                                        dropout_prob=FLAGS.dropout_prob,
                                        max_word_length=FLAGS.max_word_length,
                                        forward_only=FLAGS.forward_only,
                                        dataset_name=FLAGS.dataset,
                                        use_char=FLAGS.use_char,
                                        use_word=FLAGS.use_word,
                                        data_dir=FLAGS.data_dir)

        if not FLAGS.forward_only:
            model.run(FLAGS.epoch, FLAGS.learning_rate, FLAGS.decay)
        else:
            # TODO: run an evaluation pass to compute test_loss before reporting it.
            print(" [*] Test loss: %2.6f, perplexity: %2.6f" % (test_loss, np.exp(test_loss)))

if __name__ == '__main__':
    tf.app.run()
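# Note: eval() on the list-valued flag strings above executes arbitrary code;
# ast.literal_eval is a safer drop-in for parsing such literals, e.g.:
#
#   import ast
#   feature_maps = ast.literal_eval(FLAGS.feature_maps)  # -> [50, 100, 150, ...]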
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf

class Model(object):
    """Abstract object representing a Reader model."""
    def __init__(self):
        self.vocab = None
        self.data = None

    def get_model_dir(self):
        # Subclasses are expected to define self.dataset and self._attrs.
        model_dir = self.dataset
        for attr in self._attrs:
            if hasattr(self, attr):
                model_dir += "_%s:%s" % (attr, getattr(self, attr))
        return model_dir

    def save(self, checkpoint_dir, global_step=None):
        self.saver = tf.train.Saver()

        print(" [*] Saving checkpoints...")
        model_name = type(self).__name__
        model_dir = self.get_model_dir()

        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name), global_step=global_step)

    def load(self, checkpoint_dir):
        self.saver = tf.train.Saver()

        print(" [*] Loading checkpoints...")
        model_dir = self.get_model_dir()
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            print(" [*] Load SUCCESS")
            return True
        else:
            print(" [!] Load failed...")
            return False
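# Illustrative usage sketch (assumes a concrete subclass such as NAMAS that
# sets self.sess, self.dataset, and self._attrs before saving/loading):
#
#   model = NAMAS(sess, checkpoint_dir="checkpoint", ...)
#   model.save("checkpoint", global_step=1000)  # -> checkpoint/<model_dir>/NAMAS-1000
#   if not model.load("checkpoint"):
#       print(" [!] No checkpoint found, starting from scratch")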
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from .ops import *

class Encoder(object):
    """Encoder
    """
    def __init__(self, model_name="bow", bow_dim=50,
                 attention_pool=5, hidden_size=1000, kernel_width=5):
        """Initialize the parameters for Encoder

        Args:
            model_name: the name of encoder to use [bow]
            bow_dim: the size of article embedding [50]
            attention_pool: the size of attention model pooling [5]
            hidden_size: the size of hidden units in ConvNet [1000]
            kernel_width: the size of kernel width in ConvNet [5]
        """
        self.model = None
        self.bow_dim = bow_dim
        self.attention_pool = attention_pool
        self.hidden_size = hidden_size
        self.kernel_width = kernel_width

        self.build_model(model_name)

    def build_model(self, model_name, data=None):
        if model_name is None:
            self.model = self.build_blank_model()
        elif model_name == "bow":
            self.model = self.build_bow_model(data)
        elif model_name == "conv":
            self.model = self.build_conv_model(data)
        else:
            print(" [!] Wrong model name : %s" % model_name)

    def build_blank_model(self):
        """Ignores the article layer entirely (acts like an LM).
        """
        # The blank encoder contributes nothing; return a constant zero.
        return tf.constant(0)

    def build_bow_model(self, data):
        print(" [*] Build Encoder: Bag-of-Words")
        bow_embed = tf.get_variable("bow_embed", [len(data), self.bow_dim],
                                    dtype=tf.float32)

        # Embed the article words and average the embeddings over positions.
        # NOTE: `input_` and linear() are not defined yet; work in progress.
        start = tf.nn.embedding_lookup(bow_embed, input_)
        mout = linear(tf.reduce_mean(start, 1), self.bow_dim)
        return mout

    def build_conv_model(self, data):
        print(" [*] Build Encoder: ConvNet")
        V2 = len(data.article_data.i2s)

        article_embed = tf.get_variable("article_embed", [V2, self.hidden_size],
                                        dtype=tf.float32)

        # Ignore the context.
        ignore1, ignore2 = None, None

        # NOTE: `input_` is not defined yet; work in progress.
        start = tf.nn.embedding_lookup(article_embed, input_)
        return start
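if __name__ == "__main__":
    # Minimal sketch of the intended bag-of-words encoder (one reading of the
    # paper, not the finished implementation): embed every article word and
    # average the embeddings over positions. Shapes are assumptions.
    vocab_size, article_len, bow_dim = 1000, 50, 50
    article_ids = tf.placeholder(tf.int32, [None, article_len])
    demo_embed = tf.get_variable("demo_bow_embed", [vocab_size, bow_dim])
    embedded = tf.nn.embedding_lookup(demo_embed, article_ids)  # [B, L, bow_dim]
    enc = tf.reduce_mean(embedded, 1)                           # [B, bow_dim]
    print(enc.get_shape())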
43 | """ 44 | loookup, ignore1, ignore2 = None, None, None 45 | 46 | start = ignore2 47 | mout = tf.constant(0) * start 48 | 49 | def build_bow_model(self, data): 50 | print(" [*] Build Encoder: Bag-of-Words") 51 | loookup, ignore1, ignore2 = None, None, None 52 | bow_embed = tf.get_variable(tf.float32, [len(data), self.bow_dim]) 53 | 54 | start = tf.nn.lookup_embedding(bow_embed, input_) 55 | mout = linear(tf.reduce_mean(tf.transpose(start, [2,3]), 2), self.bow_dim) 56 | 57 | def build_conv_model(self): 58 | loookup, ignore1, ignore2 = None, None, None 59 | print(" [*] Build Encoder: ConvNet") 60 | V2 = len(data.article_data.i2s) 61 | 62 | article_embed = tf.get_variable(tf.float32, [self.hidden_size]) 63 | 64 | # Ignore the context 65 | ignore1, ignore2 = None, None 66 | 67 | start = tf.nn.lookup_embedding(article_embed, input_) 68 | -------------------------------------------------------------------------------- /model/nnlm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .base import Model 4 | 5 | class NNLM(Model): 6 | """Neural Network Language Model""" 7 | 8 | def __init__(self, dictionary, encoder_size, 9 | embedding_dim=50, hidden_dim=100, 10 | window_size=5, encoder_type="deep"): 11 | """Initialize the parameters for Encoder 12 | 13 | Args: 14 | model_name: the name of encoder to use [bow] 15 | bow_dim: the size of article embedding [50] 16 | attention_pool: the size of attention model pooling [5] 17 | hidden_dim: the size of hidden units in ConvNet [1000] 18 | kernel_width: the size of kernel width in ConvNet [5] 19 | """ 20 | self.dictionary = dictionary 21 | self.encoder_size = encoder_size 22 | self.window_size = window_size 23 | self.embedding_dim = embedding_dim 24 | self.hidden_dim = hidden_dim 25 | self.encoder_type = encoder_type 26 | 27 | def train(self, epochs=5, batch_size=64, learning_rate=0.1): 28 | pass 29 | -------------------------------------------------------------------------------- /model/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def conv2d(input_, output_dim, k_h, k_w, 4 | stddev=0.02, name="conv2d"): 5 | with tf.variable_scope(name): 6 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 7 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 8 | conv = tf.nn.conv2d(input_, w, strides=[1, 1, 1, 1], padding='VALID') 9 | return conv 10 | -------------------------------------------------------------------------------- /processing.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #!/bin/python 3 | import re 4 | import os 5 | import sys 6 | import json 7 | from collections import defaultdict 8 | 9 | input_path = sys.argv[1] 10 | try: 11 | is_korean = sys.argv[2] 12 | except: 13 | is_korean = False 14 | 15 | title_path = os.path.join(os.path.dirname(input_path), "title.txt") 16 | article_path = os.path.join(os.path.dirname(input_path), "article.txt") 17 | 18 | def remove_digits(parse): 19 | return re.sub(ur"…", "", re.sub(ur".+ (기자|특파원) = ", "", re.sub(ur'(·|!|\'|`|"|‘|’|“|”|=|ㆍ)', " ", re.sub(r'\d+', '#', parse)))) 20 | 21 | case = defaultdict(int) 22 | 23 | with open(input_path) as f, \ 24 | open(title_path, "w") as t_f, \ 25 | open(article_path, "w") as a_f: 26 | try: 27 | from tqdm import tqdm 28 | items = tqdm(json.loads(f.read())) 29 | except: 30 | items = json.loads(f.read()) 31 | 32 | for item in items: 
33 | title = " ".join(remove_digits(item['t']).lower().split()) 34 | article = " ".join(remove_digits(item['a']).lower().split()) 35 | 36 | title = re.sub(r'\[.+\]', '', title).strip() 37 | title = re.sub(r'\(.+\)', '', title).strip() 38 | article = re.sub(r'\[.+\]', '', article).strip() 39 | article = re.sub(r'\(.+\)', '', article).strip() 40 | 41 | if u'\ub2e4.' in article: 42 | article = article.split(u'\ub2e4.', 1)[0] 43 | article = article + u'\ub2e4' 44 | 45 | title_words = title.strip().split() 46 | article_words = article.strip().split() 47 | 48 | # No blanks. 49 | if any((word == "" for word in title_words)) or title_words == []: 50 | case['empty_title'] += 1 51 | continue 52 | 53 | if any((word == "" for word in article_words)) or article_words == []: 54 | case['empty_article'] += 1 55 | continue 56 | 57 | if is_korean: 58 | if article_words[-1][-1] != u"다": 59 | case['not_end_with_da'] += 1 60 | continue 61 | 62 | if article[-2:] != u'\ub2e4.': 63 | continue 64 | 65 | if len(re.findall(r"\w", article)) > 20 or len(re.findall(r"\w", title)) > 10: 66 | case['only english'] += 1 67 | continue 68 | 69 | bad_words = [u'?', u'"', u"-"] 70 | 71 | if any((bad in title.lower() 72 | for bad in bad_words)): 73 | case['bad in title'] += 1 74 | continue 75 | 76 | bad_words = ['======', u'"'] 77 | 78 | if any((bad in article.lower() 79 | for bad in bad_words)): 80 | case['bad in article'] += 1 81 | continue 82 | 83 | # Reasonable lengths 84 | if not (10 < len(article_words) < 100 and 85 | 2 < len(title_words) < 50): 86 | case['reasonable length'] += 1 87 | continue 88 | 89 | # Some word match. 90 | if is_korean: 91 | matches = len(set([w[:2] for w in title_words if len(w) > 1]) & 92 | set([w[:2] for w in article_words if len(w) > 1])) 93 | else: 94 | matches = len(set([w for w in title_words if len(w) > 1]) & 95 | set([w for w in article_words if len(w) > 1])) 96 | if matches < 1: 97 | case['zero match'] += 1 98 | continue 99 | 100 | t_f.write(title.encode('utf-8') + "\n") 101 | a_f.write(article.encode('utf-8') + "\n") 102 | 103 | print case 104 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from loader import Loader 2 | 3 | x = Loader("duc", 64) 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pprint 3 | 4 | try: 5 | xrange 6 | except NameError: 7 | xrange = range 8 | 9 | pp = pprint.PrettyPrinter() 10 | --------------------------------------------------------------------------------