├── .gitignore ├── LICENSE ├── README.md ├── demo ├── __init__.py ├── data │ └── fake_or_real_news.csv ├── explore_data.py ├── models │ ├── one-shot-rnn-architecture.json │ ├── one-shot-rnn-config.npy │ ├── one-shot-rnn-weights.h5 │ ├── recursive-rnn-1-architecture.json │ ├── recursive-rnn-1-config.npy │ ├── recursive-rnn-1-weights.h5 │ ├── recursive-rnn-2-architecture.json │ ├── recursive-rnn-2-config.npy │ ├── recursive-rnn-2-weights.h5 │ ├── seq2seq-architecture.json │ ├── seq2seq-config.npy │ ├── seq2seq-glove-architecture.json │ ├── seq2seq-glove-config.npy │ ├── seq2seq-glove-v2-architecture.json │ ├── seq2seq-glove-v2-config.npy │ ├── seq2seq-glove-v2-weights.h5 │ ├── seq2seq-glove-weights.h5 │ └── seq2seq-weights.h5 ├── one_hot_rnn_predict.py ├── one_hot_rnn_train.py ├── recursive_rnn_v1_predict.py ├── recursive_rnn_v1_train.py ├── recursive_rnn_v2_predict.py ├── recursive_rnn_v2_train.py ├── recursive_rnn_v3_predict.py ├── recursive_rnn_v3_train.py ├── reports │ ├── recursive-rnn-1-history.png │ ├── recursive-rnn-2-history.png │ ├── seq2seq-glove-history-v1.png │ ├── seq2seq-glove-history.png │ ├── seq2seq-glove-v2-history.png │ ├── seq2seq-history-v2.png │ └── seq2seq-history.png ├── seq2seq_glove_predict.py ├── seq2seq_glove_train.py ├── seq2seq_glove_v2_predict.py ├── seq2seq_glove_v2_train.py ├── seq2seq_predict.py ├── seq2seq_train.py └── very_large_data │ └── .gitignore ├── keras_text_summarization ├── __init__.py └── library │ ├── __init__.py │ ├── applications │ ├── __init__.py │ └── fake_news_loader.py │ ├── rnn.py │ ├── seq2seq.py │ └── utility │ ├── __init__.py │ ├── device_utils.py │ ├── glove_loader.py │ ├── plot_utils.py │ └── text_utils.py ├── notes ├── ReadMe.md └── evaluation.md ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | .idea/ 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Xianshun Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # keras-text-summarization
2 | 
3 | Text summarization using seq2seq and encoder-decoder recurrent networks in Keras
4 | 
5 | # Machine Learning Models
6 | 
7 | The following neural network models are implemented and studied for text summarization:
8 | 
9 | ### Seq2Seq
10 | 
11 | The seq2seq model encodes the content of an article (encoder input) and one character (decoder input) from the summarized text to predict the next character in the summarized text (a rough sketch of this wiring is shown under Usage below).
12 | 
13 | The implementation can be found in [keras_text_summarization/library/seq2seq.py](keras_text_summarization/library/seq2seq.py)
14 | 
15 | There are three variants of the seq2seq model implemented for text summarization:
16 | * Seq2SeqSummarizer (one-hot encoding)
17 |     * training: run [demo/seq2seq_train.py](demo/seq2seq_train.py)
18 |     * prediction: demo code is available in [demo/seq2seq_predict.py](demo/seq2seq_predict.py)
19 | * Seq2SeqGloVeSummarizer (GloVe encoding for encoder input)
20 |     * training: run [demo/seq2seq_glove_train.py](demo/seq2seq_glove_train.py)
21 |     * prediction: demo code is available in [demo/seq2seq_glove_predict.py](demo/seq2seq_glove_predict.py)
22 | * Seq2SeqGloVeSummarizerV2 (GloVe encoding for both encoder input and decoder input)
23 |     * training: run [demo/seq2seq_glove_v2_train.py](demo/seq2seq_glove_v2_train.py)
24 |     * prediction: demo code is available in [demo/seq2seq_glove_v2_predict.py](demo/seq2seq_glove_v2_predict.py)
25 | 
26 | ### Other RNN models
27 | 
28 | There are currently 3 other encoder-decoder recurrent models, based on the recommendations [here](https://machinelearningmastery.com/encoder-decoder-models-text-summarization-keras/)
29 | 
30 | The implementation can be found in [keras_text_summarization/library/rnn.py](keras_text_summarization/library/rnn.py)
31 | 
32 | * One-Shot RNN (OneShotRNN in [rnn.py](keras_text_summarization/library/rnn.py)):
33 | The one-shot RNN is a very simple encoder-decoder recurrent network model which encodes the content of an article and decodes the entire content of the summarized text.
34 |     * training: run [demo/one_hot_rnn_train.py](demo/one_hot_rnn_train.py)
35 |     * prediction: run [demo/one_hot_rnn_predict.py](demo/one_hot_rnn_predict.py)
36 | * Recursive RNN 1 (RecursiveRNN1 in [rnn.py](keras_text_summarization/library/rnn.py)):
37 | The recursive RNN 1 takes the article content and the current built-up summarized text to predict the next character of the summarized text.
38 |     * training: run [demo/recursive_rnn_v1_train.py](demo/recursive_rnn_v1_train.py)
39 |     * prediction: run [demo/recursive_rnn_v1_predict.py](demo/recursive_rnn_v1_predict.py)
40 | * Recursive RNN 2 (RecursiveRNN2 in [rnn.py](keras_text_summarization/library/rnn.py)):
41 | The recursive RNN 2 takes the article content and the current built-up summarized text to predict the next character of the summarized text, with an additional LSTM decoder layer.
42 |     * training: run [demo/recursive_rnn_v2_train.py](demo/recursive_rnn_v2_train.py)
43 |     * prediction: run [demo/recursive_rnn_v2_predict.py](demo/recursive_rnn_v2_predict.py)
44 | 
45 | The trained models are available in the demo/models folder.
46 | 
47 | # Usage
48 | 
49 | The demo below shows how to use seq2seq to do training and prediction, but the other models described above also follow
50 | the same process of training and prediction.
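For orientation, the sketch below shows roughly how the Seq2SeqSummarizer network is wired. It is not the library code itself (that lives in keras_text_summarization/library/seq2seq.py); it is a minimal reconstruction based on demo/models/seq2seq-architecture.json, from which the 100-unit LSTMs, the 100-dimensional embedding, the 5002-token input vocabulary, the 2001-token target vocabulary and the 500-token input length are taken. The loss and optimizer shown are illustrative assumptions, not necessarily what the library uses.

```python
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Embedding

# Sizes taken from demo/models/seq2seq-architecture.json; in the library they
# come from the config dictionary produced by fit_text().
num_input_tokens = 5002
num_target_tokens = 2001
max_input_seq_length = 500
embedding_dim = 100
hidden_units = 100

# Encoder: embed the article word indices and keep only the final LSTM states.
encoder_inputs = Input(shape=(None,), name='encoder_inputs')
encoder_embedding = Embedding(input_dim=num_input_tokens, output_dim=embedding_dim,
                              input_length=max_input_seq_length,
                              name='encoder_embedding')(encoder_inputs)
encoder_lstm = LSTM(units=hidden_units, return_state=True, name='encoder_lstm')
_, state_h, state_c = encoder_lstm(encoder_embedding)

# Decoder: one-hot encoded summary tokens, initialised with the encoder states,
# predicting a softmax distribution over the target vocabulary at each step.
decoder_inputs = Input(shape=(None, num_target_tokens), name='decoder_inputs')
decoder_lstm = LSTM(units=hidden_units, return_sequences=True, return_state=True,
                    name='decoder_lstm')
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=[state_h, state_c])
decoder_outputs = Dense(units=num_target_tokens, activation='softmax',
                        name='decoder_dense')(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Loss/optimizer here are assumptions for illustration only.
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.summary()
```

The GloVe variants drop the learned Embedding layer: judging from the input shapes in demo/models/seq2seq-glove-architecture.json and demo/models/seq2seq-glove-v2-architecture.json, the encoder (and, for Seq2SeqGloVeSummarizerV2, also the decoder) is fed sequences of 100-dimensional GloVe word vectors directly.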
51 | 
52 | ### Train Deep Learning model
53 | 
54 | To train a deep learning model, say Seq2SeqSummarizer, run the following commands:
55 | 
56 | ```bash
57 | pip install -r requirements.txt
58 | 
59 | cd demo
60 | python seq2seq_train.py
61 | ```
62 | 
63 | The training code in seq2seq_train.py is quite straightforward and is illustrated below:
64 | 
65 | ```python
66 | from __future__ import print_function
67 | 
68 | import pandas as pd
69 | from sklearn.model_selection import train_test_split
70 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history
71 | from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer
72 | from keras_text_summarization.library.applications.fake_news_loader import fit_text
73 | import numpy as np
74 | 
75 | LOAD_EXISTING_WEIGHTS = True
76 | 
77 | np.random.seed(42)
78 | data_dir_path = './data'
79 | report_dir_path = './reports'
80 | model_dir_path = './models'
81 | 
82 | print('loading csv file ...')
83 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")
84 | 
85 | print('extract configuration from input texts ...')
86 | Y = df.title
87 | X = df['text']
88 | 
89 | config = fit_text(X, Y)
90 | 
91 | summarizer = Seq2SeqSummarizer(config)
92 | 
93 | if LOAD_EXISTING_WEIGHTS:
94 |     summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path))
95 | 
96 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)
97 | 
98 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=100)
99 | 
100 | history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history.png'
101 | if LOAD_EXISTING_WEIGHTS:
102 |     history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history-v' + str(summarizer.version) + '.png'
103 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'})
104 | ```
105 | 
106 | After the training is completed, the trained model is saved in the demo/models folder as seq2seq-architecture.json, seq2seq-config.npy and seq2seq-weights.h5 (the other summarizers follow the same naming pattern with their own model_name prefix).
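The configuration extracted by `fit_text` is persisted alongside the weights, so a prediction script can rebuild the summarizer without re-scanning the corpus. The snippet below is a small illustrative sketch (not a script from the repository) showing how that saved configuration can be reloaded and inspected; the four keys are the ones printed by demo/explore_data.py.

```python
import numpy as np
from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer

model_dir_path = './models'  # refers to the demo/models folder

# The config is a plain Python dict saved with numpy (see seq2seq-config.npy).
config = np.load(Seq2SeqSummarizer.get_config_file_path(model_dir_path=model_dir_path)).item()

# These are the corpus statistics that fit_text() extracts from the article/headline pairs.
for key in ('num_input_tokens', 'num_target_tokens', 'max_input_seq_length', 'max_target_seq_length'):
    print(key, '=', config[key])
```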
107 | 
108 | ### Summarization
109 | 
110 | To use the trained deep learning model to summarize an article, the following code demonstrates how to do this:
111 | 
112 | ```python
113 | 
114 | from __future__ import print_function
115 | 
116 | import pandas as pd
117 | from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer
118 | import numpy as np
119 | 
120 | np.random.seed(42)
121 | data_dir_path = './data' # refers to the demo/data folder
122 | model_dir_path = './models' # refers to the demo/models folder
123 | 
124 | print('loading csv file ...')
125 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")
126 | X = df['text']
127 | Y = df.title
128 | 
129 | config = np.load(Seq2SeqSummarizer.get_config_file_path(model_dir_path=model_dir_path)).item()
130 | 
131 | summarizer = Seq2SeqSummarizer(config)
132 | summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path))
133 | 
134 | print('start predicting ...')
135 | for i in range(20):
136 |     x = X[i]
137 |     actual_headline = Y[i]
138 |     headline = summarizer.summarize(x)
139 |     print('Article: ', x)
140 |     print('Generated Headline: ', headline)
141 |     print('Original Headline: ', actual_headline)
142 | ```
143 | 
144 | # Configure to run on GPU on Windows
145 | 
146 | * Step 1: Change tensorflow to tensorflow-gpu in requirements.txt and install tensorflow-gpu
147 | * Step 2: Download and install the [CUDA® Toolkit 9.0](https://developer.nvidia.com/cuda-90-download-archive) (Please note that
148 | currently CUDA® Toolkit 9.1 is not yet supported by tensorflow, therefore you should download CUDA® Toolkit 9.0)
149 | * Step 3: Download and unzip the [cuDNN 7.0.4 for CUDA® Toolkit 9.0](https://developer.nvidia.com/cudnn) and add the
150 | bin folder of the unzipped directory to the $PATH of your Windows environment
151 | 
152 | 
153 | 
154 | 
-------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/__init__.py -------------------------------------------------------------------------------- /demo/explore_data.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from sklearn.model_selection import train_test_split
3 | from keras_text_summarization.library.applications.fake_news_loader import fit_text
4 | 
5 | 
6 | def main():
7 |     data_dir_path = './data'
8 | 
9 |     # Import `fake_or_real_news.csv`
10 |     df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")
11 | 
12 |     # Inspect shape of `df`
13 |     print(df.shape)
14 | 
15 |     # Print first lines of `df`
16 |     print(df.head())
17 | 
18 |     # Set index
19 |     df = df.set_index("Unnamed: 0")
20 | 
21 |     # Print first lines of `df`
22 |     print(df.head())
23 | 
24 |     # Set the target `Y` (headlines) and input `X` (article text)
25 |     Y = df.title
26 |     X = df['text']
27 | 
28 |     # Drop the `title` column (assign the result back so the drop takes effect)
29 |     df = df.drop("title", axis=1)
30 | 
31 |     # Make training and test sets
32 |     X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=53)
33 | 
34 |     print('X train: ', X_train.shape)
35 |     print('Y train: ', y_train.shape)
36 | 
37 |     config = fit_text(X, Y)
38 | 
39 |     print('num_input_tokens: ', config['num_input_tokens'])
40 |     print('num_target_tokens: ', config['num_target_tokens'])
41 |     print('max_input_seq_length: ', config['max_input_seq_length'])
42 |     print('max_target_seq_length: ',
config['max_target_seq_length']) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /demo/models/one-shot-rnn-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Sequential", "config": [{"class_name": "Embedding", "config": {"name": "embedding_1", "trainable": true, "batch_input_shape": [null, 500], "dtype": "float32", "input_dim": 5002, "output_dim": 128, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": 500}}, {"class_name": "LSTM", "config": {"name": "lstm_1", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}}, {"class_name": "RepeatVector", "config": {"name": "repeat_vector_1", "trainable": true, "n": 37}}, {"class_name": "LSTM", "config": {"name": "lstm_2", "trainable": true, "return_sequences": true, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}}, {"class_name": "TimeDistributed", "config": {"name": "time_distributed_1", "trainable": true, "layer": {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}}}], "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/one-shot-rnn-config.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/one-shot-rnn-config.npy -------------------------------------------------------------------------------- /demo/models/one-shot-rnn-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/one-shot-rnn-weights.h5 -------------------------------------------------------------------------------- /demo/models/recursive-rnn-1-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "input_1", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 500], "dtype": "float32", "sparse": false, "name": "input_1"}, "inbound_nodes": []}, {"name": "input_2", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 50], "dtype": "float32", "sparse": false, "name": "input_2"}, "inbound_nodes": []}, {"name": "embedding_1", "class_name": "Embedding", "config": {"name": "embedding_1", "trainable": true, "batch_input_shape": [null, null], "dtype": "float32", "input_dim": 5002, "output_dim": 128, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": null}, "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"name": "embedding_2", "class_name": "Embedding", "config": {"name": "embedding_2", "trainable": true, "batch_input_shape": [null, null], "dtype": "float32", "input_dim": 2001, "output_dim": 128, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": null}, "inbound_nodes": [[["input_2", 0, 0, {}]]]}, {"name": "lstm_1", "class_name": "LSTM", "config": {"name": "lstm_1", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["embedding_1", 0, 0, {}]]]}, {"name": "lstm_2", "class_name": "LSTM", "config": {"name": "lstm_2", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": 
null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["embedding_2", 0, 0, {}]]]}, {"name": "concatenate_1", "class_name": "Concatenate", "config": {"name": "concatenate_1", "trainable": true, "axis": -1}, "inbound_nodes": [[["lstm_1", 0, 0, {}], ["lstm_2", 0, 0, {}]]]}, {"name": "dense_1", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["concatenate_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0], ["input_2", 0, 0]], "output_layers": [["dense_1", 0, 0]]}, "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/recursive-rnn-1-config.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/recursive-rnn-1-config.npy -------------------------------------------------------------------------------- /demo/models/recursive-rnn-1-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/recursive-rnn-1-weights.h5 -------------------------------------------------------------------------------- /demo/models/recursive-rnn-2-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "input_1", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 500], "dtype": "float32", "sparse": false, "name": "input_1"}, "inbound_nodes": []}, {"name": "input_2", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 50], "dtype": "float32", "sparse": false, "name": "input_2"}, "inbound_nodes": []}, {"name": "embedding_1", "class_name": "Embedding", "config": {"name": "embedding_1", "trainable": true, "batch_input_shape": [null, null], "dtype": "float32", "input_dim": 5002, "output_dim": 128, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": null}, "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"name": "embedding_2", "class_name": "Embedding", "config": {"name": "embedding_2", "trainable": true, "batch_input_shape": [null, null], "dtype": "float32", "input_dim": 2001, "output_dim": 128, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": null}, 
"inbound_nodes": [[["input_2", 0, 0, {}]]]}, {"name": "lstm_1", "class_name": "LSTM", "config": {"name": "lstm_1", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["embedding_1", 0, 0, {}]]]}, {"name": "lstm_2", "class_name": "LSTM", "config": {"name": "lstm_2", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["embedding_2", 0, 0, {}]]]}, {"name": "repeat_vector_1", "class_name": "RepeatVector", "config": {"name": "repeat_vector_1", "trainable": true, "n": 128}, "inbound_nodes": [[["lstm_1", 0, 0, {}]]]}, {"name": "repeat_vector_2", "class_name": "RepeatVector", "config": {"name": "repeat_vector_2", "trainable": true, "n": 128}, "inbound_nodes": [[["lstm_2", 0, 0, {}]]]}, {"name": "concatenate_1", "class_name": "Concatenate", "config": {"name": "concatenate_1", "trainable": true, "axis": -1}, "inbound_nodes": [[["repeat_vector_1", 0, 0, {}], ["repeat_vector_2", 0, 0, {}]]]}, {"name": "lstm_3", "class_name": "LSTM", "config": {"name": "lstm_3", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 128, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["concatenate_1", 0, 0, {}]]]}, {"name": "dense_1", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["lstm_3", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0], ["input_2", 0, 0]], "output_layers": [["dense_1", 0, 0]]}, "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/recursive-rnn-2-config.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/recursive-rnn-2-config.npy -------------------------------------------------------------------------------- /demo/models/recursive-rnn-2-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/recursive-rnn-2-weights.h5 -------------------------------------------------------------------------------- /demo/models/seq2seq-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "encoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null], "dtype": "float32", "sparse": false, "name": "encoder_inputs"}, "inbound_nodes": []}, {"name": "encoder_embedding", "class_name": "Embedding", "config": {"name": "encoder_embedding", "trainable": true, "batch_input_shape": [null, 500], "dtype": "float32", "input_dim": 5002, "output_dim": 100, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": 500}, "inbound_nodes": [[["encoder_inputs", 0, 0, {}]]]}, {"name": "decoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null, 2001], "dtype": "float32", "sparse": false, "name": "decoder_inputs"}, "inbound_nodes": []}, {"name": "encoder_lstm", "class_name": "LSTM", "config": {"name": "encoder_lstm", "trainable": true, "return_sequences": false, "return_state": true, "go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["encoder_embedding", 0, 0, {}]]]}, {"name": "decoder_lstm", "class_name": "LSTM", "config": {"name": "decoder_lstm", "trainable": true, "return_sequences": true, "return_state": true, "go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", 
"recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["decoder_inputs", 0, 0, {}], ["encoder_lstm", 0, 1, {}], ["encoder_lstm", 0, 2, {}]]]}, {"name": "decoder_dense", "class_name": "Dense", "config": {"name": "decoder_dense", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["decoder_lstm", 0, 0, {}]]]}], "input_layers": [["encoder_inputs", 0, 0], ["decoder_inputs", 0, 0]], "output_layers": [["decoder_dense", 0, 0]]}, "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/seq2seq-config.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-config.npy -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "encoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null, 100], "dtype": "float32", "sparse": false, "name": "encoder_inputs"}, "inbound_nodes": []}, {"name": "decoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null, 2001], "dtype": "float32", "sparse": false, "name": "decoder_inputs"}, "inbound_nodes": []}, {"name": "encoder_lstm", "class_name": "LSTM", "config": {"name": "encoder_lstm", "trainable": true, "return_sequences": false, "return_state": true, "go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["encoder_inputs", 0, 0, {}]]]}, {"name": "decoder_lstm", "class_name": "LSTM", "config": {"name": "decoder_lstm", "trainable": true, "return_sequences": true, "return_state": true, 
"go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["decoder_inputs", 0, 0, {}], ["encoder_lstm", 0, 1, {}], ["encoder_lstm", 0, 2, {}]]]}, {"name": "decoder_dense", "class_name": "Dense", "config": {"name": "decoder_dense", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["decoder_lstm", 0, 0, {}]]]}], "input_layers": [["encoder_inputs", 0, 0], ["decoder_inputs", 0, 0]], "output_layers": [["decoder_dense", 0, 0]]}, "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-config.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-glove-config.npy -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-v2-architecture.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "encoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null, 100], "dtype": "float32", "sparse": false, "name": "encoder_inputs"}, "inbound_nodes": []}, {"name": "decoder_inputs", "class_name": "InputLayer", "config": {"batch_input_shape": [null, null, 100], "dtype": "float32", "sparse": false, "name": "decoder_inputs"}, "inbound_nodes": []}, {"name": "encoder_lstm", "class_name": "LSTM", "config": {"name": "encoder_lstm", "trainable": true, "return_sequences": false, "return_state": true, "go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["encoder_inputs", 0, 0, {}]]]}, {"name": "decoder_lstm", "class_name": 
"LSTM", "config": {"name": "decoder_lstm", "trainable": true, "return_sequences": true, "return_state": true, "go_backwards": false, "stateful": false, "unroll": false, "units": 100, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 1}, "inbound_nodes": [[["decoder_inputs", 0, 0, {}], ["encoder_lstm", 0, 1, {}], ["encoder_lstm", 0, 2, {}]]]}, {"name": "decoder_dense", "class_name": "Dense", "config": {"name": "decoder_dense", "trainable": true, "units": 2001, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["decoder_lstm", 0, 0, {}]]]}], "input_layers": [["encoder_inputs", 0, 0], ["decoder_inputs", 0, 0]], "output_layers": [["decoder_dense", 0, 0]]}, "keras_version": "2.1.2", "backend": "tensorflow"} -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-v2-config.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-glove-v2-config.npy -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-v2-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-glove-v2-weights.h5 -------------------------------------------------------------------------------- /demo/models/seq2seq-glove-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-glove-weights.h5 -------------------------------------------------------------------------------- /demo/models/seq2seq-weights.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/models/seq2seq-weights.h5 -------------------------------------------------------------------------------- /demo/one_hot_rnn_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.rnn import OneShotRNN 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | model_dir_path = './models' 12 | 13 | print('loading csv file 
...') 14 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 15 | # df = df.loc[df.index < 1000] 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = np.load(OneShotRNN.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = OneShotRNN(config) 22 | summarizer.load_weights(weight_file_path=OneShotRNN.get_weight_file_path(model_dir_path=model_dir_path)) 23 | 24 | print('start predicting ...') 25 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 26 | x = X[i] 27 | actual_headline = Y[i] 28 | headline = summarizer.summarize(x) 29 | # print('Article: ', x) 30 | print('Generated Headline: ', headline) 31 | print('Original Headline: ', actual_headline) 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /demo/one_hot_rnn_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.rnn import OneShotRNN 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | report_dir_path = './reports' 17 | model_dir_path = './models' 18 | 19 | print('loading csv file ...') 20 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 21 | 22 | print('extract configuration from input texts ...') 23 | # df = df.loc[df.index < 1000] 24 | Y = df.title 25 | X = df['text'] 26 | config = fit_text(X, Y) 27 | 28 | print('configuration extracted from input texts ...') 29 | 30 | summarizer = OneShotRNN(config) 31 | 32 | if LOAD_EXISTING_WEIGHTS: 33 | weight_file_path = OneShotRNN.get_weight_file_path(model_dir_path=model_dir_path) 34 | summarizer.load_weights(weight_file_path=weight_file_path) 35 | 36 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 37 | 38 | print('training size: ', len(Xtrain)) 39 | print('testing size: ', len(Xtest)) 40 | 41 | print('start fitting ...') 42 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=100, batch_size=20) 43 | 44 | history_plot_file_path = report_dir_path + '/' + OneShotRNN.model_name + '-history.png' 45 | if LOAD_EXISTING_WEIGHTS: 46 | history_plot_file_path = report_dir_path + '/' + OneShotRNN.model_name + '-history-v' + str(summarizer.version) + '.png' 47 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /demo/recursive_rnn_v1_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.rnn import RecursiveRNN1 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | model_dir_path = './models' 12 | 13 | print('loading csv file ...') 14 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 15 | # df = df.loc[df.index < 1000] 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = 
np.load(RecursiveRNN1.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = RecursiveRNN1(config) 22 | summarizer.load_weights(weight_file_path=RecursiveRNN1.get_weight_file_path(model_dir_path=model_dir_path)) 23 | 24 | print('start predicting ...') 25 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 26 | x = X[i] 27 | actual_headline = Y[i] 28 | headline = summarizer.summarize(x) 29 | # print('Article: ', x) 30 | print('Generated Headline: ', headline) 31 | print('Original Headline: ', actual_headline) 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /demo/recursive_rnn_v1_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.rnn import RecursiveRNN1 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | report_dir_path = './reports' 17 | model_dir_path = './models' 18 | 19 | print('loading csv file ...') 20 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 21 | 22 | # df = df.loc[df.index < 1000] 23 | 24 | print('extract configuration from input texts ...') 25 | Y = df.title 26 | X = df['text'] 27 | config = fit_text(X, Y) 28 | 29 | print('configuration extracted from input texts ...') 30 | 31 | summarizer = RecursiveRNN1(config) 32 | 33 | if LOAD_EXISTING_WEIGHTS: 34 | weight_file_path = RecursiveRNN1.get_weight_file_path(model_dir_path=model_dir_path) 35 | summarizer.load_weights(weight_file_path=weight_file_path) 36 | 37 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 38 | 39 | print('demo size: ', len(Xtrain)) 40 | print('testing size: ', len(Xtest)) 41 | 42 | print('start fitting ...') 43 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20) 44 | 45 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN1.model_name + '-history.png' 46 | if LOAD_EXISTING_WEIGHTS: 47 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN1.model_name + '-history-v' + str(summarizer.version) + '.png' 48 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() -------------------------------------------------------------------------------- /demo/recursive_rnn_v2_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.rnn import RecursiveRNN2 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | model_dir_path = './models' 12 | 13 | print('loading csv file ...') 14 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 15 | # df = df.loc[df.index < 1000] 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = np.load(RecursiveRNN2.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = RecursiveRNN2(config) 22 | 
summarizer.load_weights(weight_file_path=RecursiveRNN2.get_weight_file_path(model_dir_path=model_dir_path)) 23 | 24 | print('start predicting ...') 25 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 26 | x = X[i] 27 | actual_headline = Y[i] 28 | headline = summarizer.summarize(x) 29 | # print('Article: ', x) 30 | print('Generated Headline: ', headline) 31 | print('Original Headline: ', actual_headline) 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /demo/recursive_rnn_v2_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.rnn import RecursiveRNN2 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | report_dir_path = './reports' 17 | model_dir_path = './models' 18 | 19 | print('loading csv file ...') 20 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 21 | 22 | # df = df.loc[df.index < 1000] 23 | 24 | print('extract configuration from input texts ...') 25 | Y = df.title 26 | X = df['text'] 27 | config = fit_text(X, Y) 28 | 29 | print('configuration extracted from input texts ...') 30 | 31 | summarizer = RecursiveRNN2(config) 32 | 33 | if LOAD_EXISTING_WEIGHTS: 34 | weight_file_path = RecursiveRNN2.get_weight_file_path(model_dir_path=model_dir_path) 35 | summarizer.load_weights(weight_file_path=weight_file_path) 36 | 37 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 38 | 39 | print('demo size: ', len(Xtrain)) 40 | print('testing size: ', len(Xtest)) 41 | 42 | print('start fitting ...') 43 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=256) 44 | 45 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN2.model_name + '-history.png' 46 | if LOAD_EXISTING_WEIGHTS: 47 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN2.model_name + '-history-v' + str(summarizer.version) + '.png' 48 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() -------------------------------------------------------------------------------- /demo/recursive_rnn_v3_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.rnn import RecursiveRNN3 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | model_dir_path = './models' 12 | 13 | print('loading csv file ...') 14 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 15 | # df = df.loc[df.index < 1000] 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = np.load(RecursiveRNN3.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = RecursiveRNN3(config) 22 | summarizer.load_weights(weight_file_path=RecursiveRNN3.get_weight_file_path(model_dir_path=model_dir_path)) 23 | 24 | print('start predicting ...') 25 | for i in 
np.random.permutation(np.arange(len(X)))[0:20]: 26 | x = X[i] 27 | actual_headline = Y[i] 28 | headline = summarizer.summarize(x) 29 | # print('Article: ', x) 30 | print('Generated Headline: ', headline) 31 | print('Original Headline: ', actual_headline) 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /demo/recursive_rnn_v3_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.rnn import RecursiveRNN3 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | report_dir_path = './reports' 17 | model_dir_path = './models' 18 | 19 | print('loading csv file ...') 20 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 21 | 22 | # df = df.loc[df.index < 1000] 23 | 24 | print('extract configuration from input texts ...') 25 | Y = df.title 26 | X = df['text'] 27 | config = fit_text(X, Y) 28 | 29 | print('configuration extracted from input texts ...') 30 | 31 | summarizer = RecursiveRNN3(config) 32 | 33 | if LOAD_EXISTING_WEIGHTS: 34 | weight_file_path = RecursiveRNN3.get_weight_file_path(model_dir_path=model_dir_path) 35 | summarizer.load_weights(weight_file_path=weight_file_path) 36 | 37 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 38 | 39 | print('demo size: ', len(Xtrain)) 40 | print('testing size: ', len(Xtest)) 41 | 42 | print('start fitting ...') 43 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=256) 44 | 45 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN3.model_name + '-history.png' 46 | if LOAD_EXISTING_WEIGHTS: 47 | history_plot_file_path = report_dir_path + '/' + RecursiveRNN3.model_name + '-history-v' + str(summarizer.version) + '.png' 48 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() -------------------------------------------------------------------------------- /demo/reports/recursive-rnn-1-history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/recursive-rnn-1-history.png -------------------------------------------------------------------------------- /demo/reports/recursive-rnn-2-history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/recursive-rnn-2-history.png -------------------------------------------------------------------------------- /demo/reports/seq2seq-glove-history-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/seq2seq-glove-history-v1.png -------------------------------------------------------------------------------- 
/demo/reports/seq2seq-glove-history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/seq2seq-glove-history.png -------------------------------------------------------------------------------- /demo/reports/seq2seq-glove-v2-history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/seq2seq-glove-v2-history.png -------------------------------------------------------------------------------- /demo/reports/seq2seq-history-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/seq2seq-history-v2.png -------------------------------------------------------------------------------- /demo/reports/seq2seq-history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/demo/reports/seq2seq-history.png -------------------------------------------------------------------------------- /demo/seq2seq_glove_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizer 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | very_large_data_dir_path = './very_large_data' 12 | model_dir_path = './models' 13 | 14 | print('loading csv file ...') 15 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = np.load(Seq2SeqGloVeSummarizer.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = Seq2SeqGloVeSummarizer(config) 22 | summarizer.load_glove(very_large_data_dir_path) 23 | summarizer.load_weights(weight_file_path=Seq2SeqGloVeSummarizer.get_weight_file_path(model_dir_path=model_dir_path)) 24 | 25 | print('start predicting ...') 26 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 27 | x = X[i] 28 | actual_headline = Y[i] 29 | headline = summarizer.summarize(x) 30 | 31 | print('Generated Headline: ', headline) 32 | print('Original Headline: ', actual_headline) 33 | # print('Article: ', x[0:100]) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /demo/seq2seq_glove_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizer 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | very_large_data_dir_path = './very_large_data' 17 | report_dir_path = './reports' 18 | model_dir_path = './models' 19 | 20 | 
print('loading csv file ...') 21 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 22 | 23 | print('extract configuration from input texts ...') 24 | Y = df.title 25 | X = df['text'] 26 | config = fit_text(X, Y) 27 | 28 | print('configuration extracted from input texts ...') 29 | 30 | summarizer = Seq2SeqGloVeSummarizer(config) 31 | summarizer.load_glove(very_large_data_dir_path) 32 | 33 | if LOAD_EXISTING_WEIGHTS: 34 | summarizer.load_weights(weight_file_path=Seq2SeqGloVeSummarizer.get_weight_file_path(model_dir_path=model_dir_path)) 35 | 36 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 37 | 38 | print('training size: ', len(Xtrain)) 39 | print('testing size: ', len(Xtest)) 40 | 41 | print('start fitting ...') 42 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=16) 43 | 44 | history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-history.png' 45 | if LOAD_EXISTING_WEIGHTS: 46 | history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-history-v' + str(summarizer.version) + '.png' 47 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /demo/seq2seq_glove_v2_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizerV2 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | very_large_data_dir_path = './very_large_data' 12 | model_dir_path = './models' 13 | 14 | print('loading csv file ...') 15 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 16 | X = df['text'] 17 | Y = df.title 18 | 19 | config = np.load(Seq2SeqGloVeSummarizerV2.get_config_file_path(model_dir_path=model_dir_path)).item() 20 | 21 | summarizer = Seq2SeqGloVeSummarizerV2(config) 22 | summarizer.load_glove(very_large_data_dir_path) 23 | summarizer.load_weights(weight_file_path=Seq2SeqGloVeSummarizerV2.get_weight_file_path(model_dir_path=model_dir_path)) 24 | 25 | print('start predicting ...') 26 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 27 | x = X[i] 28 | actual_headline = Y[i] 29 | headline = summarizer.summarize(x) 30 | 31 | print('Generated Headline: ', headline) 32 | print('Original Headline: ', actual_headline) 33 | # print('Article: ', x[0:100]) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /demo/seq2seq_glove_v2_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizerV2 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | very_large_data_dir_path = './very_large_data' 17 | report_dir_path = './reports' 18 | model_dir_path = './models' 19 | 20 
| print('loading csv file ...') 21 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 22 | 23 | print('extract configuration from input texts ...') 24 | Y = df.title 25 | X = df['text'] 26 | config = fit_text(X, Y) 27 | 28 | print('configuration extracted from input texts ...') 29 | 30 | summarizer = Seq2SeqGloVeSummarizerV2(config) 31 | summarizer.load_glove(very_large_data_dir_path) 32 | 33 | if LOAD_EXISTING_WEIGHTS: 34 | summarizer.load_weights(weight_file_path=Seq2SeqGloVeSummarizerV2.get_weight_file_path(model_dir_path=model_dir_path)) 35 | 36 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 37 | 38 | print('demo size: ', len(Xtrain)) 39 | print('testing size: ', len(Xtest)) 40 | 41 | print('start fitting ...') 42 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=16) 43 | 44 | history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizerV2.model_name + '-history.png' 45 | if LOAD_EXISTING_WEIGHTS: 46 | history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizerV2.model_name + '-history-v' + str(summarizer.version) + '.png' 47 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /demo/seq2seq_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer 5 | import numpy as np 6 | 7 | 8 | def main(): 9 | np.random.seed(42) 10 | data_dir_path = './data' 11 | model_dir_path = './models' 12 | 13 | print('loading csv file ...') 14 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 15 | X = df['text'] 16 | Y = df.title 17 | 18 | config = np.load(Seq2SeqSummarizer.get_config_file_path(model_dir_path=model_dir_path)).item() 19 | 20 | summarizer = Seq2SeqSummarizer(config) 21 | summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path)) 22 | 23 | print('start predicting ...') 24 | for i in np.random.permutation(np.arange(len(X)))[0:20]: 25 | x = X[i] 26 | actual_headline = Y[i] 27 | headline = summarizer.summarize(x) 28 | # print('Article: ', x) 29 | print('Generated Headline: ', headline) 30 | print('Original Headline: ', actual_headline) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /demo/seq2seq_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from keras_text_summarization.library.utility.plot_utils import plot_and_save_history 6 | from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer 7 | from keras_text_summarization.library.applications.fake_news_loader import fit_text 8 | import numpy as np 9 | 10 | LOAD_EXISTING_WEIGHTS = False 11 | 12 | 13 | def main(): 14 | np.random.seed(42) 15 | data_dir_path = './data' 16 | report_dir_path = './reports' 17 | model_dir_path = './models' 18 | 19 | print('loading csv file ...') 20 | df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv") 21 | 22 | print('extract configuration from input texts ...') 23 | Y = df.title 24 | X = df['text'] 25 | 26 | config = 
fit_text(X, Y) 27 | 28 | summarizer = Seq2SeqSummarizer(config) 29 | 30 | if LOAD_EXISTING_WEIGHTS: 31 | summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path)) 32 | 33 | Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42) 34 | 35 | print('demo size: ', len(Xtrain)) 36 | print('testing size: ', len(Xtest)) 37 | 38 | print('start fitting ...') 39 | history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=100) 40 | 41 | history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history.png' 42 | if LOAD_EXISTING_WEIGHTS: 43 | history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history-v' + str(summarizer.version) + '.png' 44 | plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'}) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() -------------------------------------------------------------------------------- /demo/very_large_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /keras_text_summarization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/keras_text_summarization/__init__.py -------------------------------------------------------------------------------- /keras_text_summarization/library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/keras_text_summarization/library/__init__.py -------------------------------------------------------------------------------- /keras_text_summarization/library/applications/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/keras_text_summarization/library/applications/__init__.py -------------------------------------------------------------------------------- /keras_text_summarization/library/applications/fake_news_loader.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | MAX_INPUT_SEQ_LENGTH = 500 4 | MAX_TARGET_SEQ_LENGTH = 50 5 | MAX_INPUT_VOCAB_SIZE = 5000 6 | MAX_TARGET_VOCAB_SIZE = 2000 7 | 8 | 9 | def fit_text(X, Y, input_seq_max_length=None, target_seq_max_length=None): 10 | if input_seq_max_length is None: 11 | input_seq_max_length = MAX_INPUT_SEQ_LENGTH 12 | if target_seq_max_length is None: 13 | target_seq_max_length = MAX_TARGET_SEQ_LENGTH 14 | input_counter = Counter() 15 | target_counter = Counter() 16 | max_input_seq_length = 0 17 | max_target_seq_length = 0 18 | 19 | for line in X: 20 | text = [word.lower() for word in line.split(' ')] 21 | seq_length = len(text) 22 | if seq_length > input_seq_max_length: 23 | text = text[0:input_seq_max_length] 24 | seq_length = len(text) 25 | for word in text: 26 | input_counter[word] += 1 27 | max_input_seq_length = max(max_input_seq_length, seq_length) 28 | 29 | for line in Y: 30 | line2 = 'START ' + line.lower() + ' END' 31 | text = [word for word in line2.split(' ')] 32 | seq_length = len(text) 33 | if seq_length > 
target_seq_max_length: 34 | text = text[0:target_seq_max_length] 35 | seq_length = len(text) 36 | for word in text: 37 | target_counter[word] += 1 38 | max_target_seq_length = max(max_target_seq_length, seq_length) 39 | 40 | input_word2idx = dict() 41 | for idx, word in enumerate(input_counter.most_common(MAX_INPUT_VOCAB_SIZE)): 42 | input_word2idx[word[0]] = idx + 2 43 | input_word2idx['PAD'] = 0 44 | input_word2idx['UNK'] = 1 45 | input_idx2word = dict([(idx, word) for word, idx in input_word2idx.items()]) 46 | 47 | target_word2idx = dict() 48 | for idx, word in enumerate(target_counter.most_common(MAX_TARGET_VOCAB_SIZE)): 49 | target_word2idx[word[0]] = idx + 1 50 | target_word2idx['UNK'] = 0 51 | 52 | target_idx2word = dict([(idx, word) for word, idx in target_word2idx.items()]) 53 | 54 | num_input_tokens = len(input_word2idx) 55 | num_target_tokens = len(target_word2idx) 56 | 57 | config = dict() 58 | config['input_word2idx'] = input_word2idx 59 | config['input_idx2word'] = input_idx2word 60 | config['target_word2idx'] = target_word2idx 61 | config['target_idx2word'] = target_idx2word 62 | config['num_input_tokens'] = num_input_tokens 63 | config['num_target_tokens'] = num_target_tokens 64 | config['max_input_seq_length'] = max_input_seq_length 65 | config['max_target_seq_length'] = max_target_seq_length 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /keras_text_summarization/library/rnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from keras.models import Model, Sequential 4 | from keras.layers import Embedding, Dense, Input, RepeatVector, TimeDistributed, concatenate, Merge, add, Dropout 5 | from keras.layers.recurrent import LSTM 6 | from keras.preprocessing.sequence import pad_sequences 7 | from keras.callbacks import ModelCheckpoint 8 | import numpy as np 9 | import os 10 | 11 | HIDDEN_UNITS = 100 12 | DEFAULT_BATCH_SIZE = 64 13 | VERBOSE = 1 14 | DEFAULT_EPOCHS = 10 15 | 16 | 17 | class OneShotRNN(object): 18 | model_name = 'one-shot-rnn' 19 | """ 20 | The first alternative model is to generate the entire output sequence in a one-shot manner. 21 | That is, the decoder uses the context vector alone to generate the output sequence. 22 | 23 | This model puts a heavy burden on the decoder. 24 | It is likely that the decoder will not have sufficient context for generating a coherent output sequence as it 25 | must choose the words and their order. 
26 | """ 27 | 28 | def __init__(self, config): 29 | self.num_input_tokens = config['num_input_tokens'] 30 | self.max_input_seq_length = config['max_input_seq_length'] 31 | self.num_target_tokens = config['num_target_tokens'] 32 | self.max_target_seq_length = config['max_target_seq_length'] 33 | self.input_word2idx = config['input_word2idx'] 34 | self.input_idx2word = config['input_idx2word'] 35 | self.target_word2idx = config['target_word2idx'] 36 | self.target_idx2word = config['target_idx2word'] 37 | self.config = config 38 | self.version = 0 39 | if 'version' in config: 40 | self.version = config['version'] 41 | 42 | print('max_input_seq_length', self.max_input_seq_length) 43 | print('max_target_seq_length', self.max_target_seq_length) 44 | print('num_input_tokens', self.num_input_tokens) 45 | print('num_target_tokens', self.num_target_tokens) 46 | 47 | # encoder input model 48 | model = Sequential() 49 | model.add(Embedding(output_dim=128, input_dim=self.num_input_tokens, input_length=self.max_input_seq_length)) 50 | 51 | # encoder model 52 | model.add(LSTM(128)) 53 | model.add(RepeatVector(self.max_target_seq_length)) 54 | # decoder model 55 | model.add(LSTM(128, return_sequences=True)) 56 | model.add(TimeDistributed(Dense(self.num_target_tokens, activation='softmax'))) 57 | 58 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 59 | 60 | self.model = model 61 | 62 | def load_weights(self, weight_file_path): 63 | if os.path.exists(weight_file_path): 64 | self.model.load_weights(weight_file_path) 65 | 66 | def transform_input_text(self, texts): 67 | temp = [] 68 | for line in texts: 69 | x = [] 70 | for word in line.lower().split(' '): 71 | wid = 1 72 | if word in self.input_word2idx: 73 | wid = self.input_word2idx[word] 74 | x.append(wid) 75 | if len(x) >= self.max_input_seq_length: 76 | break 77 | temp.append(x) 78 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 79 | 80 | print(temp.shape) 81 | return temp 82 | 83 | def transform_target_encoding(self, texts): 84 | temp = [] 85 | for line in texts: 86 | x = [] 87 | line2 = 'START ' + line.lower() + ' END' 88 | for word in line2.split(' '): 89 | x.append(word) 90 | if len(x) >= self.max_target_seq_length: 91 | break 92 | temp.append(x) 93 | 94 | temp = np.array(temp) 95 | print(temp.shape) 96 | return temp 97 | 98 | def generate_batch(self, x_samples, y_samples, batch_size): 99 | num_batches = len(x_samples) // batch_size 100 | while True: 101 | for batchIdx in range(0, num_batches): 102 | start = batchIdx * batch_size 103 | end = (batchIdx + 1) * batch_size 104 | encoder_input_data_batch = pad_sequences(x_samples[start:end], self.max_input_seq_length) 105 | decoder_target_data_batch = np.zeros( 106 | shape=(batch_size, self.max_target_seq_length, self.num_target_tokens)) 107 | for lineIdx, target_words in enumerate(y_samples[start:end]): 108 | for idx, w in enumerate(target_words): 109 | w2idx = 0 # default [UNK] 110 | if w in self.target_word2idx: 111 | w2idx = self.target_word2idx[w] 112 | if w2idx != 0: 113 | decoder_target_data_batch[lineIdx, idx, w2idx] = 1 114 | yield encoder_input_data_batch, decoder_target_data_batch 115 | 116 | @staticmethod 117 | def get_weight_file_path(model_dir_path): 118 | return model_dir_path + '/' + OneShotRNN.model_name + '-weights.h5' 119 | 120 | @staticmethod 121 | def get_config_file_path(model_dir_path): 122 | return model_dir_path + '/' + OneShotRNN.model_name + '-config.npy' 123 | 124 | @staticmethod 125 | def 
get_architecture_file_path(model_dir_path): 126 | return model_dir_path + '/' + OneShotRNN.model_name + '-architecture.json' 127 | 128 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, model_dir_path=None, batch_size=None): 129 | if epochs is None: 130 | epochs = DEFAULT_EPOCHS 131 | if model_dir_path is None: 132 | model_dir_path = './models' 133 | if batch_size is None: 134 | batch_size = DEFAULT_BATCH_SIZE 135 | 136 | self.version += 1 137 | self.config['version'] = self.version 138 | 139 | config_file_path = OneShotRNN.get_config_file_path(model_dir_path) 140 | weight_file_path = OneShotRNN.get_weight_file_path(model_dir_path) 141 | checkpoint = ModelCheckpoint(weight_file_path) 142 | np.save(config_file_path, self.config) 143 | architecture_file_path = OneShotRNN.get_architecture_file_path(model_dir_path) 144 | open(architecture_file_path, 'w').write(self.model.to_json()) 145 | 146 | Ytrain = self.transform_target_encoding(Ytrain) 147 | Ytest = self.transform_target_encoding(Ytest) 148 | 149 | Xtrain = self.transform_input_text(Xtrain) 150 | Xtest = self.transform_input_text(Xtest) 151 | 152 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 153 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 154 | 155 | train_num_batches = len(Xtrain) // batch_size 156 | test_num_batches = len(Xtest) // batch_size 157 | 158 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 159 | epochs=epochs, 160 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 161 | callbacks=[checkpoint]) 162 | self.model.save_weights(weight_file_path) 163 | return history 164 | 165 | def summarize(self, input_text): 166 | input_seq = [] 167 | input_wids = [] 168 | for word in input_text.lower().split(' '): 169 | idx = 1 # default [UNK] 170 | if word in self.input_word2idx: 171 | idx = self.input_word2idx[word] 172 | input_wids.append(idx) 173 | input_seq.append(input_wids) 174 | input_seq = pad_sequences(input_seq, self.max_input_seq_length) 175 | predicted = self.model.predict(input_seq) 176 | predicted_word_idx_list = np.argmax(predicted, axis=1) 177 | predicted_word_list = [self.target_idx2word[wid] for wid in predicted_word_idx_list[0]] 178 | return predicted_word_list 179 | 180 | 181 | class RecursiveRNN1(object): 182 | model_name = 'recursive-rnn-1' 183 | """ 184 | A second alternative model is to develop a model that generates a single word forecast and call it recursively. 185 | 186 | That is, the decoder uses the context vector and the distributed representation of all words generated so far as 187 | input in order to generate the next word. 188 | 189 | A language model can be used to interpret the sequence of words generated so far to provide a second context vector 190 | to combine with the representation of the source document in order to generate the next word in the sequence. 191 | 192 | The summary is built up by recursively calling the model with the previously generated word appended (or, more 193 | specifically, the expected previous word during training). 194 | 195 | The context vectors could be concentrated or added together to provide a broader context for the decoder to 196 | interpret and output the next word. 
197 | """ 198 | 199 | def __init__(self, config): 200 | self.num_input_tokens = config['num_input_tokens'] 201 | self.max_input_seq_length = config['max_input_seq_length'] 202 | self.num_target_tokens = config['num_target_tokens'] 203 | self.max_target_seq_length = config['max_target_seq_length'] 204 | self.input_word2idx = config['input_word2idx'] 205 | self.input_idx2word = config['input_idx2word'] 206 | self.target_word2idx = config['target_word2idx'] 207 | self.target_idx2word = config['target_idx2word'] 208 | if 'version' in config: 209 | self.version = config['version'] 210 | else: 211 | self.version = 0 212 | self.config = config 213 | 214 | print('max_input_seq_length', self.max_input_seq_length) 215 | print('max_target_seq_length', self.max_target_seq_length) 216 | print('num_input_tokens', self.num_input_tokens) 217 | print('num_target_tokens', self.num_target_tokens) 218 | 219 | inputs1 = Input(shape=(self.max_input_seq_length,)) 220 | am1 = Embedding(self.num_input_tokens, 128)(inputs1) 221 | am2 = LSTM(128)(am1) 222 | 223 | inputs2 = Input(shape=(self.max_target_seq_length,)) 224 | sm1 = Embedding(self.num_target_tokens, 128)(inputs2) 225 | sm2 = LSTM(128)(sm1) 226 | 227 | decoder1 = concatenate([am2, sm2]) 228 | outputs = Dense(self.num_target_tokens, activation='softmax')(decoder1) 229 | 230 | model = Model(inputs=[inputs1, inputs2], outputs=outputs) 231 | 232 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 233 | self.model = model 234 | 235 | def load_weights(self, weight_file_path): 236 | if os.path.exists(weight_file_path): 237 | self.model.load_weights(weight_file_path) 238 | 239 | def transform_input_text(self, texts): 240 | temp = [] 241 | for line in texts: 242 | x = [] 243 | for word in line.lower().split(' '): 244 | wid = 1 245 | if word in self.input_word2idx: 246 | wid = self.input_word2idx[word] 247 | x.append(wid) 248 | if len(x) >= self.max_input_seq_length: 249 | break 250 | temp.append(x) 251 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 252 | 253 | print(temp.shape) 254 | return temp 255 | 256 | def split_target_text(self, texts): 257 | temp = [] 258 | for line in texts: 259 | x = [] 260 | line2 = 'START ' + line.lower() + ' END' 261 | for word in line2.split(' '): 262 | x.append(word) 263 | if len(x)+1 >= self.max_target_seq_length: 264 | x.append('END') 265 | break 266 | temp.append(x) 267 | return temp 268 | 269 | def generate_batch(self, x_samples, y_samples, batch_size): 270 | encoder_input_data_batch = [] 271 | decoder_input_data_batch = [] 272 | decoder_target_data_batch = [] 273 | line_idx = 0 274 | while True: 275 | for recordIdx in range(0, len(x_samples)): 276 | target_words = y_samples[recordIdx] 277 | x = x_samples[recordIdx] 278 | decoder_input_line = [] 279 | 280 | for idx in range(0, len(target_words)-1): 281 | w2idx = 0 # default [UNK] 282 | w = target_words[idx] 283 | if w in self.target_word2idx: 284 | w2idx = self.target_word2idx[w] 285 | decoder_input_line = decoder_input_line + [w2idx] 286 | decoder_target_label = np.zeros(self.num_target_tokens) 287 | w2idx_next = 0 288 | if target_words[idx+1] in self.target_word2idx: 289 | w2idx_next = self.target_word2idx[target_words[idx+1]] 290 | if w2idx_next != 0: 291 | decoder_target_label[w2idx_next] = 1 292 | decoder_input_data_batch.append(decoder_input_line) 293 | encoder_input_data_batch.append(x) 294 | decoder_target_data_batch.append(decoder_target_label) 295 | 296 | line_idx += 1 297 | if line_idx >= batch_size: 298 | yield 
[pad_sequences(encoder_input_data_batch, self.max_input_seq_length), 299 | pad_sequences(decoder_input_data_batch, 300 | self.max_target_seq_length)], np.array(decoder_target_data_batch) 301 | line_idx = 0 302 | encoder_input_data_batch = [] 303 | decoder_input_data_batch = [] 304 | decoder_target_data_batch = [] 305 | 306 | @staticmethod 307 | def get_weight_file_path(model_dir_path): 308 | return model_dir_path + '/' + RecursiveRNN1.model_name + '-weights.h5' 309 | 310 | @staticmethod 311 | def get_config_file_path(model_dir_path): 312 | return model_dir_path + '/' + RecursiveRNN1.model_name + '-config.npy' 313 | 314 | @staticmethod 315 | def get_architecture_file_path(model_dir_path): 316 | return model_dir_path + '/' + RecursiveRNN1.model_name + '-architecture.json' 317 | 318 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, model_dir_path=None, batch_size=None): 319 | if epochs is None: 320 | epochs = DEFAULT_EPOCHS 321 | if model_dir_path is None: 322 | model_dir_path = './models' 323 | if batch_size is None: 324 | batch_size = DEFAULT_BATCH_SIZE 325 | 326 | self.version += 1 327 | self.config['version'] = self.version 328 | 329 | config_file_path = RecursiveRNN1.get_config_file_path(model_dir_path) 330 | weight_file_path = RecursiveRNN1.get_weight_file_path(model_dir_path) 331 | checkpoint = ModelCheckpoint(weight_file_path) 332 | np.save(config_file_path, self.config) 333 | architecture_file_path = RecursiveRNN1.get_architecture_file_path(model_dir_path) 334 | open(architecture_file_path, 'w').write(self.model.to_json()) 335 | 336 | Ytrain = self.split_target_text(Ytrain) 337 | Ytest = self.split_target_text(Ytest) 338 | 339 | Xtrain = self.transform_input_text(Xtrain) 340 | Xtest = self.transform_input_text(Xtest) 341 | 342 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 343 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 344 | 345 | total_training_samples = sum([len(target_text)-1 for target_text in Ytrain]) 346 | total_testing_samples = sum([len(target_text)-1 for target_text in Ytest]) 347 | train_num_batches = total_training_samples // batch_size 348 | test_num_batches = total_testing_samples // batch_size 349 | 350 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 351 | epochs=epochs, 352 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 353 | callbacks=[checkpoint]) 354 | self.model.save_weights(weight_file_path) 355 | return history 356 | 357 | def summarize(self, input_text): 358 | input_seq = [] 359 | input_wids = [] 360 | for word in input_text.lower().split(' '): 361 | idx = 1 # default [UNK] 362 | if word in self.input_word2idx: 363 | idx = self.input_word2idx[word] 364 | input_wids.append(idx) 365 | input_seq.append(input_wids) 366 | input_seq = pad_sequences(input_seq, self.max_input_seq_length) 367 | start_token = self.target_word2idx['START'] 368 | wid_list = [start_token] 369 | sum_input_seq = pad_sequences([wid_list], self.max_target_seq_length) 370 | terminated = False 371 | 372 | target_text = '' 373 | 374 | while not terminated: 375 | output_tokens = self.model.predict([input_seq, sum_input_seq]) 376 | sample_token_idx = np.argmax(output_tokens[0, :]) 377 | sample_word = self.target_idx2word[sample_token_idx] 378 | wid_list = wid_list + [sample_token_idx] 379 | 380 | if sample_word != 'START' and sample_word != 'END': 381 | target_text += ' ' + sample_word 382 | 383 | if sample_word == 'END' or len(wid_list) >= self.max_target_seq_length: 384 | terminated 
= True 385 | else: 386 | sum_input_seq = pad_sequences([wid_list], self.max_target_seq_length) 387 | return target_text.strip() 388 | 389 | 390 | class RecursiveRNN2(object): 391 | model_name = 'recursive-rnn-2' 392 | """ 393 | In this third alternative, the Encoder generates a context vector representation of the source document. 394 | 395 | This document is fed to the decoder at each step of the generated output sequence. This allows the decoder to build 396 | up the same internal state as was used to generate the words in the output sequence so that it is primed to generate 397 | the next word in the sequence. 398 | 399 | This process is then repeated by calling the model again and again for each word in the output sequence until a 400 | maximum length or end-of-sequence token is generated. 401 | """ 402 | 403 | MAX_DECODER_SEQ_LENGTH = 4 404 | 405 | def __init__(self, config): 406 | self.num_input_tokens = config['num_input_tokens'] 407 | self.max_input_seq_length = config['max_input_seq_length'] 408 | self.num_target_tokens = config['num_target_tokens'] 409 | self.max_target_seq_length = config['max_target_seq_length'] 410 | self.input_word2idx = config['input_word2idx'] 411 | self.input_idx2word = config['input_idx2word'] 412 | self.target_word2idx = config['target_word2idx'] 413 | self.target_idx2word = config['target_idx2word'] 414 | self.config = config 415 | 416 | self.version = 0 417 | if 'version' in config: 418 | self.version = config['version'] 419 | 420 | # article input model 421 | inputs1 = Input(shape=(self.max_input_seq_length,)) 422 | article1 = Embedding(self.num_input_tokens, 128)(inputs1) 423 | article2 = Dropout(0.3)(article1) 424 | 425 | # summary input model 426 | inputs2 = Input(shape=(min(self.num_target_tokens, RecursiveRNN2.MAX_DECODER_SEQ_LENGTH), )) 427 | summ1 = Embedding(self.num_target_tokens, 128)(inputs2) 428 | summ2 = Dropout(0.3)(summ1) 429 | summ3 = LSTM(128)(summ2) 430 | summ4 = RepeatVector(self.max_input_seq_length)(summ3) 431 | 432 | # decoder model 433 | decoder1 = concatenate([article2, summ4]) 434 | decoder2 = LSTM(128)(decoder1) 435 | outputs = Dense(self.num_target_tokens, activation='softmax')(decoder2) 436 | # tie it together [article, summary] [word] 437 | model = Model(inputs=[inputs1, inputs2], outputs=outputs) 438 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 439 | 440 | print(model.summary()) 441 | 442 | self.model = model 443 | 444 | def load_weights(self, weight_file_path): 445 | if os.path.exists(weight_file_path): 446 | print('loading weights from ', weight_file_path) 447 | self.model.load_weights(weight_file_path) 448 | 449 | def transform_input_text(self, texts): 450 | temp = [] 451 | for line in texts: 452 | x = [] 453 | for word in line.lower().split(' '): 454 | wid = 1 455 | if word in self.input_word2idx: 456 | wid = self.input_word2idx[word] 457 | x.append(wid) 458 | if len(x) >= self.max_input_seq_length: 459 | break 460 | temp.append(x) 461 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 462 | 463 | print(temp.shape) 464 | return temp 465 | 466 | def split_target_text(self, texts): 467 | temp = [] 468 | for line in texts: 469 | x = [] 470 | line2 = 'START ' + line.lower() + ' END' 471 | for word in line2.split(' '): 472 | x.append(word) 473 | if len(x)+1 >= self.max_target_seq_length: 474 | x.append('END') 475 | break 476 | temp.append(x) 477 | return temp 478 | 479 | def generate_batch(self, x_samples, y_samples, batch_size): 480 | encoder_input_data_batch = [] 
481 | decoder_input_data_batch = [] 482 | decoder_target_data_batch = [] 483 | line_idx = 0 484 | while True: 485 | for recordIdx in range(0, len(x_samples)): 486 | target_words = y_samples[recordIdx] 487 | x = x_samples[recordIdx] 488 | decoder_input_line = [] 489 | 490 | for idx in range(0, len(target_words)-1): 491 | w2idx = 0 # default [UNK] 492 | w = target_words[idx] 493 | if w in self.target_word2idx: 494 | w2idx = self.target_word2idx[w] 495 | decoder_input_line = decoder_input_line + [w2idx] 496 | decoder_target_label = np.zeros(self.num_target_tokens) 497 | w2idx_next = 0 498 | if target_words[idx+1] in self.target_word2idx: 499 | w2idx_next = self.target_word2idx[target_words[idx+1]] 500 | if w2idx_next != 0: 501 | decoder_target_label[w2idx_next] = 1 502 | 503 | decoder_input_data_batch.append(decoder_input_line) 504 | encoder_input_data_batch.append(x) 505 | decoder_target_data_batch.append(decoder_target_label) 506 | 507 | line_idx += 1 508 | if line_idx >= batch_size: 509 | yield [pad_sequences(encoder_input_data_batch, self.max_input_seq_length), 510 | pad_sequences(decoder_input_data_batch, 511 | min(self.num_target_tokens, RecursiveRNN2.MAX_DECODER_SEQ_LENGTH))], np.array(decoder_target_data_batch) 512 | line_idx = 0 513 | encoder_input_data_batch = [] 514 | decoder_input_data_batch = [] 515 | decoder_target_data_batch = [] 516 | 517 | @staticmethod 518 | def get_weight_file_path(model_dir_path): 519 | return model_dir_path + '/' + RecursiveRNN2.model_name + '-weights.h5' 520 | 521 | @staticmethod 522 | def get_config_file_path(model_dir_path): 523 | return model_dir_path + '/' + RecursiveRNN2.model_name + '-config.npy' 524 | 525 | @staticmethod 526 | def get_architecture_file_path(model_dir_path): 527 | return model_dir_path + '/' + RecursiveRNN2.model_name + '-architecture.json' 528 | 529 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, model_dir_path=None, batch_size=None): 530 | if epochs is None: 531 | epochs = DEFAULT_EPOCHS 532 | if model_dir_path is None: 533 | model_dir_path = './models' 534 | if batch_size is None: 535 | batch_size = DEFAULT_BATCH_SIZE 536 | 537 | self.version += 1 538 | self.config['version'] = self.version 539 | 540 | config_file_path = RecursiveRNN2.get_config_file_path(model_dir_path) 541 | weight_file_path = RecursiveRNN2.get_weight_file_path(model_dir_path) 542 | checkpoint = ModelCheckpoint(weight_file_path) 543 | np.save(config_file_path, self.config) 544 | architecture_file_path = RecursiveRNN2.get_architecture_file_path(model_dir_path) 545 | open(architecture_file_path, 'w').write(self.model.to_json()) 546 | 547 | Ytrain = self.split_target_text(Ytrain) 548 | Ytest = self.split_target_text(Ytest) 549 | 550 | Xtrain = self.transform_input_text(Xtrain) 551 | Xtest = self.transform_input_text(Xtest) 552 | 553 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 554 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 555 | 556 | total_training_samples = sum([len(target_text)-1 for target_text in Ytrain]) 557 | total_testing_samples = sum([len(target_text)-1 for target_text in Ytest]) 558 | train_num_batches = total_training_samples // batch_size 559 | test_num_batches = total_testing_samples // batch_size 560 | 561 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 562 | epochs=epochs, 563 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 564 | callbacks=[checkpoint]) 565 | self.model.save_weights(weight_file_path) 566 | return history 567 | 
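# Note on summarize() below: decoding is recursive. The article is converted to word indices once;
# the partial summary, seeded with the 'START' token, is padded to the decoder window of
# min(num_target_tokens, MAX_DECODER_SEQ_LENGTH) and fed back through the model after every
# predicted word, until 'END' is sampled or max_target_seq_length tokens have been produced.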
568 | def summarize(self, input_text): 569 | input_seq = [] 570 | input_wids = [] 571 | for word in input_text.lower().split(' '): 572 | idx = 1 # default [UNK] 573 | if word in self.input_word2idx: 574 | idx = self.input_word2idx[word] 575 | input_wids.append(idx) 576 | input_seq.append(input_wids) 577 | input_seq = pad_sequences(input_seq, self.max_input_seq_length) 578 | start_token = self.target_word2idx['START'] 579 | wid_list = [start_token] 580 | sum_input_seq = pad_sequences([wid_list], min(self.num_target_tokens, RecursiveRNN2.MAX_DECODER_SEQ_LENGTH)) 581 | terminated = False 582 | 583 | target_text = '' 584 | 585 | while not terminated: 586 | output_tokens = self.model.predict([input_seq, sum_input_seq]) 587 | sample_token_idx = np.argmax(output_tokens[0, :]) 588 | sample_word = self.target_idx2word[sample_token_idx] 589 | wid_list = wid_list + [sample_token_idx] 590 | 591 | if sample_word != 'START' and sample_word != 'END': 592 | target_text += ' ' + sample_word 593 | 594 | if sample_word == 'END' or len(wid_list) >= self.max_target_seq_length: 595 | terminated = True 596 | else: 597 | sum_input_seq = pad_sequences([wid_list], min(self.num_target_tokens, RecursiveRNN2.MAX_DECODER_SEQ_LENGTH)) 598 | return target_text.strip() 599 | 600 | 601 | class RecursiveRNN3(object): 602 | model_name = 'recursive-rnn-3' 603 | """ 604 | In this third alternative, the Encoder generates a context vector representation of the source document. 605 | 606 | This document is fed to the decoder at each step of the generated output sequence. This allows the decoder to build 607 | up the same internal state as was used to generate the words in the output sequence so that it is primed to generate 608 | the next word in the sequence. 609 | 610 | This process is then repeated by calling the model again and again for each word in the output sequence until a 611 | maximum length or end-of-sequence token is generated. 
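Unlike recursive-rnn-2, this variant first compresses both the article and the summary-so-far with separate LSTM encoders and repeats the resulting fixed-length vectors before the decoder LSTM, instead of feeding the raw article word embeddings to the decoder.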
612 | """ 613 | 614 | def __init__(self, config): 615 | self.num_input_tokens = config['num_input_tokens'] 616 | self.max_input_seq_length = config['max_input_seq_length'] 617 | self.num_target_tokens = config['num_target_tokens'] 618 | self.max_target_seq_length = config['max_target_seq_length'] 619 | self.input_word2idx = config['input_word2idx'] 620 | self.input_idx2word = config['input_idx2word'] 621 | self.target_word2idx = config['target_word2idx'] 622 | self.target_idx2word = config['target_idx2word'] 623 | self.config = config 624 | 625 | self.version = 0 626 | if 'version' in config: 627 | self.version = config['version'] 628 | 629 | # article input model 630 | inputs1 = Input(shape=(self.max_input_seq_length,)) 631 | article1 = Embedding(self.num_input_tokens, 128)(inputs1) 632 | article2 = LSTM(128)(article1) 633 | article3 = RepeatVector(128)(article2) 634 | # summary input model 635 | inputs2 = Input(shape=(self.max_target_seq_length,)) 636 | summ1 = Embedding(self.num_target_tokens, 128)(inputs2) 637 | summ2 = LSTM(128)(summ1) 638 | summ3 = RepeatVector(128)(summ2) 639 | # decoder model 640 | decoder1 = concatenate([article3, summ3]) 641 | decoder2 = LSTM(128)(decoder1) 642 | outputs = Dense(self.num_target_tokens, activation='softmax')(decoder2) 643 | # tie it together [article, summary] [word] 644 | model = Model(inputs=[inputs1, inputs2], outputs=outputs) 645 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 646 | 647 | print(model.summary()) 648 | 649 | self.model = model 650 | 651 | def load_weights(self, weight_file_path): 652 | if os.path.exists(weight_file_path): 653 | print('loading weights from ', weight_file_path) 654 | self.model.load_weights(weight_file_path) 655 | 656 | def transform_input_text(self, texts): 657 | temp = [] 658 | for line in texts: 659 | x = [] 660 | for word in line.lower().split(' '): 661 | wid = 1 662 | if word in self.input_word2idx: 663 | wid = self.input_word2idx[word] 664 | x.append(wid) 665 | if len(x) >= self.max_input_seq_length: 666 | break 667 | temp.append(x) 668 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 669 | 670 | print(temp.shape) 671 | return temp 672 | 673 | def split_target_text(self, texts): 674 | temp = [] 675 | for line in texts: 676 | x = [] 677 | line2 = 'START ' + line.lower() + ' END' 678 | for word in line2.split(' '): 679 | x.append(word) 680 | if len(x)+1 >= self.max_target_seq_length: 681 | x.append('END') 682 | break 683 | temp.append(x) 684 | return temp 685 | 686 | def generate_batch(self, x_samples, y_samples, batch_size): 687 | encoder_input_data_batch = [] 688 | decoder_input_data_batch = [] 689 | decoder_target_data_batch = [] 690 | line_idx = 0 691 | while True: 692 | for recordIdx in range(0, len(x_samples)): 693 | target_words = y_samples[recordIdx] 694 | x = x_samples[recordIdx] 695 | decoder_input_line = [] 696 | 697 | for idx in range(0, len(target_words)-1): 698 | w2idx = 0 # default [UNK] 699 | w = target_words[idx] 700 | if w in self.target_word2idx: 701 | w2idx = self.target_word2idx[w] 702 | decoder_input_line = decoder_input_line + [w2idx] 703 | decoder_target_label = np.zeros(self.num_target_tokens) 704 | w2idx_next = 0 705 | if target_words[idx+1] in self.target_word2idx: 706 | w2idx_next = self.target_word2idx[target_words[idx+1]] 707 | if w2idx_next != 0: 708 | decoder_target_label[w2idx_next] = 1 709 | 710 | decoder_input_data_batch.append(decoder_input_line) 711 | encoder_input_data_batch.append(x) 712 | 
decoder_target_data_batch.append(decoder_target_label)
713 | 
714 | line_idx += 1
715 | if line_idx >= batch_size:
716 | yield [pad_sequences(encoder_input_data_batch, self.max_input_seq_length),
717 | pad_sequences(decoder_input_data_batch,
718 | self.max_target_seq_length)], np.array(decoder_target_data_batch)
719 | line_idx = 0
720 | encoder_input_data_batch = []
721 | decoder_input_data_batch = []
722 | decoder_target_data_batch = []
723 | 
724 | @staticmethod
725 | def get_weight_file_path(model_dir_path):
726 | return model_dir_path + '/' + RecursiveRNN3.model_name + '-weights.h5'
727 | 
728 | @staticmethod
729 | def get_config_file_path(model_dir_path):
730 | return model_dir_path + '/' + RecursiveRNN3.model_name + '-config.npy'
731 | 
732 | @staticmethod
733 | def get_architecture_file_path(model_dir_path):
734 | return model_dir_path + '/' + RecursiveRNN3.model_name + '-architecture.json'
735 | 
736 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, model_dir_path=None, batch_size=None):
737 | if epochs is None:
738 | epochs = DEFAULT_EPOCHS
739 | if model_dir_path is None:
740 | model_dir_path = './models'
741 | if batch_size is None:
742 | batch_size = DEFAULT_BATCH_SIZE
743 | 
744 | self.version += 1
745 | self.config['version'] = self.version
746 | 
747 | config_file_path = RecursiveRNN3.get_config_file_path(model_dir_path)
748 | weight_file_path = RecursiveRNN3.get_weight_file_path(model_dir_path)
749 | checkpoint = ModelCheckpoint(weight_file_path)
750 | np.save(config_file_path, self.config)
751 | architecture_file_path = RecursiveRNN3.get_architecture_file_path(model_dir_path)
752 | open(architecture_file_path, 'w').write(self.model.to_json())
753 | 
754 | Ytrain = self.split_target_text(Ytrain)
755 | Ytest = self.split_target_text(Ytest)
756 | 
757 | Xtrain = self.transform_input_text(Xtrain)
758 | Xtest = self.transform_input_text(Xtest)
759 | 
760 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size)
761 | test_gen = self.generate_batch(Xtest, Ytest, batch_size)
762 | 
763 | total_training_samples = sum([len(target_text)-1 for target_text in Ytrain])
764 | total_testing_samples = sum([len(target_text)-1 for target_text in Ytest])
765 | train_num_batches = total_training_samples // batch_size
766 | test_num_batches = total_testing_samples // batch_size
767 | 
768 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches,
769 | epochs=epochs,
770 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches,
771 | callbacks=[checkpoint])
772 | self.model.save_weights(weight_file_path)
773 | return history
774 | 
775 | def summarize(self, input_text):
776 | input_seq = []
777 | input_wids = []
778 | for word in input_text.lower().split(' '):
779 | idx = 1 # default [UNK]
780 | if word in self.input_word2idx:
781 | idx = self.input_word2idx[word]
782 | input_wids.append(idx)
783 | input_seq.append(input_wids)
784 | input_seq = pad_sequences(input_seq, self.max_input_seq_length)
785 | start_token = self.target_word2idx['START']
786 | wid_list = [start_token]
787 | sum_input_seq = pad_sequences([wid_list], self.max_target_seq_length)
788 | terminated = False
789 | 
790 | target_text = ''
791 | 
792 | while not terminated:
793 | output_tokens = self.model.predict([input_seq, sum_input_seq])
794 | sample_token_idx = np.argmax(output_tokens[0, :])
795 | sample_word = self.target_idx2word[sample_token_idx]
796 | wid_list = wid_list + [sample_token_idx]
797 | 
798 | if sample_word != 'START' and sample_word != 'END':
799 | 
target_text += ' ' + sample_word 800 | 801 | if sample_word == 'END' or len(wid_list) >= self.max_target_seq_length: 802 | terminated = True 803 | else: 804 | sum_input_seq = pad_sequences([wid_list], self.max_target_seq_length) 805 | return target_text.strip() 806 | -------------------------------------------------------------------------------- /keras_text_summarization/library/seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from keras.models import Model 4 | from keras.layers import Embedding, Dense, Input 5 | from keras.layers.recurrent import LSTM 6 | from keras.preprocessing.sequence import pad_sequences 7 | from keras.callbacks import ModelCheckpoint 8 | from keras_text_summarization.library.utility.glove_loader import load_glove, GLOVE_EMBEDDING_SIZE 9 | import numpy as np 10 | import os 11 | 12 | HIDDEN_UNITS = 100 13 | DEFAULT_BATCH_SIZE = 64 14 | VERBOSE = 1 15 | DEFAULT_EPOCHS = 10 16 | 17 | 18 | class Seq2SeqSummarizer(object): 19 | 20 | model_name = 'seq2seq' 21 | 22 | def __init__(self, config): 23 | self.num_input_tokens = config['num_input_tokens'] 24 | self.max_input_seq_length = config['max_input_seq_length'] 25 | self.num_target_tokens = config['num_target_tokens'] 26 | self.max_target_seq_length = config['max_target_seq_length'] 27 | self.input_word2idx = config['input_word2idx'] 28 | self.input_idx2word = config['input_idx2word'] 29 | self.target_word2idx = config['target_word2idx'] 30 | self.target_idx2word = config['target_idx2word'] 31 | self.config = config 32 | 33 | self.version = 0 34 | if 'version' in config: 35 | self.version = config['version'] 36 | 37 | encoder_inputs = Input(shape=(None,), name='encoder_inputs') 38 | encoder_embedding = Embedding(input_dim=self.num_input_tokens, output_dim=HIDDEN_UNITS, 39 | input_length=self.max_input_seq_length, name='encoder_embedding') 40 | encoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, name='encoder_lstm') 41 | encoder_outputs, encoder_state_h, encoder_state_c = encoder_lstm(encoder_embedding(encoder_inputs)) 42 | encoder_states = [encoder_state_h, encoder_state_c] 43 | 44 | decoder_inputs = Input(shape=(None, self.num_target_tokens), name='decoder_inputs') 45 | decoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, return_sequences=True, name='decoder_lstm') 46 | decoder_outputs, decoder_state_h, decoder_state_c = decoder_lstm(decoder_inputs, 47 | initial_state=encoder_states) 48 | decoder_dense = Dense(units=self.num_target_tokens, activation='softmax', name='decoder_dense') 49 | decoder_outputs = decoder_dense(decoder_outputs) 50 | 51 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs) 52 | 53 | model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) 54 | 55 | self.model = model 56 | 57 | self.encoder_model = Model(encoder_inputs, encoder_states) 58 | 59 | decoder_state_inputs = [Input(shape=(HIDDEN_UNITS,)), Input(shape=(HIDDEN_UNITS,))] 60 | decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_state_inputs) 61 | decoder_states = [state_h, state_c] 62 | decoder_outputs = decoder_dense(decoder_outputs) 63 | self.decoder_model = Model([decoder_inputs] + decoder_state_inputs, [decoder_outputs] + decoder_states) 64 | 65 | def load_weights(self, weight_file_path): 66 | if os.path.exists(weight_file_path): 67 | self.model.load_weights(weight_file_path) 68 | 69 | def transform_input_text(self, texts): 70 | temp = [] 71 | for line 
in texts: 72 | x = [] 73 | for word in line.lower().split(' '): 74 | wid = 1 75 | if word in self.input_word2idx: 76 | wid = self.input_word2idx[word] 77 | x.append(wid) 78 | if len(x) >= self.max_input_seq_length: 79 | break 80 | temp.append(x) 81 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 82 | 83 | print(temp.shape) 84 | return temp 85 | 86 | def transform_target_encoding(self, texts): 87 | temp = [] 88 | for line in texts: 89 | x = [] 90 | line2 = 'START ' + line.lower() + ' END' 91 | for word in line2.split(' '): 92 | x.append(word) 93 | if len(x) >= self.max_target_seq_length: 94 | break 95 | temp.append(x) 96 | 97 | temp = np.array(temp) 98 | print(temp.shape) 99 | return temp 100 | 101 | def generate_batch(self, x_samples, y_samples, batch_size): 102 | num_batches = len(x_samples) // batch_size 103 | while True: 104 | for batchIdx in range(0, num_batches): 105 | start = batchIdx * batch_size 106 | end = (batchIdx + 1) * batch_size 107 | encoder_input_data_batch = pad_sequences(x_samples[start:end], self.max_input_seq_length) 108 | decoder_target_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, self.num_target_tokens)) 109 | decoder_input_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, self.num_target_tokens)) 110 | for lineIdx, target_words in enumerate(y_samples[start:end]): 111 | for idx, w in enumerate(target_words): 112 | w2idx = 0 # default [UNK] 113 | if w in self.target_word2idx: 114 | w2idx = self.target_word2idx[w] 115 | if w2idx != 0: 116 | decoder_input_data_batch[lineIdx, idx, w2idx] = 1 117 | if idx > 0: 118 | decoder_target_data_batch[lineIdx, idx - 1, w2idx] = 1 119 | yield [encoder_input_data_batch, decoder_input_data_batch], decoder_target_data_batch 120 | 121 | @staticmethod 122 | def get_weight_file_path(model_dir_path): 123 | return model_dir_path + '/' + Seq2SeqSummarizer.model_name + '-weights.h5' 124 | 125 | @staticmethod 126 | def get_config_file_path(model_dir_path): 127 | return model_dir_path + '/' + Seq2SeqSummarizer.model_name + '-config.npy' 128 | 129 | @staticmethod 130 | def get_architecture_file_path(model_dir_path): 131 | return model_dir_path + '/' + Seq2SeqSummarizer.model_name + '-architecture.json' 132 | 133 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, batch_size=None, model_dir_path=None): 134 | if epochs is None: 135 | epochs = DEFAULT_EPOCHS 136 | if model_dir_path is None: 137 | model_dir_path = './models' 138 | if batch_size is None: 139 | batch_size = DEFAULT_BATCH_SIZE 140 | 141 | self.version += 1 142 | self.config['version'] = self.version 143 | config_file_path = Seq2SeqSummarizer.get_config_file_path(model_dir_path) 144 | weight_file_path = Seq2SeqSummarizer.get_weight_file_path(model_dir_path) 145 | checkpoint = ModelCheckpoint(weight_file_path) 146 | np.save(config_file_path, self.config) 147 | architecture_file_path = Seq2SeqSummarizer.get_architecture_file_path(model_dir_path) 148 | open(architecture_file_path, 'w').write(self.model.to_json()) 149 | 150 | Ytrain = self.transform_target_encoding(Ytrain) 151 | Ytest = self.transform_target_encoding(Ytest) 152 | 153 | Xtrain = self.transform_input_text(Xtrain) 154 | Xtest = self.transform_input_text(Xtest) 155 | 156 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 157 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 158 | 159 | train_num_batches = len(Xtrain) // batch_size 160 | test_num_batches = len(Xtest) // batch_size 161 | 162 | history = 
self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 163 | epochs=epochs, 164 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 165 | callbacks=[checkpoint]) 166 | self.model.save_weights(weight_file_path) 167 | return history 168 | 169 | def summarize(self, input_text): 170 | input_seq = [] 171 | input_wids = [] 172 | for word in input_text.lower().split(' '): 173 | idx = 1 # default [UNK] 174 | if word in self.input_word2idx: 175 | idx = self.input_word2idx[word] 176 | input_wids.append(idx) 177 | input_seq.append(input_wids) 178 | input_seq = pad_sequences(input_seq, self.max_input_seq_length) 179 | states_value = self.encoder_model.predict(input_seq) 180 | target_seq = np.zeros((1, 1, self.num_target_tokens)) 181 | target_seq[0, 0, self.target_word2idx['START']] = 1 182 | target_text = '' 183 | target_text_len = 0 184 | terminated = False 185 | while not terminated: 186 | output_tokens, h, c = self.decoder_model.predict([target_seq] + states_value) 187 | 188 | sample_token_idx = np.argmax(output_tokens[0, -1, :]) 189 | sample_word = self.target_idx2word[sample_token_idx] 190 | target_text_len += 1 191 | 192 | if sample_word != 'START' and sample_word != 'END': 193 | target_text += ' ' + sample_word 194 | 195 | if sample_word == 'END' or target_text_len >= self.max_target_seq_length: 196 | terminated = True 197 | 198 | target_seq = np.zeros((1, 1, self.num_target_tokens)) 199 | target_seq[0, 0, sample_token_idx] = 1 200 | 201 | states_value = [h, c] 202 | return target_text.strip() 203 | 204 | 205 | class Seq2SeqGloVeSummarizer(object): 206 | 207 | model_name = 'seq2seq-glove' 208 | 209 | def __init__(self, config): 210 | self.max_input_seq_length = config['max_input_seq_length'] 211 | self.num_target_tokens = config['num_target_tokens'] 212 | self.max_target_seq_length = config['max_target_seq_length'] 213 | self.target_word2idx = config['target_word2idx'] 214 | self.target_idx2word = config['target_idx2word'] 215 | self.version = 0 216 | if 'version' in config: 217 | self.version = config['version'] 218 | 219 | self.word2em = dict() 220 | if 'unknown_emb' in config: 221 | self.unknown_emb = config['unknown_emb'] 222 | else: 223 | self.unknown_emb = np.random.rand(1, GLOVE_EMBEDDING_SIZE) 224 | config['unknown_emb'] = self.unknown_emb 225 | 226 | self.config = config 227 | 228 | encoder_inputs = Input(shape=(None, GLOVE_EMBEDDING_SIZE), name='encoder_inputs') 229 | encoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, name='encoder_lstm') 230 | encoder_outputs, encoder_state_h, encoder_state_c = encoder_lstm(encoder_inputs) 231 | encoder_states = [encoder_state_h, encoder_state_c] 232 | 233 | decoder_inputs = Input(shape=(None, self.num_target_tokens), name='decoder_inputs') 234 | decoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, return_sequences=True, name='decoder_lstm') 235 | decoder_outputs, decoder_state_h, decoder_state_c = decoder_lstm(decoder_inputs, 236 | initial_state=encoder_states) 237 | decoder_dense = Dense(units=self.num_target_tokens, activation='softmax', name='decoder_dense') 238 | decoder_outputs = decoder_dense(decoder_outputs) 239 | 240 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs) 241 | 242 | model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) 243 | 244 | self.model = model 245 | 246 | self.encoder_model = Model(encoder_inputs, encoder_states) 247 | 248 | decoder_state_inputs = [Input(shape=(HIDDEN_UNITS,)), 
Input(shape=(HIDDEN_UNITS,))] 249 | decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_state_inputs) 250 | decoder_states = [state_h, state_c] 251 | decoder_outputs = decoder_dense(decoder_outputs) 252 | self.decoder_model = Model([decoder_inputs] + decoder_state_inputs, [decoder_outputs] + decoder_states) 253 | 254 | def load_weights(self, weight_file_path): 255 | if os.path.exists(weight_file_path): 256 | self.model.load_weights(weight_file_path) 257 | 258 | def load_glove(self, data_dir_path): 259 | self.word2em = load_glove(data_dir_path) 260 | 261 | def transform_input_text(self, texts): 262 | temp = [] 263 | for line in texts: 264 | x = np.zeros(shape=(self.max_input_seq_length, GLOVE_EMBEDDING_SIZE)) 265 | for idx, word in enumerate(line.lower().split(' ')): 266 | if idx >= self.max_input_seq_length: 267 | break 268 | emb = self.unknown_emb 269 | if word in self.word2em: 270 | emb = self.word2em[word] 271 | x[idx, :] = emb 272 | temp.append(x) 273 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length) 274 | 275 | print(temp.shape) 276 | return temp 277 | 278 | def transform_target_encoding(self, texts): 279 | temp = [] 280 | for line in texts: 281 | x = [] 282 | line2 = 'START ' + line.lower() + ' END' 283 | for word in line2.split(' '): 284 | x.append(word) 285 | if len(x) >= self.max_target_seq_length: 286 | break 287 | temp.append(x) 288 | 289 | temp = np.array(temp) 290 | print(temp.shape) 291 | return temp 292 | 293 | def generate_batch(self, x_samples, y_samples, batch_size): 294 | num_batches = len(x_samples) // batch_size 295 | while True: 296 | for batchIdx in range(0, num_batches): 297 | start = batchIdx * batch_size 298 | end = (batchIdx + 1) * batch_size 299 | encoder_input_data_batch = pad_sequences(x_samples[start:end], self.max_input_seq_length) 300 | decoder_target_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, self.num_target_tokens)) 301 | decoder_input_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, self.num_target_tokens)) 302 | for lineIdx, target_words in enumerate(y_samples[start:end]): 303 | for idx, w in enumerate(target_words): 304 | w2idx = 0 # default [UNK] 305 | if w in self.target_word2idx: 306 | w2idx = self.target_word2idx[w] 307 | if w2idx != 0: 308 | decoder_input_data_batch[lineIdx, idx, w2idx] = 1 309 | if idx > 0: 310 | decoder_target_data_batch[lineIdx, idx - 1, w2idx] = 1 311 | yield [encoder_input_data_batch, decoder_input_data_batch], decoder_target_data_batch 312 | 313 | @staticmethod 314 | def get_weight_file_path(model_dir_path): 315 | return model_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-weights.h5' 316 | 317 | @staticmethod 318 | def get_config_file_path(model_dir_path): 319 | return model_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-config.npy' 320 | 321 | @staticmethod 322 | def get_architecture_file_path(model_dir_path): 323 | return model_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-architecture.json' 324 | 325 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, batch_size=None, model_dir_path=None): 326 | if epochs is None: 327 | epochs = DEFAULT_EPOCHS 328 | if model_dir_path is None: 329 | model_dir_path = './models' 330 | if batch_size is None: 331 | batch_size = DEFAULT_BATCH_SIZE 332 | 333 | self.version += 1 334 | self.config['version'] = self.version 335 | config_file_path = Seq2SeqGloVeSummarizer.get_config_file_path(model_dir_path) 336 | weight_file_path = 
Seq2SeqGloVeSummarizer.get_weight_file_path(model_dir_path) 337 | checkpoint = ModelCheckpoint(weight_file_path) 338 | np.save(config_file_path, self.config) 339 | architecture_file_path = Seq2SeqGloVeSummarizer.get_architecture_file_path(model_dir_path) 340 | open(architecture_file_path, 'w').write(self.model.to_json()) 341 | 342 | Ytrain = self.transform_target_encoding(Ytrain) 343 | Ytest = self.transform_target_encoding(Ytest) 344 | 345 | Xtrain = self.transform_input_text(Xtrain) 346 | Xtest = self.transform_input_text(Xtest) 347 | 348 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 349 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 350 | 351 | train_num_batches = len(Xtrain) // batch_size 352 | test_num_batches = len(Xtest) // batch_size 353 | 354 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 355 | epochs=epochs, 356 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 357 | callbacks=[checkpoint]) 358 | self.model.save_weights(weight_file_path) 359 | return history 360 | 361 | def summarize(self, input_text): 362 | input_seq = np.zeros(shape=(1, self.max_input_seq_length, GLOVE_EMBEDDING_SIZE)) 363 | for idx, word in enumerate(input_text.lower().split(' ')): 364 | if idx >= self.max_input_seq_length: 365 | break 366 | emb = self.unknown_emb # default [UNK] 367 | if word in self.word2em: 368 | emb = self.word2em[word] 369 | input_seq[0, idx, :] = emb 370 | states_value = self.encoder_model.predict(input_seq) 371 | target_seq = np.zeros((1, 1, self.num_target_tokens)) 372 | target_seq[0, 0, self.target_word2idx['START']] = 1 373 | target_text = '' 374 | target_text_len = 0 375 | terminated = False 376 | while not terminated: 377 | output_tokens, h, c = self.decoder_model.predict([target_seq] + states_value) 378 | 379 | sample_token_idx = np.argmax(output_tokens[0, -1, :]) 380 | sample_word = self.target_idx2word[sample_token_idx] 381 | target_text_len += 1 382 | 383 | if sample_word != 'START' and sample_word != 'END': 384 | target_text += ' ' + sample_word 385 | 386 | if sample_word == 'END' or target_text_len >= self.max_target_seq_length: 387 | terminated = True 388 | 389 | target_seq = np.zeros((1, 1, self.num_target_tokens)) 390 | target_seq[0, 0, sample_token_idx] = 1 391 | 392 | states_value = [h, c] 393 | return target_text.strip() 394 | 395 | 396 | class Seq2SeqGloVeSummarizerV2(object): 397 | 398 | model_name = 'seq2seq-glove-v2' 399 | 400 | def __init__(self, config): 401 | self.max_input_seq_length = config['max_input_seq_length'] 402 | self.num_target_tokens = config['num_target_tokens'] 403 | self.max_target_seq_length = config['max_target_seq_length'] 404 | self.target_word2idx = config['target_word2idx'] 405 | self.target_idx2word = config['target_idx2word'] 406 | self.version = 0 407 | if 'version' in config: 408 | self.version = config['version'] 409 | 410 | self.word2em = dict() 411 | if 'unknown_emb' in config: 412 | self.unknown_emb = config['unknown_emb'] 413 | else: 414 | self.unknown_emb = np.random.rand(1, GLOVE_EMBEDDING_SIZE) 415 | config['unknown_emb'] = self.unknown_emb 416 | 417 | self.config = config 418 | 419 | encoder_inputs = Input(shape=(None, GLOVE_EMBEDDING_SIZE), name='encoder_inputs') 420 | encoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, name='encoder_lstm') 421 | encoder_outputs, encoder_state_h, encoder_state_c = encoder_lstm(encoder_inputs) 422 | encoder_states = [encoder_state_h, encoder_state_c] 423 | 424 | decoder_inputs = 
Input(shape=(None, GLOVE_EMBEDDING_SIZE), name='decoder_inputs')
425 | decoder_lstm = LSTM(units=HIDDEN_UNITS, return_state=True, return_sequences=True, name='decoder_lstm')
426 | decoder_outputs, decoder_state_h, decoder_state_c = decoder_lstm(decoder_inputs,
427 | initial_state=encoder_states)
428 | decoder_dense = Dense(units=self.num_target_tokens, activation='softmax', name='decoder_dense')
429 | decoder_outputs = decoder_dense(decoder_outputs)
430 |
431 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
432 |
433 | model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
434 |
435 | self.model = model
436 |
437 | self.encoder_model = Model(encoder_inputs, encoder_states)
438 |
439 | decoder_state_inputs = [Input(shape=(HIDDEN_UNITS,)), Input(shape=(HIDDEN_UNITS,))]
440 | decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_state_inputs)
441 | decoder_states = [state_h, state_c]
442 | decoder_outputs = decoder_dense(decoder_outputs)
443 | self.decoder_model = Model([decoder_inputs] + decoder_state_inputs, [decoder_outputs] + decoder_states)
444 |
445 | def load_weights(self, weight_file_path):
446 | if os.path.exists(weight_file_path):
447 | self.model.load_weights(weight_file_path)
448 |
449 | def load_glove(self, data_dir_path):
450 | self.word2em = load_glove(data_dir_path)
451 |
452 | def transform_input_text(self, texts):
453 | temp = []
454 | for line in texts:
455 | x = np.zeros(shape=(self.max_input_seq_length, GLOVE_EMBEDDING_SIZE))
456 | for idx, word in enumerate(line.lower().split(' ')):
457 | if idx >= self.max_input_seq_length:
458 | break
459 | emb = self.unknown_emb
460 | if word in self.word2em:
461 | emb = self.word2em[word]
462 | x[idx, :] = emb
463 | temp.append(x)
464 | temp = pad_sequences(temp, maxlen=self.max_input_seq_length)
465 |
466 | print(temp.shape)
467 | return temp
468 |
469 | def transform_target_encoding(self, texts):
470 | temp = []
471 | for line in texts:
472 | x = []
473 | line2 = 'start ' + line.lower() + ' end'
474 | for word in line2.split(' '):
475 | x.append(word)
476 | if len(x) >= self.max_target_seq_length:
477 | break
478 | temp.append(x)
479 |
480 | temp = np.array(temp)
481 | print(temp.shape)
482 | return temp
483 |
484 | def generate_batch(self, x_samples, y_samples, batch_size):
485 | num_batches = len(x_samples) // batch_size
486 | while True:
487 | for batchIdx in range(0, num_batches):
488 | start = batchIdx * batch_size
489 | end = (batchIdx + 1) * batch_size
490 | encoder_input_data_batch = pad_sequences(x_samples[start:end], self.max_input_seq_length)
491 | decoder_target_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, self.num_target_tokens))
492 | decoder_input_data_batch = np.zeros(shape=(batch_size, self.max_target_seq_length, GLOVE_EMBEDDING_SIZE))
493 | for lineIdx, target_words in enumerate(y_samples[start:end]):
494 | for idx, w in enumerate(target_words):
495 | w2idx = 0 # default [UNK]
496 | # feed the decoder the word's GloVe vector, falling back to the shared unknown embedding
497 | emb = self.word2em[w] if w in self.word2em else self.unknown_emb
498 | decoder_input_data_batch[lineIdx, idx, :] = emb
499 | if w in self.target_word2idx:
500 | w2idx = self.target_word2idx[w]
501 | if w2idx != 0:
502 | if idx > 0:
503 | decoder_target_data_batch[lineIdx, idx - 1, w2idx] = 1
504 | yield [encoder_input_data_batch, decoder_input_data_batch], decoder_target_data_batch
505 |
506 | @staticmethod
507 | def get_weight_file_path(model_dir_path):
508 | return model_dir_path + '/' + Seq2SeqGloVeSummarizerV2.model_name + '-weights.h5'
509 | 510 | @staticmethod 511 | def get_config_file_path(model_dir_path): 512 | return model_dir_path + '/' + Seq2SeqGloVeSummarizerV2.model_name + '-config.npy' 513 | 514 | @staticmethod 515 | def get_architecture_file_path(model_dir_path): 516 | return model_dir_path + '/' + Seq2SeqGloVeSummarizerV2.model_name + '-architecture.json' 517 | 518 | def fit(self, Xtrain, Ytrain, Xtest, Ytest, epochs=None, batch_size=None, model_dir_path=None): 519 | if epochs is None: 520 | epochs = DEFAULT_EPOCHS 521 | if model_dir_path is None: 522 | model_dir_path = './models' 523 | if batch_size is None: 524 | batch_size = DEFAULT_BATCH_SIZE 525 | 526 | self.version += 1 527 | self.config['version'] = self.version 528 | config_file_path = Seq2SeqGloVeSummarizerV2.get_config_file_path(model_dir_path) 529 | weight_file_path = Seq2SeqGloVeSummarizerV2.get_weight_file_path(model_dir_path) 530 | checkpoint = ModelCheckpoint(weight_file_path) 531 | np.save(config_file_path, self.config) 532 | architecture_file_path = Seq2SeqGloVeSummarizerV2.get_architecture_file_path(model_dir_path) 533 | open(architecture_file_path, 'w').write(self.model.to_json()) 534 | 535 | Ytrain = self.transform_target_encoding(Ytrain) 536 | Ytest = self.transform_target_encoding(Ytest) 537 | 538 | Xtrain = self.transform_input_text(Xtrain) 539 | Xtest = self.transform_input_text(Xtest) 540 | 541 | train_gen = self.generate_batch(Xtrain, Ytrain, batch_size) 542 | test_gen = self.generate_batch(Xtest, Ytest, batch_size) 543 | 544 | train_num_batches = len(Xtrain) // batch_size 545 | test_num_batches = len(Xtest) // batch_size 546 | 547 | history = self.model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches, 548 | epochs=epochs, 549 | verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches, 550 | callbacks=[checkpoint]) 551 | self.model.save_weights(weight_file_path) 552 | return history 553 | 554 | def summarize(self, input_text): 555 | input_seq = np.zeros(shape=(1, self.max_input_seq_length, GLOVE_EMBEDDING_SIZE)) 556 | for idx, word in enumerate(input_text.lower().split(' ')): 557 | if idx >= self.max_input_seq_length: 558 | break 559 | emb = self.unknown_emb # default [UNK] 560 | if word in self.word2em: 561 | emb = self.word2em[word] 562 | input_seq[0, idx, :] = emb 563 | states_value = self.encoder_model.predict(input_seq) 564 | target_seq = np.zeros((1, 1, GLOVE_EMBEDDING_SIZE)) 565 | target_seq[0, 0, :] = self.word2em['start'] 566 | target_text = '' 567 | target_text_len = 0 568 | terminated = False 569 | while not terminated: 570 | output_tokens, h, c = self.decoder_model.predict([target_seq] + states_value) 571 | 572 | sample_token_idx = np.argmax(output_tokens[0, -1, :]) 573 | sample_word = self.target_idx2word[sample_token_idx] 574 | target_text_len += 1 575 | 576 | if sample_word != 'start' and sample_word != 'end': 577 | target_text += ' ' + sample_word 578 | 579 | if sample_word == 'end' or target_text_len >= self.max_target_seq_length: 580 | terminated = True 581 | 582 | if sample_word in self.word2em: 583 | target_seq[0, 0, :] = self.word2em[sample_word] 584 | else: 585 | target_seq[0, 0, :] = self.unknown_emb 586 | 587 | states_value = [h, c] 588 | return target_text.strip() 589 | 590 | 591 | 592 | -------------------------------------------------------------------------------- /keras_text_summarization/library/utility/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen0040/keras-text-summarization/11df7c7bf30de8ccd8aecef5a551c136c85f0092/keras_text_summarization/library/utility/__init__.py -------------------------------------------------------------------------------- /keras_text_summarization/library/utility/device_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras import backend as K 3 | 4 | 5 | def init_devices(device_type=None): 6 | if device_type is None: 7 | device_type = 'cpu' 8 | 9 | num_cores = 4 10 | 11 | if device_type == 'gpu': 12 | num_GPU = 1 13 | num_CPU = 1 14 | else: 15 | num_CPU = 1 16 | num_GPU = 0 17 | 18 | config = tf.ConfigProto(intra_op_parallelism_threads=num_cores, 19 | inter_op_parallelism_threads=num_cores, allow_soft_placement=True, 20 | device_count={'CPU': num_CPU, 'GPU': num_GPU}) 21 | session = tf.Session(config=config) 22 | K.set_session(session) 23 | -------------------------------------------------------------------------------- /keras_text_summarization/library/utility/glove_loader.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import os 3 | import sys 4 | import zipfile 5 | import numpy as np 6 | 7 | GLOVE_EMBEDDING_SIZE = 100 8 | 9 | 10 | def reporthook(block_num, block_size, total_size): 11 | read_so_far = block_num * block_size 12 | if total_size > 0: 13 | percent = read_so_far * 1e2 / total_size 14 | s = "\r%5.1f%% %*d / %d" % ( 15 | percent, len(str(total_size)), read_so_far, total_size) 16 | sys.stderr.write(s) 17 | if read_so_far >= total_size: # near the end 18 | sys.stderr.write("\n") 19 | else: # total size is unknown 20 | sys.stderr.write("read %d\n" % (read_so_far,)) 21 | 22 | 23 | def download_glove(data_dir_path=None): 24 | if data_dir_path is None: 25 | data_dir_path = 'very_large_data' 26 | glove_model_path = data_dir_path + "/glove.6B." + str(GLOVE_EMBEDDING_SIZE) + "d.txt" 27 | if not os.path.exists(glove_model_path): 28 | 29 | glove_zip = data_dir_path + '/glove.6B.zip' 30 | 31 | if not os.path.exists(data_dir_path): 32 | os.makedirs(data_dir_path) 33 | 34 | if not os.path.exists(glove_zip): 35 | print('glove file does not exist, downloading from internet') 36 | urllib.request.urlretrieve(url='http://nlp.stanford.edu/data/glove.6B.zip', filename=glove_zip, 37 | reporthook=reporthook) 38 | 39 | print('unzipping glove file') 40 | zip_ref = zipfile.ZipFile(glove_zip, 'r') 41 | zip_ref.extractall(data_dir_path) 42 | zip_ref.close() 43 | 44 | 45 | def load_glove(data_dir_path=None): 46 | if data_dir_path is None: 47 | data_dir_path = 'very_large_data' 48 | download_glove(data_dir_path) 49 | _word2em = {} 50 | glove_model_path = data_dir_path + "/glove.6B." 
+ str(GLOVE_EMBEDDING_SIZE) + "d.txt" 51 | file = open(glove_model_path, mode='rt', encoding='utf8') 52 | for line in file: 53 | words = line.strip().split() 54 | word = words[0] 55 | embeds = np.array(words[1:], dtype=np.float32) 56 | _word2em[word] = embeds 57 | file.close() 58 | return _word2em 59 | 60 | 61 | def glove_zero_emb(): 62 | return np.zeros(shape=GLOVE_EMBEDDING_SIZE) 63 | 64 | 65 | class Glove(object): 66 | 67 | word2em = None 68 | 69 | GLOVE_EMBEDDING_SIZE = GLOVE_EMBEDDING_SIZE 70 | 71 | def __init__(self): 72 | self.word2em = load_glove() 73 | -------------------------------------------------------------------------------- /keras_text_summarization/library/utility/plot_utils.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | import itertools 4 | 5 | 6 | def plot_confusion_matrix(cm, classes, 7 | normalize=False, 8 | title='Confusion matrix', 9 | cmap=plt.cm.Blues): 10 | """ 11 | See full source and example: 12 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 13 | 14 | This function prints and plots the confusion matrix. 15 | Normalization can be applied by setting `normalize=True`. 16 | """ 17 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 18 | plt.title(title) 19 | plt.colorbar() 20 | tick_marks = np.arange(len(classes)) 21 | plt.xticks(tick_marks, classes, rotation=45) 22 | plt.yticks(tick_marks, classes) 23 | 24 | if normalize: 25 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 26 | print("Normalized confusion matrix") 27 | else: 28 | print('Confusion matrix, without normalization') 29 | 30 | thresh = cm.max() / 2. 31 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 32 | plt.text(j, i, cm[i, j], 33 | horizontalalignment="center", 34 | color="white" if cm[i, j] > thresh else "black") 35 | 36 | plt.tight_layout() 37 | plt.ylabel('True label') 38 | plt.xlabel('Predicted label') 39 | plt.show() 40 | 41 | 42 | def most_informative_feature_for_binary_classification(vectorizer, classifier, n=100): 43 | """ 44 | See: https://stackoverflow.com/a/26980472 45 | 46 | Identify most important features if given a vectorizer and binary classifier. Set n to the number 47 | of weighted features you would like to show. (Note: current implementation merely prints and does not 48 | return top classes.) 
49 | """
50 |
51 | class_labels = classifier.classes_
52 | feature_names = vectorizer.get_feature_names()
53 | topn_class1 = sorted(zip(classifier.coef_[0], feature_names))[:n]
54 | topn_class2 = sorted(zip(classifier.coef_[0], feature_names))[-n:]
55 |
56 | for coef, feat in topn_class1:
57 | print(class_labels[0], coef, feat)
58 |
59 | print()
60 |
61 | for coef, feat in reversed(topn_class2):
62 | print(class_labels[1], coef, feat)
63 |
64 |
65 | def plot_history_2win(history):
66 | plt.subplot(211)
67 | plt.title('Accuracy')
68 | plt.plot(history.history['acc'], color='g', label='Train')
69 | plt.plot(history.history['val_acc'], color='b', label='Validation')
70 | plt.legend(loc='best')
71 |
72 | plt.subplot(212)
73 | plt.title('Loss')
74 | plt.plot(history.history['loss'], color='g', label='Train')
75 | plt.plot(history.history['val_loss'], color='b', label='Validation')
76 | plt.legend(loc='best')
77 |
78 | plt.tight_layout()
79 | plt.show()
80 |
81 |
82 | def create_history_plot(history, model_name, metrics=None):
83 | plt.title('Accuracy and Loss (' + model_name + ')')
84 | if metrics is None:
85 | metrics = {'acc', 'loss'}
86 | if 'acc' in metrics:
87 | plt.plot(history.history['acc'], color='g', label='Train Accuracy')
88 | plt.plot(history.history['val_acc'], color='b', label='Validation Accuracy')
89 | if 'loss' in metrics:
90 | plt.plot(history.history['loss'], color='r', label='Train Loss')
91 | plt.plot(history.history['val_loss'], color='m', label='Validation Loss')
92 | plt.legend(loc='best')
93 |
94 | plt.tight_layout()
95 |
96 |
97 | def plot_history(history, model_name):
98 | create_history_plot(history, model_name)
99 | plt.show()
100 |
101 |
102 | def plot_and_save_history(history, model_name, file_path, metrics=None):
103 | if metrics is None:
104 | metrics = {'acc', 'loss'}
105 | create_history_plot(history, model_name, metrics)
106 | plt.savefig(file_path)
107 |
-------------------------------------------------------------------------------- /keras_text_summarization/library/utility/text_utils.py: --------------------------------------------------------------------------------
1 | WHITELIST = 'abcdefghijklmnopqrstuvwxyz1234567890?.,'
2 |
3 |
4 | def in_white_list(_word):
5 | for char in _word:
6 | if char in WHITELIST:
7 | return True
8 |
9 | return False
10 |
-------------------------------------------------------------------------------- /notes/ReadMe.md: --------------------------------------------------------------------------------
1 | # References
2 |
3 | * https://machinelearningmastery.com/encoder-decoder-models-text-summarization-keras/
-------------------------------------------------------------------------------- /notes/evaluation.md: --------------------------------------------------------------------------------
1 | # Model Evaluation
2 |
3 | The data source is [https://github.com/GeorgeMcIntire/fake_real_news_dataset](https://github.com/GeorgeMcIntire/fake_real_news_dataset)
4 |
5 | The training plot below was obtained by running for 100 epochs (note that further improvement may be possible by
6 | increasing the number of epochs, as the result has not yet converged)
7 |
8 |
9 | ![seq2seq-history](/demo/reports/seq2seq-history.png)
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | scikit-learn
2 | keras
3 | tensorflow
4 | pandas
5 | numpy
6 | scipy
7 | h5py
8 | matplotlib
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | 5 | setup(name='keras_text_summarization', 6 | version='0.0.1', 7 | description='Text Summarization in Keras using Seq2Seq and Recurrent Networks', 8 | author='Xianshun Chen', 9 | author_email='xs0040@gmail.com', 10 | url='https://github.com/chen0040/keras-text-summarization', 11 | download_url='https://github.com/chen0040/keras-text-summarization/tarball/0.0.1', 12 | license='MIT', 13 | install_requires=['Keras'], 14 | packages=find_packages()) 15 | --------------------------------------------------------------------------------
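To see how the pieces above fit together, here is a minimal end-to-end sketch. It is illustrative only, not one of the repo's demo scripts: it assumes the summarizer classes are importable as keras_text_summarization.library.seq2seq, that it is run from the repository root, and the toy texts plus the hand-built config dictionary stand in for whatever the demo scripts derive from demo/data/fake_or_real_news.csv. The output directory name is likewise made up for the example.

import os

from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizer
from keras_text_summarization.library.utility.plot_utils import plot_and_save_history

# Toy parallel corpus: X holds article bodies, Y holds the reference summaries (titles).
X = ['donald trump wins the presidential election by a narrow margin'] * 32
Y = ['trump wins election'] * 32

# Hand-built target vocabulary: index 0 stays reserved for unknown words, and the uppercase
# START / END markers match what transform_target_encoding() wraps around each summary.
target_words = ['START', 'END'] + sorted({w for line in Y for w in line.lower().split(' ')})
target_word2idx = dict((w, i + 1) for i, w in enumerate(target_words))
target_idx2word = dict((i, w) for w, i in target_word2idx.items())

config = {
    'max_input_seq_length': 50,
    'max_target_seq_length': 10,
    'num_target_tokens': len(target_word2idx) + 1,  # +1 for the reserved unknown slot at index 0
    'target_word2idx': target_word2idx,
    'target_idx2word': target_idx2word,
}

summarizer = Seq2SeqGloVeSummarizer(config)
summarizer.load_glove('demo/very_large_data')  # downloads glove.6B.zip on first use

model_dir = './models_sketch'  # made-up output directory so the shipped demo/models files are left alone
os.makedirs(model_dir, exist_ok=True)

history = summarizer.fit(X, Y, X, Y, epochs=1, batch_size=16, model_dir_path=model_dir)
plot_and_save_history(history, Seq2SeqGloVeSummarizer.model_name, model_dir + '/seq2seq-glove-history.png')

print(summarizer.summarize('donald trump wins the presidential election by a narrow margin'))

With real data the vocabulary and sequence-length fields would be computed from the full training corpus; the toy values here exist only so the one-hot decoder targets line up with target_word2idx.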
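The unknown-word fallback that both GloVe summarizers rely on can also be exercised on its own. The snippet below is a small sketch rather than repo code: it loads the same glove.6B.100d table through load_glove and mirrors the word-by-word lookup-with-fallback that transform_input_text() and summarize() perform.

import numpy as np

from keras_text_summarization.library.utility.glove_loader import load_glove, GLOVE_EMBEDDING_SIZE

word2em = load_glove('demo/very_large_data')  # downloads and unzips glove.6B.zip if it is not already cached
unknown_emb = np.random.rand(1, GLOVE_EMBEDDING_SIZE)  # one shared random vector for out-of-vocabulary words


def embed(word):
    # A known word gets its 100-dimensional GloVe vector; anything else gets the shared unknown
    # embedding, whose (1, 100) shape broadcasts into a single row of the input matrix below.
    return word2em[word] if word in word2em else unknown_emb


words = 'the election results were contested'.split(' ')
x = np.zeros(shape=(len(words), GLOVE_EMBEDDING_SIZE))
for idx, word in enumerate(words):
    x[idx, :] = embed(word)

print(x.shape)  # (5, 100)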